1 | package agents.anac.y2015.Phoenix.GP;/* This file is part of the jgpml Project.
|
---|
2 | * http://github.com/renzodenardi/jgpml
|
---|
3 | *
|
---|
4 | * Copyright (c) 2011 Renzo De Nardi and Hugo Gravato-Marques
|
---|
5 | *
|
---|
6 | * Permission is hereby granted, free of charge, to any person
|
---|
7 | * obtaining a copy of this software and associated documentation
|
---|
8 | * files (the "Software"), to deal in the Software without
|
---|
9 | * restriction, including without limitation the rights to use,
|
---|
10 | * copy, modify, merge, publish, distribute, sublicense, and/or sell
|
---|
11 | * copies of the Software, and to permit persons to whom the
|
---|
12 | * Software is furnished to do so, subject to the following
|
---|
13 | * conditions:
|
---|
14 | *
|
---|
15 | * The above copyright notice and this permission notice shall be
|
---|
16 | * included in all copies or substantial portions of the Software.
|
---|
17 | *
|
---|
18 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
---|
19 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
---|
20 | * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
---|
21 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
---|
22 | * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
---|
23 | * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
---|
24 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
---|
25 | * OTHER DEALINGS IN THE SOFTWARE.
|
---|
26 | */
|
---|
27 |
|
---|
28 | import agents.Jama.Matrix;
|
---|
29 |
|
---|
30 | /**
|
---|
31 | * Neural network covariance function with a single parameter for the distance
|
---|
32 | * measure. The covariance function is parameterized as:
|
---|
33 | * <P>
|
---|
34 | * k(x^p,x^q) = sf2 * asin(x^p'*P*x^q / sqrt[(1+x^p'*P*x^p)*(1+x^q'*P*x^q)])
|
---|
35 | * <P>
|
---|
36 | * where the x^p and x^q vectors on the right hand side have an added extra bias
|
---|
37 | * entry with unit value. P is ell^-2 times the unit matrix and sf2 controls the
|
---|
38 | * signal variance. The hyperparameters are:
|
---|
39 | * <P>
|
---|
40 | * [ log(ell)
|
---|
41 | * log(sqrt(sf2) ]
|
---|
42 | */
|
---|
43 |
|
---|
44 | public class CovNNone implements CovarianceFunction{
|
---|
45 |
|
---|
46 | double[][] k;
|
---|
47 | double[][] q;
|
---|
48 |
|
---|
49 | public CovNNone(){}
|
---|
50 |
|
---|
51 |
|
---|
52 | /**
|
---|
53 | * Returns the number of hyperparameters of this<code>PhoenixAlpha.CovarianceFunction</code>
|
---|
54 | *
|
---|
55 | * @return number of hyperparameters
|
---|
56 | */
|
---|
57 | public int numParameters() {
|
---|
58 | return 2;
|
---|
59 | }
|
---|
60 |
|
---|
61 | /**
|
---|
62 | * Compute covariance matrix of a dataset X
|
---|
63 | *
|
---|
64 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
65 | * @param X input dataset
|
---|
66 | * @return K covariance <code>Matrix</code>
|
---|
67 | */
|
---|
68 | public Matrix compute(Matrix loghyper, Matrix X) {
|
---|
69 |
|
---|
70 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
71 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
72 |
|
---|
73 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
74 | final double em2 = 1/(ell*ell);
|
---|
75 | final double oneplusem2 = 1+em2;
|
---|
76 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
77 |
|
---|
78 |
|
---|
79 | final int m = X.getRowDimension();
|
---|
80 | final int n = X.getColumnDimension();
|
---|
81 | double[][] x= X.getArray();
|
---|
82 |
|
---|
83 | // Matrix Xc= X.times(1/ell);
|
---|
84 | //
|
---|
85 | // Q = Xc.times(Xc.transpose());
|
---|
86 | // System.out.print("Q=");Q.print(Q.getColumnDimension(), 8);
|
---|
87 |
|
---|
88 | // Q = new Matrix(m,m);
|
---|
89 | // double[][] q = Q.getArray();
|
---|
90 | q = new double[m][m];
|
---|
91 |
|
---|
92 | for(int i=0;i<m;i++){
|
---|
93 | for(int j=0;j<m;j++){
|
---|
94 | double t = 0;
|
---|
95 | for(int k=0;k<n;k++){
|
---|
96 | t+=x[i][k]*x[j][k]*em2;
|
---|
97 | }
|
---|
98 | q[i][j]=t;
|
---|
99 | }
|
---|
100 | }
|
---|
101 | // System.out.print("q=");Q.print(Q.getColumnDimension(), 8);
|
---|
102 |
|
---|
103 | // Matrix dQ = diag(Q);
|
---|
104 | // Matrix dQT = dQ.transpose();
|
---|
105 | // Matrix Qc = Q.copy();
|
---|
106 | // K = addValue(Qc,em2).arrayRightDivide(sqrt(addValue(dQ,1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
107 | // System.out.print("K=");K.print(K.getColumnDimension(), 8);
|
---|
108 |
|
---|
109 | double[] dq = new double[m];
|
---|
110 | for(int i=0;i<m;i++){
|
---|
111 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
112 | }
|
---|
113 |
|
---|
114 | //K = new Matrix(m,m);
|
---|
115 | Matrix A = new Matrix(m,m);
|
---|
116 | double[][] k = new double[m][m];//K.getArray();
|
---|
117 | double[][] a =A.getArray();
|
---|
118 | for(int i=0;i<m;i++){
|
---|
119 | final double dqi = dq[i];
|
---|
120 | for(int j=0;j<m;j++){
|
---|
121 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
122 | k[i][j]=t;
|
---|
123 | a[i][j]=sf2*Math.asin(t);
|
---|
124 | }
|
---|
125 | }
|
---|
126 | // System.out.print("k=");K.print(K.getColumnDimension(), 8);
|
---|
127 | // System.out.println("");
|
---|
128 |
|
---|
129 | // Matrix A = asin(K).times(sf2);
|
---|
130 | return A;
|
---|
131 | }
|
---|
132 |
|
---|
133 | /**
|
---|
134 | * Compute compute test set covariances
|
---|
135 | *
|
---|
136 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
137 | * @param X input dataset
|
---|
138 | * @param Xstar test set
|
---|
139 | * @return [K(Xstar,Xstar) K(X,Xstar)]
|
---|
140 | */
|
---|
141 | public Matrix[] compute(Matrix loghyper, Matrix X, Matrix Xstar) {
|
---|
142 |
|
---|
143 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
144 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
145 |
|
---|
146 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
147 | final double em2 = 1/(ell*ell);
|
---|
148 | final double oneplusem2 = 1+em2;
|
---|
149 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
150 |
|
---|
151 |
|
---|
152 |
|
---|
153 | final int m = X.getRowDimension();
|
---|
154 | final int n = X.getColumnDimension();
|
---|
155 | double[][] x= X.getArray();
|
---|
156 | final int mstar = Xstar.getRowDimension();
|
---|
157 | final int nstar = Xstar.getColumnDimension();
|
---|
158 | double[][] xstar= Xstar.getArray();
|
---|
159 |
|
---|
160 |
|
---|
161 | double[] sumxstardotTimesxstar = new double[mstar];
|
---|
162 | for(int i=0; i<mstar; i++){
|
---|
163 | double t =0;
|
---|
164 | for(int j=0; j<nstar; j++){
|
---|
165 | final double tt = xstar[i][j];
|
---|
166 | t+=tt*tt*em2;
|
---|
167 | }
|
---|
168 | sumxstardotTimesxstar[i]=t;
|
---|
169 | }
|
---|
170 |
|
---|
171 | Matrix A = new Matrix(mstar,1);
|
---|
172 | double[][] a = A.getArray();
|
---|
173 | for(int i=0; i<mstar; i++){
|
---|
174 | a[i][0]=sf2*Math.asin((em2+sumxstardotTimesxstar[i])/(oneplusem2+sumxstardotTimesxstar[i]));
|
---|
175 | }
|
---|
176 |
|
---|
177 |
|
---|
178 |
|
---|
179 | // X = X.times(1/ell);
|
---|
180 | // Xstar = Xstar.times(1/ell);
|
---|
181 | // Matrix tmp = sumRows(Xstar.arrayTimes(Xstar));
|
---|
182 | //
|
---|
183 | // Matrix tmp2 = tmp.copy();
|
---|
184 | // addValue(tmp,em2);
|
---|
185 | // addValue(tmp2,oneplusem2);
|
---|
186 | // Matrix A = asin(tmp.arrayRightDivide(tmp2)).times(sf2);
|
---|
187 |
|
---|
188 |
|
---|
189 | double[] sumxdotTimesx = new double[m];
|
---|
190 | for(int i=0; i<m; i++){
|
---|
191 | double t =0;
|
---|
192 | for(int j=0; j<n; j++){
|
---|
193 | final double tt = x[i][j];
|
---|
194 | t+=tt*tt*em2;
|
---|
195 | }
|
---|
196 | sumxdotTimesx[i]=t+oneplusem2;
|
---|
197 | }
|
---|
198 |
|
---|
199 | Matrix B = new Matrix(m,mstar);
|
---|
200 | double[][] b = B.getArray();
|
---|
201 | for(int i=0; i<m; i++){
|
---|
202 | final double[] xi = x[i];
|
---|
203 | for(int j=0; j<mstar; j++){
|
---|
204 | double t=0;
|
---|
205 | final double[] xstarj = xstar[j];
|
---|
206 | for(int k=0; k<n; k++){
|
---|
207 | t+=xi[k]*xstarj[k]*em2;
|
---|
208 | }
|
---|
209 | b[i][j]=t+em2;
|
---|
210 | }
|
---|
211 | }
|
---|
212 |
|
---|
213 | for(int i=0; i<m; i++){
|
---|
214 | for(int j=0; j<mstar; j++){
|
---|
215 | b[i][j] = sf2*Math.asin(b[i][j]/Math.sqrt((sumxstardotTimesxstar[j]+oneplusem2)*sumxdotTimesx[i]));
|
---|
216 | }
|
---|
217 | }
|
---|
218 |
|
---|
219 |
|
---|
220 |
|
---|
221 | // tmp = sumRows(X.arrayTimes(X));
|
---|
222 | // addValue(tmp,oneplusem2);
|
---|
223 | //
|
---|
224 | // tmp2=tmp2.transpose();
|
---|
225 | //
|
---|
226 | // tmp = addValue(X.times(Xstar.transpose()),em2).arrayRightDivide(sqrt(tmp.times(tmp2)));
|
---|
227 | // Matrix B = asin(tmp).times(sf2);
|
---|
228 |
|
---|
229 | //System.out.println("");
|
---|
230 | return new Matrix[]{A,B};
|
---|
231 | }
|
---|
232 |
|
---|
233 | /**
|
---|
234 | * Coompute the derivatives of this <code>PhoenixAlpha.CovarianceFunction</code> with respect
|
---|
235 | * to the hyperparameter with index <code>idx</code>
|
---|
236 | *
|
---|
237 | * @param loghyper hyperparameters
|
---|
238 | * @param X input dataset
|
---|
239 | * @param index hyperparameter index
|
---|
240 | * @return <code>Matrix</code> of derivatives
|
---|
241 | */
|
---|
242 | public Matrix computeDerivatives(Matrix loghyper, Matrix X, int index) {
|
---|
243 |
|
---|
244 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
245 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
246 | if(index>numParameters()-1)
|
---|
247 | throw new IllegalArgumentException("Wrong hyperparameters index "+index+" it should be smaller or equal to "+(numParameters()-1));
|
---|
248 |
|
---|
249 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
250 | final double em2 = 1/(ell*ell);
|
---|
251 | final double oneplusem2 = 1+em2;
|
---|
252 | final double twosf2 = 2*Math.exp(2*loghyper.get(1,0));
|
---|
253 |
|
---|
254 | final int m = X.getRowDimension();
|
---|
255 | final int n = X.getColumnDimension();
|
---|
256 | double[][] x= X.getArray();
|
---|
257 |
|
---|
258 | // Matrix X = XX.times(1/ell);
|
---|
259 |
|
---|
260 | if(q==null || q.length!=m || q[0].length!=m) {
|
---|
261 | q = new double[m][m];
|
---|
262 |
|
---|
263 | for(int i=0;i<m;i++){
|
---|
264 | for(int j=0;j<m;j++){
|
---|
265 | double t = 0;
|
---|
266 | for(int k=0;k<n;k++){
|
---|
267 | t+=x[i][k]*x[j][k]*em2;
|
---|
268 | }
|
---|
269 | q[i][j]=t;
|
---|
270 | }
|
---|
271 | }
|
---|
272 | }
|
---|
273 |
|
---|
274 | double[] dq = new double[m];
|
---|
275 | for(int i=0;i<m;i++){
|
---|
276 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
277 | }
|
---|
278 |
|
---|
279 | if(k==null || k.length!=m || k[0].length!=m) {
|
---|
280 | k = new double[m][m];
|
---|
281 | for(int i=0;i<m;i++){
|
---|
282 | final double dqi = dq[i];
|
---|
283 | for(int j=0;j<m;j++){
|
---|
284 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
285 | k[i][j]=t;
|
---|
286 | }
|
---|
287 | }
|
---|
288 | }
|
---|
289 |
|
---|
290 | // Matrix Xc= XX.times(1/ell);
|
---|
291 | // Matrix Q = Xc.times(Xc.transpose());
|
---|
292 | //
|
---|
293 | // Matrix dQ = diag(Q);
|
---|
294 | // Matrix dQT = dQ.transpose();
|
---|
295 | // Matrix K = addValue(Q.copy(),em2).arrayRightDivide(sqrt(addValue(dQ.copy(),1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
296 | // Matrix dQc = dQ.copy();
|
---|
297 |
|
---|
298 | Matrix A;
|
---|
299 | if(index==0){
|
---|
300 | for(int i=0;i<m;i++){
|
---|
301 | dq[i]=oneplusem2+q[i][i];
|
---|
302 | }
|
---|
303 | double[] v = new double[m];
|
---|
304 | for(int i=0; i<m; i++){
|
---|
305 | double t =0;
|
---|
306 | for(int j=0; j<n; j++){
|
---|
307 | final double xij = x[i][j];
|
---|
308 | t+=xij*xij*em2;
|
---|
309 | }
|
---|
310 | v[i]=(t+em2)/(dq[i]);
|
---|
311 | }
|
---|
312 |
|
---|
313 | // Matrix test = addValue(sumRows(X.arrayTimes(X)),em2);
|
---|
314 | // Matrix tmp = addValue(dQc,1+em2);
|
---|
315 | // Matrix V = addValue(sumRows(X.arrayTimes(X)),em2).arrayRightDivide(tmp);
|
---|
316 | //
|
---|
317 | // tmp = sqrt(tmp);
|
---|
318 | // tmp = addValue(Q.copy(),em2).arrayRightDivide(tmp.times(tmp.transpose()));
|
---|
319 |
|
---|
320 | for(int i=0; i<m; i++){
|
---|
321 | final double vi = v[i];
|
---|
322 | for(int j=0; j<m; j++){
|
---|
323 | double t =(q[i][j]+em2)/(Math.sqrt(dq[i])*Math.sqrt(dq[j]));
|
---|
324 | final double kij = k[i][j];
|
---|
325 | q[i][j]=-twosf2*((t-(0.5*kij*(vi+v[j])))/Math.sqrt(1-kij*kij));
|
---|
326 | }
|
---|
327 | }
|
---|
328 |
|
---|
329 | // Matrix tmp2 = new Matrix(m,m);
|
---|
330 | // for(int j=0; j<m; j++)
|
---|
331 | // tmp2.setMatrix(0,m-1,j,j,V);
|
---|
332 | //
|
---|
333 | // tmp = tmp.minus(K.arrayTimes(tmp2.plus(tmp2.transpose())).times(0.5));
|
---|
334 | //
|
---|
335 | // A = tmp.arrayRightDivide(sqrtOneMinusSqr(K)).times(-twosf2);
|
---|
336 |
|
---|
337 | A = new Matrix(q);
|
---|
338 | // System.out.println("");
|
---|
339 | q=null;
|
---|
340 | } else{
|
---|
341 | for(int i=0; i<m; i++){
|
---|
342 | for(int j=0; j<m; j++){
|
---|
343 | k[i][j]=Math.asin(k[i][j])*twosf2;
|
---|
344 | }
|
---|
345 | }
|
---|
346 | // A = asin(K).times(twosf2);
|
---|
347 | // K=null;
|
---|
348 | A = new Matrix(k);
|
---|
349 | k=null;
|
---|
350 | }
|
---|
351 |
|
---|
352 |
|
---|
353 | return A;
|
---|
354 | }
|
---|
355 |
|
---|
356 | // private static Matrix sqrtOneMinusSqr(Matrix in){
|
---|
357 | // Matrix out = new Matrix(in.getRowDimension(),in.getColumnDimension());
|
---|
358 | // for(int i=0; i<in.getRowDimension(); i++)
|
---|
359 | // for(int j=0; j<in.getColumnDimension(); j++) {
|
---|
360 | // final double tmp = in.get(i,j);
|
---|
361 | // out.set(i,j,Math.sqrt(1-tmp*tmp));
|
---|
362 | // }
|
---|
363 | // return out;
|
---|
364 | // }
|
---|
365 |
|
---|
366 | public static void main(String[] args) {
|
---|
367 |
|
---|
368 | CovNNone cf = new CovNNone();
|
---|
369 |
|
---|
370 | Matrix X = Matrix.identity(6,6);
|
---|
371 | Matrix logtheta = new Matrix(new double[][]{{0.1},{0.2}});
|
---|
372 |
|
---|
373 | Matrix z =new Matrix(new double[][]{{1,2,3,4,5,6},{1,2,3,4,5,6}});
|
---|
374 |
|
---|
375 | // System.out.println("")
|
---|
376 | //
|
---|
377 | // long start = System.currentTimeMillis()
|
---|
378 | //
|
---|
379 | // Matrix K = cf.compute(logtheta,X);
|
---|
380 | // long stop = System.currentTimeMillis();
|
---|
381 | // System.out.println(""+(stop-start));
|
---|
382 |
|
---|
383 | // K.print(K.getColumnDimension(), 15);
|
---|
384 |
|
---|
385 | // long start = System.currentTimeMillis();
|
---|
386 | // Matrix[] res = cf.compute(logtheta,X,z);
|
---|
387 | // long stop = System.currentTimeMillis();
|
---|
388 | // System.out.println(""+(stop-start));
|
---|
389 |
|
---|
390 | // res[0].print(res[0].getColumnDimension(), 8);
|
---|
391 | // res[1].print(res[1].getColumnDimension(), 8);
|
---|
392 |
|
---|
393 | Matrix d = cf.computeDerivatives(logtheta,X,1);
|
---|
394 |
|
---|
395 | d.print(d.getColumnDimension(), 8);
|
---|
396 |
|
---|
397 | }
|
---|
398 | }
|
---|