1 | /*
|
---|
2 | * Licensed to the Apache Software Foundation (ASF) under one or more
|
---|
3 | * contributor license agreements. See the NOTICE file distributed with
|
---|
4 | * this work for additional information regarding copyright ownership.
|
---|
5 | * The ASF licenses this file to You under the Apache License, Version 2.0
|
---|
6 | * (the "License"); you may not use this file except in compliance with
|
---|
7 | * the License. You may obtain a copy of the License at
|
---|
8 | *
|
---|
9 | * http://www.apache.org/licenses/LICENSE-2.0
|
---|
10 | *
|
---|
11 | * Unless required by applicable law or agreed to in writing, software
|
---|
12 | * distributed under the License is distributed on an "AS IS" BASIS,
|
---|
13 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
---|
14 | * See the License for the specific language governing permissions and
|
---|
15 | * limitations under the License.
|
---|
16 | */
|
---|
17 |
|
---|
18 | package agents.anac.y2019.harddealer.math3.analysis.function;
|
---|
19 |
|
---|
20 | import java.util.Arrays;
|
---|
21 |
|
---|
22 | import agents.anac.y2019.harddealer.math3.analysis.FunctionUtils;
|
---|
23 | import agents.anac.y2019.harddealer.math3.analysis.UnivariateFunction;
|
---|
24 | import agents.anac.y2019.harddealer.math3.analysis.DifferentiableUnivariateFunction;
|
---|
25 | import agents.anac.y2019.harddealer.math3.analysis.ParametricUnivariateFunction;
|
---|
26 | import agents.anac.y2019.harddealer.math3.analysis.differentiation.DerivativeStructure;
|
---|
27 | import agents.anac.y2019.harddealer.math3.analysis.differentiation.UnivariateDifferentiableFunction;
|
---|
28 | import agents.anac.y2019.harddealer.math3.exception.NotStrictlyPositiveException;
|
---|
29 | import agents.anac.y2019.harddealer.math3.exception.NullArgumentException;
|
---|
30 | import agents.anac.y2019.harddealer.math3.exception.DimensionMismatchException;
|
---|
31 | import agents.anac.y2019.harddealer.math3.util.FastMath;
|
---|
32 | import agents.anac.y2019.harddealer.math3.util.Precision;
|
---|
33 |
|
---|
34 | /**
|
---|
35 | * <a href="http://en.wikipedia.org/wiki/Gaussian_function">
|
---|
36 | * Gaussian</a> function.
|
---|
37 | *
|
---|
38 | * @since 3.0
|
---|
39 | */
|
---|
40 | public class Gaussian implements UnivariateDifferentiableFunction, DifferentiableUnivariateFunction {
|
---|
41 | /** Mean. */
|
---|
42 | private final double mean;
|
---|
43 | /** Inverse of the standard deviation. */
|
---|
44 | private final double is;
|
---|
45 | /** Inverse of twice the square of the standard deviation. */
|
---|
46 | private final double i2s2;
|
---|
47 | /** Normalization factor. */
|
---|
48 | private final double norm;
|
---|
49 |
|
---|
50 | /**
|
---|
51 | * Gaussian with given normalization factor, mean and standard deviation.
|
---|
52 | *
|
---|
53 | * @param norm Normalization factor.
|
---|
54 | * @param mean Mean.
|
---|
55 | * @param sigma Standard deviation.
|
---|
56 | * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
|
---|
57 | */
|
---|
58 | public Gaussian(double norm,
|
---|
59 | double mean,
|
---|
60 | double sigma)
|
---|
61 | throws NotStrictlyPositiveException {
|
---|
62 | if (sigma <= 0) {
|
---|
63 | throw new NotStrictlyPositiveException(sigma);
|
---|
64 | }
|
---|
65 |
|
---|
66 | this.norm = norm;
|
---|
67 | this.mean = mean;
|
---|
68 | this.is = 1 / sigma;
|
---|
69 | this.i2s2 = 0.5 * is * is;
|
---|
70 | }
|
---|
71 |
|
---|
72 | /**
|
---|
73 | * Normalized gaussian with given mean and standard deviation.
|
---|
74 | *
|
---|
75 | * @param mean Mean.
|
---|
76 | * @param sigma Standard deviation.
|
---|
77 | * @throws NotStrictlyPositiveException if {@code sigma <= 0}.
|
---|
78 | */
|
---|
79 | public Gaussian(double mean,
|
---|
80 | double sigma)
|
---|
81 | throws NotStrictlyPositiveException {
|
---|
82 | this(1 / (sigma * FastMath.sqrt(2 * Math.PI)), mean, sigma);
|
---|
83 | }
|
---|
84 |
|
---|
85 | /**
|
---|
86 | * Normalized gaussian with zero mean and unit standard deviation.
|
---|
87 | */
|
---|
88 | public Gaussian() {
|
---|
89 | this(0, 1);
|
---|
90 | }
|
---|
91 |
|
---|
92 | /** {@inheritDoc} */
|
---|
93 | public double value(double x) {
|
---|
94 | return value(x - mean, norm, i2s2);
|
---|
95 | }
|
---|
96 |
|
---|
97 | /** {@inheritDoc}
|
---|
98 | * @deprecated as of 3.1, replaced by {@link #value(DerivativeStructure)}
|
---|
99 | */
|
---|
100 | @Deprecated
|
---|
101 | public UnivariateFunction derivative() {
|
---|
102 | return FunctionUtils.toDifferentiableUnivariateFunction(this).derivative();
|
---|
103 | }
|
---|
104 |
|
---|
105 | /**
|
---|
106 | * Parametric function where the input array contains the parameters of
|
---|
107 | * the Gaussian, ordered as follows:
|
---|
108 | * <ul>
|
---|
109 | * <li>Norm</li>
|
---|
110 | * <li>Mean</li>
|
---|
111 | * <li>Standard deviation</li>
|
---|
112 | * </ul>
|
---|
113 | */
|
---|
114 | public static class Parametric implements ParametricUnivariateFunction {
|
---|
115 | /**
|
---|
116 | * Computes the value of the Gaussian at {@code x}.
|
---|
117 | *
|
---|
118 | * @param x Value for which the function must be computed.
|
---|
119 | * @param param Values of norm, mean and standard deviation.
|
---|
120 | * @return the value of the function.
|
---|
121 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
122 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
123 | * not 3.
|
---|
124 | * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
|
---|
125 | */
|
---|
126 | public double value(double x, double ... param)
|
---|
127 | throws NullArgumentException,
|
---|
128 | DimensionMismatchException,
|
---|
129 | NotStrictlyPositiveException {
|
---|
130 | validateParameters(param);
|
---|
131 |
|
---|
132 | final double diff = x - param[1];
|
---|
133 | final double i2s2 = 1 / (2 * param[2] * param[2]);
|
---|
134 | return Gaussian.value(diff, param[0], i2s2);
|
---|
135 | }
|
---|
136 |
|
---|
137 | /**
|
---|
138 | * Computes the value of the gradient at {@code x}.
|
---|
139 | * The components of the gradient vector are the partial
|
---|
140 | * derivatives of the function with respect to each of the
|
---|
141 | * <em>parameters</em> (norm, mean and standard deviation).
|
---|
142 | *
|
---|
143 | * @param x Value at which the gradient must be computed.
|
---|
144 | * @param param Values of norm, mean and standard deviation.
|
---|
145 | * @return the gradient vector at {@code x}.
|
---|
146 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
147 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
148 | * not 3.
|
---|
149 | * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
|
---|
150 | */
|
---|
151 | public double[] gradient(double x, double ... param)
|
---|
152 | throws NullArgumentException,
|
---|
153 | DimensionMismatchException,
|
---|
154 | NotStrictlyPositiveException {
|
---|
155 | validateParameters(param);
|
---|
156 |
|
---|
157 | final double norm = param[0];
|
---|
158 | final double diff = x - param[1];
|
---|
159 | final double sigma = param[2];
|
---|
160 | final double i2s2 = 1 / (2 * sigma * sigma);
|
---|
161 |
|
---|
162 | final double n = Gaussian.value(diff, 1, i2s2);
|
---|
163 | final double m = norm * n * 2 * i2s2 * diff;
|
---|
164 | final double s = m * diff / sigma;
|
---|
165 |
|
---|
166 | return new double[] { n, m, s };
|
---|
167 | }
|
---|
168 |
|
---|
169 | /**
|
---|
170 | * Validates parameters to ensure they are appropriate for the evaluation of
|
---|
171 | * the {@link #value(double,double[])} and {@link #gradient(double,double[])}
|
---|
172 | * methods.
|
---|
173 | *
|
---|
174 | * @param param Values of norm, mean and standard deviation.
|
---|
175 | * @throws NullArgumentException if {@code param} is {@code null}.
|
---|
176 | * @throws DimensionMismatchException if the size of {@code param} is
|
---|
177 | * not 3.
|
---|
178 | * @throws NotStrictlyPositiveException if {@code param[2]} is negative.
|
---|
179 | */
|
---|
180 | private void validateParameters(double[] param)
|
---|
181 | throws NullArgumentException,
|
---|
182 | DimensionMismatchException,
|
---|
183 | NotStrictlyPositiveException {
|
---|
184 | if (param == null) {
|
---|
185 | throw new NullArgumentException();
|
---|
186 | }
|
---|
187 | if (param.length != 3) {
|
---|
188 | throw new DimensionMismatchException(param.length, 3);
|
---|
189 | }
|
---|
190 | if (param[2] <= 0) {
|
---|
191 | throw new NotStrictlyPositiveException(param[2]);
|
---|
192 | }
|
---|
193 | }
|
---|
194 | }
|
---|
195 |
|
---|
196 | /**
|
---|
197 | * @param xMinusMean {@code x - mean}.
|
---|
198 | * @param norm Normalization factor.
|
---|
199 | * @param i2s2 Inverse of twice the square of the standard deviation.
|
---|
200 | * @return the value of the Gaussian at {@code x}.
|
---|
201 | */
|
---|
202 | private static double value(double xMinusMean,
|
---|
203 | double norm,
|
---|
204 | double i2s2) {
|
---|
205 | return norm * FastMath.exp(-xMinusMean * xMinusMean * i2s2);
|
---|
206 | }
|
---|
207 |
|
---|
208 | /** {@inheritDoc}
|
---|
209 | * @since 3.1
|
---|
210 | */
|
---|
211 | public DerivativeStructure value(final DerivativeStructure t)
|
---|
212 | throws DimensionMismatchException {
|
---|
213 |
|
---|
214 | final double u = is * (t.getValue() - mean);
|
---|
215 | double[] f = new double[t.getOrder() + 1];
|
---|
216 |
|
---|
217 | // the nth order derivative of the Gaussian has the form:
|
---|
218 | // dn(g(x)/dxn = (norm / s^n) P_n(u) exp(-u^2/2) with u=(x-m)/s
|
---|
219 | // where P_n(u) is a degree n polynomial with same parity as n
|
---|
220 | // P_0(u) = 1, P_1(u) = -u, P_2(u) = u^2 - 1, P_3(u) = -u^3 + 3 u...
|
---|
221 | // the general recurrence relation for P_n is:
|
---|
222 | // P_n(u) = P_(n-1)'(u) - u P_(n-1)(u)
|
---|
223 | // as per polynomial parity, we can store coefficients of both P_(n-1) and P_n in the same array
|
---|
224 | final double[] p = new double[f.length];
|
---|
225 | p[0] = 1;
|
---|
226 | final double u2 = u * u;
|
---|
227 | double coeff = norm * FastMath.exp(-0.5 * u2);
|
---|
228 | if (coeff <= Precision.SAFE_MIN) {
|
---|
229 | Arrays.fill(f, 0.0);
|
---|
230 | } else {
|
---|
231 | f[0] = coeff;
|
---|
232 | for (int n = 1; n < f.length; ++n) {
|
---|
233 |
|
---|
234 | // update and evaluate polynomial P_n(x)
|
---|
235 | double v = 0;
|
---|
236 | p[n] = -p[n - 1];
|
---|
237 | for (int k = n; k >= 0; k -= 2) {
|
---|
238 | v = v * u2 + p[k];
|
---|
239 | if (k > 2) {
|
---|
240 | p[k - 2] = (k - 1) * p[k - 1] - p[k - 3];
|
---|
241 | } else if (k == 2) {
|
---|
242 | p[0] = p[1];
|
---|
243 | }
|
---|
244 | }
|
---|
245 | if ((n & 0x1) == 1) {
|
---|
246 | v *= u;
|
---|
247 | }
|
---|
248 |
|
---|
249 | coeff *= is;
|
---|
250 | f[n] = coeff * v;
|
---|
251 |
|
---|
252 | }
|
---|
253 | }
|
---|
254 |
|
---|
255 | return t.compose(f);
|
---|
256 |
|
---|
257 | }
|
---|
258 |
|
---|
259 | }
|
---|