[1] | 1 | package agents.anac.y2015.Phoenix.GP;/* This file is part of the jgpml Project.
|
---|
| 2 | * http://github.com/renzodenardi/jgpml
|
---|
| 3 | *
|
---|
| 4 | * Copyright (c) 2011 Renzo De Nardi and Hugo Gravato-Marques
|
---|
| 5 | *
|
---|
| 6 | * Permission is hereby granted, free of charge, to any person
|
---|
| 7 | * obtaining a copy of this software and associated documentation
|
---|
| 8 | * files (the "Software"), to deal in the Software without
|
---|
| 9 | * restriction, including without limitation the rights to use,
|
---|
| 10 | * copy, modify, merge, publish, distribute, sublicense, and/or sell
|
---|
| 11 | * copies of the Software, and to permit persons to whom the
|
---|
| 12 | * Software is furnished to do so, subject to the following
|
---|
| 13 | * conditions:
|
---|
| 14 | *
|
---|
| 15 | * The above copyright notice and this permission notice shall be
|
---|
| 16 | * included in all copies or substantial portions of the Software.
|
---|
| 17 | *
|
---|
| 18 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
---|
| 19 | * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
---|
| 20 | * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
---|
| 21 | * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
---|
| 22 | * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
---|
| 23 | * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
---|
| 24 | * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
---|
| 25 | * OTHER DEALINGS IN THE SOFTWARE.
|
---|
| 26 | */
|
---|
| 27 |
|
---|
| 28 | import agents.Jama.CholeskyDecomposition;
|
---|
| 29 | import agents.Jama.Matrix;
|
---|
| 30 |
|
---|
| 31 | /**
|
---|
| 32 | * Main class of the package, contains the objects that constitutes a Gaussian
|
---|
| 33 | * Process as well as the algorithm to train the Hyperparameters and to do
|
---|
| 34 | * predictions.
|
---|
| 35 | */
|
---|
| 36 | public class GaussianProcess {
|
---|
| 37 |
|
---|
	/**
	 * Hyperparameters of the covariance function, on a log scale; set by
	 * {@link #train}.
	 */
	public Matrix logtheta;

	/**
	 * Input (training) data points, one row per point; stored by train() and
	 * read by the predict methods.
	 */
	public Matrix X;

	/**
	 * Lower-triangular Cholesky factor of the training covariance matrix;
	 * computed as a side effect of negativeLogLikelihood().
	 */
	public Matrix L;

	/**
	 * Partial factor alpha = K^-1 * y; computed as a side effect of
	 * negativeLogLikelihood() and reused for predictions.
	 */
	public Matrix alpha;

	/**
	 * Covariance function of this GP (supplied at construction).
	 */
	CovarianceFunction covFunction;
| 62 |
|
---|
	/**
	 * Creates a new GP object.
	 *
	 * @param covFunction
	 *            the covariance function used for training and prediction
	 */
	public GaussianProcess(CovarianceFunction covFunction) {
		this.covFunction = covFunction;
	}
| 72 |
|
---|
	/**
	 * Trains the GP hyperparameters by maximizing the marginal likelihood.
	 * Delegates to {@link #train(Matrix, Matrix, Matrix, int)} with a length
	 * of -100: the negative sign makes the minimizer count function
	 * evaluations rather than iterations, so at most 100 evaluations are
	 * performed.
	 *
	 * @param X
	 *            the input data points (one row per point)
	 * @param y
	 *            the target data points
	 * @param logtheta0
	 *            the initial (log) hyperparameters of the covariance function
	 */
	public void train(Matrix X, Matrix y, Matrix logtheta0) {
		train(X, y, logtheta0, -100);
	}
| 87 |
|
---|
	/**
	 * Trains the GP hyperparameters by maximizing the marginal likelihood.
	 * Stores X for later predictions and leaves the optimized hyperparameters
	 * in {@link #logtheta}.
	 *
	 * @param X
	 *            the input data points (one row per point)
	 * @param y
	 *            the target data points
	 * @param logtheta0
	 *            the initial (log) hyperparameters of the covariance function
	 * @param iterations
	 *            budget for the minimizer; a positive value counts
	 *            line-search iterations, a negative value counts function
	 *            evaluations (see minimize())
	 */
	public void train(Matrix X, Matrix y, Matrix logtheta0, int iterations) {
		this.X = X;
		logtheta = minimize(logtheta0, iterations, X, y);
	}
| 108 |
|
---|
	/**
	 * Computes the negative log marginal likelihood and its partial
	 * derivatives with respect to the hyperparameters; this is the objective
	 * used to fit the hyperparameters.
	 * <p>
	 * Side effects: stores the Cholesky factor in {@link #L} and the partial
	 * factor K^-1*y in {@link #alpha}, which the predict methods reuse.
	 *
	 * @param logtheta
	 *            column <code>Matrix</code> of hyperparameters
	 * @param x
	 *            input dataset (one row per training point)
	 * @param y
	 *            output dataset
	 * @param df0
	 *            out-parameter: receives the partial derivatives with respect
	 *            to the hyperparameters (one row per hyperparameter)
	 * @return minus the log marginal likelihood
	 * @throws RuntimeException
	 *             if the covariance matrix is not symmetric positive definite
	 */
	public double negativeLogLikelihood(Matrix logtheta, Matrix x, Matrix y,
			Matrix df0) {

		int n = x.getRowDimension();

		// training-set covariance matrix K(x, x)
		Matrix K = covFunction.compute(logtheta, x);

		CholeskyDecomposition cd = K.chol();
		if (!cd.isSPD()) {
			throw new RuntimeException(
					"The covariance Matrix is not SDP, check your covariance function (maybe you mess the noise term..)");
		} else {
			L = cd.getL(); // Cholesky factor: K = L * L'

			// alpha = L'\(L\y) = K^-1 * y
			alpha = bSubstitutionWithTranspose(L, fSubstitution(L, y));

			// negative log marginal likelihood:
			// 0.5*y'*alpha + sum(log(diag(L))) + 0.5*n*log(2*pi)
			double lml = (y.transpose().times(alpha).times(0.5)).get(0, 0);

			for (int i = 0; i < L.getRowDimension(); i++)
				lml += Math.log(L.get(i, i));
			lml += 0.5 * n * Math.log(2 * Math.PI);

			// W = K^-1 - alpha*alpha', precomputed for convenience
			Matrix W = bSubstitutionWithTranspose(L,
					(fSubstitution(L, Matrix.identity(n, n)))).minus(
					alpha.times(alpha.transpose()));

			// d(nlml)/dtheta_i = 0.5 * sum(W .* dK/dtheta_i)
			for (int i = 0; i < df0.getRowDimension(); i++) {
				df0.set(i, 0, sum(W.arrayTimes(covFunction.computeDerivatives(
						logtheta, x, i))) / 2);
			}

			return lml;
		}
	}
| 169 |
|
---|
| 170 | /**
|
---|
| 171 | * Computes Gaussian predictions, whose mean and variance are returned. Note
|
---|
| 172 | * that in cases where the covariance function has noise contributions, the
|
---|
| 173 | * variance returned in S2 is for noisy test targets; if you want the
|
---|
| 174 | * variance of the noise-free latent function, you must subtract the noise
|
---|
| 175 | * variance.
|
---|
| 176 | *
|
---|
| 177 | * @param xstar
|
---|
| 178 | * test dataset
|
---|
| 179 | * @return [ystar Sstar] predicted mean and covariance
|
---|
| 180 | */
|
---|
| 181 |
|
---|
| 182 | public Matrix[] predict(Matrix xstar) {
|
---|
| 183 |
|
---|
| 184 | if (alpha == null || L == null) {
|
---|
| 185 | System.out.println("GP needs to be trained first..");
|
---|
| 186 | throw new IllegalStateException();
|
---|
| 187 | }
|
---|
| 188 | if (xstar.getColumnDimension() != X.getColumnDimension())
|
---|
| 189 | throw new IllegalArgumentException("Wrong size of the input "
|
---|
| 190 | + xstar.getColumnDimension() + " instead of "
|
---|
| 191 | + X.getColumnDimension());
|
---|
| 192 | Matrix[] star = covFunction.compute(logtheta, X, xstar);
|
---|
| 193 |
|
---|
| 194 | Matrix Kstar = star[1];
|
---|
| 195 | Matrix Kss = star[0];
|
---|
| 196 |
|
---|
| 197 | Matrix ystar = Kstar.transpose().times(alpha);
|
---|
| 198 |
|
---|
| 199 | Matrix v = fSubstitution(L, Kstar);
|
---|
| 200 |
|
---|
| 201 | v.arrayTimesEquals(v);
|
---|
| 202 |
|
---|
| 203 | Matrix Sstar = Kss.minus(sumColumns(v).transpose());
|
---|
| 204 |
|
---|
| 205 | // System.out.println("predict: "+ystar.get(0, 0));
|
---|
| 206 | return new Matrix[] { ystar, Sstar };
|
---|
| 207 | }
|
---|
| 208 |
|
---|
| 209 | /**
|
---|
| 210 | * Computes Gaussian predictions, whose mean is returned. Note that in cases
|
---|
| 211 | * where the covariance function has noise contributions, the variance
|
---|
| 212 | * returned in S2 is for noisy test targets; if you want the variance of the
|
---|
| 213 | * noise-free latent function, you must substract the noise variance.
|
---|
| 214 | *
|
---|
| 215 | * @param xstar
|
---|
| 216 | * test dataset
|
---|
| 217 | * @return [ystar Sstar] predicted mean and covariance
|
---|
| 218 | */
|
---|
| 219 |
|
---|
| 220 | public Matrix predictMean(Matrix xstar) {
|
---|
| 221 |
|
---|
| 222 | if (alpha == null || L == null) {
|
---|
| 223 | System.out.println("GP needs to be trained first..");
|
---|
| 224 | throw new IllegalStateException();
|
---|
| 225 | }
|
---|
| 226 | if (xstar.getColumnDimension() != X.getColumnDimension())
|
---|
| 227 | throw new IllegalArgumentException("Wrong size of the input"
|
---|
| 228 | + xstar.getColumnDimension() + " instead of "
|
---|
| 229 | + X.getColumnDimension());
|
---|
| 230 |
|
---|
| 231 | Matrix[] star = covFunction.compute(logtheta, X, xstar);
|
---|
| 232 |
|
---|
| 233 | Matrix Kstar = star[1];
|
---|
| 234 |
|
---|
| 235 | Matrix ystar = Kstar.transpose().times(alpha);
|
---|
| 236 |
|
---|
| 237 | return ystar;
|
---|
| 238 | }
|
---|
| 239 |
|
---|
| 240 | private static Matrix sumColumns(Matrix a) {
|
---|
| 241 | Matrix sum = new Matrix(1, a.getColumnDimension());
|
---|
| 242 | for (int i = 0; i < a.getRowDimension(); i++)
|
---|
| 243 | sum.plusEquals(a.getMatrix(i, i, 0, a.getColumnDimension() - 1));
|
---|
| 244 | return sum;
|
---|
| 245 | }
|
---|
| 246 |
|
---|
| 247 | private static double sum(Matrix a) {
|
---|
| 248 | double sum = 0;
|
---|
| 249 | for (int i = 0; i < a.getRowDimension(); i++)
|
---|
| 250 | for (int j = 0; j < a.getColumnDimension(); j++)
|
---|
| 251 | sum += a.get(i, j);
|
---|
| 252 | return sum;
|
---|
| 253 | }
|
---|
| 254 |
|
---|
| 255 | private static Matrix fSubstitution(Matrix L, Matrix B) {
|
---|
| 256 |
|
---|
| 257 | final double[][] l = L.getArray();
|
---|
| 258 | final double[][] b = B.getArray();
|
---|
| 259 | final double[][] x = new double[B.getRowDimension()][B
|
---|
| 260 | .getColumnDimension()];
|
---|
| 261 |
|
---|
| 262 | final int n = x.length;
|
---|
| 263 |
|
---|
| 264 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
| 265 | for (int k = 0; k < n; k++) {
|
---|
| 266 | x[k][i] = b[k][i];
|
---|
| 267 | for (int j = 0; j < k; j++) {
|
---|
| 268 | x[k][i] -= l[k][j] * x[j][i];
|
---|
| 269 | }
|
---|
| 270 | x[k][i] /= l[k][k];
|
---|
| 271 | }
|
---|
| 272 | }
|
---|
| 273 | return new Matrix(x);
|
---|
| 274 | }
|
---|
| 275 |
|
---|
| 276 | private static Matrix bSubstitution(Matrix L, Matrix B) {
|
---|
| 277 |
|
---|
| 278 | final double[][] l = L.getArray();
|
---|
| 279 | final double[][] b = B.getArray();
|
---|
| 280 | final double[][] x = new double[B.getRowDimension()][B
|
---|
| 281 | .getColumnDimension()];
|
---|
| 282 |
|
---|
| 283 | final int n = x.length - 1;
|
---|
| 284 |
|
---|
| 285 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
| 286 | for (int k = n; k > -1; k--) {
|
---|
| 287 | x[k][i] = b[k][i];
|
---|
| 288 | for (int j = n; j > k; j--) {
|
---|
| 289 | x[k][i] -= l[k][j] * x[j][i];
|
---|
| 290 | }
|
---|
| 291 | x[k][i] /= l[k][k];
|
---|
| 292 | }
|
---|
| 293 | }
|
---|
| 294 | return new Matrix(x);
|
---|
| 295 |
|
---|
| 296 | }
|
---|
| 297 |
|
---|
| 298 | private static Matrix bSubstitutionWithTranspose(Matrix L, Matrix B) {
|
---|
| 299 |
|
---|
| 300 | final double[][] l = L.getArray();
|
---|
| 301 | final double[][] b = B.getArray();
|
---|
| 302 | final double[][] x = new double[B.getRowDimension()][B
|
---|
| 303 | .getColumnDimension()];
|
---|
| 304 |
|
---|
| 305 | final int n = x.length - 1;
|
---|
| 306 |
|
---|
| 307 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
| 308 | for (int k = n; k > -1; k--) {
|
---|
| 309 | x[k][i] = b[k][i];
|
---|
| 310 | for (int j = n; j > k; j--) {
|
---|
| 311 | x[k][i] -= l[j][k] * x[j][i];
|
---|
| 312 | }
|
---|
| 313 | x[k][i] /= l[k][k];
|
---|
| 314 | }
|
---|
| 315 | }
|
---|
| 316 | return new Matrix(x);
|
---|
| 317 |
|
---|
| 318 | }
|
---|
| 319 |
|
---|
	// ---- Line-search constants used by minimize() ----
	// NOTE(review): names, values and comments match Carl Rasmussen's
	// minimize.m (Polack-Ribiere conjugate gradients); minimize() below
	// appears to be a port of it — confirm against the original if in doubt.

	private final static double INT = 0.1; // don't reevaluate within 0.1 of
											// the limit of the current bracket

	private final static double EXT = 3.0; // extrapolate maximum 3 times the
											// current step-size

	private final static int MAX = 20; // max 20 function evaluations per line
										// search

	private final static double RATIO = 10; // maximum allowed slope ratio

	// SIG and RHO are the constants controlling the Wolfe-Powell conditions.
	// SIG is the maximum allowed absolute ratio between previous and new
	// slopes (derivatives in the search direction); setting SIG to low
	// (positive) values forces higher precision in the line-searches. RHO is
	// the minimum allowed fraction of the expected decrease (from the slope
	// at the initial point in the linesearch). Constants must satisfy
	// 0 < RHO < SIG < 1. Tuning SIG (depending on the nature of the function
	// to be optimized) may speed up the minimization; it is probably not
	// worth playing much with RHO.
	private final static double SIG = 0.1, RHO = SIG / 2;
	/**
	 * Minimizes the negative log marginal likelihood with respect to the
	 * hyperparameters using Polack-Ribiere conjugate gradients with a
	 * Wolfe-Powell line search (cubic/quadratic extrapolation and
	 * interpolation).
	 *
	 * @param params
	 *            initial hyperparameters (column vector); also the starting
	 *            point of the search
	 * @param length
	 *            if positive, the maximum number of line-search iterations;
	 *            if negative, its absolute value bounds the number of
	 *            function evaluations
	 * @param in
	 *            input dataset, forwarded to negativeLogLikelihood()
	 * @param out
	 *            target dataset, forwarded to negativeLogLikelihood()
	 * @return the optimized hyperparameters
	 */
	private Matrix minimize(Matrix params, int length, Matrix in, Matrix out) {

		double A, B;
		double x1, x2, x3, x4; // step sizes along the search direction
		double f0, f1, f2, f3, f4; // function values at those steps
		double d0, d1, d2, d3, d4; // directional derivatives (slopes)
		Matrix df0, df3; // gradients
		Matrix fX; // record of function values (never read back)

		double red = 1.0; // reduction used to size the first step

		int i = 0; // iteration / evaluation counter
		int ls_failed = 0; // did the previous line search fail?

		int sizeX = params.getRowDimension();

		// initial objective and gradient
		df0 = new Matrix(sizeX, 1);
		f0 = negativeLogLikelihood(params, in, out, df0);

		fX = new Matrix(new double[] { f0 }, 1);

		i = (length < 0) ? i + 1 : i; // negative length counts evaluations

		Matrix s = df0.times(-1);

		// initial search direction (steepest) and slope
		d0 = s.times(-1).transpose().times(s).get(0, 0);
		x3 = red / (1 - d0); // initial step is red/(|s|+1)

		final int nCycles = Math.abs(length);

		int success;

		double M; // remaining function evaluations for this line search
		while (i < nCycles) {
			i = (length > 0) ? i + 1 : i; // count iterations

			// make a copy of current values
			double F0 = f0;
			Matrix X0 = params.copy();
			Matrix dF0 = df0.copy();

			M = (length > 0) ? MAX : Math.min(MAX, -length - i);

			while (true) { // keep extrapolating as long as necessary

				x2 = 0;
				f2 = f0;
				d2 = d0;
				f3 = f0;
				df3 = df0.copy();

				success = 0;

				while (success == 0 && M > 0) {
					M = M - 1;
					i = (length < 0) ? i + 1 : i; // count evaluations

					Matrix m1 = params.plus(s.times(x3));
					f3 = negativeLogLikelihood(m1, in, out, df3);

					if (Double.isNaN(f3) || Double.isInfinite(f3)
							|| hasInvalidNumbers(df3.getRowPackedCopy())) {
						// catch any error which occurred in the objective:
						// bisect back towards the last good point and retry
						x3 = (x2 + x3) / 2;
					} else {
						success = 1;
					}
				}

				if (f3 < F0) { // keep best values
					X0 = s.times(x3).plus(params);
					F0 = f3;
					dF0 = df3;
				}

				d3 = df3.transpose().times(s).get(0, 0); // new slope

				// are we done extrapolating?
				if (d3 > SIG * d0 || f3 > f0 + x3 * RHO * d0 || M == 0) {
					break;
				}

				x1 = x2;
				f1 = f2;
				d1 = d2; // move point 2 to point 1
				x2 = x3;
				f2 = f3;
				d2 = d3; // move point 3 to point 2

				// make cubic extrapolation
				A = 6 * (f1 - f2) + 3 * (d2 + d1) * (x2 - x1);
				B = 3 * (f2 - f1) - (2 * d1 + d2) * (x2 - x1);

				// num. error possible, ok!
				x3 = x1 - d1 * (x2 - x1) * (x2 - x1)
						/ (B + Math.sqrt(B * B - A * d1 * (x2 - x1)));

				if (Double.isNaN(x3) || Double.isInfinite(x3) || x3 < 0) // num prob | wrong sign?
					x3 = x2 * EXT; // extrapolate maximum amount
				else if (x3 > x2 * EXT) // new point beyond extrapolation limit?
					x3 = x2 * EXT; // extrapolate maximum amount
				else if (x3 < x2 + INT * (x2 - x1)) // new point too close to previous point?
					x3 = x2 + INT * (x2 - x1);

			}

			f4 = 0;
			x4 = 0;
			d4 = 0;

			while ((Math.abs(d3) > -SIG * d0 || f3 > f0 + x3 * RHO * d0)
					&& M > 0) { // keep interpolating

				if (d3 > 0 || f3 > f0 + x3 * RHO * d0) { // choose subinterval
					x4 = x3;
					f4 = f3;
					d4 = d3; // move point 3 to point 4
				} else {
					x2 = x3;
					f2 = f3;
					d2 = d3; // move point 3 to point 2
				}

				if (f4 > f0) {
					// quadratic interpolation
					x3 = x2 - (0.5 * d2 * (x4 - x2) * (x4 - x2))
							/ (f4 - f2 - d2 * (x4 - x2));
				} else {
					// cubic interpolation
					A = 6 * (f2 - f4) / (x4 - x2) + 3 * (d4 + d2);
					B = 3 * (f4 - f2) - (2 * d2 + d4) * (x4 - x2);
					x3 = x2
							+ (Math.sqrt(B * B - A * d2 * (x4 - x2) * (x4 - x2)) - B)
							/ A; // num. error possible, ok!
				}

				if (Double.isNaN(x3) || Double.isInfinite(x3)) {
					x3 = (x2 + x4) / 2; // numerical problem: bisect
				}

				// don't accept too close to the bracket limits
				x3 = Math.max(Math.min(x3, x4 - INT * (x4 - x2)), x2 + INT
						* (x4 - x2));

				Matrix m1 = s.times(x3).plus(params);
				f3 = negativeLogLikelihood(m1, in, out, df3);

				if (f3 < F0) {
					X0 = m1.copy();
					F0 = f3;
					dF0 = df3.copy(); // keep best values
				}

				M = M - 1;
				i = (length < 0) ? i + 1 : i; // count evaluations

				d3 = df3.transpose().times(s).get(0, 0); // new slope

			} // end interpolation

			if (Math.abs(d3) < -SIG * d0 && f3 < f0 + x3 * RHO * d0) { // line search succeeded
				params = s.times(x3).plus(params);
				f0 = f3;

				// append f0 to the record of function values
				double[] elem = fX.getColumnPackedCopy();
				double[] newfX = new double[elem.length + 1];

				System.arraycopy(elem, 0, newfX, 0, elem.length);
				// NOTE(review): this overwrites the last copied value instead
				// of filling the new slot (newfX[elem.length]); fX is never
				// read back, so the bug is currently harmless — confirm before
				// relying on fX.
				newfX[elem.length - 1] = f0;
				fX = new Matrix(newfX, newfX.length); // update variables

				// Polack-Ribiere conjugate-gradient direction update
				double tmp1 = df3.transpose().times(df3)
						.minus(df0.transpose().times(df3)).get(0, 0);
				double tmp2 = df0.transpose().times(df0).get(0, 0);

				s = s.times(tmp1 / tmp2).minus(df3);

				df0 = df3; // swap derivatives
				d3 = d0;
				d0 = df0.transpose().times(s).get(0, 0);

				if (d0 > 0) { // new slope must be negative
					s = df0.times(-1); // otherwise use steepest direction
					d0 = s.times(-1).transpose().times(s).get(0, 0);
				}

				// slope ratio but max RATIO
				x3 = x3 * Math.min(RATIO, d3 / (d0 - Double.MIN_VALUE));
				ls_failed = 0; // this line search did not fail

			} else {

				params = X0;
				f0 = F0;
				df0 = dF0; // restore best point so far

				// give up if the line search failed twice in a row or we ran
				// out of time
				if (ls_failed == 1 || i > Math.abs(length)) {
					break;
				}

				s = df0.times(-1);
				d0 = s.times(-1).transpose().times(s).get(0, 0); // try steepest
				x3 = 1 / (1 - d0);
				ls_failed = 1; // this line search failed

			}

		}

		return params;
	}
---|
| 587 |
|
---|
| 588 | private static boolean hasInvalidNumbers(double[] array) {
|
---|
| 589 |
|
---|
| 590 | for (double a : array) {
|
---|
| 591 | if (Double.isInfinite(a) || Double.isNaN(a)) {
|
---|
| 592 | return true;
|
---|
| 593 | }
|
---|
| 594 | }
|
---|
| 595 |
|
---|
| 596 | return false;
|
---|
| 597 | }
|
---|
| 598 |
|
---|
	/**
	 * A simple smoke test: trains a GP (linear + noise covariance) on a CSV
	 * dataset and prints predictions for a second CSV dataset.
	 * <p>
	 * NOTE(review): the input paths are absolute paths on the original
	 * author's machine, so this main() only runs there.
	 *
	 * @param args
	 *            unused
	 */
	public static void main(String[] args) {

		CovarianceFunction covFunc = new CovSum(6, new CovLINone(),
				new CovNoise());
		GaussianProcess gp = new GaussianProcess(covFunc);

		// initial log-hyperparameters: one entry per covariance term
		double[][] logtheta0 = new double[][] { { 0.1 }, { Math.log(0.1) } };

		Matrix params0 = new Matrix(logtheta0);

		// load training data: 6 input columns, 1 target column
		Matrix[] data = CSVtoMatrix
				.load("/Users/MaxLam/Desktop/Researches/ANAC2015/DragonAgent/DragonAgentBeta/src/armdata.csv",
						6, 1);
		Matrix X = data[0];
		Matrix Y = data[1];

		// negative length: at most 20 function evaluations in the minimizer
		gp.train(X, Y, params0, -20);

		// load test data in the same 6+1 column layout
		Matrix[] datastar = CSVtoMatrix
				.load("/Users/MaxLam/Desktop/Researches/ANAC2015/DragonAgent/DragonAgentBeta/src/armdatastar.csv",
						6, 1);
		Matrix Xstar = datastar[0];
		Matrix Ystar = datastar[1]; // loaded but not compared against below

		Matrix[] res = gp.predict(Xstar);

		// print predicted means, then variances
		res[0].print(res[0].getColumnDimension(), 16);
		System.err.println("");
		res[1].print(res[1].getColumnDimension(), 16);
	}
---|
| 684 |
|
---|
| 685 | }
|
---|