source: src/main/java/agents/uk/ac/soton/ecs/gp4j/util/DataLearner.java

Last change on this file was 1, checked in by Wouter Pasman, 6 years ago

Initial import : Genius 9.0.0

File size: 2.5 KB
Line 
1package agents.uk.ac.soton.ecs.gp4j.util;
2
3import java.io.BufferedReader;
4import java.io.File;
5import java.io.FileReader;
6import java.io.IOException;
7import java.util.Locale;
8import java.util.Map;
9
10import agents.Jama.Matrix;
11import agents.uk.ac.soton.ecs.gp4j.bmc.BasicPrior;
12import agents.uk.ac.soton.ecs.gp4j.bmc.GaussianProcessRegressionBMC;
13import agents.uk.ac.soton.ecs.gp4j.gp.covariancefunctions.CovarianceFunctionFactory;
14
15public class DataLearner {
16 public static void main(String[] args) throws IOException {
17 BufferedReader reader = new BufferedReader(new FileReader(new File(
18 "/mnt/data/berkeley-dataset/sensor1_times_0.1.txt")));
19 Matrix trainX = Matrix.read(reader);
20
21 reader = new BufferedReader(new FileReader(new File(
22 "/mnt/data/berkeley-dataset/sensor1_temps_0.1.txt")));
23 Matrix trainY = Matrix.read(reader);
24
25 GaussianProcessRegressionBMC regression = new GaussianProcessRegressionBMC();
26 regression.setCovarianceFunction(CovarianceFunctionFactory
27 .getNoisySquaredExponentialARDCovarianceFunction());
28 // .getNoisy2DTimeSquaredExponentialCovarianceFunction());
29
30 // BasicPrior lengthScalePrior = new BasicPrior(5, 5.0, 0.5);
31 BasicPrior timeScalePrior = new BasicPrior(10, 5000, 0.15);
32 BasicPrior signalVariance = new BasicPrior(10, 10, 0.15);
33 BasicPrior noise = new BasicPrior(1, 0.4, 0.3);
34
35 regression.setPriors(timeScalePrior, signalVariance, noise);
36
37 // int batchSize = 1;
38
39 regression.updateRegression(trainX, trainY);
40
41 printHyperParamWeights(regression.getHyperParameterWeights(), 1);
42
43 // for (int i = 0; i < trainX.getRowDimension(); i++) {
44 //
45 // System.out.println(i);
46 //
47 // regression.updateRegression(trainX.getMatrix(i, i, 0, 0), trainY
48 // .getMatrix(i, i, 0, 0));
49 //
50 // // if (regression.getTrainingSampleCount() > 120)
51 // // regression.downdateRegression();
52 //
53 // printHyperParamWeights(regression.getHyperParameterWeights(), i);
54 // }
55 }
56
57 private static void printHyperParamWeights(
58 Map<Double[], Double> hyperParameterWeights, int round)
59 throws IOException {
60
61 StringBuffer buffer = new StringBuffer();
62
63 for (Double[] hyper : hyperParameterWeights.keySet()) {
64 Double weight = hyperParameterWeights.get(hyper);
65
66 buffer.append(String.format(Locale.US, " %6d", round));
67
68 for (int j = 0; j < hyper.length; j++) {
69 buffer.append(String.format(Locale.US, " %15.5f", hyper[j]));
70 }
71
72 buffer.append(String.format(Locale.US, " %15.5f\n", weight));
73 }
74
75 // FileUtils.writeStringToFile(new File("params", "params-" + round
76 // + ".txt"), buffer.toString());
77 }
78}
Note: See TracBrowser for help on using the repository browser.