source: src/main/java/agents/anac/y2019/harddealer/math3/fitting/SimpleCurveFitter.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 4.8 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package agents.anac.y2019.harddealer.math3.fitting;
18
19import java.util.Collection;
20
21import agents.anac.y2019.harddealer.math3.analysis.ParametricUnivariateFunction;
22import agents.anac.y2019.harddealer.math3.fitting.leastsquares.LeastSquaresBuilder;
23import agents.anac.y2019.harddealer.math3.fitting.leastsquares.LeastSquaresProblem;
24import agents.anac.y2019.harddealer.math3.linear.DiagonalMatrix;
25
26/**
27 * Fits points to a user-defined {@link ParametricUnivariateFunction function}.
28 *
29 * @since 3.4
30 */
31public class SimpleCurveFitter extends AbstractCurveFitter {
32 /** Function to fit. */
33 private final ParametricUnivariateFunction function;
34 /** Initial guess for the parameters. */
35 private final double[] initialGuess;
36 /** Maximum number of iterations of the optimization algorithm. */
37 private final int maxIter;
38
39 /**
40 * Contructor used by the factory methods.
41 *
42 * @param function Function to fit.
43 * @param initialGuess Initial guess. Cannot be {@code null}. Its length must
44 * be consistent with the number of parameters of the {@code function} to fit.
45 * @param maxIter Maximum number of iterations of the optimization algorithm.
46 */
47 private SimpleCurveFitter(ParametricUnivariateFunction function,
48 double[] initialGuess,
49 int maxIter) {
50 this.function = function;
51 this.initialGuess = initialGuess;
52 this.maxIter = maxIter;
53 }
54
55 /**
56 * Creates a curve fitter.
57 * The maximum number of iterations of the optimization algorithm is set
58 * to {@link Integer#MAX_VALUE}.
59 *
60 * @param f Function to fit.
61 * @param start Initial guess for the parameters. Cannot be {@code null}.
62 * Its length must be consistent with the number of parameters of the
63 * function to fit.
64 * @return a curve fitter.
65 *
66 * @see #withStartPoint(double[])
67 * @see #withMaxIterations(int)
68 */
69 public static SimpleCurveFitter create(ParametricUnivariateFunction f,
70 double[] start) {
71 return new SimpleCurveFitter(f, start, Integer.MAX_VALUE);
72 }
73
74 /**
75 * Configure the start point (initial guess).
76 * @param newStart new start point (initial guess)
77 * @return a new instance.
78 */
79 public SimpleCurveFitter withStartPoint(double[] newStart) {
80 return new SimpleCurveFitter(function,
81 newStart.clone(),
82 maxIter);
83 }
84
85 /**
86 * Configure the maximum number of iterations.
87 * @param newMaxIter maximum number of iterations
88 * @return a new instance.
89 */
90 public SimpleCurveFitter withMaxIterations(int newMaxIter) {
91 return new SimpleCurveFitter(function,
92 initialGuess,
93 newMaxIter);
94 }
95
96 /** {@inheritDoc} */
97 @Override
98 protected LeastSquaresProblem getProblem(Collection<WeightedObservedPoint> observations) {
99 // Prepare least-squares problem.
100 final int len = observations.size();
101 final double[] target = new double[len];
102 final double[] weights = new double[len];
103
104 int count = 0;
105 for (WeightedObservedPoint obs : observations) {
106 target[count] = obs.getY();
107 weights[count] = obs.getWeight();
108 ++count;
109 }
110
111 final AbstractCurveFitter.TheoreticalValuesFunction model
112 = new AbstractCurveFitter.TheoreticalValuesFunction(function,
113 observations);
114
115 // Create an optimizer for fitting the curve to the observed points.
116 return new LeastSquaresBuilder().
117 maxEvaluations(Integer.MAX_VALUE).
118 maxIterations(maxIter).
119 start(initialGuess).
120 target(target).
121 weight(new DiagonalMatrix(weights)).
122 model(model.getModelFunction(), model.getModelFunctionJacobian()).
123 build();
124 }
125}
Note: See TracBrowser for help on using the repository browser.