source: src/main/java/agents/anac/y2019/harddealer/math3/distribution/EnumeratedDistribution.java

Last change on this file was 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 11.2 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package agents.anac.y2019.harddealer.math3.distribution;
18
19import java.io.Serializable;
20import java.lang.reflect.Array;
21import java.util.ArrayList;
22import java.util.Arrays;
23import java.util.List;
24
25import agents.anac.y2019.harddealer.math3.exception.MathArithmeticException;
26import agents.anac.y2019.harddealer.math3.exception.NotANumberException;
27import agents.anac.y2019.harddealer.math3.exception.NotFiniteNumberException;
28import agents.anac.y2019.harddealer.math3.exception.NotPositiveException;
29import agents.anac.y2019.harddealer.math3.exception.NotStrictlyPositiveException;
30import agents.anac.y2019.harddealer.math3.exception.NullArgumentException;
31import agents.anac.y2019.harddealer.math3.exception.util.LocalizedFormats;
32import agents.anac.y2019.harddealer.math3.random.RandomGenerator;
33import agents.anac.y2019.harddealer.math3.random.Well19937c;
34import agents.anac.y2019.harddealer.math3.util.MathArrays;
35import agents.anac.y2019.harddealer.math3.util.Pair;
36
37/**
38 * <p>A generic implementation of a
39 * <a href="http://en.wikipedia.org/wiki/Probability_distribution#Discrete_probability_distribution">
40 * discrete probability distribution (Wikipedia)</a> over a finite sample space,
41 * based on an enumerated list of &lt;value, probability&gt; pairs. Input probabilities must all be non-negative,
42 * but zero values are allowed and their sum does not have to equal one. Constructors will normalize input
43 * probabilities to make them sum to one.</p>
44 *
45 * <p>The list of <value, probability> pairs does not, strictly speaking, have to be a function and it can
46 * contain null values. The pmf created by the constructor will combine probabilities of equal values and
47 * will treat null values as equal. For example, if the list of pairs &lt;"dog", 0.2&gt;, &lt;null, 0.1&gt;,
48 * &lt;"pig", 0.2&gt;, &lt;"dog", 0.1&gt;, &lt;null, 0.4&gt; is provided to the constructor, the resulting
49 * pmf will assign mass of 0.5 to null, 0.3 to "dog" and 0.2 to null.</p>
50 *
51 * @param <T> type of the elements in the sample space.
52 * @since 3.2
53 */
54public class EnumeratedDistribution<T> implements Serializable {
55
56 /** Serializable UID. */
57 private static final long serialVersionUID = 20123308L;
58
59 /**
60 * RNG instance used to generate samples from the distribution.
61 */
62 protected final RandomGenerator random;
63
64 /**
65 * List of random variable values.
66 */
67 private final List<T> singletons;
68
69 /**
70 * Probabilities of respective random variable values. For i = 0, ..., singletons.size() - 1,
71 * probability[i] is the probability that a random variable following this distribution takes
72 * the value singletons[i].
73 */
74 private final double[] probabilities;
75
76 /**
77 * Cumulative probabilities, cached to speed up sampling.
78 */
79 private final double[] cumulativeProbabilities;
80
81 /**
82 * Create an enumerated distribution using the given probability mass function
83 * enumeration.
84 * <p>
85 * <b>Note:</b> this constructor will implicitly create an instance of
86 * {@link Well19937c} as random generator to be used for sampling only (see
87 * {@link #sample()} and {@link #sample(int)}). In case no sampling is
88 * needed for the created distribution, it is advised to pass {@code null}
89 * as random generator via the appropriate constructors to avoid the
90 * additional initialisation overhead.
91 *
92 * @param pmf probability mass function enumerated as a list of <T, probability>
93 * pairs.
94 * @throws NotPositiveException if any of the probabilities are negative.
95 * @throws NotFiniteNumberException if any of the probabilities are infinite.
96 * @throws NotANumberException if any of the probabilities are NaN.
97 * @throws MathArithmeticException all of the probabilities are 0.
98 */
99 public EnumeratedDistribution(final List<Pair<T, Double>> pmf)
100 throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
101 this(new Well19937c(), pmf);
102 }
103
104 /**
105 * Create an enumerated distribution using the given random number generator
106 * and probability mass function enumeration.
107 *
108 * @param rng random number generator.
109 * @param pmf probability mass function enumerated as a list of <T, probability>
110 * pairs.
111 * @throws NotPositiveException if any of the probabilities are negative.
112 * @throws NotFiniteNumberException if any of the probabilities are infinite.
113 * @throws NotANumberException if any of the probabilities are NaN.
114 * @throws MathArithmeticException all of the probabilities are 0.
115 */
116 public EnumeratedDistribution(final RandomGenerator rng, final List<Pair<T, Double>> pmf)
117 throws NotPositiveException, MathArithmeticException, NotFiniteNumberException, NotANumberException {
118 random = rng;
119
120 singletons = new ArrayList<T>(pmf.size());
121 final double[] probs = new double[pmf.size()];
122
123 for (int i = 0; i < pmf.size(); i++) {
124 final Pair<T, Double> sample = pmf.get(i);
125 singletons.add(sample.getKey());
126 final double p = sample.getValue();
127 if (p < 0) {
128 throw new NotPositiveException(sample.getValue());
129 }
130 if (Double.isInfinite(p)) {
131 throw new NotFiniteNumberException(p);
132 }
133 if (Double.isNaN(p)) {
134 throw new NotANumberException();
135 }
136 probs[i] = p;
137 }
138
139 probabilities = MathArrays.normalizeArray(probs, 1.0);
140
141 cumulativeProbabilities = new double[probabilities.length];
142 double sum = 0;
143 for (int i = 0; i < probabilities.length; i++) {
144 sum += probabilities[i];
145 cumulativeProbabilities[i] = sum;
146 }
147 }
148
149 /**
150 * Reseed the random generator used to generate samples.
151 *
152 * @param seed the new seed
153 */
154 public void reseedRandomGenerator(long seed) {
155 random.setSeed(seed);
156 }
157
158 /**
159 * <p>For a random variable {@code X} whose values are distributed according to
160 * this distribution, this method returns {@code P(X = x)}. In other words,
161 * this method represents the probability mass function (PMF) for the
162 * distribution.</p>
163 *
164 * <p>Note that if {@code x1} and {@code x2} satisfy {@code x1.equals(x2)},
165 * or both are null, then {@code probability(x1) = probability(x2)}.</p>
166 *
167 * @param x the point at which the PMF is evaluated
168 * @return the value of the probability mass function at {@code x}
169 */
170 double probability(final T x) {
171 double probability = 0;
172
173 for (int i = 0; i < probabilities.length; i++) {
174 if ((x == null && singletons.get(i) == null) ||
175 (x != null && x.equals(singletons.get(i)))) {
176 probability += probabilities[i];
177 }
178 }
179
180 return probability;
181 }
182
183 /**
184 * <p>Return the probability mass function as a list of <value, probability> pairs.</p>
185 *
186 * <p>Note that if duplicate and / or null values were provided to the constructor
187 * when creating this EnumeratedDistribution, the returned list will contain these
188 * values. If duplicates values exist, what is returned will not represent
189 * a pmf (i.e., it is up to the caller to consolidate duplicate mass points).</p>
190 *
191 * @return the probability mass function.
192 */
193 public List<Pair<T, Double>> getPmf() {
194 final List<Pair<T, Double>> samples = new ArrayList<Pair<T, Double>>(probabilities.length);
195
196 for (int i = 0; i < probabilities.length; i++) {
197 samples.add(new Pair<T, Double>(singletons.get(i), probabilities[i]));
198 }
199
200 return samples;
201 }
202
203 /**
204 * Generate a random value sampled from this distribution.
205 *
206 * @return a random value.
207 */
208 public T sample() {
209 final double randomValue = random.nextDouble();
210
211 int index = Arrays.binarySearch(cumulativeProbabilities, randomValue);
212 if (index < 0) {
213 index = -index-1;
214 }
215
216 if (index >= 0 &&
217 index < probabilities.length &&
218 randomValue < cumulativeProbabilities[index]) {
219 return singletons.get(index);
220 }
221
222 /* This should never happen, but it ensures we will return a correct
223 * object in case there is some floating point inequality problem
224 * wrt the cumulative probabilities. */
225 return singletons.get(singletons.size() - 1);
226 }
227
228 /**
229 * Generate a random sample from the distribution.
230 *
231 * @param sampleSize the number of random values to generate.
232 * @return an array representing the random sample.
233 * @throws NotStrictlyPositiveException if {@code sampleSize} is not
234 * positive.
235 */
236 public Object[] sample(int sampleSize) throws NotStrictlyPositiveException {
237 if (sampleSize <= 0) {
238 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES,
239 sampleSize);
240 }
241
242 final Object[] out = new Object[sampleSize];
243
244 for (int i = 0; i < sampleSize; i++) {
245 out[i] = sample();
246 }
247
248 return out;
249
250 }
251
252 /**
253 * Generate a random sample from the distribution.
254 * <p>
255 * If the requested samples fit in the specified array, it is returned
256 * therein. Otherwise, a new array is allocated with the runtime type of
257 * the specified array and the size of this collection.
258 *
259 * @param sampleSize the number of random values to generate.
260 * @param array the array to populate.
261 * @return an array representing the random sample.
262 * @throws NotStrictlyPositiveException if {@code sampleSize} is not positive.
263 * @throws NullArgumentException if {@code array} is null
264 */
265 public T[] sample(int sampleSize, final T[] array) throws NotStrictlyPositiveException {
266 if (sampleSize <= 0) {
267 throw new NotStrictlyPositiveException(LocalizedFormats.NUMBER_OF_SAMPLES, sampleSize);
268 }
269
270 if (array == null) {
271 throw new NullArgumentException(LocalizedFormats.INPUT_ARRAY);
272 }
273
274 T[] out;
275 if (array.length < sampleSize) {
276 @SuppressWarnings("unchecked") // safe as both are of type T
277 final T[] unchecked = (T[]) Array.newInstance(array.getClass().getComponentType(), sampleSize);
278 out = unchecked;
279 } else {
280 out = array;
281 }
282
283 for (int i = 0; i < sampleSize; i++) {
284 out[i] = sample();
285 }
286
287 return out;
288
289 }
290
291}
Note: See TracBrowser for help on using the repository browser.