1 | package agents.anac.y2015.Phoenix.GP;/* This file is part of the jgpml Project.
|
---|
2 | * http://github.com/renzodenardi/jgpml
|
---|
3 | *
|
---|
4 | * Copyright (c) 2011 Renzo De Nardi and Hugo Gravato-Marques
|
---|
5 | *
|
---|
6 | * Permission is hereby granted, free of charge, to any person
|
---|
7 | * obtaining a copy of this software and associated documentation
|
---|
8 | * files (the "Software"), to deal in the Software without
|
---|
9 | * restriction, including without limitation the rights to use,
|
---|
10 | * copy, modify, merge, publish, distribute, sublicense, and/or sell
|
---|
11 | * copies of the Software, and to permit persons to whom the
|
---|
12 | * Software is furnished to do so, subject to the following
|
---|
13 | * conditions:
|
---|
14 | *
|
---|
15 | * The above copyright notice and this permission notice shall be
|
---|
16 | * included in all copies or substantial portions of the Software.
|
---|
17 | *
|
---|
18 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
---|
19 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
---|
20 | * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
---|
21 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
---|
22 | * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
---|
23 | * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
---|
24 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
---|
25 | * OTHER DEALINGS IN THE SOFTWARE.
|
---|
26 | */
|
---|
27 |
|
---|
28 | import agents.Jama.Matrix;
|
---|
29 |
|
---|
30 | /**
|
---|
31 | * Neural network covariance function with a single parameter for the distance
|
---|
32 | * measure and white noise. The covariance function is parameterized as:
|
---|
33 | * <P>
|
---|
34 | * k(x^p,x^q) = sf2 * asin(x^p'*P*x^q / sqrt[(1+x^p'*P*x^p)*(1+x^q'*P*x^q)]) + s2 * \delta(p,q)
|
---|
35 | * <P>
|
---|
36 | * where the x^p and x^q vectors on the right hand side have an added extra bias
|
---|
37 | * entry with unit value. P is ell^-2 times the unit matrix and sf2 controls the
|
---|
38 | * signal variance. The hyperparameters are:
|
---|
39 | * <P>
|
---|
40 | * [ log(ell)
|
---|
41 | * log(sqrt(sf2)
|
---|
42 | * log(s2)]
|
---|
43 | *
|
---|
44 | * <P>
|
---|
45 | * For reson of speed consider to use this covariance function instead of <code>PhoenixAlpha.CovSum(PhoenixAlpha.CovNNone,PhoenixAlpha.CovNoise)</code>
|
---|
46 | *
|
---|
47 | */
|
---|
48 |
|
---|
49 | public class CovNNoneNoise implements CovarianceFunction{
|
---|
50 |
|
---|
51 | double[][] k;
|
---|
52 | double[][] q;
|
---|
53 |
|
---|
54 | public CovNNoneNoise(){}
|
---|
55 |
|
---|
56 |
|
---|
57 | /**
|
---|
58 | * Returns the number of hyperparameters of this<code>PhoenixAlpha.CovarianceFunction</code>
|
---|
59 | *
|
---|
60 | * @return number of hyperparameters
|
---|
61 | */
|
---|
62 | public int numParameters() {
|
---|
63 | return 3;
|
---|
64 | }
|
---|
65 |
|
---|
66 | /**
|
---|
67 | * Compute covariance matrix of a dataset X
|
---|
68 | *
|
---|
69 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
70 | * @param X input dataset
|
---|
71 | * @return K covariance <code>Matrix</code>
|
---|
72 | */
|
---|
73 | public Matrix compute(Matrix loghyper, Matrix X) {
|
---|
74 |
|
---|
75 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
76 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
77 |
|
---|
78 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
79 | final double em2 = 1/(ell*ell);
|
---|
80 | final double oneplusem2 = 1+em2;
|
---|
81 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
82 | final double s2 = Math.exp(2*loghyper.get(2,0));
|
---|
83 |
|
---|
84 | final int m = X.getRowDimension();
|
---|
85 | final int n = X.getColumnDimension();
|
---|
86 | double[][] x= X.getArray();
|
---|
87 |
|
---|
88 | // Matrix Xc= X.times(1/ell);
|
---|
89 | //
|
---|
90 | // Q = Xc.times(Xc.transpose());
|
---|
91 | // System.out.print("Q=");Q.print(Q.getColumnDimension(), 8);
|
---|
92 |
|
---|
93 | // Q = new Matrix(m,m);
|
---|
94 | // double[][] q = Q.getArray();
|
---|
95 | // double[][] q;
|
---|
96 | q = new double[m][m];
|
---|
97 |
|
---|
98 | for(int i=0;i<m;i++){
|
---|
99 | for(int j=0;j<m;j++){
|
---|
100 | double t = 0;
|
---|
101 | for(int k=0;k<n;k++){
|
---|
102 | t+=x[i][k]*x[j][k]*em2;
|
---|
103 | }
|
---|
104 | q[i][j]=t;
|
---|
105 | }
|
---|
106 | }
|
---|
107 | // System.out.print("q=");Q.print(Q.getColumnDimension(), 8);
|
---|
108 |
|
---|
109 | // Matrix dQ = diag(Q);
|
---|
110 | // Matrix dQT = dQ.transpose();
|
---|
111 | // Matrix Qc = Q.copy();
|
---|
112 | // K = addValue(Qc,em2).arrayRightDivide(sqrt(addValue(dQ,1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
113 | // System.out.print("K=");K.print(K.getColumnDimension(), 8);
|
---|
114 |
|
---|
115 | double[] dq = new double[m];
|
---|
116 | for(int i=0;i<m;i++){
|
---|
117 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
118 | }
|
---|
119 |
|
---|
120 | //K = new Matrix(m,m);
|
---|
121 | Matrix A = new Matrix(m,m);
|
---|
122 | //double[][] k;
|
---|
123 | k = new double[m][m];//K.getArray();
|
---|
124 | double[][] a =A.getArray();
|
---|
125 | for(int i=0;i<m;i++){
|
---|
126 | final double dqi = dq[i];
|
---|
127 | for(int j=0;j<m;j++){
|
---|
128 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
129 | k[i][j]=t;
|
---|
130 | a[i][j]=sf2*Math.asin(t);
|
---|
131 | }
|
---|
132 | a[i][i]+=s2;
|
---|
133 | }
|
---|
134 | // System.out.print("k=");K.print(K.getColumnDimension(), 8);
|
---|
135 | // System.out.println("");
|
---|
136 |
|
---|
137 | // Matrix A = asin(K).times(sf2);
|
---|
138 | return A;
|
---|
139 | }
|
---|
140 |
|
---|
141 | /**
|
---|
142 | * Compute compute test set covariances
|
---|
143 | *
|
---|
144 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
145 | * @param X input dataset
|
---|
146 | * @param Xstar test set
|
---|
147 | * @return [K(Xstar,Xstar) K(X,Xstar)]
|
---|
148 | */
|
---|
149 | public Matrix[] compute(Matrix loghyper, Matrix X, Matrix Xstar) {
|
---|
150 |
|
---|
151 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
152 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
153 |
|
---|
154 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
155 | final double em2 = 1/(ell*ell);
|
---|
156 | final double oneplusem2 = 1+em2;
|
---|
157 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
158 | final double s2 = Math.exp(2*loghyper.get(2,0));
|
---|
159 |
|
---|
160 |
|
---|
161 | final int m = X.getRowDimension();
|
---|
162 | final int n = X.getColumnDimension();
|
---|
163 | double[][] x= X.getArray();
|
---|
164 | final int mstar = Xstar.getRowDimension();
|
---|
165 | final int nstar = Xstar.getColumnDimension();
|
---|
166 | double[][] xstar= Xstar.getArray();
|
---|
167 |
|
---|
168 |
|
---|
169 | double[] sumxstardotTimesxstar = new double[mstar];
|
---|
170 | for(int i=0; i<mstar; i++){
|
---|
171 | double t =0;
|
---|
172 | for(int j=0; j<nstar; j++){
|
---|
173 | final double tt = xstar[i][j];
|
---|
174 | t+=tt*tt*em2;
|
---|
175 | }
|
---|
176 | sumxstardotTimesxstar[i]=t;
|
---|
177 | }
|
---|
178 |
|
---|
179 | Matrix A = new Matrix(mstar,1);
|
---|
180 | double[][] a = A.getArray();
|
---|
181 | for(int i=0; i<mstar; i++){
|
---|
182 | a[i][0]=sf2*Math.asin((em2+sumxstardotTimesxstar[i])/(oneplusem2+sumxstardotTimesxstar[i]))+s2;
|
---|
183 | }
|
---|
184 |
|
---|
185 |
|
---|
186 |
|
---|
187 | // X = X.times(1/ell);
|
---|
188 | // Xstar = Xstar.times(1/ell);
|
---|
189 | // Matrix tmp = sumRows(Xstar.arrayTimes(Xstar));
|
---|
190 | //
|
---|
191 | // Matrix tmp2 = tmp.copy();
|
---|
192 | // addValue(tmp,em2);
|
---|
193 | // addValue(tmp2,oneplusem2);
|
---|
194 | // Matrix A = asin(tmp.arrayRightDivide(tmp2)).times(sf2);
|
---|
195 |
|
---|
196 |
|
---|
197 | double[] sumxdotTimesx = new double[m];
|
---|
198 | for(int i=0; i<m; i++){
|
---|
199 | double t =0;
|
---|
200 | for(int j=0; j<n; j++){
|
---|
201 | final double tt = x[i][j];
|
---|
202 | t+=tt*tt*em2;
|
---|
203 | }
|
---|
204 | sumxdotTimesx[i]=t+oneplusem2;
|
---|
205 | }
|
---|
206 |
|
---|
207 | Matrix B = new Matrix(m,mstar);
|
---|
208 | double[][] b = B.getArray();
|
---|
209 | for(int i=0; i<m; i++){
|
---|
210 | final double[] xi = x[i];
|
---|
211 | for(int j=0; j<mstar; j++){
|
---|
212 | double t=0;
|
---|
213 | final double[] xstarj = xstar[j];
|
---|
214 | for(int k=0; k<n; k++){
|
---|
215 | t+=xi[k]*xstarj[k]*em2;
|
---|
216 | }
|
---|
217 | b[i][j]=t+em2;
|
---|
218 | }
|
---|
219 | }
|
---|
220 |
|
---|
221 | for(int i=0; i<m; i++){
|
---|
222 | for(int j=0; j<mstar; j++){
|
---|
223 | b[i][j] = sf2*Math.asin(b[i][j]/Math.sqrt((sumxstardotTimesxstar[j]+oneplusem2)*sumxdotTimesx[i]));
|
---|
224 | }
|
---|
225 | }
|
---|
226 |
|
---|
227 |
|
---|
228 |
|
---|
229 | // tmp = sumRows(X.arrayTimes(X));
|
---|
230 | // addValue(tmp,oneplusem2);
|
---|
231 | //
|
---|
232 | // tmp2=tmp2.transpose();
|
---|
233 | //
|
---|
234 | // tmp = addValue(X.times(Xstar.transpose()),em2).arrayRightDivide(sqrt(tmp.times(tmp2)));
|
---|
235 | // Matrix B = asin(tmp).times(sf2);
|
---|
236 |
|
---|
237 | //System.out.println("");
|
---|
238 | return new Matrix[]{A,B};
|
---|
239 | }
|
---|
240 |
|
---|
241 | /**
|
---|
242 | * Coompute the derivatives of this <code>PhoenixAlpha.CovarianceFunction</code> with respect
|
---|
243 | * to the hyperparameter with index <code>idx</code>
|
---|
244 | *
|
---|
245 | * @param loghyper hyperparameters
|
---|
246 | * @param X input dataset
|
---|
247 | * @param index hyperparameter index
|
---|
248 | * @return <code>Matrix</code> of derivatives
|
---|
249 | */
|
---|
250 | public Matrix computeDerivatives(Matrix loghyper, Matrix X, int index) {
|
---|
251 |
|
---|
252 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
253 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
254 | if(index>numParameters()-1)
|
---|
255 | throw new IllegalArgumentException("Wrong hyperparameters index "+index+" it should be smaller or equal to "+(numParameters()-1));
|
---|
256 |
|
---|
257 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
258 | final double em2 = 1/(ell*ell);
|
---|
259 | final double oneplusem2 = 1+em2;
|
---|
260 | final double twosf2 = 2*Math.exp(2*loghyper.get(1,0));
|
---|
261 | final double twos2 = 2*Math.exp(2*loghyper.get(2,0));
|
---|
262 |
|
---|
263 | final int m = X.getRowDimension();
|
---|
264 | final int n = X.getColumnDimension();
|
---|
265 | double[][] x= X.getArray();
|
---|
266 |
|
---|
267 | // Matrix X = XX.times(1/ell);
|
---|
268 | // double[][] q=null;
|
---|
269 | if(q==null || q.length!=m || q[0].length!=m) {
|
---|
270 | q = new double[m][m];
|
---|
271 |
|
---|
272 | for(int i=0;i<m;i++){
|
---|
273 | for(int j=0;j<m;j++){
|
---|
274 | double t = 0;
|
---|
275 | for(int k=0;k<n;k++){
|
---|
276 | t+=x[i][k]*x[j][k]*em2;
|
---|
277 | }
|
---|
278 | q[i][j]=t;
|
---|
279 | }
|
---|
280 | }
|
---|
281 | }
|
---|
282 |
|
---|
283 | double[] dq = new double[m];
|
---|
284 | for(int i=0;i<m;i++){
|
---|
285 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
286 | }
|
---|
287 | // double[][] k=null;
|
---|
288 | if(k==null || k.length!=m || k[0].length!=m) {
|
---|
289 | k = new double[m][m];
|
---|
290 | for(int i=0;i<m;i++){
|
---|
291 | final double dqi = dq[i];
|
---|
292 | for(int j=0;j<m;j++){
|
---|
293 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
294 | k[i][j]=t;
|
---|
295 | }
|
---|
296 | }
|
---|
297 | }
|
---|
298 |
|
---|
299 | // Matrix Xc= XX.times(1/ell);
|
---|
300 | // Matrix Q = Xc.times(Xc.transpose());
|
---|
301 | //
|
---|
302 | // Matrix dQ = diag(Q);
|
---|
303 | // Matrix dQT = dQ.transpose();
|
---|
304 | // Matrix K = addValue(Q.copy(),em2).arrayRightDivide(sqrt(addValue(dQ.copy(),1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
305 | // Matrix dQc = dQ.copy();
|
---|
306 |
|
---|
307 | Matrix A = null;
|
---|
308 | switch(index){
|
---|
309 | case 0:
|
---|
310 | for(int i=0;i<m;i++){
|
---|
311 | dq[i]=oneplusem2+q[i][i];
|
---|
312 | }
|
---|
313 | double[] v = new double[m];
|
---|
314 | for(int i=0; i<m; i++){
|
---|
315 | double t =0;
|
---|
316 | for(int j=0; j<n; j++){
|
---|
317 | final double xij = x[i][j];
|
---|
318 | t+=xij*xij*em2;
|
---|
319 | }
|
---|
320 | v[i]=(t+em2)/(dq[i]);
|
---|
321 | }
|
---|
322 |
|
---|
323 | // Matrix test = addValue(sumRows(X.arrayTimes(X)),em2);
|
---|
324 | // Matrix tmp = addValue(dQc,1+em2);
|
---|
325 | // Matrix V = addValue(sumRows(X.arrayTimes(X)),em2).arrayRightDivide(tmp);
|
---|
326 | //
|
---|
327 | // tmp = sqrt(tmp);
|
---|
328 | // tmp = addValue(Q.copy(),em2).arrayRightDivide(tmp.times(tmp.transpose()));
|
---|
329 |
|
---|
330 | for(int i=0; i<m; i++){
|
---|
331 | final double vi = v[i];
|
---|
332 | for(int j=0; j<m; j++){
|
---|
333 | double t =(q[i][j]+em2)/(Math.sqrt(dq[i])*Math.sqrt(dq[j]));
|
---|
334 | final double kij = k[i][j];
|
---|
335 | q[i][j]=-twosf2*((t-(0.5*kij*(vi+v[j])))/Math.sqrt(1-kij*kij));
|
---|
336 | }
|
---|
337 | }
|
---|
338 |
|
---|
339 | // Matrix tmp2 = new Matrix(m,m);
|
---|
340 | // for(int j=0; j<m; j++)
|
---|
341 | // tmp2.setMatrix(0,m-1,j,j,V);
|
---|
342 | //
|
---|
343 | // tmp = tmp.minus(K.arrayTimes(tmp2.plus(tmp2.transpose())).times(0.5));
|
---|
344 | //
|
---|
345 | // A = tmp.arrayRightDivide(sqrtOneMinusSqr(K)).times(-twosf2);
|
---|
346 |
|
---|
347 | A = new Matrix(q);
|
---|
348 | // System.out.println("");
|
---|
349 | q=null;
|
---|
350 | break;
|
---|
351 | case 1:
|
---|
352 | for(int i=0; i<m; i++){
|
---|
353 | for(int j=0; j<m; j++){
|
---|
354 | k[i][j]=Math.asin(k[i][j])*twosf2;
|
---|
355 | }
|
---|
356 | }
|
---|
357 | // A = asin(K).times(twosf2);
|
---|
358 | // K=null;
|
---|
359 | A = new Matrix(k);
|
---|
360 | k=null;
|
---|
361 |
|
---|
362 | break;
|
---|
363 |
|
---|
364 | case 2:
|
---|
365 | double[][] a = new double[m][m];
|
---|
366 | for(int i=0; i<m;i++) a[i][i]=twos2;
|
---|
367 |
|
---|
368 | A = new Matrix(a);
|
---|
369 |
|
---|
370 | break;
|
---|
371 | default:
|
---|
372 | throw new IllegalArgumentException("the covariance function PhoenixAlpha.CovNNoneNoise alllows for a maximum of 3 parameters!!");
|
---|
373 | }
|
---|
374 | return A;
|
---|
375 | }
|
---|
376 |
|
---|
377 | // private static Matrix sqrtOneMinusSqr(Matrix in){
|
---|
378 | // Matrix out = new Matrix(in.getRowDimension(),in.getColumnDimension());
|
---|
379 | // for(int i=0; i<in.getRowDimension(); i++)
|
---|
380 | // for(int j=0; j<in.getColumnDimension(); j++) {
|
---|
381 | // final double tmp = in.get(i,j);
|
---|
382 | // out.set(i,j,Math.sqrt(1-tmp*tmp));
|
---|
383 | // }
|
---|
384 | // return out;
|
---|
385 | // }
|
---|
386 |
|
---|
387 | public static void main(String[] args) {
|
---|
388 |
|
---|
389 | CovarianceFunction cf = new CovNNoneNoise();
|
---|
390 | CovarianceFunction cf2 = new CovSum(6,new CovNNone(), new CovNoise());
|
---|
391 |
|
---|
392 |
|
---|
393 | Matrix X = Matrix.identity(10,6);
|
---|
394 |
|
---|
395 | for(int i=0; i<X.getRowDimension(); i++)
|
---|
396 | for(int j=0; j<X.getColumnDimension(); j++)
|
---|
397 | X.set(i,j,Math.random());
|
---|
398 |
|
---|
399 | Matrix logtheta = new Matrix(new double[][]{{0.1},{0.2},{Math.log(0.1)}});
|
---|
400 |
|
---|
401 | Matrix z =new Matrix(new double[][]{{1,2,3,4,5,6},{1,2,3,4,5,6}});
|
---|
402 |
|
---|
403 | // long start = System.currentTimeMillis();
|
---|
404 | // Matrix K = cf.compute(logtheta,X);
|
---|
405 | // K.print(K.getColumnDimension(), 15);
|
---|
406 |
|
---|
407 | // long stop = System.currentTimeMillis();
|
---|
408 | // System.out.println(""+(stop-start));
|
---|
409 | //
|
---|
410 | // start = System.currentTimeMillis();
|
---|
411 | // K = cf2.compute(logtheta,X);
|
---|
412 | // K.print(K.getColumnDimension(), 15);
|
---|
413 | // stop = System.currentTimeMillis();
|
---|
414 | // System.out.println(""+(stop-start));
|
---|
415 |
|
---|
416 |
|
---|
417 |
|
---|
418 | // long start = System.currentTimeMillis();
|
---|
419 | // Matrix[] res = cf.compute(logtheta,X,z);
|
---|
420 | // res[0].print(res[0].getColumnDimension(), 8);
|
---|
421 | // res[1].print(res[1].getColumnDimension(), 8);
|
---|
422 |
|
---|
423 | // long stop = System.currentTimeMillis();
|
---|
424 | // System.out.println(""+(stop-start));
|
---|
425 |
|
---|
426 |
|
---|
427 |
|
---|
428 | // res = cf2.compute(logtheta,X,z);
|
---|
429 | // res[0].print(res[0].getColumnDimension(), 8);
|
---|
430 | // res[1].print(res[1].getColumnDimension(), 8);
|
---|
431 |
|
---|
432 | // long stop = System.currentTimeMillis();
|
---|
433 | // System.out.println(""+(stop-start));
|
---|
434 |
|
---|
435 |
|
---|
436 |
|
---|
437 |
|
---|
438 | // Matrix d = cf.computeDerivatives(logtheta,X,0);
|
---|
439 | // d.print(d.getColumnDimension(), 8);
|
---|
440 |
|
---|
441 | //d = cf2.computeDerivatives(logtheta,X,0);
|
---|
442 | //d.print(d.getColumnDimension(), 8);
|
---|
443 |
|
---|
444 |
|
---|
445 | }
|
---|
446 | }
|
---|