source: src/main/java/agents/anac/y2019/harddealer/math3/distribution/MixtureMultivariateRealDistribution.java

Last change on this file was changeset 204, checked in by Katsuhide Fujita, 5 years ago

Fixed errors of ANAC2019 agents

  • Property svn:executable set to *
File size: 6.3 KB
Line 
1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements. See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License. You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17package agents.anac.y2019.harddealer.math3.distribution;
18
19import java.util.ArrayList;
20import java.util.List;
21
22import agents.anac.y2019.harddealer.math3.exception.DimensionMismatchException;
23import agents.anac.y2019.harddealer.math3.exception.MathArithmeticException;
24import agents.anac.y2019.harddealer.math3.exception.NotPositiveException;
25import agents.anac.y2019.harddealer.math3.exception.util.LocalizedFormats;
26import agents.anac.y2019.harddealer.math3.random.RandomGenerator;
27import agents.anac.y2019.harddealer.math3.random.Well19937c;
28import agents.anac.y2019.harddealer.math3.util.Pair;
29
30/**
31 * Class for representing <a href="http://en.wikipedia.org/wiki/Mixture_model">
32 * mixture model</a> distributions.
33 *
34 * @param <T> Type of the mixture components.
35 *
36 * @since 3.1
37 */
38public class MixtureMultivariateRealDistribution<T extends MultivariateRealDistribution>
39 extends AbstractMultivariateRealDistribution {
40 /** Normalized weight of each mixture component. */
41 private final double[] weight;
42 /** Mixture components. */
43 private final List<T> distribution;
44
45 /**
46 * Creates a mixture model from a list of distributions and their
47 * associated weights.
48 * <p>
49 * <b>Note:</b> this constructor will implicitly create an instance of
50 * {@link Well19937c} as random generator to be used for sampling only (see
51 * {@link #sample()} and {@link #sample(int)}). In case no sampling is
52 * needed for the created distribution, it is advised to pass {@code null}
53 * as random generator via the appropriate constructors to avoid the
54 * additional initialisation overhead.
55 *
56 * @param components List of (weight, distribution) pairs from which to sample.
57 */
58 public MixtureMultivariateRealDistribution(List<Pair<Double, T>> components) {
59 this(new Well19937c(), components);
60 }
61
62 /**
63 * Creates a mixture model from a list of distributions and their
64 * associated weights.
65 *
66 * @param rng Random number generator.
67 * @param components Distributions from which to sample.
68 * @throws NotPositiveException if any of the weights is negative.
69 * @throws DimensionMismatchException if not all components have the same
70 * number of variables.
71 */
72 public MixtureMultivariateRealDistribution(RandomGenerator rng,
73 List<Pair<Double, T>> components) {
74 super(rng, components.get(0).getSecond().getDimension());
75
76 final int numComp = components.size();
77 final int dim = getDimension();
78 double weightSum = 0;
79 for (int i = 0; i < numComp; i++) {
80 final Pair<Double, T> comp = components.get(i);
81 if (comp.getSecond().getDimension() != dim) {
82 throw new DimensionMismatchException(comp.getSecond().getDimension(), dim);
83 }
84 if (comp.getFirst() < 0) {
85 throw new NotPositiveException(comp.getFirst());
86 }
87 weightSum += comp.getFirst();
88 }
89
90 // Check for overflow.
91 if (Double.isInfinite(weightSum)) {
92 throw new MathArithmeticException(LocalizedFormats.OVERFLOW);
93 }
94
95 // Store each distribution and its normalized weight.
96 distribution = new ArrayList<T>();
97 weight = new double[numComp];
98 for (int i = 0; i < numComp; i++) {
99 final Pair<Double, T> comp = components.get(i);
100 weight[i] = comp.getFirst() / weightSum;
101 distribution.add(comp.getSecond());
102 }
103 }
104
105 /** {@inheritDoc} */
106 public double density(final double[] values) {
107 double p = 0;
108 for (int i = 0; i < weight.length; i++) {
109 p += weight[i] * distribution.get(i).density(values);
110 }
111 return p;
112 }
113
114 /** {@inheritDoc} */
115 @Override
116 public double[] sample() {
117 // Sampled values.
118 double[] vals = null;
119
120 // Determine which component to sample from.
121 final double randomValue = random.nextDouble();
122 double sum = 0;
123
124 for (int i = 0; i < weight.length; i++) {
125 sum += weight[i];
126 if (randomValue <= sum) {
127 // pick model i
128 vals = distribution.get(i).sample();
129 break;
130 }
131 }
132
133 if (vals == null) {
134 // This should never happen, but it ensures we won't return a null in
135 // case the loop above has some floating point inequality problem on
136 // the final iteration.
137 vals = distribution.get(weight.length - 1).sample();
138 }
139
140 return vals;
141 }
142
143 /** {@inheritDoc} */
144 @Override
145 public void reseedRandomGenerator(long seed) {
146 // Seed needs to be propagated to underlying components
147 // in order to maintain consistency between runs.
148 super.reseedRandomGenerator(seed);
149
150 for (int i = 0; i < distribution.size(); i++) {
151 // Make each component's seed different in order to avoid
152 // using the same sequence of random numbers.
153 distribution.get(i).reseedRandomGenerator(i + 1 + seed);
154 }
155 }
156
157 /**
158 * Gets the distributions that make up the mixture model.
159 *
160 * @return the component distributions and associated weights.
161 */
162 public List<Pair<Double, T>> getComponents() {
163 final List<Pair<Double, T>> list = new ArrayList<Pair<Double, T>>(weight.length);
164
165 for (int i = 0; i < weight.length; i++) {
166 list.add(new Pair<Double, T>(weight[i], distribution.get(i)));
167 }
168
169 return list;
170 }
171}
Note: See TracBrowser for help on using the repository browser.