source: src/main/java/agents/anac/y2019/harddealer/math3/analysis/function/Logit.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 7.5 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package agents.anac.y2019.harddealer.math3.analysis.function;
19
20import agents.anac.y2019.harddealer.math3.analysis.DifferentiableUnivariateFunction;
21import agents.anac.y2019.harddealer.math3.analysis.FunctionUtils;
22import agents.anac.y2019.harddealer.math3.analysis.ParametricUnivariateFunction;
23import agents.anac.y2019.harddealer.math3.analysis.UnivariateFunction;
24import agents.anac.y2019.harddealer.math3.analysis.differentiation.DerivativeStructure;
25import agents.anac.y2019.harddealer.math3.analysis.differentiation.UnivariateDifferentiableFunction;
26import agents.anac.y2019.harddealer.math3.exception.DimensionMismatchException;
27import agents.anac.y2019.harddealer.math3.exception.NullArgumentException;
28import agents.anac.y2019.harddealer.math3.exception.OutOfRangeException;
29import agents.anac.y2019.harddealer.math3.util.FastMath;
30
31/**
32 * <a href="http://en.wikipedia.org/wiki/Logit">
33 * Logit</a> function.
34 * It is the inverse of the {@link Sigmoid sigmoid} function.
35 *
36 * @since 3.0
37 */
38public class Logit implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
39 /** Lower bound. */
40 private final double lo;
41 /** Higher bound. */
42 private final double hi;
43
44 /**
45 * Usual logit function, where the lower bound is 0 and the higher
46 * bound is 1.
47 */
48 public Logit() {
49 this(0, 1);
50 }
51
52 /**
53 * Logit function.
54 *
55 * @param lo Lower bound of the function domain.
56 * @param hi Higher bound of the function domain.
57 */
58 public Logit(double lo,
59 double hi) {
60 this.lo = lo;
61 this.hi = hi;
62 }
63
64 /** {@inheritDoc} */
65 public double value(double x)
66 throws OutOfRangeException {
67 return value(x, lo, hi);
68 }
69
70 /** {@inheritDoc}
71 * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
72 */
73 @Deprecated
74 public UnivariateFunction derivative() {
75 return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
76 }
77
78 /**
79 * Parametric function where the input array contains the parameters of
80 * the logit function, ordered as follows:
81 * <ul>
82 * <li>Lower bound</li>
83 * <li>Higher bound</li>
84 * </ul>
85 */
86 public static class Parametric implements ParametricUnivariateFunction {
87 /**
88 * Computes the value of the logit at {@code x}.
89 *
90 * @param x Value for which the function must be computed.
91 * @param param Values of lower bound and higher bounds.
92 * @return the value of the function.
93 * @throws NullArgumentException if {@code param} is {@code null}.
94 * @throws DimensionMismatchException if the size of {@code param} is
95 * not 2.
96 */
97 public double value(double x, double ... param)
98 throws NullArgumentException,
99 DimensionMismatchException {
100 validateParameters(param);
101 return Logit.value(x, param[0], param[1]);
102 }
103
104 /**
105 * Computes the value of the gradient at {@code x}.
106 * The components of the gradient vector are the partial
107 * derivatives of the function with respect to each of the
108 * <em>parameters</em> (lower bound and higher bound).
109 *
110 * @param x Value at which the gradient must be computed.
111 * @param param Values for lower and higher bounds.
112 * @return the gradient vector at {@code x}.
113 * @throws NullArgumentException if {@code param} is {@code null}.
114 * @throws DimensionMismatchException if the size of {@code param} is
115 * not 2.
116 */
117 public double[] gradient(double x, double ... param)
118 throws NullArgumentException,
119 DimensionMismatchException {
120 validateParameters(param);
121
122 final double lo = param[0];
123 final double hi = param[1];
124
125 return new double[] { 1 / (lo - x), 1 / (hi - x) };
126 }
127
128 /**
129 * Validates parameters to ensure they are appropriate for the evaluation of
130 * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
131 * methods.
132 *
133 * @param param Values for lower and higher bounds.
134 * @throws NullArgumentException if {@code param} is {@code null}.
135 * @throws DimensionMismatchException if the size of {@code param} is
136 * not 2.
137 */
138 private void validateParameters(double[] param)
139 throws NullArgumentException,
140 DimensionMismatchException {
141 if (param == null) {
142 throw new NullArgumentException();
143 }
144 if (param.length != 2) {
145 throw new DimensionMismatchException(param.length, 2);
146 }
147 }
148 }
149
150 /**
151 * @param x Value at which to compute the logit.
152 * @param lo Lower bound.
153 * @param hi Higher bound.
154 * @return the value of the logit function at {@code x}.
155 * @throws OutOfRangeException if {@code x < lo} or {@code x > hi}.
156 */
157 private static double value(double x,
158 double lo,
159 double hi)
160 throws OutOfRangeException {
161 if (x < lo || x > hi) {
162 throw new OutOfRangeException(x, lo, hi);
163 }
164 return FastMath.log((x - lo) / (hi - x));
165 }
166
167 /** {@inheritDoc}
168 * @since 3.1
169 * @exception OutOfRangeException if parameter is outside of function domain
170 */
171 public DerivativeStructure value(final DerivativeStructure t)
172 throws OutOfRangeException {
173 final double x = t.getValue();
174 if (x < lo || x > hi) {
175 throw new OutOfRangeException(x, lo, hi);
176 }
177 double[] f = new double[t.getOrder() + 1];
178
179 // function value
180 f[0] = FastMath.log((x - lo) / (hi - x));
181
182 if (Double.isInfinite(f[0])) {
183
184 if (f.length > 1) {
185 f[1] = Double.POSITIVE_INFINITY;
186 }
187 // fill the array with infinities
188 // (for x close to lo the signs will flip between -inf and +inf,
189 // for x close to hi the signs will always be +inf)
190 // this is probably overkill, since the call to compose at the end
191 // of the method will transform most infinities into NaN ...
192 for (int i = 2; i < f.length; ++i) {
193 f[i] = f[i - 2];
194 }
195
196 } else {
197
198 // function derivatives
199 final double invL = 1.0 / (x - lo);
200 double xL = invL;
201 final double invH = 1.0 / (hi - x);
202 double xH = invH;
203 for (int i = 1; i < f.length; ++i) {
204 f[i] = xL + xH;
205 xL *= -i * invL;
206 xH *= i * invH;
207 }
208 }
209
210 return t.compose(f);
211 }
212}
Note: See TracBrowser for help on using the repository browser.