1 | package agents.uk.ac.soton.ecs.gp4j.util;
|
---|
2 |
|
---|
3 | import java.io.BufferedReader;
|
---|
4 | import java.io.File;
|
---|
5 | import java.io.FileReader;
|
---|
6 | import java.io.IOException;
|
---|
7 | import java.util.Locale;
|
---|
8 | import java.util.Map;
|
---|
9 |
|
---|
10 | import agents.Jama.Matrix;
|
---|
11 | import agents.uk.ac.soton.ecs.gp4j.bmc.BasicPrior;
|
---|
12 | import agents.uk.ac.soton.ecs.gp4j.bmc.GaussianProcessRegressionBMC;
|
---|
13 | import agents.uk.ac.soton.ecs.gp4j.gp.covariancefunctions.CovarianceFunctionFactory;
|
---|
14 |
|
---|
15 | public class DataLearner {
|
---|
16 | public static void main(String[] args) throws IOException {
|
---|
17 | BufferedReader reader = new BufferedReader(new FileReader(new File(
|
---|
18 | "/mnt/data/berkeley-dataset/sensor1_times_0.1.txt")));
|
---|
19 | Matrix trainX = Matrix.read(reader);
|
---|
20 |
|
---|
21 | reader = new BufferedReader(new FileReader(new File(
|
---|
22 | "/mnt/data/berkeley-dataset/sensor1_temps_0.1.txt")));
|
---|
23 | Matrix trainY = Matrix.read(reader);
|
---|
24 |
|
---|
25 | GaussianProcessRegressionBMC regression = new GaussianProcessRegressionBMC();
|
---|
26 | regression.setCovarianceFunction(CovarianceFunctionFactory
|
---|
27 | .getNoisySquaredExponentialARDCovarianceFunction());
|
---|
28 | // .getNoisy2DTimeSquaredExponentialCovarianceFunction());
|
---|
29 |
|
---|
30 | // BasicPrior lengthScalePrior = new BasicPrior(5, 5.0, 0.5);
|
---|
31 | BasicPrior timeScalePrior = new BasicPrior(10, 5000, 0.15);
|
---|
32 | BasicPrior signalVariance = new BasicPrior(10, 10, 0.15);
|
---|
33 | BasicPrior noise = new BasicPrior(1, 0.4, 0.3);
|
---|
34 |
|
---|
35 | regression.setPriors(timeScalePrior, signalVariance, noise);
|
---|
36 |
|
---|
37 | // int batchSize = 1;
|
---|
38 |
|
---|
39 | regression.updateRegression(trainX, trainY);
|
---|
40 |
|
---|
41 | printHyperParamWeights(regression.getHyperParameterWeights(), 1);
|
---|
42 |
|
---|
43 | // for (int i = 0; i < trainX.getRowDimension(); i++) {
|
---|
44 | //
|
---|
45 | // System.out.println(i);
|
---|
46 | //
|
---|
47 | // regression.updateRegression(trainX.getMatrix(i, i, 0, 0), trainY
|
---|
48 | // .getMatrix(i, i, 0, 0));
|
---|
49 | //
|
---|
50 | // // if (regression.getTrainingSampleCount() > 120)
|
---|
51 | // // regression.downdateRegression();
|
---|
52 | //
|
---|
53 | // printHyperParamWeights(regression.getHyperParameterWeights(), i);
|
---|
54 | // }
|
---|
55 | }
|
---|
56 |
|
---|
57 | private static void printHyperParamWeights(
|
---|
58 | Map<Double[], Double> hyperParameterWeights, int round)
|
---|
59 | throws IOException {
|
---|
60 |
|
---|
61 | StringBuffer buffer = new StringBuffer();
|
---|
62 |
|
---|
63 | for (Double[] hyper : hyperParameterWeights.keySet()) {
|
---|
64 | Double weight = hyperParameterWeights.get(hyper);
|
---|
65 |
|
---|
66 | buffer.append(String.format(Locale.US, " %6d", round));
|
---|
67 |
|
---|
68 | for (int j = 0; j < hyper.length; j++) {
|
---|
69 | buffer.append(String.format(Locale.US, " %15.5f", hyper[j]));
|
---|
70 | }
|
---|
71 |
|
---|
72 | buffer.append(String.format(Locale.US, " %15.5f\n", weight));
|
---|
73 | }
|
---|
74 |
|
---|
75 | // FileUtils.writeStringToFile(new File("params", "params-" + round
|
---|
76 | // + ".txt"), buffer.toString());
|
---|
77 | }
|
---|
78 | }
|
---|