package agents.anac.y2015.Phoenix.GP;

/* This file is part of the jgpml Project.
 * http://github.com/renzodenardi/jgpml
 *
 * Copyright (c) 2011 Renzo De Nardi and Hugo Gravato-Marques
 *
 * Permission is hereby granted, free of charge, to any person
 * obtaining a copy of this software and associated documentation
 * files (the "Software"), to deal in the Software without
 * restriction, including without limitation the rights to use,
 * copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following
 * conditions:
 *
 * The above copyright notice and this permission notice shall be
 * included in all copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
 * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
 * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
 * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
 * OTHER DEALINGS IN THE SOFTWARE.
 */

import agents.Jama.CholeskyDecomposition;
import agents.Jama.Matrix;
31 | /**
|
---|
32 | * Main class of the package, contains the objects that constitutes a Gaussian
|
---|
33 | * Process as well as the algorithm to train the Hyperparameters and to do
|
---|
34 | * predictions.
|
---|
35 | */
|
---|
36 | public class GaussianProcess {
|
---|
37 |
|
---|
38 | /**
|
---|
39 | * hyperparameters
|
---|
40 | */
|
---|
41 | public Matrix logtheta;
|
---|
42 |
|
---|
43 | /**
|
---|
44 | * input data points
|
---|
45 | */
|
---|
46 | public Matrix X;
|
---|
47 |
|
---|
48 | /**
|
---|
49 | * Cholesky decomposition of the input
|
---|
50 | */
|
---|
51 | public Matrix L;
|
---|
52 |
|
---|
53 | /**
|
---|
54 | * partial factor
|
---|
55 | */
|
---|
56 | public Matrix alpha;
|
---|
57 |
|
---|
58 | /**
|
---|
59 | * covariance function
|
---|
60 | */
|
---|
61 | CovarianceFunction covFunction;
|
---|
62 |
|
---|
63 | /**
|
---|
64 | * Creates a new GP object.
|
---|
65 | *
|
---|
66 | * @param covFunction
|
---|
67 | * - the covariance function
|
---|
68 | */
|
---|
69 | public GaussianProcess(CovarianceFunction covFunction) {
|
---|
70 | this.covFunction = covFunction;
|
---|
71 | }
|
---|
72 |
|
---|
73 | /**
|
---|
74 | * Trains the GP Hyperparameters maximizing the marginal likelihood. By
|
---|
75 | * default the minimisation algorithm performs 100 iterations.
|
---|
76 | *
|
---|
77 | * @param X
|
---|
78 | * - the input data points
|
---|
79 | * @param y
|
---|
80 | * - the target data points
|
---|
81 | * @param logtheta0
|
---|
82 | * - the initial hyperparameters of the covariance function
|
---|
83 | */
|
---|
84 | public void train(Matrix X, Matrix y, Matrix logtheta0) {
|
---|
85 | train(X, y, logtheta0, -100);
|
---|
86 | }
|
---|
87 |
|
---|
88 | /**
|
---|
89 | * Trains the GP Hyperparameters maximizing the marginal likelihood. By
|
---|
90 | * default the algorithm performs 100 iterations.
|
---|
91 | *
|
---|
92 | * @param X
|
---|
93 | * - the input data points
|
---|
94 | * @param y
|
---|
95 | * - the target data points
|
---|
96 | * @param logtheta0
|
---|
97 | * - the initial hyperparameters of the covariance function
|
---|
98 | * @param iterations
|
---|
99 | * - number of iterations performed by the minimization algorithm
|
---|
100 | */
|
---|
101 | public void train(Matrix X, Matrix y, Matrix logtheta0, int iterations) {
|
---|
102 | // System.out.println("X Dimension:"+X.getRowDimension()+"x"+X.getColumnDimension());
|
---|
103 | // System.out.println("y Dimension:"+y.getRowDimension()+"x"+y.getColumnDimension());
|
---|
104 | // System.out.println("training started...");
|
---|
105 | this.X = X;
|
---|
106 | logtheta = minimize(logtheta0, iterations, X, y);
|
---|
107 | }
|
---|
108 |
|
---|
109 | /**
|
---|
110 | * Computes minus the log likelihood and its partial derivatives with
|
---|
111 | * respect to the hyperparameters; this mode is used to fit the
|
---|
112 | * hyperparameters.
|
---|
113 | *
|
---|
114 | * @param logtheta
|
---|
115 | * column <code>Matrix</code> of hyperparameters
|
---|
116 | * @param y
|
---|
117 | * output dataset
|
---|
118 | * @param df0
|
---|
119 | * returned partial derivatives with respect to the
|
---|
120 | * hyperparameters
|
---|
121 | * @return lml minus log marginal likelihood
|
---|
122 | */
|
---|
123 | public double negativeLogLikelihood(Matrix logtheta, Matrix x, Matrix y,
|
---|
124 | Matrix df0) {
|
---|
125 |
|
---|
126 | int n = x.getRowDimension();
|
---|
127 |
|
---|
128 | Matrix K = covFunction.compute(logtheta, x); // compute training set
|
---|
129 | // covariance matrix
|
---|
130 |
|
---|
131 | CholeskyDecomposition cd = K.chol();
|
---|
132 | if (!cd.isSPD()) {
|
---|
133 | throw new RuntimeException(
|
---|
134 | "The covariance Matrix is not SDP, check your covariance function (maybe you mess the noise term..)");
|
---|
135 | } else {
|
---|
136 | L = cd.getL(); // cholesky factorization of the covariance
|
---|
137 |
|
---|
138 | // alpha = L'\(L\y);
|
---|
139 | alpha = bSubstitutionWithTranspose(L, fSubstitution(L, y));
|
---|
140 |
|
---|
141 | // double[][] yarr = y.getArray();
|
---|
142 | // double[][] alphaarr = alpha.getArray();
|
---|
143 | // double lml =0;
|
---|
144 | // for(int i=0; i<n; i++){
|
---|
145 | // lml+= yarr[i][0]*alphaarr[i][0];
|
---|
146 | // }
|
---|
147 | // lml*=0.5;
|
---|
148 | //
|
---|
149 |
|
---|
150 | // compute the negative log marginal likelihood
|
---|
151 | double lml = (y.transpose().times(alpha).times(0.5)).get(0, 0);
|
---|
152 |
|
---|
153 | for (int i = 0; i < L.getRowDimension(); i++)
|
---|
154 | lml += Math.log(L.get(i, i));
|
---|
155 | lml += 0.5 * n * Math.log(2 * Math.PI);
|
---|
156 |
|
---|
157 | Matrix W = bSubstitutionWithTranspose(L,
|
---|
158 | (fSubstitution(L, Matrix.identity(n, n)))).minus(
|
---|
159 | alpha.times(alpha.transpose())); // precompute for
|
---|
160 | // convenience
|
---|
161 | for (int i = 0; i < df0.getRowDimension(); i++) {
|
---|
162 | df0.set(i, 0, sum(W.arrayTimes(covFunction.computeDerivatives(
|
---|
163 | logtheta, x, i))) / 2);
|
---|
164 | }
|
---|
165 |
|
---|
166 | return lml;
|
---|
167 | }
|
---|
168 | }
|
---|
169 |
|
---|
170 | /**
|
---|
171 | * Computes Gaussian predictions, whose mean and variance are returned. Note
|
---|
172 | * that in cases where the covariance function has noise contributions, the
|
---|
173 | * variance returned in S2 is for noisy test targets; if you want the
|
---|
174 | * variance of the noise-free latent function, you must subtract the noise
|
---|
175 | * variance.
|
---|
176 | *
|
---|
177 | * @param xstar
|
---|
178 | * test dataset
|
---|
179 | * @return [ystar Sstar] predicted mean and covariance
|
---|
180 | */
|
---|
181 |
|
---|
182 | public Matrix[] predict(Matrix xstar) {
|
---|
183 |
|
---|
184 | if (alpha == null || L == null) {
|
---|
185 | System.out.println("GP needs to be trained first..");
|
---|
186 | throw new IllegalStateException();
|
---|
187 | }
|
---|
188 | if (xstar.getColumnDimension() != X.getColumnDimension())
|
---|
189 | throw new IllegalArgumentException("Wrong size of the input "
|
---|
190 | + xstar.getColumnDimension() + " instead of "
|
---|
191 | + X.getColumnDimension());
|
---|
192 | Matrix[] star = covFunction.compute(logtheta, X, xstar);
|
---|
193 |
|
---|
194 | Matrix Kstar = star[1];
|
---|
195 | Matrix Kss = star[0];
|
---|
196 |
|
---|
197 | Matrix ystar = Kstar.transpose().times(alpha);
|
---|
198 |
|
---|
199 | Matrix v = fSubstitution(L, Kstar);
|
---|
200 |
|
---|
201 | v.arrayTimesEquals(v);
|
---|
202 |
|
---|
203 | Matrix Sstar = Kss.minus(sumColumns(v).transpose());
|
---|
204 |
|
---|
205 | // System.out.println("predict: "+ystar.get(0, 0));
|
---|
206 | return new Matrix[] { ystar, Sstar };
|
---|
207 | }
|
---|
208 |
|
---|
209 | /**
|
---|
210 | * Computes Gaussian predictions, whose mean is returned. Note that in cases
|
---|
211 | * where the covariance function has noise contributions, the variance
|
---|
212 | * returned in S2 is for noisy test targets; if you want the variance of the
|
---|
213 | * noise-free latent function, you must substract the noise variance.
|
---|
214 | *
|
---|
215 | * @param xstar
|
---|
216 | * test dataset
|
---|
217 | * @return [ystar Sstar] predicted mean and covariance
|
---|
218 | */
|
---|
219 |
|
---|
220 | public Matrix predictMean(Matrix xstar) {
|
---|
221 |
|
---|
222 | if (alpha == null || L == null) {
|
---|
223 | System.out.println("GP needs to be trained first..");
|
---|
224 | throw new IllegalStateException();
|
---|
225 | }
|
---|
226 | if (xstar.getColumnDimension() != X.getColumnDimension())
|
---|
227 | throw new IllegalArgumentException("Wrong size of the input"
|
---|
228 | + xstar.getColumnDimension() + " instead of "
|
---|
229 | + X.getColumnDimension());
|
---|
230 |
|
---|
231 | Matrix[] star = covFunction.compute(logtheta, X, xstar);
|
---|
232 |
|
---|
233 | Matrix Kstar = star[1];
|
---|
234 |
|
---|
235 | Matrix ystar = Kstar.transpose().times(alpha);
|
---|
236 |
|
---|
237 | return ystar;
|
---|
238 | }
|
---|
239 |
|
---|
240 | private static Matrix sumColumns(Matrix a) {
|
---|
241 | Matrix sum = new Matrix(1, a.getColumnDimension());
|
---|
242 | for (int i = 0; i < a.getRowDimension(); i++)
|
---|
243 | sum.plusEquals(a.getMatrix(i, i, 0, a.getColumnDimension() - 1));
|
---|
244 | return sum;
|
---|
245 | }
|
---|
246 |
|
---|
247 | private static double sum(Matrix a) {
|
---|
248 | double sum = 0;
|
---|
249 | for (int i = 0; i < a.getRowDimension(); i++)
|
---|
250 | for (int j = 0; j < a.getColumnDimension(); j++)
|
---|
251 | sum += a.get(i, j);
|
---|
252 | return sum;
|
---|
253 | }
|
---|
254 |
|
---|
255 | private static Matrix fSubstitution(Matrix L, Matrix B) {
|
---|
256 |
|
---|
257 | final double[][] l = L.getArray();
|
---|
258 | final double[][] b = B.getArray();
|
---|
259 | final double[][] x = new double[B.getRowDimension()][B
|
---|
260 | .getColumnDimension()];
|
---|
261 |
|
---|
262 | final int n = x.length;
|
---|
263 |
|
---|
264 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
265 | for (int k = 0; k < n; k++) {
|
---|
266 | x[k][i] = b[k][i];
|
---|
267 | for (int j = 0; j < k; j++) {
|
---|
268 | x[k][i] -= l[k][j] * x[j][i];
|
---|
269 | }
|
---|
270 | x[k][i] /= l[k][k];
|
---|
271 | }
|
---|
272 | }
|
---|
273 | return new Matrix(x);
|
---|
274 | }
|
---|
275 |
|
---|
276 | private static Matrix bSubstitution(Matrix L, Matrix B) {
|
---|
277 |
|
---|
278 | final double[][] l = L.getArray();
|
---|
279 | final double[][] b = B.getArray();
|
---|
280 | final double[][] x = new double[B.getRowDimension()][B
|
---|
281 | .getColumnDimension()];
|
---|
282 |
|
---|
283 | final int n = x.length - 1;
|
---|
284 |
|
---|
285 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
286 | for (int k = n; k > -1; k--) {
|
---|
287 | x[k][i] = b[k][i];
|
---|
288 | for (int j = n; j > k; j--) {
|
---|
289 | x[k][i] -= l[k][j] * x[j][i];
|
---|
290 | }
|
---|
291 | x[k][i] /= l[k][k];
|
---|
292 | }
|
---|
293 | }
|
---|
294 | return new Matrix(x);
|
---|
295 |
|
---|
296 | }
|
---|
297 |
|
---|
298 | private static Matrix bSubstitutionWithTranspose(Matrix L, Matrix B) {
|
---|
299 |
|
---|
300 | final double[][] l = L.getArray();
|
---|
301 | final double[][] b = B.getArray();
|
---|
302 | final double[][] x = new double[B.getRowDimension()][B
|
---|
303 | .getColumnDimension()];
|
---|
304 |
|
---|
305 | final int n = x.length - 1;
|
---|
306 |
|
---|
307 | for (int i = 0; i < B.getColumnDimension(); i++) {
|
---|
308 | for (int k = n; k > -1; k--) {
|
---|
309 | x[k][i] = b[k][i];
|
---|
310 | for (int j = n; j > k; j--) {
|
---|
311 | x[k][i] -= l[j][k] * x[j][i];
|
---|
312 | }
|
---|
313 | x[k][i] /= l[k][k];
|
---|
314 | }
|
---|
315 | }
|
---|
316 | return new Matrix(x);
|
---|
317 |
|
---|
318 | }
|
---|
319 |
|
---|
320 | private final static double INT = 0.1; // don't reevaluate within 0.1 of the
|
---|
321 | // limit of the current bracket
|
---|
322 |
|
---|
323 | private final static double EXT = 3.0; // extrapolate maximum 3 times the
|
---|
324 | // current step-size
|
---|
325 |
|
---|
326 | private final static int MAX = 20; // max 20 function evaluations per line
|
---|
327 | // search
|
---|
328 |
|
---|
329 | private final static double RATIO = 10; // maximum allowed slope ratio
|
---|
330 |
|
---|
331 | private final static double SIG = 0.1, RHO = SIG / 2; // SIG and RHO are the
|
---|
332 | // constants
|
---|
333 | // controlling the
|
---|
334 | // Wolfe-
|
---|
335 |
|
---|
336 | // Powell conditions. SIG is the maximum allowed absolute ratio between
|
---|
337 | // previous and new slopes (derivatives in the search direction), thus
|
---|
338 | // setting
|
---|
339 | // SIG to low (positive) values forces higher precision in the
|
---|
340 | // line-searches.
|
---|
341 | // RHO is the minimum allowed fraction of the expected (from the slope at
|
---|
342 | // the
|
---|
343 | // initial point in the linesearch). Constants must satisfy 0 < RHO < SIG <
|
---|
344 | // 1.
|
---|
345 | // Tuning of SIG (depending on the nature of the function to be optimized)
|
---|
346 | // may
|
---|
347 | // speed up the minimization; it is probably not worth playing much with
|
---|
348 | // RHO.
|
---|
349 |
|
---|
350 | private Matrix minimize(Matrix params, int length, Matrix in, Matrix out) {
|
---|
351 |
|
---|
352 | double A, B;
|
---|
353 | double x1, x2, x3, x4;
|
---|
354 | double f0, f1, f2, f3, f4;
|
---|
355 | double d0, d1, d2, d3, d4;
|
---|
356 | Matrix df0, df3;
|
---|
357 | Matrix fX;
|
---|
358 |
|
---|
359 | double red = 1.0;
|
---|
360 |
|
---|
361 | int i = 0;
|
---|
362 | int ls_failed = 0;
|
---|
363 |
|
---|
364 | int sizeX = params.getRowDimension();
|
---|
365 |
|
---|
366 | df0 = new Matrix(sizeX, 1);
|
---|
367 | f0 = negativeLogLikelihood(params, in, out, df0);
|
---|
368 | // f0 = f.evaluate(params,cf, in, out, df0);
|
---|
369 |
|
---|
370 | fX = new Matrix(new double[] { f0 }, 1);
|
---|
371 |
|
---|
372 | i = (length < 0) ? i + 1 : i;
|
---|
373 |
|
---|
374 | Matrix s = df0.times(-1);
|
---|
375 |
|
---|
376 | // initial search direction (steepest) and slope
|
---|
377 | d0 = s.times(-1).transpose().times(s).get(0, 0);
|
---|
378 | x3 = red / (1 - d0); // initial step is red/(|s|+1)
|
---|
379 |
|
---|
380 | final int nCycles = Math.abs(length);
|
---|
381 |
|
---|
382 | int success;
|
---|
383 |
|
---|
384 | double M;
|
---|
385 | while (i < nCycles) {
|
---|
386 | // System.out.println("-");
|
---|
387 | i = (length > 0) ? i + 1 : i; // count iterations?!
|
---|
388 |
|
---|
389 | // make a copy of current values
|
---|
390 | double F0 = f0;
|
---|
391 | Matrix X0 = params.copy();
|
---|
392 | Matrix dF0 = df0.copy();
|
---|
393 |
|
---|
394 | M = (length > 0) ? MAX : Math.min(MAX, -length - i);
|
---|
395 |
|
---|
396 | while (true) { // keep extrapolating as long as necessary
|
---|
397 |
|
---|
398 | x2 = 0;
|
---|
399 | f2 = f0;
|
---|
400 | d2 = d0;
|
---|
401 | f3 = f0;
|
---|
402 | df3 = df0.copy();
|
---|
403 |
|
---|
404 | success = 0;
|
---|
405 |
|
---|
406 | while (success == 0 && M > 0) {
|
---|
407 | // try
|
---|
408 | M = M - 1;
|
---|
409 | i = (length < 0) ? i + 1 : i; // count iterations?!
|
---|
410 |
|
---|
411 | Matrix m1 = params.plus(s.times(x3));
|
---|
412 | // f3 = f.evaluate(m1,cf, in, out, df3);
|
---|
413 | f3 = negativeLogLikelihood(m1, in, out, df3);
|
---|
414 |
|
---|
415 | if (Double.isNaN(f3) || Double.isInfinite(f3)
|
---|
416 | || hasInvalidNumbers(df3.getRowPackedCopy())) {
|
---|
417 | x3 = (x2 + x3) / 2; // catch any error which occured in
|
---|
418 | // f
|
---|
419 | } else {
|
---|
420 | success = 1;
|
---|
421 | }
|
---|
422 |
|
---|
423 | }
|
---|
424 |
|
---|
425 | if (f3 < F0) { // keep best values
|
---|
426 | X0 = s.times(x3).plus(params);
|
---|
427 | F0 = f3;
|
---|
428 | dF0 = df3;
|
---|
429 | }
|
---|
430 |
|
---|
431 | d3 = df3.transpose().times(s).get(0, 0); // new slope
|
---|
432 |
|
---|
433 | if (d3 > SIG * d0 || f3 > f0 + x3 * RHO * d0 || M == 0) { // are
|
---|
434 | // we
|
---|
435 | // done
|
---|
436 | // extrapolating?
|
---|
437 | break;
|
---|
438 | }
|
---|
439 |
|
---|
440 | x1 = x2;
|
---|
441 | f1 = f2;
|
---|
442 | d1 = d2; // move point 2 to point 1
|
---|
443 | x2 = x3;
|
---|
444 | f2 = f3;
|
---|
445 | d2 = d3; // move point 3 to point 2
|
---|
446 |
|
---|
447 | A = 6 * (f1 - f2) + 3 * (d2 + d1) * (x2 - x1); // make cubic
|
---|
448 | // extrapolation
|
---|
449 | B = 3 * (f2 - f1) - (2 * d1 + d2) * (x2 - x1);
|
---|
450 |
|
---|
451 | x3 = x1 - d1 * (x2 - x1) * (x2 - x1)
|
---|
452 | / (B + Math.sqrt(B * B - A * d1 * (x2 - x1))); // num.
|
---|
453 | // error
|
---|
454 | // possible,
|
---|
455 | // ok!
|
---|
456 |
|
---|
457 | if (Double.isNaN(x3) || Double.isInfinite(x3) || x3 < 0) // num
|
---|
458 | // prob
|
---|
459 | // |
|
---|
460 | // wrong
|
---|
461 | // sign?
|
---|
462 | x3 = x2 * EXT; // extrapolate maximum amount
|
---|
463 | else if (x3 > x2 * EXT) // new point beyond extrapolation limit?
|
---|
464 | x3 = x2 * EXT; // extrapolate maximum amount
|
---|
465 | else if (x3 < x2 + INT * (x2 - x1)) // new point too close to
|
---|
466 | // previous point?
|
---|
467 | x3 = x2 + INT * (x2 - x1);
|
---|
468 |
|
---|
469 | }
|
---|
470 |
|
---|
471 | f4 = 0;
|
---|
472 | x4 = 0;
|
---|
473 | d4 = 0;
|
---|
474 |
|
---|
475 | while ((Math.abs(d3) > -SIG * d0 || f3 > f0 + x3 * RHO * d0)
|
---|
476 | && M > 0) { // keep interpolating
|
---|
477 |
|
---|
478 | if (d3 > 0 || f3 > f0 + x3 * RHO * d0) { // choose subinterval
|
---|
479 | x4 = x3;
|
---|
480 | f4 = f3;
|
---|
481 | d4 = d3; // move point 3 to point 4
|
---|
482 | } else {
|
---|
483 | x2 = x3;
|
---|
484 | f2 = f3;
|
---|
485 | d2 = d3; // move point 3 to point 2
|
---|
486 | }
|
---|
487 |
|
---|
488 | if (f4 > f0) {
|
---|
489 | x3 = x2 - (0.5 * d2 * (x4 - x2) * (x4 - x2))
|
---|
490 | / (f4 - f2 - d2 * (x4 - x2)); // quadratic
|
---|
491 | // interpolation
|
---|
492 | } else {
|
---|
493 | A = 6 * (f2 - f4) / (x4 - x2) + 3 * (d4 + d2); // cubic
|
---|
494 | // interpolation
|
---|
495 | B = 3 * (f4 - f2) - (2 * d2 + d4) * (x4 - x2);
|
---|
496 | x3 = x2
|
---|
497 | + (Math.sqrt(B * B - A * d2 * (x4 - x2) * (x4 - x2)) - B)
|
---|
498 | / A; // num. error possible, ok!
|
---|
499 | }
|
---|
500 |
|
---|
501 | if (Double.isNaN(x3) || Double.isInfinite(x3)) {
|
---|
502 | x3 = (x2 + x4) / 2; // if we had a numerical problem then
|
---|
503 | // bisect
|
---|
504 | }
|
---|
505 |
|
---|
506 | x3 = Math.max(Math.min(x3, x4 - INT * (x4 - x2)), x2 + INT
|
---|
507 | * (x4 - x2)); // don't accept too close
|
---|
508 |
|
---|
509 | Matrix m1 = s.times(x3).plus(params);
|
---|
510 | // f3 = f.evaluate(m1,cf, in, out, df3);
|
---|
511 | f3 = negativeLogLikelihood(m1, in, out, df3);
|
---|
512 |
|
---|
513 | if (f3 < F0) {
|
---|
514 | X0 = m1.copy();
|
---|
515 | F0 = f3;
|
---|
516 | dF0 = df3.copy(); // keep best values
|
---|
517 | }
|
---|
518 |
|
---|
519 | M = M - 1;
|
---|
520 | i = (length < 0) ? i + 1 : i; // count iterations?!
|
---|
521 |
|
---|
522 | d3 = df3.transpose().times(s).get(0, 0); // new slope
|
---|
523 |
|
---|
524 | } // end interpolation
|
---|
525 |
|
---|
526 | if (Math.abs(d3) < -SIG * d0 && f3 < f0 + x3 * RHO * d0) { // if
|
---|
527 | // line
|
---|
528 | // search
|
---|
529 | // succeeded
|
---|
530 | params = s.times(x3).plus(params);
|
---|
531 | f0 = f3;
|
---|
532 |
|
---|
533 | double[] elem = fX.getColumnPackedCopy();
|
---|
534 | double[] newfX = new double[elem.length + 1];
|
---|
535 |
|
---|
536 | System.arraycopy(elem, 0, newfX, 0, elem.length);
|
---|
537 | newfX[elem.length - 1] = f0;
|
---|
538 | fX = new Matrix(newfX, newfX.length); // update variables
|
---|
539 |
|
---|
540 | // System.out.println("Function evaluation "+i+" Value "+f0);
|
---|
541 |
|
---|
542 | double tmp1 = df3.transpose().times(df3)
|
---|
543 | .minus(df0.transpose().times(df3)).get(0, 0);
|
---|
544 | double tmp2 = df0.transpose().times(df0).get(0, 0);
|
---|
545 |
|
---|
546 | s = s.times(tmp1 / tmp2).minus(df3);
|
---|
547 |
|
---|
548 | df0 = df3; // swap derivatives
|
---|
549 | d3 = d0;
|
---|
550 | d0 = df0.transpose().times(s).get(0, 0);
|
---|
551 |
|
---|
552 | if (d0 > 0) { // new slope must be negative
|
---|
553 | s = df0.times(-1); // otherwise use steepest direction
|
---|
554 | d0 = s.times(-1).transpose().times(s).get(0, 0);
|
---|
555 | }
|
---|
556 |
|
---|
557 | x3 = x3 * Math.min(RATIO, d3 / (d0 - Double.MIN_VALUE)); // slope
|
---|
558 | // ratio
|
---|
559 | // but
|
---|
560 | // max
|
---|
561 | // RATIO
|
---|
562 | ls_failed = 0; // this line search did not fail
|
---|
563 |
|
---|
564 | } else {
|
---|
565 |
|
---|
566 | params = X0;
|
---|
567 | f0 = F0;
|
---|
568 | df0 = dF0; // restore best point so far
|
---|
569 |
|
---|
570 | if (ls_failed == 1 || i > Math.abs(length)) { // line search
|
---|
571 | // failed twice
|
---|
572 | // in a row
|
---|
573 | break; // or we ran out of time, so we give up
|
---|
574 | }
|
---|
575 |
|
---|
576 | s = df0.times(-1);
|
---|
577 | d0 = s.times(-1).transpose().times(s).get(0, 0); // try steepest
|
---|
578 | x3 = 1 / (1 - d0);
|
---|
579 | ls_failed = 1; // this line search failed
|
---|
580 |
|
---|
581 | }
|
---|
582 |
|
---|
583 | }
|
---|
584 |
|
---|
585 | return params;
|
---|
586 | }
|
---|
587 |
|
---|
588 | private static boolean hasInvalidNumbers(double[] array) {
|
---|
589 |
|
---|
590 | for (double a : array) {
|
---|
591 | if (Double.isInfinite(a) || Double.isNaN(a)) {
|
---|
592 | return true;
|
---|
593 | }
|
---|
594 | }
|
---|
595 |
|
---|
596 | return false;
|
---|
597 | }
|
---|
598 |
|
---|
599 | /**
|
---|
600 | * A simple test
|
---|
601 | *
|
---|
602 | * @param args
|
---|
603 | */
|
---|
604 | public static void main(String[] args) {
|
---|
605 |
|
---|
606 | CovarianceFunction covFunc = new CovSum(6, new CovLINone(),
|
---|
607 | new CovNoise());
|
---|
608 | GaussianProcess gp = new GaussianProcess(covFunc);
|
---|
609 |
|
---|
610 | double[][] logtheta0 = new double[][] { { 0.1 }, { Math.log(0.1) } };
|
---|
611 | /*
|
---|
612 | * double[][] logtheta0 = new double[][]{ {0.1}, {0.2}, {0.3}, {0.4},
|
---|
613 | * {0.5}, {0.6}, {0.7}, {Math.log(0.1)}};
|
---|
614 | */
|
---|
615 |
|
---|
616 | Matrix params0 = new Matrix(logtheta0);
|
---|
617 |
|
---|
618 | Matrix[] data = CSVtoMatrix
|
---|
619 | .load("/Users/MaxLam/Desktop/Researches/ANAC2015/DragonAgent/DragonAgentBeta/src/armdata.csv",
|
---|
620 | 6, 1);
|
---|
621 | Matrix X = data[0];
|
---|
622 | Matrix Y = data[1];
|
---|
623 |
|
---|
624 | gp.train(X, Y, params0, -20);
|
---|
625 |
|
---|
626 | // int size = 100;
|
---|
627 | // Matrix Xtrain = new Matrix(size, 1);
|
---|
628 | // Matrix Ytrain = new Matrix(size, 1);
|
---|
629 | //
|
---|
630 | // Matrix Xtest = new Matrix(size, 1);
|
---|
631 | // Matrix Ytest = new Matrix(size, 1);
|
---|
632 |
|
---|
633 | // half of the sinusoid uses points very close to each other and the
|
---|
634 | // other half uses
|
---|
635 | // more sparse data
|
---|
636 |
|
---|
637 | // double inc = 2 * Math.PI / 1000;
|
---|
638 | //
|
---|
639 | // double[][] data = new double[1000][2];
|
---|
640 | //
|
---|
641 | // Random random = new Random();
|
---|
642 | //
|
---|
643 | // for(int i=0; i<1000; i++){
|
---|
644 | // data[i][0] = i*inc;
|
---|
645 | // data[i][1] = Math.sin(i*+inc);
|
---|
646 | // }
|
---|
647 | //
|
---|
648 | //
|
---|
649 | // // NEED TO FILL Xtrain, Ytrain and Xtest, Ytest
|
---|
650 | //
|
---|
651 | //
|
---|
652 | // gp.train(Xtrain,Ytest,params0);
|
---|
653 | //
|
---|
654 | // // SimpleRealTimePlotter plot = new
|
---|
655 | // SimpleRealTimePlotter("","","",false,false,true,200);
|
---|
656 | //
|
---|
657 | // final Matrix[] out = gp.predict(Xtest);
|
---|
658 | //
|
---|
659 | // for(int i=0; i<Xtest.getRowDimension(); i++){
|
---|
660 | //
|
---|
661 | // final double mean = out[0].get(i,0);
|
---|
662 | // final double var = 3 * Math.sqrt(out[1].get(i,0));
|
---|
663 | //
|
---|
664 | // plot.addPoint(i, mean, 0, true);
|
---|
665 | // plot.addPoint(i, mean-var, 1, true);
|
---|
666 | // plot.addPoint(i, mean+var, 2, true);
|
---|
667 | // plot.addPoint(i, Ytest.get(i,0), 3, true);
|
---|
668 | //
|
---|
669 | // plot.fillPlot();
|
---|
670 | // }
|
---|
671 |
|
---|
672 | Matrix[] datastar = CSVtoMatrix
|
---|
673 | .load("/Users/MaxLam/Desktop/Researches/ANAC2015/DragonAgent/DragonAgentBeta/src/armdatastar.csv",
|
---|
674 | 6, 1);
|
---|
675 | Matrix Xstar = datastar[0];
|
---|
676 | Matrix Ystar = datastar[1];
|
---|
677 |
|
---|
678 | Matrix[] res = gp.predict(Xstar);
|
---|
679 |
|
---|
680 | res[0].print(res[0].getColumnDimension(), 16);
|
---|
681 | System.err.println("");
|
---|
682 | res[1].print(res[1].getColumnDimension(), 16);
|
---|
683 | }
|
---|
684 |
|
---|
685 | }
|
---|