source: src/main/java/agents/anac/y2019/harddealer/math3/fitting/GaussianCurveFitter.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 16.8 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package agents.anac.y2019.harddealer.math3.fitting;
18
19import java.util.ArrayList;
20import java.util.Collection;
21import java.util.Collections;
22import java.util.Comparator;
23import java.util.List;
24
25import agents.anac.y2019.harddealer.math3.analysis.function.Gaussian;
26import agents.anac.y2019.harddealer.math3.exception.NotStrictlyPositiveException;
27import agents.anac.y2019.harddealer.math3.exception.NullArgumentException;
28import agents.anac.y2019.harddealer.math3.exception.NumberIsTooSmallException;
29import agents.anac.y2019.harddealer.math3.exception.OutOfRangeException;
30import agents.anac.y2019.harddealer.math3.exception.ZeroException;
31import agents.anac.y2019.harddealer.math3.exception.util.LocalizedFormats;
32import agents.anac.y2019.harddealer.math3.fitting.leastsquares.LeastSquaresBuilder;
33import agents.anac.y2019.harddealer.math3.fitting.leastsquares.LeastSquaresProblem;
34import agents.anac.y2019.harddealer.math3.linear.DiagonalMatrix;
35import agents.anac.y2019.harddealer.math3.util.FastMath;
36
37/**
38 * Fits points to a {@link
39 * agents.anac.y2019.harddealer.math3.analysis.function.Gaussian.Parametric Gaussian}
40 * function.
41 * <br/>
42 * The {@link #withStartPoint(double[]) initial guess values} must be passed
43 * in the following order:
44 * <ul>
45 * <li>Normalization</li>
46 * <li>Mean</li>
47 * <li>Sigma</li>
48 * </ul>
49 * The optimal values will be returned in the same order.
50 *
51 * <p>
52 * Usage example:
53 * <pre>
54 * WeightedObservedPoints obs = new WeightedObservedPoints();
55 * obs.add(4.0254623, 531026.0);
56 * obs.add(4.03128248, 984167.0);
57 * obs.add(4.03839603, 1887233.0);
58 * obs.add(4.04421621, 2687152.0);
59 * obs.add(4.05132976, 3461228.0);
60 * obs.add(4.05326982, 3580526.0);
61 * obs.add(4.05779662, 3439750.0);
62 * obs.add(4.0636168, 2877648.0);
63 * obs.add(4.06943698, 2175960.0);
64 * obs.add(4.07525716, 1447024.0);
65 * obs.add(4.08237071, 717104.0);
66 * obs.add(4.08366408, 620014.0);
67 * double[] parameters = GaussianCurveFitter.create().fit(obs.toList());
68 * </pre>
69 *
70 * @since 3.3
71 */
72public class GaussianCurveFitter extends AbstractCurveFitter {
73 /** Parametric function to be fitted. */
74 private static final Gaussian.Parametric FUNCTION = new Gaussian.Parametric() {
75 /** {@inheritDoc} */
76 @Override
77 public double value(double x, double ... p) {
78 double v = Double.POSITIVE_INFINITY;
79 try {
80 v = super.value(x, p);
81 } catch (NotStrictlyPositiveException e) { // NOPMD
82 // Do nothing.
83 }
84 return v;
85 }
86
87 /** {@inheritDoc} */
88 @Override
89 public double[] gradient(double x, double ... p) {
90 double[] v = { Double.POSITIVE_INFINITY,
91 Double.POSITIVE_INFINITY,
92 Double.POSITIVE_INFINITY };
93 try {
94 v = super.gradient(x, p);
95 } catch (NotStrictlyPositiveException e) { // NOPMD
96 // Do nothing.
97 }
98 return v;
99 }
100 };
101 /** Initial guess. */
102 private final double[] initialGuess;
103 /** Maximum number of iterations of the optimization algorithm. */
104 private final int maxIter;
105
106 /**
107 * Contructor used by the factory methods.
108 *
109 * @param initialGuess Initial guess. If set to {@code null}, the initial guess
110 * will be estimated using the {@link ParameterGuesser}.
111 * @param maxIter Maximum number of iterations of the optimization algorithm.
112 */
113 private GaussianCurveFitter(double[] initialGuess,
114 int maxIter) {
115 this.initialGuess = initialGuess;
116 this.maxIter = maxIter;
117 }
118
119 /**
120 * Creates a default curve fitter.
121 * The initial guess for the parameters will be {@link ParameterGuesser}
122 * computed automatically, and the maximum number of iterations of the
123 * optimization algorithm is set to {@link Integer#MAX_VALUE}.
124 *
125 * @return a curve fitter.
126 *
127 * @see #withStartPoint(double[])
128 * @see #withMaxIterations(int)
129 */
130 public static GaussianCurveFitter create() {
131 return new GaussianCurveFitter(null, Integer.MAX_VALUE);
132 }
133
134 /**
135 * Configure the start point (initial guess).
136 * @param newStart new start point (initial guess)
137 * @return a new instance.
138 */
139 public GaussianCurveFitter withStartPoint(double[] newStart) {
140 return new GaussianCurveFitter(newStart.clone(),
141 maxIter);
142 }
143
144 /**
145 * Configure the maximum number of iterations.
146 * @param newMaxIter maximum number of iterations
147 * @return a new instance.
148 */
149 public GaussianCurveFitter withMaxIterations(int newMaxIter) {
150 return new GaussianCurveFitter(initialGuess,
151 newMaxIter);
152 }
153
154 /** {@inheritDoc} */
155 @Override
156 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
157
158 // Prepare least-squares problem.
159 final int len = observations.size();
160 final double[] target = new double[len];
161 final double[] weights = new double[len];
162
163 int i = 0;
164 for (WeightedObservedPoint obs : observations) {
165 target[i] = obs.getY();
166 weights[i] = obs.getWeight();
167 ++i;
168 }
169
170 final AbstractCurveFitter.TheoreticalValuesFunction model =
171 new AbstractCurveFitter.TheoreticalValuesFunction(FUNCTION, observations);
172
173 final double[] startPoint = initialGuess != null ?
174 initialGuess :
175 // Compute estimation.
176 new ParameterGuesser(observations).guess();
177
178 // Return a new least squares problem set up to fit a Gaussian curve to the
179 // observed points.
180 return new LeastSquaresBuilder().
181 maxEvaluations(Integer.MAX_VALUE).
182 maxIterations(maxIter).
183 start(startPoint).
184 target(target).
185 weight(new DiagonalMatrix(weights)).
186 model(model.getModelFunction(), model.getModelFunctionJacobian()).
187 build();
188
189 }
190
191 /**
192 * Guesses the parameters {@code norm}, {@code mean}, and {@code sigma}
193 * of a {@link agents.anac.y2019.harddealer.math3.analysis.function.Gaussian.Parametric}
194 * based on the specified observed points.
195 */
196 public static class ParameterGuesser {
197 /** Normalization factor. */
198 private final double norm;
199 /** Mean. */
200 private final double mean;
201 /** Standard deviation. */
202 private final double sigma;
203
204 /**
205 * Constructs instance with the specified observed points.
206 *
207 * @param observations Observed points from which to guess the
208 * parameters of the Gaussian.
209 * @throws NullArgumentException if {@code observations} is
210 * {@code null}.
211 * @throws NumberIsTooSmallException if there are less than 3
212 * observations.
213 */
214 public ParameterGuesser(Collection<WeightedObservedPoint> observations) {
215 if (observations == null) {
216 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
217 }
218 if (observations.size() < 3) {
219 throw new NumberIsTooSmallException(observations.size(), 3, true);
220 }
221
222 final List<WeightedObservedPoint> sorted = sortObservations(observations);
223 final double[] params = basicGuess(sorted.toArray(new WeightedObservedPoint[0]));
224
225 norm = params[0];
226 mean = params[1];
227 sigma = params[2];
228 }
229
230 /**
231 * Gets an estimation of the parameters.
232 *
233 * @return the guessed parameters, in the following order:
234 * <ul>
235 * <li>Normalization factor</li>
236 * <li>Mean</li>
237 * <li>Standard deviation</li>
238 * </ul>
239 */
240 public double[] guess() {
241 return new double[] { norm, mean, sigma };
242 }
243
244 /**
245 * Sort the observations.
246 *
247 * @param unsorted Input observations.
248 * @return the input observations, sorted.
249 */
250 private List<WeightedObservedPoint> sortObservations(Collection<WeightedObservedPoint> unsorted) {
251 final List<WeightedObservedPoint> observations = new ArrayList<WeightedObservedPoint>(unsorted);
252
253 final Comparator<WeightedObservedPoint> cmp = new Comparator<WeightedObservedPoint>() {
254 /** {@inheritDoc} */
255 public int compare(WeightedObservedPoint p1,
256 WeightedObservedPoint p2) {
257 if (p1 == null && p2 == null) {
258 return 0;
259 }
260 if (p1 == null) {
261 return -1;
262 }
263 if (p2 == null) {
264 return 1;
265 }
266 final int cmpX = Double.compare(p1.getX(), p2.getX());
267 if (cmpX < 0) {
268 return -1;
269 }
270 if (cmpX > 0) {
271 return 1;
272 }
273 final int cmpY = Double.compare(p1.getY(), p2.getY());
274 if (cmpY < 0) {
275 return -1;
276 }
277 if (cmpY > 0) {
278 return 1;
279 }
280 final int cmpW = Double.compare(p1.getWeight(), p2.getWeight());
281 if (cmpW < 0) {
282 return -1;
283 }
284 if (cmpW > 0) {
285 return 1;
286 }
287 return 0;
288 }
289 };
290
291 Collections.sort(observations, cmp);
292 return observations;
293 }
294
295 /**
296 * Guesses the parameters based on the specified observed points.
297 *
298 * @param points Observed points, sorted.
299 * @return the guessed parameters (normalization factor, mean and
300 * sigma).
301 */
302 private double[] basicGuess(WeightedObservedPoint[] points) {
303 final int maxYIdx = findMaxY(points);
304 final double n = points[maxYIdx].getY();
305 final double m = points[maxYIdx].getX();
306
307 double fwhmApprox;
308 try {
309 final double halfY = n + ((m - n) / 2);
310 final double fwhmX1 = interpolateXAtY(points, maxYIdx, -1, halfY);
311 final double fwhmX2 = interpolateXAtY(points, maxYIdx, 1, halfY);
312 fwhmApprox = fwhmX2 - fwhmX1;
313 } catch (OutOfRangeException e) {
314 // TODO: Exceptions should not be used for flow control.
315 fwhmApprox = points[points.length - 1].getX() - points[0].getX();
316 }
317 final double s = fwhmApprox / (2 * FastMath.sqrt(2 * FastMath.log(2)));
318
319 return new double[] { n, m, s };
320 }
321
322 /**
323 * Finds index of point in specified points with the largest Y.
324 *
325 * @param points Points to search.
326 * @return the index in specified points array.
327 */
328 private int findMaxY(WeightedObservedPoint[] points) {
329 int maxYIdx = 0;
330 for (int i = 1; i < points.length; i++) {
331 if (points[i].getY() > points[maxYIdx].getY()) {
332 maxYIdx = i;
333 }
334 }
335 return maxYIdx;
336 }
337
338 /**
339 * Interpolates using the specified points to determine X at the
340 * specified Y.
341 *
342 * @param points Points to use for interpolation.
343 * @param startIdx Index within points from which to start the search for
344 * interpolation bounds points.
345 * @param idxStep Index step for searching interpolation bounds points.
346 * @param y Y value for which X should be determined.
347 * @return the value of X for the specified Y.
348 * @throws ZeroException if {@code idxStep} is 0.
349 * @throws OutOfRangeException if specified {@code y} is not within the
350 * range of the specified {@code points}.
351 */
352 private double interpolateXAtY(WeightedObservedPoint[] points,
353 int startIdx,
354 int idxStep,
355 double y)
356 throws OutOfRangeException {
357 if (idxStep == 0) {
358 throw new ZeroException();
359 }
360 final WeightedObservedPoint[] twoPoints
361 = getInterpolationPointsForY(points, startIdx, idxStep, y);
362 final WeightedObservedPoint p1 = twoPoints[0];
363 final WeightedObservedPoint p2 = twoPoints[1];
364 if (p1.getY() == y) {
365 return p1.getX();
366 }
367 if (p2.getY() == y) {
368 return p2.getX();
369 }
370 return p1.getX() + (((y - p1.getY()) * (p2.getX() - p1.getX())) /
371 (p2.getY() - p1.getY()));
372 }
373
374 /**
375 * Gets the two bounding interpolation points from the specified points
376 * suitable for determining X at the specified Y.
377 *
378 * @param points Points to use for interpolation.
379 * @param startIdx Index within points from which to start search for
380 * interpolation bounds points.
381 * @param idxStep Index step for search for interpolation bounds points.
382 * @param y Y value for which X should be determined.
383 * @return the array containing two points suitable for determining X at
384 * the specified Y.
385 * @throws ZeroException if {@code idxStep} is 0.
386 * @throws OutOfRangeException if specified {@code y} is not within the
387 * range of the specified {@code points}.
388 */
389 private WeightedObservedPoint[] getInterpolationPointsForY(WeightedObservedPoint[] points,
390 int startIdx,
391 int idxStep,
392 double y)
393 throws OutOfRangeException {
394 if (idxStep == 0) {
395 throw new ZeroException();
396 }
397 for (int i = startIdx;
398 idxStep < 0 ? i + idxStep >= 0 : i + idxStep < points.length;
399 i += idxStep) {
400 final WeightedObservedPoint p1 = points[i];
401 final WeightedObservedPoint p2 = points[i + idxStep];
402 if (isBetween(y, p1.getY(), p2.getY())) {
403 if (idxStep < 0) {
404 return new WeightedObservedPoint[] { p2, p1 };
405 } else {
406 return new WeightedObservedPoint[] { p1, p2 };
407 }
408 }
409 }
410
411 // Boundaries are replaced by dummy values because the raised
412 // exception is caught and the message never displayed.
413 // TODO: Exceptions should not be used for flow control.
414 throw new OutOfRangeException(y,
415 Double.NEGATIVE_INFINITY,
416 Double.POSITIVE_INFINITY);
417 }
418
419 /**
420 * Determines whether a value is between two other values.
421 *
422 * @param value Value to test whether it is between {@code boundary1}
423 * and {@code boundary2}.
424 * @param boundary1 One end of the range.
425 * @param boundary2 Other end of the range.
426 * @return {@code true} if {@code value} is between {@code boundary1} and
427 * {@code boundary2} (inclusive), {@code false} otherwise.
428 */
429 private boolean isBetween(double value,
430 double boundary1,
431 double boundary2) {
432 return (value >= boundary1 && value <= boundary2) ||
433 (value >= boundary2 && value <= boundary1);
434 }
435 }
436}
Note: See TracBrowser for help on using the repository browser.