[1] | 1 | package agents.anac.y2015.Phoenix.GP;/* This file is part of the jgpml Project.
|
---|
| 2 | * http://github.com/renzodenardi/jgpml
|
---|
| 3 | *
|
---|
| 4 | * Copyright (c) 2011 Renzo De Nardi and Hugo Gravato-Marques
|
---|
| 5 | *
|
---|
| 6 | * Permission is hereby granted, free of charge, to any person
|
---|
| 7 | * obtaining a copy of this software and associated documentation
|
---|
| 8 | * files (the "Software"), to deal in the Software without
|
---|
| 9 | * restriction, including without limitation the rights to use,
|
---|
| 10 | * copy, modify, merge, publish, distribute, sublicense, and/or sell
|
---|
| 11 | * copies of the Software, and to permit persons to whom the
|
---|
| 12 | * Software is furnished to do so, subject to the following
|
---|
| 13 | * conditions:
|
---|
| 14 | *
|
---|
| 15 | * The above copyright notice and this permission notice shall be
|
---|
| 16 | * included in all copies or substantial portions of the Software.
|
---|
| 17 | *
|
---|
| 18 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
---|
| 19 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
---|
| 20 | * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
---|
| 21 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
---|
| 22 | * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
---|
| 23 | * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
---|
| 24 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
---|
| 25 | * OTHER DEALINGS IN THE SOFTWARE.
|
---|
| 26 | */
|
---|
| 27 |
|
---|
| 28 | import agents.Jama.Matrix;
|
---|
| 29 |
|
---|
| 30 | /**
|
---|
| 31 | * Neural network covariance function with a single parameter for the distance
|
---|
| 32 | * measure. The covariance function is parameterized as:
|
---|
| 33 | * <P>
|
---|
| 34 | * k(x^p,x^q) = sf2 * asin(x^p'*P*x^q / sqrt[(1+x^p'*P*x^p)*(1+x^q'*P*x^q)])
|
---|
| 35 | * <P>
|
---|
| 36 | * where the x^p and x^q vectors on the right hand side have an added extra bias
|
---|
| 37 | * entry with unit value. P is ell^-2 times the unit matrix and sf2 controls the
|
---|
| 38 | * signal variance. The hyperparameters are:
|
---|
| 39 | * <P>
|
---|
| 40 | * [ log(ell)
|
---|
| 41 | * log(sqrt(sf2) ]
|
---|
| 42 | */
|
---|
| 43 |
|
---|
| 44 | public class CovNNone implements CovarianceFunction{
|
---|
| 45 |
|
---|
| 46 | double[][] k;
|
---|
| 47 | double[][] q;
|
---|
| 48 |
|
---|
| 49 | public CovNNone(){}
|
---|
| 50 |
|
---|
| 51 |
|
---|
| 52 | /**
|
---|
| 53 | * Returns the number of hyperparameters of this<code>PhoenixAlpha.CovarianceFunction</code>
|
---|
| 54 | *
|
---|
| 55 | * @return number of hyperparameters
|
---|
| 56 | */
|
---|
| 57 | public int numParameters() {
|
---|
| 58 | return 2;
|
---|
| 59 | }
|
---|
| 60 |
|
---|
| 61 | /**
|
---|
| 62 | * Compute covariance matrix of a dataset X
|
---|
| 63 | *
|
---|
| 64 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
| 65 | * @param X input dataset
|
---|
| 66 | * @return K covariance <code>Matrix</code>
|
---|
| 67 | */
|
---|
| 68 | public Matrix compute(Matrix loghyper, Matrix X) {
|
---|
| 69 |
|
---|
| 70 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
| 71 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
| 72 |
|
---|
| 73 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
| 74 | final double em2 = 1/(ell*ell);
|
---|
| 75 | final double oneplusem2 = 1+em2;
|
---|
| 76 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
| 77 |
|
---|
| 78 |
|
---|
| 79 | final int m = X.getRowDimension();
|
---|
| 80 | final int n = X.getColumnDimension();
|
---|
| 81 | double[][] x= X.getArray();
|
---|
| 82 |
|
---|
| 83 | // Matrix Xc= X.times(1/ell);
|
---|
| 84 | //
|
---|
| 85 | // Q = Xc.times(Xc.transpose());
|
---|
| 86 | // System.out.print("Q=");Q.print(Q.getColumnDimension(), 8);
|
---|
| 87 |
|
---|
| 88 | // Q = new Matrix(m,m);
|
---|
| 89 | // double[][] q = Q.getArray();
|
---|
| 90 | q = new double[m][m];
|
---|
| 91 |
|
---|
| 92 | for(int i=0;i<m;i++){
|
---|
| 93 | for(int j=0;j<m;j++){
|
---|
| 94 | double t = 0;
|
---|
| 95 | for(int k=0;k<n;k++){
|
---|
| 96 | t+=x[i][k]*x[j][k]*em2;
|
---|
| 97 | }
|
---|
| 98 | q[i][j]=t;
|
---|
| 99 | }
|
---|
| 100 | }
|
---|
| 101 | // System.out.print("q=");Q.print(Q.getColumnDimension(), 8);
|
---|
| 102 |
|
---|
| 103 | // Matrix dQ = diag(Q);
|
---|
| 104 | // Matrix dQT = dQ.transpose();
|
---|
| 105 | // Matrix Qc = Q.copy();
|
---|
| 106 | // K = addValue(Qc,em2).arrayRightDivide(sqrt(addValue(dQ,1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
| 107 | // System.out.print("K=");K.print(K.getColumnDimension(), 8);
|
---|
| 108 |
|
---|
| 109 | double[] dq = new double[m];
|
---|
| 110 | for(int i=0;i<m;i++){
|
---|
| 111 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
| 112 | }
|
---|
| 113 |
|
---|
| 114 | //K = new Matrix(m,m);
|
---|
| 115 | Matrix A = new Matrix(m,m);
|
---|
| 116 | double[][] k = new double[m][m];//K.getArray();
|
---|
| 117 | double[][] a =A.getArray();
|
---|
| 118 | for(int i=0;i<m;i++){
|
---|
| 119 | final double dqi = dq[i];
|
---|
| 120 | for(int j=0;j<m;j++){
|
---|
| 121 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
| 122 | k[i][j]=t;
|
---|
| 123 | a[i][j]=sf2*Math.asin(t);
|
---|
| 124 | }
|
---|
| 125 | }
|
---|
| 126 | // System.out.print("k=");K.print(K.getColumnDimension(), 8);
|
---|
| 127 | // System.out.println("");
|
---|
| 128 |
|
---|
| 129 | // Matrix A = asin(K).times(sf2);
|
---|
| 130 | return A;
|
---|
| 131 | }
|
---|
| 132 |
|
---|
| 133 | /**
|
---|
| 134 | * Compute compute test set covariances
|
---|
| 135 | *
|
---|
| 136 | * @param loghyper column <code>Matrix</code> of hyperparameters
|
---|
| 137 | * @param X input dataset
|
---|
| 138 | * @param Xstar test set
|
---|
| 139 | * @return [K(Xstar,Xstar) K(X,Xstar)]
|
---|
| 140 | */
|
---|
| 141 | public Matrix[] compute(Matrix loghyper, Matrix X, Matrix Xstar) {
|
---|
| 142 |
|
---|
| 143 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
| 144 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
| 145 |
|
---|
| 146 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
| 147 | final double em2 = 1/(ell*ell);
|
---|
| 148 | final double oneplusem2 = 1+em2;
|
---|
| 149 | final double sf2 = Math.exp(2*loghyper.get(1,0));
|
---|
| 150 |
|
---|
| 151 |
|
---|
| 152 |
|
---|
| 153 | final int m = X.getRowDimension();
|
---|
| 154 | final int n = X.getColumnDimension();
|
---|
| 155 | double[][] x= X.getArray();
|
---|
| 156 | final int mstar = Xstar.getRowDimension();
|
---|
| 157 | final int nstar = Xstar.getColumnDimension();
|
---|
| 158 | double[][] xstar= Xstar.getArray();
|
---|
| 159 |
|
---|
| 160 |
|
---|
| 161 | double[] sumxstardotTimesxstar = new double[mstar];
|
---|
| 162 | for(int i=0; i<mstar; i++){
|
---|
| 163 | double t =0;
|
---|
| 164 | for(int j=0; j<nstar; j++){
|
---|
| 165 | final double tt = xstar[i][j];
|
---|
| 166 | t+=tt*tt*em2;
|
---|
| 167 | }
|
---|
| 168 | sumxstardotTimesxstar[i]=t;
|
---|
| 169 | }
|
---|
| 170 |
|
---|
| 171 | Matrix A = new Matrix(mstar,1);
|
---|
| 172 | double[][] a = A.getArray();
|
---|
| 173 | for(int i=0; i<mstar; i++){
|
---|
| 174 | a[i][0]=sf2*Math.asin((em2+sumxstardotTimesxstar[i])/(oneplusem2+sumxstardotTimesxstar[i]));
|
---|
| 175 | }
|
---|
| 176 |
|
---|
| 177 |
|
---|
| 178 |
|
---|
| 179 | // X = X.times(1/ell);
|
---|
| 180 | // Xstar = Xstar.times(1/ell);
|
---|
| 181 | // Matrix tmp = sumRows(Xstar.arrayTimes(Xstar));
|
---|
| 182 | //
|
---|
| 183 | // Matrix tmp2 = tmp.copy();
|
---|
| 184 | // addValue(tmp,em2);
|
---|
| 185 | // addValue(tmp2,oneplusem2);
|
---|
| 186 | // Matrix A = asin(tmp.arrayRightDivide(tmp2)).times(sf2);
|
---|
| 187 |
|
---|
| 188 |
|
---|
| 189 | double[] sumxdotTimesx = new double[m];
|
---|
| 190 | for(int i=0; i<m; i++){
|
---|
| 191 | double t =0;
|
---|
| 192 | for(int j=0; j<n; j++){
|
---|
| 193 | final double tt = x[i][j];
|
---|
| 194 | t+=tt*tt*em2;
|
---|
| 195 | }
|
---|
| 196 | sumxdotTimesx[i]=t+oneplusem2;
|
---|
| 197 | }
|
---|
| 198 |
|
---|
| 199 | Matrix B = new Matrix(m,mstar);
|
---|
| 200 | double[][] b = B.getArray();
|
---|
| 201 | for(int i=0; i<m; i++){
|
---|
| 202 | final double[] xi = x[i];
|
---|
| 203 | for(int j=0; j<mstar; j++){
|
---|
| 204 | double t=0;
|
---|
| 205 | final double[] xstarj = xstar[j];
|
---|
| 206 | for(int k=0; k<n; k++){
|
---|
| 207 | t+=xi[k]*xstarj[k]*em2;
|
---|
| 208 | }
|
---|
| 209 | b[i][j]=t+em2;
|
---|
| 210 | }
|
---|
| 211 | }
|
---|
| 212 |
|
---|
| 213 | for(int i=0; i<m; i++){
|
---|
| 214 | for(int j=0; j<mstar; j++){
|
---|
| 215 | b[i][j] = sf2*Math.asin(b[i][j]/Math.sqrt((sumxstardotTimesxstar[j]+oneplusem2)*sumxdotTimesx[i]));
|
---|
| 216 | }
|
---|
| 217 | }
|
---|
| 218 |
|
---|
| 219 |
|
---|
| 220 |
|
---|
| 221 | // tmp = sumRows(X.arrayTimes(X));
|
---|
| 222 | // addValue(tmp,oneplusem2);
|
---|
| 223 | //
|
---|
| 224 | // tmp2=tmp2.transpose();
|
---|
| 225 | //
|
---|
| 226 | // tmp = addValue(X.times(Xstar.transpose()),em2).arrayRightDivide(sqrt(tmp.times(tmp2)));
|
---|
| 227 | // Matrix B = asin(tmp).times(sf2);
|
---|
| 228 |
|
---|
| 229 | //System.out.println("");
|
---|
| 230 | return new Matrix[]{A,B};
|
---|
| 231 | }
|
---|
| 232 |
|
---|
| 233 | /**
|
---|
| 234 | * Coompute the derivatives of this <code>PhoenixAlpha.CovarianceFunction</code> with respect
|
---|
| 235 | * to the hyperparameter with index <code>idx</code>
|
---|
| 236 | *
|
---|
| 237 | * @param loghyper hyperparameters
|
---|
| 238 | * @param X input dataset
|
---|
| 239 | * @param index hyperparameter index
|
---|
| 240 | * @return <code>Matrix</code> of derivatives
|
---|
| 241 | */
|
---|
| 242 | public Matrix computeDerivatives(Matrix loghyper, Matrix X, int index) {
|
---|
| 243 |
|
---|
| 244 | if(loghyper.getColumnDimension()!=1 || loghyper.getRowDimension()!=numParameters())
|
---|
| 245 | throw new IllegalArgumentException("Wrong number of hyperparameters, "+loghyper.getRowDimension()+" instead of "+numParameters());
|
---|
| 246 | if(index>numParameters()-1)
|
---|
| 247 | throw new IllegalArgumentException("Wrong hyperparameters index "+index+" it should be smaller or equal to "+(numParameters()-1));
|
---|
| 248 |
|
---|
| 249 | final double ell = Math.exp(loghyper.get(0,0));
|
---|
| 250 | final double em2 = 1/(ell*ell);
|
---|
| 251 | final double oneplusem2 = 1+em2;
|
---|
| 252 | final double twosf2 = 2*Math.exp(2*loghyper.get(1,0));
|
---|
| 253 |
|
---|
| 254 | final int m = X.getRowDimension();
|
---|
| 255 | final int n = X.getColumnDimension();
|
---|
| 256 | double[][] x= X.getArray();
|
---|
| 257 |
|
---|
| 258 | // Matrix X = XX.times(1/ell);
|
---|
| 259 |
|
---|
| 260 | if(q==null || q.length!=m || q[0].length!=m) {
|
---|
| 261 | q = new double[m][m];
|
---|
| 262 |
|
---|
| 263 | for(int i=0;i<m;i++){
|
---|
| 264 | for(int j=0;j<m;j++){
|
---|
| 265 | double t = 0;
|
---|
| 266 | for(int k=0;k<n;k++){
|
---|
| 267 | t+=x[i][k]*x[j][k]*em2;
|
---|
| 268 | }
|
---|
| 269 | q[i][j]=t;
|
---|
| 270 | }
|
---|
| 271 | }
|
---|
| 272 | }
|
---|
| 273 |
|
---|
| 274 | double[] dq = new double[m];
|
---|
| 275 | for(int i=0;i<m;i++){
|
---|
| 276 | dq[i]=Math.sqrt(oneplusem2+q[i][i]);
|
---|
| 277 | }
|
---|
| 278 |
|
---|
| 279 | if(k==null || k.length!=m || k[0].length!=m) {
|
---|
| 280 | k = new double[m][m];
|
---|
| 281 | for(int i=0;i<m;i++){
|
---|
| 282 | final double dqi = dq[i];
|
---|
| 283 | for(int j=0;j<m;j++){
|
---|
| 284 | final double t = (em2+q[i][j])/(dqi*dq[j]);
|
---|
| 285 | k[i][j]=t;
|
---|
| 286 | }
|
---|
| 287 | }
|
---|
| 288 | }
|
---|
| 289 |
|
---|
| 290 | // Matrix Xc= XX.times(1/ell);
|
---|
| 291 | // Matrix Q = Xc.times(Xc.transpose());
|
---|
| 292 | //
|
---|
| 293 | // Matrix dQ = diag(Q);
|
---|
| 294 | // Matrix dQT = dQ.transpose();
|
---|
| 295 | // Matrix K = addValue(Q.copy(),em2).arrayRightDivide(sqrt(addValue(dQ.copy(),1+em2)).times(sqrt(addValue(dQT,1+em2))));
|
---|
| 296 | // Matrix dQc = dQ.copy();
|
---|
| 297 |
|
---|
| 298 | Matrix A;
|
---|
| 299 | if(index==0){
|
---|
| 300 | for(int i=0;i<m;i++){
|
---|
| 301 | dq[i]=oneplusem2+q[i][i];
|
---|
| 302 | }
|
---|
| 303 | double[] v = new double[m];
|
---|
| 304 | for(int i=0; i<m; i++){
|
---|
| 305 | double t =0;
|
---|
| 306 | for(int j=0; j<n; j++){
|
---|
| 307 | final double xij = x[i][j];
|
---|
| 308 | t+=xij*xij*em2;
|
---|
| 309 | }
|
---|
| 310 | v[i]=(t+em2)/(dq[i]);
|
---|
| 311 | }
|
---|
| 312 |
|
---|
| 313 | // Matrix test = addValue(sumRows(X.arrayTimes(X)),em2);
|
---|
| 314 | // Matrix tmp = addValue(dQc,1+em2);
|
---|
| 315 | // Matrix V = addValue(sumRows(X.arrayTimes(X)),em2).arrayRightDivide(tmp);
|
---|
| 316 | //
|
---|
| 317 | // tmp = sqrt(tmp);
|
---|
| 318 | // tmp = addValue(Q.copy(),em2).arrayRightDivide(tmp.times(tmp.transpose()));
|
---|
| 319 |
|
---|
| 320 | for(int i=0; i<m; i++){
|
---|
| 321 | final double vi = v[i];
|
---|
| 322 | for(int j=0; j<m; j++){
|
---|
| 323 | double t =(q[i][j]+em2)/(Math.sqrt(dq[i])*Math.sqrt(dq[j]));
|
---|
| 324 | final double kij = k[i][j];
|
---|
| 325 | q[i][j]=-twosf2*((t-(0.5*kij*(vi+v[j])))/Math.sqrt(1-kij*kij));
|
---|
| 326 | }
|
---|
| 327 | }
|
---|
| 328 |
|
---|
| 329 | // Matrix tmp2 = new Matrix(m,m);
|
---|
| 330 | // for(int j=0; j<m; j++)
|
---|
| 331 | // tmp2.setMatrix(0,m-1,j,j,V);
|
---|
| 332 | //
|
---|
| 333 | // tmp = tmp.minus(K.arrayTimes(tmp2.plus(tmp2.transpose())).times(0.5));
|
---|
| 334 | //
|
---|
| 335 | // A = tmp.arrayRightDivide(sqrtOneMinusSqr(K)).times(-twosf2);
|
---|
| 336 |
|
---|
| 337 | A = new Matrix(q);
|
---|
| 338 | // System.out.println("");
|
---|
| 339 | q=null;
|
---|
| 340 | } else{
|
---|
| 341 | for(int i=0; i<m; i++){
|
---|
| 342 | for(int j=0; j<m; j++){
|
---|
| 343 | k[i][j]=Math.asin(k[i][j])*twosf2;
|
---|
| 344 | }
|
---|
| 345 | }
|
---|
| 346 | // A = asin(K).times(twosf2);
|
---|
| 347 | // K=null;
|
---|
| 348 | A = new Matrix(k);
|
---|
| 349 | k=null;
|
---|
| 350 | }
|
---|
| 351 |
|
---|
| 352 |
|
---|
| 353 | return A;
|
---|
| 354 | }
|
---|
| 355 |
|
---|
| 356 | // private static Matrix sqrtOneMinusSqr(Matrix in){
|
---|
| 357 | // Matrix out = new Matrix(in.getRowDimension(),in.getColumnDimension());
|
---|
| 358 | // for(int i=0; i<in.getRowDimension(); i++)
|
---|
| 359 | // for(int j=0; j<in.getColumnDimension(); j++) {
|
---|
| 360 | // final double tmp = in.get(i,j);
|
---|
| 361 | // out.set(i,j,Math.sqrt(1-tmp*tmp));
|
---|
| 362 | // }
|
---|
| 363 | // return out;
|
---|
| 364 | // }
|
---|
| 365 |
|
---|
| 366 | public static void main(String[] args) {
|
---|
| 367 |
|
---|
| 368 | CovNNone cf = new CovNNone();
|
---|
| 369 |
|
---|
| 370 | Matrix X = Matrix.identity(6,6);
|
---|
| 371 | Matrix logtheta = new Matrix(new double[][]{{0.1},{0.2}});
|
---|
| 372 |
|
---|
| 373 | Matrix z =new Matrix(new double[][]{{1,2,3,4,5,6},{1,2,3,4,5,6}});
|
---|
| 374 |
|
---|
| 375 | // System.out.println("")
|
---|
| 376 | //
|
---|
| 377 | // long start = System.currentTimeMillis()
|
---|
| 378 | //
|
---|
| 379 | // Matrix K = cf.compute(logtheta,X);
|
---|
| 380 | // long stop = System.currentTimeMillis();
|
---|
| 381 | // System.out.println(""+(stop-start));
|
---|
| 382 |
|
---|
| 383 | // K.print(K.getColumnDimension(), 15);
|
---|
| 384 |
|
---|
| 385 | // long start = System.currentTimeMillis();
|
---|
| 386 | // Matrix[] res = cf.compute(logtheta,X,z);
|
---|
| 387 | // long stop = System.currentTimeMillis();
|
---|
| 388 | // System.out.println(""+(stop-start));
|
---|
| 389 |
|
---|
| 390 | // res[0].print(res[0].getColumnDimension(), 8);
|
---|
| 391 | // res[1].print(res[1].getColumnDimension(), 8);
|
---|
| 392 |
|
---|
| 393 | Matrix d = cf.computeDerivatives(logtheta,X,1);
|
---|
| 394 |
|
---|
| 395 | d.print(d.getColumnDimension(), 8);
|
---|
| 396 |
|
---|
| 397 | }
|
---|
| 398 | }
|
---|