source: src/main/java/agents/uk/ac/soton/ecs/gp4j/bmc/GaussianProcessRegressionBMC.java

Last change to this file was revision 1, checked in by Wouter Pasman 6 years ago.

Initial import: Genius 9.0.0

File size: 11.3 KB
Line 
1package agents.uk.ac.soton.ecs.gp4j.bmc;
2
3import java.util.ArrayList;
4import java.util.Arrays;
5import java.util.HashMap;
6import java.util.List;
7import java.util.Map;
8
9import agents.Jama.Matrix;
10import agents.org.apache.commons.lang.NotImplementedException;
11import agents.org.apache.commons.lang.Validate;
12import agents.org.apache.commons.math.stat.StatUtils;
13import agents.uk.ac.soton.ecs.gp4j.gp.GaussianPredictor;
14import agents.uk.ac.soton.ecs.gp4j.gp.GaussianProcess;
15import agents.uk.ac.soton.ecs.gp4j.gp.GaussianProcessRegression;
16import agents.uk.ac.soton.ecs.gp4j.gp.GaussianRegression;
17import agents.uk.ac.soton.ecs.gp4j.gp.covariancefunctions.CovarianceFunction;
18import agents.uk.ac.soton.ecs.gp4j.util.ArrayUtils;
19import agents.uk.ac.soton.ecs.gp4j.util.MathUtils;
20import agents.uk.ac.soton.ecs.gp4j.util.MatrixUtils;
21
22public class GaussianProcessRegressionBMC implements
23 GaussianRegression<GaussianProcessMixture> {
24
25 // private static Log log =
26 // LogFactory.getLog(GaussianProcessRegressionBMC.class);
27
28 private CovarianceFunction function;
29
30 private List<BasicPrior> priors;
31
32 private List<GaussianProcessRegression> gpRegressions;
33
34 private List<Double> weights;
35
36 private Matrix KSinv_NS_KSinv;
37
38 private boolean initialized;
39
40 private int dataPointsProcessed = 0;
41
42 private GaussianProcessMixture currentPredictor;
43
44 public GaussianProcessRegressionBMC() {
45
46 }
47
48 public void reset() {
49 throw new NotImplementedException();
50 }
51
52 public GaussianProcessRegressionBMC(CovarianceFunction function,
53 List<BasicPrior> priors) {
54 this.function = function;
55 this.priors = priors;
56 initialize();
57 }
58
59 // copy constructor
60 private GaussianProcessRegressionBMC(GaussianProcessRegressionBMC toCopy) {
61 this.function = toCopy.function;
62 this.priors = new ArrayList<BasicPrior>(toCopy.priors);
63
64 this.weights = new ArrayList<Double>(toCopy.weights);
65
66 this.KSinv_NS_KSinv = toCopy.KSinv_NS_KSinv.copy();
67
68 this.currentPredictor = toCopy.currentPredictor;
69 initialize();
70 this.gpRegressions = new ArrayList<GaussianProcessRegression>();
71 }
72
73 public void initialize() {
74 if (!initialized) {
75 initializeSamples();
76 calculateKSinv_NS_KSinv();
77 }
78
79 initialized = true;
80 }
81
82 private void initializeSamples() {
83 int i = 0;
84 double[][] independentSamples = new double[priors.size()][0];
85 for (BasicPrior prior : priors)
86 independentSamples[i++] = prior.getLogSamples();
87
88 double[][] samples = ArrayUtils.allCombinations(independentSamples);
89
90 gpRegressions = new ArrayList<GaussianProcessRegression>(samples.length);
91
92 for (i = 0; i < samples.length; i++) {
93 gpRegressions.add(new GaussianProcessRegression(samples[i],
94 function));
95 }
96
97 // for (int j = 0; j < samples.length; j++) {
98 // log.debug("Sample : " + ArrayUtils.toString(samples[j]));
99 // }
100 }
101
102 /**
103 * Calculate the first three terms of equation 3.8.13, which is the product
104 * of the inverse of KS, NS, and the inverse of KS
105 */
106 private void calculateKSinv_NS_KSinv() {
107 KSinv_NS_KSinv = new Matrix(1, 1, 1.0);
108
109 for (BasicPrior prior : priors) {
110 Matrix KS = calculateKS(prior.getWidth(), prior.getLogSamples());
111 Matrix NS = calculateNS(prior.getWidth(), prior.getLogSamples(),
112 prior.getLogMean(), prior.getStandardDeviation());
113
114 Matrix cholKS = KS.chol().getL();
115 Matrix result = MatrixUtils.solveChol(cholKS, NS).transpose();
116 result = cholKS.transpose().solve(result);
117 result = cholKS.solve(result);
118
119 KSinv_NS_KSinv = MatrixUtils.kronecker(KSinv_NS_KSinv, result);
120 }
121 }
122
123 /**
124 * Calculate the NS Matrix using equation 3.8.10
125 */
126 private Matrix calculateNS(double width, double[] samples, double mean,
127 double standardDeviation) {
128 Matrix NS = new Matrix(samples.length, samples.length);
129
130 double variance = standardDeviation * standardDeviation;
131 double lambda = variance + width * width;
132 double precX = lambda - variance * variance / lambda;
133 double precY = 1 / (variance - lambda * lambda / variance);
134 double multConst = 1 / Math.sqrt(Math.pow(2 * Math.PI, 2) * lambda
135 * precX);
136
137 for (int i = 0; i < samples.length; i++) {
138 for (int j = 0; j < samples.length; j++) {
139 double xDev = samples[i] - mean;
140 double yDev = samples[j] - mean;
141
142 NS.set(i, j, multConst
143 * Math.exp(-0.5 / precX * (xDev * xDev + yDev * yDev)
144 - precY * xDev * yDev));
145 }
146 }
147
148 return NS;
149 }
150
151 /**
152 * Calculate the KS matrix
153 */
154 private Matrix calculateKS(double width, double[] samples) {
155 Matrix KS = new Matrix(samples.length, samples.length);
156
157 for (int i = 0; i < samples.length; i++) {
158 for (int j = 0; j < samples.length; j++) {
159 KS.set(i, j, MathUtils.normPDF(samples[i], samples[j], width));
160 }
161 }
162
163 return KS;
164 }
165
166 public List<GaussianProcessRegression> getGpRegressions() {
167 return gpRegressions;
168 }
169
170 public GaussianProcessMixture calculateRegression(Matrix trainX,
171 Matrix trainY) {
172 initialize();
173 List<GaussianProcess> gaussianProcesses = new ArrayList<GaussianProcess>();
174
175 for (GaussianProcessRegression gpRegression : gpRegressions) {
176 GaussianProcess gp = gpRegression.calculateRegression(trainX,
177 trainY);
178 gaussianProcesses.add(gp);
179 }
180
181 calculateWeights();
182 Validate.isTrue(gpRegressions.size() == weights.size());
183
184 currentPredictor = new GaussianProcessMixture(gaussianProcesses,
185 weights);
186 return currentPredictor;
187 }
188
189 public GaussianProcessMixture downdateRegression(int i) {
190 List<GaussianProcess> gaussianProcesses = new ArrayList<GaussianProcess>();
191
192 for (GaussianProcessRegression gpRegression : gpRegressions) {
193 GaussianProcess gp = gpRegression.downdateRegression(i);
194 gaussianProcesses.add(gp);
195 }
196
197 // log-likelihoods will not change during a downdate. Therefore, weights
198 // need not be recalculated
199 // calculateWeights();
200
201 currentPredictor = new GaussianProcessMixture(gaussianProcesses,
202 weights);
203 return currentPredictor;
204
205 }
206
207 public GaussianProcessMixture downdateRegression() {
208 return downdateRegression(1);
209 }
210
211 public GaussianProcessMixture updateRegression(Matrix addedTrainX,
212 Matrix addedTrainY, boolean downDate) {
213 return updateRegression(addedTrainX, addedTrainY);
214 }
215
216 public GaussianProcessMixture updateRegression(Matrix addedTrainX,
217 Matrix addedTrainY) {
218 initialize();
219
220 dataPointsProcessed += addedTrainX.getRowDimension();
221
222 List<GaussianProcess> gaussianProcesses = new ArrayList<GaussianProcess>();
223
224 for (GaussianProcessRegression gpRegression : gpRegressions) {
225 GaussianProcess gp = gpRegression.updateRegression(addedTrainX,
226 addedTrainY);
227 gaussianProcesses.add(gp);
228 }
229
230 calculateWeights();
231
232 Validate.isTrue(gpRegressions.size() == weights.size());
233
234 currentPredictor = new GaussianProcessMixture(gaussianProcesses,
235 weights);
236
237 recalculateSamples();
238
239 return currentPredictor;
240 }
241
242 private void calculateWeights() {
243 int size = gpRegressions.size();
244
245 double[] logLikelihoods = new double[size];
246
247 // calculate the weights using equation 3.8.16
248 for (int i = 0; i < size; i++)
249 logLikelihoods[i] = gpRegressions.get(i).getLogLikelihood();
250
251 // scale log-likelihoods for numerical stability
252 double maxLogLikelihood = StatUtils.max(logLikelihoods);
253 Matrix rs = new Matrix(size, 1);
254 for (int i = 0; i < size; i++)
255 rs.set(i, 0, Math.exp(logLikelihoods[i] - maxLogLikelihood));
256
257 Matrix numerator = KSinv_NS_KSinv.times(rs);
258 double denominator = MatrixUtils.sum(numerator).get(0, 0);
259 Matrix weightsMatrix = numerator.times(1 / denominator);
260
261 weights = Arrays.asList(ArrayUtils.toObject(weightsMatrix
262 .getColumnPackedCopy()));
263 }
264
265 private void recalculateSamples() {
266 if (dataPointsProcessed % 50 == 0) {
267 double threshold = 1e-3;
268
269 for (int j = 0; j < priors.size(); j++) {
270 if (priors.get(j).getSampleCount() <= 5)
271 continue;
272
273 double[][] marginalizedWeights = getMarginalizedHyperParameterWeights(j);
274 double maxWeight = Double.NEGATIVE_INFINITY;
275 double maxWeightedParam = Double.NEGATIVE_INFINITY;
276 int underThreshold = 0;
277
278 for (int i = 0; i < marginalizedWeights.length; i++) {
279 if (marginalizedWeights[i][1] < threshold)
280 underThreshold++;
281
282 if (marginalizedWeights[i][1] > maxWeight) {
283 maxWeight = marginalizedWeights[i][1];
284 maxWeightedParam = marginalizedWeights[i][0];
285 }
286 }
287
288 BasicPrior oldPrior = priors.get(j);
289 int newSampleCount = Math.max(5, oldPrior.getSampleCount()
290 - underThreshold / 2);
291 double newStandardDeviation = oldPrior.getStandardDeviation() * 0.8;
292 BasicPrior newPrior = new BasicPrior(newSampleCount, Math
293 .exp(maxWeightedParam), newStandardDeviation);
294
295 priors.set(j, newPrior);
296 }
297
298 initialized = false;
299 calculateRegression(getTrainX(), getTrainY());
300 }
301 }
302
303 protected Matrix getKSinv_NS_KSinv() {
304 return KSinv_NS_KSinv;
305 }
306
307 protected List<Double> getWeights() {
308 return weights;
309 }
310
311 public Map<Double[], Double> getHyperParameterWeights() {
312 HashMap<Double[], Double> weighing = new HashMap<Double[], Double>();
313
314 for (int i = 0; i < weights.size(); i++) {
315 weighing.put(gpRegressions.get(i).getHyperParameters(), weights
316 .get(i));
317 }
318
319 return weighing;
320 }
321
322 /**
323 * Returns a 3D matrix of. First dimension specifies hyperparameter index.
324 * Second and third dimensions form a 2D matrix with (param value, weight)
325 * tuples
326 *
327 * @return
328 */
329 public double[][] getMarginalizedHyperParameterWeights(int paramIndex) {
330 int n = priors.get(paramIndex).getSampleCount();
331 double[] samples = priors.get(paramIndex).getLogSamples();
332 double[][] result = new double[n][];
333
334 for (int i = 0; i < samples.length; i++) {
335 result[i] = new double[2];
336 result[i][0] = samples[i];
337
338 for (int j = 0; j < gpRegressions.size(); j++) {
339 GaussianProcessRegression gpr = gpRegressions.get(j);
340 if (gpr.getLogHyperParameters()[paramIndex] == samples[i])
341 result[i][1] += weights.get(j);
342 }
343 }
344
345 return result;
346 }
347
348 public GaussianProcessMixture calculateRegression(double[] trainX,
349 double[] trainY) {
350 return calculateRegression(new Matrix(trainX, 1).transpose(),
351 new Matrix(trainY, 1).transpose());
352 }
353
354 public int getTrainingSampleCount() {
355 return gpRegressions.get(0).getTrainingSampleCount();
356 }
357
358 public Matrix getTrainX() {
359 return gpRegressions.get(0).getTrainX();
360 }
361
362 public Matrix getTrainY() {
363 return gpRegressions.get(0).getTrainY();
364 }
365
366 public GaussianProcessRegressionBMC copy() {
367 Validate.isTrue(initialized, "Cannot copy before initialized");
368
369 GaussianProcessRegressionBMC regressionBMC = new GaussianProcessRegressionBMC(
370 this);
371
372 for (GaussianProcessRegression regression : gpRegressions) {
373 regressionBMC.gpRegressions.add(regression.copy());
374 }
375
376 Validate.isTrue(regressionBMC.gpRegressions.size() == KSinv_NS_KSinv
377 .getColumnDimension());
378
379 return regressionBMC;
380 }
381
382 public GaussianProcessRegressionBMC shallowCopy() {
383 GaussianProcessRegressionBMC regressionBMC = new GaussianProcessRegressionBMC(
384 this);
385
386 for (GaussianProcessRegression regression : gpRegressions) {
387 regressionBMC.gpRegressions.add(regression.shallowCopy());
388 }
389
390 Validate.isTrue(gpRegressions.size() == KSinv_NS_KSinv
391 .getColumnDimension());
392
393 return regressionBMC;
394 }
395
396 public void setCovarianceFunction(CovarianceFunction function) {
397 this.function = function;
398 }
399
400 public void setPriors(List<BasicPrior> priors) {
401 this.priors = priors;
402 }
403
404 public GaussianPredictor<?> getCurrentPredictor() {
405 return currentPredictor;
406 }
407
408 public void setPriors(BasicPrior... priors) {
409 this.priors = Arrays.asList(priors);
410 }
411}
Note: See TracBrowser for help on using the repository browser.