1 | /*
|
---|
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
|
---|
3 | * contributor license agreements. See the NOTICE file distributed with
|
---|
4 | * this work for additional information regarding copyright ownership.
|
---|
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
|
---|
6 | * (the "License"); you may not use this file except in compliance with
|
---|
7 | * the License. You may obtain a copy of the License at
|
---|
8 | *
|
---|
9 | * http://www.apache.org/licenses/LICENSE-2.0
|
---|
10 | *
|
---|
11 | * Unless required by applicable law or agreed to in writing, software
|
---|
12 | * distributed under the License is distributed on an "AS IS" BASIS,
|
---|
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
---|
14 | * See the License for the specific language governing permissions and
|
---|
15 | * limitations under the License.
|
---|
16 | */
|
---|
17 |
|
---|
18 | package agents.anac.y2019.harddealer.math3.analysis.function;
|
---|
19 |
|
---|
20 | import java.util.Arrays;
|
---|
21 |
|
---|
22 | import agents.anac.y2019.harddealer.math3.analysis.FunctionUtils;
|
---|
23 | import agents.anac.y2019.harddealer.math3.analysis.UnivariateFunction;
|
---|
24 | import agents.anac.y2019.harddealer.math3.analysis.DifferentiableUnivariateFunction;
|
---|
25 | import agents.anac.y2019.harddealer.math3.analysis.ParametricUnivariateFunction;
|
---|
26 | import agents.anac.y2019.harddealer.math3.analysis.differentiation.DerivativeStructure;
|
---|
27 | import agents.anac.y2019.harddealer.math3.analysis.differentiation.UnivariateDifferentiableFunction;
|
---|
28 | import agents.anac.y2019.harddealer.math3.exception.NullArgumentException;
|
---|
29 | import agents.anac.y2019.harddealer.math3.exception.DimensionMismatchException;
|
---|
30 | import agents.anac.y2019.harddealer.math3.util.FastMath;
|
---|
31 |
|
---|
32 | /**
|
---|
33 | * <a href="http://en.wikipedia.org/wiki/Sigmoid_function">
|
---|
34 | * Sigmoid</a> function.
|
---|
35 | * It is the inverse of the {@link Logit logit} function.
|
---|
36 | * A more flexible version, the generalised logistic, is implemented
|
---|
37 | * by the {@link Logistic} class.
|
---|
38 | *
|
---|
39 | * @since 3.0
|
---|
40 | */
|
---|
41 | public class Sigmoid implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
|
---|
42 | /** Lower asymptote. */
|
---|
43 | private final double lo;
|
---|
44 | /** Higher asymptote. */
|
---|
45 | private final double hi;
|
---|
46 |
|
---|
47 | /**
|
---|
48 | * Usual sigmoid function, where the lower asymptote is 0 and the higher
|
---|
49 | * asymptote is 1.
|
---|
50 | */
|
---|
51 | public Sigmoid() {
|
---|
52 | this(0, 1);
|
---|
53 | }
|
---|
54 |
|
---|
55 | /**
|
---|
56 | * Sigmoid function.
|
---|
57 | *
|
---|
58 | * @param lo Lower asymptote.
|
---|
59 | * @param hi Higher asymptote.
|
---|
60 | */
|
---|
61 | public Sigmoid(double lo,
|
---|
62 | double hi) {
|
---|
63 | this.lo = lo;
|
---|
64 | this.hi = hi;
|
---|
65 | }
|
---|
66 |
|
---|
67 | /** {@inheritDoc}
|
---|
68 | * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
|
---|
69 | */
|
---|
70 | @Deprecated
|
---|
71 | public UnivariateFunction derivative() {
|
---|
72 | return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
|
---|
73 | }
|
---|
74 |
|
---|
75 | /** {@inheritDoc} */
|
---|
76 | public double value(double x) {
|
---|
77 | return value(x, lo, hi);
|
---|
78 | }
|
---|
79 |
|
---|
80 | /**
|
---|
81 | * Parametric function where the input array contains the parameters of
|
---|
82 | * the {@link Sigmoid#Sigmoid(double,double) sigmoid function}, ordered
|
---|
83 | * as follows:
|
---|
84 | * <ul>
|
---|
85 | * <li>Lower asymptote</li>
|
---|
86 | * <li>Higher asymptote</li>
|
---|
87 | * </ul>
|
---|
88 | */
|
---|
89 | public static class Parametric implements ParametricUnivariateFunction {
|
---|
90 | /**
|
---|
91 | * Computes the value of the sigmoid at {@code x}.
|
---|
92 | *
|
---|
93 | * @param x Value for which the function must be computed.
|
---|
94 | * @param param Values of lower asymptote and higher asymptote.
|
---|
95 | * @return the value of the function.
|
---|
96 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
97 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
98 | * not 2.
|
---|
99 | */
|
---|
100 | public double value(double x, double ... param)
|
---|
101 | throws NullArgumentException,
|
---|
102 | DimensionMismatchException {
|
---|
103 | validateParameters(param);
|
---|
104 | return Sigmoid.value(x, param[0], param[1]);
|
---|
105 | }
|
---|
106 |
|
---|
107 | /**
|
---|
108 | * Computes the value of the gradient at {@code x}.
|
---|
109 | * The components of the gradient vector are the partial
|
---|
110 | * derivatives of the function with respect to each of the
|
---|
111 | * <em>parameters</em> (lower asymptote and higher asymptote).
|
---|
112 | *
|
---|
113 | * @param x Value at which the gradient must be computed.
|
---|
114 | * @param param Values for lower asymptote and higher asymptote.
|
---|
115 | * @return the gradient vector at {@code x}.
|
---|
116 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
117 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
118 | * not 2.
|
---|
119 | */
|
---|
120 | public double[] gradient(double x, double ... param)
|
---|
121 | throws NullArgumentException,
|
---|
122 | DimensionMismatchException {
|
---|
123 | validateParameters(param);
|
---|
124 |
|
---|
125 | final double invExp1 = 1 / (1 + FastMath.exp(-x));
|
---|
126 |
|
---|
127 | return new double[] { 1 - invExp1, invExp1 };
|
---|
128 | }
|
---|
129 |
|
---|
130 | /**
|
---|
131 | * Validates parameters to ensure they are appropriate for the evaluation of
|
---|
132 | * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
|
---|
133 | * methods.
|
---|
134 | *
|
---|
135 | * @param param Values for lower and higher asymptotes.
|
---|
136 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
137 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
138 | * not 2.
|
---|
139 | */
|
---|
140 | private void validateParameters(double[] param)
|
---|
141 | throws NullArgumentException,
|
---|
142 | DimensionMismatchException {
|
---|
143 | if (param == null) {
|
---|
144 | throw new NullArgumentException();
|
---|
145 | }
|
---|
146 | if (param.length != 2) {
|
---|
147 | throw new DimensionMismatchException(param.length, 2);
|
---|
148 | }
|
---|
149 | }
|
---|
150 | }
|
---|
151 |
|
---|
152 | /**
|
---|
153 | * @param x Value at which to compute the sigmoid.
|
---|
154 | * @param lo Lower asymptote.
|
---|
155 | * @param hi Higher asymptote.
|
---|
156 | * @return the value of the sigmoid function at {@code x}.
|
---|
157 | */
|
---|
158 | private static double value(double x,
|
---|
159 | double lo,
|
---|
160 | double hi) {
|
---|
161 | return lo + (hi - lo) / (1 + FastMath.exp(-x));
|
---|
162 | }
|
---|
163 |
|
---|
164 | /** {@inheritDoc}
|
---|
165 | * @since 3.1
|
---|
166 | */
|
---|
167 | public DerivativeStructure value(final DerivativeStructure t)
|
---|
168 | throws DimensionMismatchException {
|
---|
169 |
|
---|
170 | double[] f = new double[t.getOrder() + 1];
|
---|
171 | final double exp = FastMath.exp(-t.getValue());
|
---|
172 | if (Double.isInfinite(exp)) {
|
---|
173 |
|
---|
174 | // special handling near lower boundary, to avoid NaN
|
---|
175 | f[0] = lo;
|
---|
176 | Arrays.fill(f, 1, f.length, 0.0);
|
---|
177 |
|
---|
178 | } else {
|
---|
179 |
|
---|
180 | // the nth order derivative of sigmoid has the form:
|
---|
181 | // dn(sigmoid(x)/dxn = P_n(exp(-x)) / (1+exp(-x))^(n+1)
|
---|
182 | // where P_n(t) is a degree n polynomial with normalized higher term
|
---|
183 | // P_0(t) = 1, P_1(t) = t, P_2(t) = t^2 - t, P_3(t) = t^3 - 4 t^2 + t...
|
---|
184 | // the general recurrence relation for P_n is:
|
---|
185 | // P_n(x) = n t P_(n-1)(t) - t (1 + t) P_(n-1)'(t)
|
---|
186 | final double[] p = new double[f.length];
|
---|
187 |
|
---|
188 | final double inv = 1 / (1 + exp);
|
---|
189 | double coeff = hi - lo;
|
---|
190 | for (int n = 0; n < f.length; ++n) {
|
---|
191 |
|
---|
192 | // update and evaluate polynomial P_n(t)
|
---|
193 | double v = 0;
|
---|
194 | p[n] = 1;
|
---|
195 | for (int k = n; k >= 0; --k) {
|
---|
196 | v = v * exp + p[k];
|
---|
197 | if (k > 1) {
|
---|
198 | p[k - 1] = (n - k + 2) * p[k - 2] - (k - 1) * p[k - 1];
|
---|
199 | } else {
|
---|
200 | p[0] = 0;
|
---|
201 | }
|
---|
202 | }
|
---|
203 |
|
---|
204 | coeff *= inv;
|
---|
205 | f[n] = coeff * v;
|
---|
206 |
|
---|
207 | }
|
---|
208 |
|
---|
209 | // fix function value
|
---|
210 | f[0] += lo;
|
---|
211 |
|
---|
212 | }
|
---|
213 |
|
---|
214 | return t.compose(f);
|
---|
215 |
|
---|
216 | }
|
---|
217 |
|
---|
218 | }
|
---|