1 | /*
|
---|
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
|
---|
3 | * contributor license agreements. See the NOTICE file distributed with
|
---|
4 | * this work for additional information regarding copyright ownership.
|
---|
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
|
---|
6 | * (the "License"); you may not use this file except in compliance with
|
---|
7 | * the License. You may obtain a copy of the License at
|
---|
8 | *
|
---|
9 | * http://www.apache.org/licenses/LICENSE-2.0
|
---|
10 | *
|
---|
11 | * Unless required by applicable law or agreed to in writing, software
|
---|
12 | * distributed under the License is distributed on an "AS IS" BASIS,
|
---|
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
---|
14 | * See the License for the specific language governing permissions and
|
---|
15 | * limitations under the License.
|
---|
16 | */
|
---|
17 | package agents.anac.y2019.harddealer.math3.optim.univariate;
|
---|
18 |
|
---|
19 | import agents.anac.y2019.harddealer.math3.util.Precision;
|
---|
20 | import agents.anac.y2019.harddealer.math3.util.FastMath;
|
---|
21 | import agents.anac.y2019.harddealer.math3.exception.NumberIsTooSmallException;
|
---|
22 | import agents.anac.y2019.harddealer.math3.exception.NotStrictlyPositiveException;
|
---|
23 | import agents.anac.y2019.harddealer.math3.optim.ConvergenceChecker;
|
---|
24 | import agents.anac.y2019.harddealer.math3.optim.nonlinear.scalar.GoalType;
|
---|
25 |
|
---|
26 | /**
|
---|
27 | * For a function defined on some interval {@code (lo, hi)}, this class
|
---|
28 | * finds an approximation {@code x} to the point at which the function
|
---|
29 | * attains its minimum.
|
---|
30 | * It implements Richard Brent's algorithm (from his book "Algorithms for
|
---|
31 | * Minimization without Derivatives", p. 79) for finding minima of real
|
---|
32 | * univariate functions.
|
---|
33 | * <br/>
|
---|
34 | * This code is an adaptation, partly based on the Python code from SciPy
|
---|
35 | * (module "optimize.py" v0.5); the original algorithm is also modified
|
---|
36 | * <ul>
|
---|
37 | * <li>to use an initial guess provided by the user,</li>
|
---|
38 | * <li>to ensure that the best point encountered is the one returned.</li>
|
---|
39 | * </ul>
|
---|
40 | *
|
---|
41 | * @since 2.0
|
---|
42 | */
|
---|
43 | public class BrentOptimizer extends UnivariateOptimizer {
|
---|
44 | /**
|
---|
45 | * Golden section.
|
---|
46 | */
|
---|
47 | private static final double GOLDEN_SECTION = 0.5 * (3 - FastMath.sqrt(5));
|
---|
48 | /**
|
---|
49 | * Minimum relative tolerance.
|
---|
50 | */
|
---|
51 | private static final double MIN_RELATIVE_TOLERANCE = 2 * FastMath.ulp(1d);
|
---|
52 | /**
|
---|
53 | * Relative threshold.
|
---|
54 | */
|
---|
55 | private final double relativeThreshold;
|
---|
56 | /**
|
---|
57 | * Absolute threshold.
|
---|
58 | */
|
---|
59 | private final double absoluteThreshold;
|
---|
60 |
|
---|
61 | /**
|
---|
62 | * The arguments are used implement the original stopping criterion
|
---|
63 | * of Brent's algorithm.
|
---|
64 | * {@code abs} and {@code rel} define a tolerance
|
---|
65 | * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
|
---|
66 | * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
|
---|
67 | * where <em>macheps</em> is the relative machine precision. {@code abs} must
|
---|
68 | * be positive.
|
---|
69 | *
|
---|
70 | * @param rel Relative threshold.
|
---|
71 | * @param abs Absolute threshold.
|
---|
72 | * @param checker Additional, user-defined, convergence checking
|
---|
73 | * procedure.
|
---|
74 | * @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
---|
75 | * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
---|
76 | */
|
---|
77 | public BrentOptimizer(double rel,
|
---|
78 | double abs,
|
---|
79 | ConvergenceChecker<UnivariatePointValuePair> checker) {
|
---|
80 | super(checker);
|
---|
81 |
|
---|
82 | if (rel < MIN_RELATIVE_TOLERANCE) {
|
---|
83 | throw new NumberIsTooSmallException(rel, MIN_RELATIVE_TOLERANCE, true);
|
---|
84 | }
|
---|
85 | if (abs <= 0) {
|
---|
86 | throw new NotStrictlyPositiveException(abs);
|
---|
87 | }
|
---|
88 |
|
---|
89 | relativeThreshold = rel;
|
---|
90 | absoluteThreshold = abs;
|
---|
91 | }
|
---|
92 |
|
---|
93 | /**
|
---|
94 | * The arguments are used for implementing the original stopping criterion
|
---|
95 | * of Brent's algorithm.
|
---|
96 | * {@code abs} and {@code rel} define a tolerance
|
---|
97 | * {@code tol = rel |x| + abs}. {@code rel} should be no smaller than
|
---|
98 | * <em>2 macheps</em> and preferably not much less than <em>sqrt(macheps)</em>,
|
---|
99 | * where <em>macheps</em> is the relative machine precision. {@code abs} must
|
---|
100 | * be positive.
|
---|
101 | *
|
---|
102 | * @param rel Relative threshold.
|
---|
103 | * @param abs Absolute threshold.
|
---|
104 | * @throws NotStrictlyPositiveException if {@code abs <= 0}.
|
---|
105 | * @throws NumberIsTooSmallException if {@code rel < 2 * Math.ulp(1d)}.
|
---|
106 | */
|
---|
107 | public BrentOptimizer(double rel,
|
---|
108 | double abs) {
|
---|
109 | this(rel, abs, null);
|
---|
110 | }
|
---|
111 |
|
---|
112 | /** {@inheritDoc} */
|
---|
113 | @Override
|
---|
114 | protected UnivariatePointValuePair doOptimize() {
|
---|
115 | final boolean isMinim = getGoalType() == GoalType.MINIMIZE;
|
---|
116 | final double lo = getMin();
|
---|
117 | final double mid = getStartValue();
|
---|
118 | final double hi = getMax();
|
---|
119 |
|
---|
120 | // Optional additional convergence criteria.
|
---|
121 | final ConvergenceChecker<UnivariatePointValuePair> checker
|
---|
122 | = getConvergenceChecker();
|
---|
123 |
|
---|
124 | double a;
|
---|
125 | double b;
|
---|
126 | if (lo < hi) {
|
---|
127 | a = lo;
|
---|
128 | b = hi;
|
---|
129 | } else {
|
---|
130 | a = hi;
|
---|
131 | b = lo;
|
---|
132 | }
|
---|
133 |
|
---|
134 | double x = mid;
|
---|
135 | double v = x;
|
---|
136 | double w = x;
|
---|
137 | double d = 0;
|
---|
138 | double e = 0;
|
---|
139 | double fx = computeObjectiveValue(x);
|
---|
140 | if (!isMinim) {
|
---|
141 | fx = -fx;
|
---|
142 | }
|
---|
143 | double fv = fx;
|
---|
144 | double fw = fx;
|
---|
145 |
|
---|
146 | UnivariatePointValuePair previous = null;
|
---|
147 | UnivariatePointValuePair current
|
---|
148 | = new UnivariatePointValuePair(x, isMinim ? fx : -fx);
|
---|
149 | // Best point encountered so far (which is the initial guess).
|
---|
150 | UnivariatePointValuePair best = current;
|
---|
151 |
|
---|
152 | while (true) {
|
---|
153 | final double m = 0.5 * (a + b);
|
---|
154 | final double tol1 = relativeThreshold * FastMath.abs(x) + absoluteThreshold;
|
---|
155 | final double tol2 = 2 * tol1;
|
---|
156 |
|
---|
157 | // Default stopping criterion.
|
---|
158 | final boolean stop = FastMath.abs(x - m) <= tol2 - 0.5 * (b - a);
|
---|
159 | if (!stop) {
|
---|
160 | double p = 0;
|
---|
161 | double q = 0;
|
---|
162 | double r = 0;
|
---|
163 | double u = 0;
|
---|
164 |
|
---|
165 | if (FastMath.abs(e) > tol1) { // Fit parabola.
|
---|
166 | r = (x - w) * (fx - fv);
|
---|
167 | q = (x - v) * (fx - fw);
|
---|
168 | p = (x - v) * q - (x - w) * r;
|
---|
169 | q = 2 * (q - r);
|
---|
170 |
|
---|
171 | if (q > 0) {
|
---|
172 | p = -p;
|
---|
173 | } else {
|
---|
174 | q = -q;
|
---|
175 | }
|
---|
176 |
|
---|
177 | r = e;
|
---|
178 | e = d;
|
---|
179 |
|
---|
180 | if (p > q * (a - x) &&
|
---|
181 | p < q * (b - x) &&
|
---|
182 | FastMath.abs(p) < FastMath.abs(0.5 * q * r)) {
|
---|
183 | // Parabolic interpolation step.
|
---|
184 | d = p / q;
|
---|
185 | u = x + d;
|
---|
186 |
|
---|
187 | // f must not be evaluated too close to a or b.
|
---|
188 | if (u - a < tol2 || b - u < tol2) {
|
---|
189 | if (x <= m) {
|
---|
190 | d = tol1;
|
---|
191 | } else {
|
---|
192 | d = -tol1;
|
---|
193 | }
|
---|
194 | }
|
---|
195 | } else {
|
---|
196 | // Golden section step.
|
---|
197 | if (x < m) {
|
---|
198 | e = b - x;
|
---|
199 | } else {
|
---|
200 | e = a - x;
|
---|
201 | }
|
---|
202 | d = GOLDEN_SECTION * e;
|
---|
203 | }
|
---|
204 | } else {
|
---|
205 | // Golden section step.
|
---|
206 | if (x < m) {
|
---|
207 | e = b - x;
|
---|
208 | } else {
|
---|
209 | e = a - x;
|
---|
210 | }
|
---|
211 | d = GOLDEN_SECTION * e;
|
---|
212 | }
|
---|
213 |
|
---|
214 | // Update by at least "tol1".
|
---|
215 | if (FastMath.abs(d) < tol1) {
|
---|
216 | if (d >= 0) {
|
---|
217 | u = x + tol1;
|
---|
218 | } else {
|
---|
219 | u = x - tol1;
|
---|
220 | }
|
---|
221 | } else {
|
---|
222 | u = x + d;
|
---|
223 | }
|
---|
224 |
|
---|
225 | double fu = computeObjectiveValue(u);
|
---|
226 | if (!isMinim) {
|
---|
227 | fu = -fu;
|
---|
228 | }
|
---|
229 |
|
---|
230 | // User-defined convergence checker.
|
---|
231 | previous = current;
|
---|
232 | current = new UnivariatePointValuePair(u, isMinim ? fu : -fu);
|
---|
233 | best = best(best,
|
---|
234 | best(previous,
|
---|
235 | current,
|
---|
236 | isMinim),
|
---|
237 | isMinim);
|
---|
238 |
|
---|
239 | if (checker != null && checker.converged(getIterations(), previous, current)) {
|
---|
240 | return best;
|
---|
241 | }
|
---|
242 |
|
---|
243 | // Update a, b, v, w and x.
|
---|
244 | if (fu <= fx) {
|
---|
245 | if (u < x) {
|
---|
246 | b = x;
|
---|
247 | } else {
|
---|
248 | a = x;
|
---|
249 | }
|
---|
250 | v = w;
|
---|
251 | fv = fw;
|
---|
252 | w = x;
|
---|
253 | fw = fx;
|
---|
254 | x = u;
|
---|
255 | fx = fu;
|
---|
256 | } else {
|
---|
257 | if (u < x) {
|
---|
258 | a = u;
|
---|
259 | } else {
|
---|
260 | b = u;
|
---|
261 | }
|
---|
262 | if (fu <= fw ||
|
---|
263 | Precision.equals(w, x)) {
|
---|
264 | v = w;
|
---|
265 | fv = fw;
|
---|
266 | w = u;
|
---|
267 | fw = fu;
|
---|
268 | } else if (fu <= fv ||
|
---|
269 | Precision.equals(v, x) ||
|
---|
270 | Precision.equals(v, w)) {
|
---|
271 | v = u;
|
---|
272 | fv = fu;
|
---|
273 | }
|
---|
274 | }
|
---|
275 | } else { // Default termination (Brent's criterion).
|
---|
276 | return best(best,
|
---|
277 | best(previous,
|
---|
278 | current,
|
---|
279 | isMinim),
|
---|
280 | isMinim);
|
---|
281 | }
|
---|
282 |
|
---|
283 | incrementIterationCount();
|
---|
284 | }
|
---|
285 | }
|
---|
286 |
|
---|
287 | /**
|
---|
288 | * Selects the best of two points.
|
---|
289 | *
|
---|
290 | * @param a Point and value.
|
---|
291 | * @param b Point and value.
|
---|
292 | * @param isMinim {@code true} if the selected point must be the one with
|
---|
293 | * the lowest value.
|
---|
294 | * @return the best point, or {@code null} if {@code a} and {@code b} are
|
---|
295 | * both {@code null}. When {@code a} and {@code b} have the same function
|
---|
296 | * value, {@code a} is returned.
|
---|
297 | */
|
---|
298 | private UnivariatePointValuePair best(UnivariatePointValuePair a,
|
---|
299 | UnivariatePointValuePair b,
|
---|
300 | boolean isMinim) {
|
---|
301 | if (a == null) {
|
---|
302 | return b;
|
---|
303 | }
|
---|
304 | if (b == null) {
|
---|
305 | return a;
|
---|
306 | }
|
---|
307 |
|
---|
308 | if (isMinim) {
|
---|
309 | return a.getValue() <= b.getValue() ? a : b;
|
---|
310 | } else {
|
---|
311 | return a.getValue() >= b.getValue() ? a : b;
|
---|
312 | }
|
---|
313 | }
|
---|
314 | }
|
---|