source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet4j/SupervisedLearning.java@ 31

Last change on this file since 31 was 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 12.8 KB
Line 
1package geniusweb.blingbling.Ranknet4j;
2
3/**
4 * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
5 *
6 * Licensed under the Apache License, Version 2.0 (the "License");
7 * you may not use this file except in compliance with the License.
8 * You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing, software
13 * distributed under the License is distributed on an "AS IS" BASIS,
14 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15 * See the License for the specific language governing permissions and
16 * limitations under the License.
17 */
18
19
20import java.io.Serializable;
21import java.util.Iterator;
22import java.util.List;
23
24import org.neuroph.core.Connection;
25import org.neuroph.core.Layer;
26import org.neuroph.core.Neuron;
27import org.neuroph.core.Weight;
28import org.neuroph.core.data.DataSet;
29import org.neuroph.core.data.DataSetRow;
30import org.neuroph.core.learning.error.ErrorFunction;
31//import org.neuroph.core.learning.error.MeanSquaredError;
32import org.neuroph.core.learning.stop.MaxErrorStop;
33import org.neuroph.core.learning.IterativeLearning;
34
35import org.neuroph.core.transfer.Sigmoid;
36
37
38/**
39 * Base class for all supervised learning algorithms.
40 * It extends IterativeLearning, and provides general supervised learning principles.
41 * Based on Template Method Pattern with abstract method calculateWeightChanges
42 *
43 * TODO: random pattern order
44 *
45 * @author Zoran Sevarac <sevarac@gmail.com>
46 */
47abstract public class SupervisedLearning extends IterativeLearning implements
48 Serializable {
49
50 /**
51 * The class fingerprint that is set to indicate serialization
52 * compatibility with a previous version of the class
53 */
54 private static final long serialVersionUID = 3L;
55
56 /**
57 * Total network error in previous epoch
58 */
59 protected transient double previousEpochError;
60
61 /**
62 * Max allowed network error (condition to stop learning)
63 */
64 protected double maxError = 0.01d;
65
66 /**
67 * Stopping condition: training stops if total network error change is smaller than minErrorChange
68 * for minErrorChangeIterationsLimit number of iterations
69 */
70 private double minErrorChange = Double.POSITIVE_INFINITY;
71
72 /**
73 * Stopping condition: training stops if total network error change is smaller than minErrorChange
74 * for minErrorChangeStopIterations number of iterations
75 */
76 private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;
77
78 /**
79 * Count iterations where error change is smaller then minErrorChange.
80 */
81 private transient int minErrorChangeIterationsCount;
82
83 /**
84 * Setting to determine if learning (weights update) is in batch mode.
85 * False by default.
86 */
87 private boolean batchMode = false;
88
89 private ErrorFunction errorFunction;
90
91 /**
92 * Creates new supervised learning rule
93 */
94 public SupervisedLearning() {
95 super();
96 errorFunction = new CrossEntropyError();
97// stopConditions.add(new MaxErrorStop(this));
98 }
99
100 /**
101 * This method should implement the weights update procedure for the whole network
102 * for the given output error vector.
103 *
104 * @param outputError output error vector for some network input (aka. patternError, network error)
105 * usually the difference between desired and actual output
106 */
107 abstract protected void calculateWeightChanges(double[] outputError);
108
109 /**
110 * Trains network for the specified training set and maxError
111 *
112 * @param trainingSet training set to learn
113 * @param maxError learning stop condition. If maxError is reached learning stops
114 */
115 public final void learn(DataSet trainingSet, double maxError) {
116 this.maxError = maxError;
117 learn(trainingSet);
118 }
119
120 /**
121 * Trains network for the specified training set, maxError and number of iterations
122 *
123 * @param trainingSet training set to learn
124 * @param maxError learning stop condition. if maxError is reached learning stops
125 * @param maxIterations maximum number of learning iterations
126 */
127 public final void learn(DataSet trainingSet, double maxError, int maxIterations) {
128 this.trainingSet = trainingSet;
129 this.maxError = maxError;
130 setMaxIterations(maxIterations);
131 learn(trainingSet);
132 }
133
134 @Override
135 protected void onStart() {
136 super.onStart(); // reset iteration counter
137 minErrorChangeIterationsCount = 0;
138 previousEpochError = 0d;
139 }
140
141 @Override
142 protected void beforeEpoch() {
143 previousEpochError = errorFunction.getTotalError();
144 errorFunction.reset();
145 }
146
147 @Override
148 protected void afterEpoch() {
149 // calculate abs error change and count iterations if its below specified min error change (used for stop condition)
150 double absErrorChange = Math.abs(previousEpochError - errorFunction.getTotalError());
151 if (absErrorChange <= this.minErrorChange) {
152 minErrorChangeIterationsCount++;
153 } else {
154 minErrorChangeIterationsCount = 0;
155 }
156
157 // if learning is performed in batch mode, apply accumulated weight changes from this epoch
158 if (batchMode == true) {
159 doBatchWeightsUpdate();
160 }
161 }
162
163 /**
164 * This method implements basic logic for one learning epoch for the
165 * supervised learning algorithms. Epoch is the one pass through the
166 * training set. This method iterates through the training set
167 * and trains network for each element. It also sets flag if conditions
168 * to stop learning has been reached: network error below some allowed
169 * value, or maximum iteration count
170 *
171 * @param trainingSet training set for training network
172 */
173 @Override
174 public void doLearningEpoch(DataSet trainingSet) {
175 Iterator<DataSetRow> iterator = trainingSet.iterator();
176 while (iterator.hasNext() && !isStopped()) { // iterate all elements from training set - maybe remove isStopped from here
177 DataSetRow dataSetRow = iterator.next();
178 learnPattern(dataSetRow); // learn current input/output pattern defined by SupervisedTrainingElement
179 }
180 }
181
182 /**
183 * Trains network with the input and desired output pattern from the specified training element
184 *
185 * @param trainingElement supervised training element which contains input and desired output
186 */
187 protected final void learnPattern(DataSetRow trainingElement) {
188 int size = trainingElement.getInput().length/2;
189 double[] input1 = new double[size];
190 double[] input2 = new double[size];
191 for (int i=0; i<size; i++) {
192 input1[i] = trainingElement.getInput()[i];
193 }
194 for (int i=0; i<size; i++) {
195 input2[i] = trainingElement.getInput()[size+i];
196 }
197 neuralNetwork.setInput(input1);
198 neuralNetwork.calculate();
199 double[] output1 = neuralNetwork.getOutput();
200 neuralNetwork.setInput(input2);
201 neuralNetwork.calculate();
202 double[] output2 = neuralNetwork.getOutput();
203 double[] output = new double[1];
204 output[0] = output1[0]-output2[0];
205 //update the under input2
206 double[] patternError2 = ((CrossEntropyError) errorFunction).addPatternError2(output, trainingElement.getDesiredOutput());
207 calculateWeightChanges(patternError2);
208 if (!batchMode) applyWeightChanges(); // batch mode updates are done i doBatchWeightsUpdate
209 //update under input1
210 neuralNetwork.setInput(input1);
211 neuralNetwork.calculate();
212 double[] patternError1 = ((CrossEntropyError) errorFunction).addPatternError(output, trainingElement.getDesiredOutput());
213 calculateWeightChanges(patternError1);
214 if (!batchMode) applyWeightChanges(); // batch mode updates are done i doBatchWeightsUpdate
215 }
216
217 /**
218 * This method updates network weights in batch mode - use accumulated weights change stored in Weight.deltaWeight
219 * It is executed after each learning epoch, only if learning is done in batch mode.
220 *
221 * @see SupervisedLearning#doLearningEpoch(org.neuroph.core.data.DataSet)
222 */
223 protected void doBatchWeightsUpdate() {
224 // iterate layers from output to input
225 List<Layer> layers = neuralNetwork.getLayers();
226 for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
227 // iterate neurons at each layer
228 for (Neuron neuron : layers.get(i).getNeurons()) {
229 // iterate connections/weights for each neuron
230 for (Connection connection : neuron.getInputConnections()) {
231 // for each connection weight apply accumulated weight change
232 Weight weight = connection.getWeight();
233 weight.value += weight.weightChange / getTrainingSet().size(); // apply delta weight which is the sum of delta weights in batch mode - TODO: add mini batch
234 weight.weightChange = 0; // reset deltaWeight
235 }
236 }
237 }
238 }
239
240 /**
241 * Returns true if learning is performed in batch mode, false otherwise
242 *
243 * @return true if learning is performed in batch mode, false otherwise
244 */
245 public boolean isBatchMode() {
246 return batchMode;
247 }
248
249 /**
250 * Sets batch mode on/off (true/false)
251 *
252 * @param batchMode batch mode setting
253 */
254 public void setBatchMode(boolean batchMode) {
255 this.batchMode = batchMode;
256 }
257
258 /**
259 * Sets allowed network error, which indicates when to stopLearning training
260 *
261 * @param maxError network error
262 */
263 public void setMaxError(double maxError) {
264 this.maxError = maxError;
265 }
266
267 /**
268 * Returns learning error tolerance - the value of total network error to stop learning.
269 *
270 * @return learning error tolerance
271 */
272 public double getMaxError() {
273 return maxError;
274 }
275
276 /**
277 * Returns total network error in previous learning epoch
278 *
279 * @return total network error in previous learning epoch
280 */
281 public double getPreviousEpochError() {
282 return previousEpochError;
283 }
284
285 /**
286 * Returns min error change stopping criteria
287 *
288 * @return min error change stopping criteria
289 */
290 public double getMinErrorChange() {
291 return minErrorChange;
292 }
293
294 /**
295 * Sets min error change stopping criteria
296 *
297 * @param minErrorChange value for min error change stopping criteria
298 */
299 public void setMinErrorChange(double minErrorChange) {
300 this.minErrorChange = minErrorChange;
301 }
302
303 /**
304 * Returns number of iterations for min error change stopping criteria
305 *
306 * @return number of iterations for min error change stopping criteria
307 */
308 public int getMinErrorChangeIterationsLimit() {
309 return minErrorChangeIterationsLimit;
310 }
311
312 /**
313 * Sets number of iterations for min error change stopping criteria
314 *
315 * @param minErrorChangeIterationsLimit number of iterations for min error change stopping criteria
316 */
317 public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
318 this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
319 }
320
321 /**
322 * Returns number of iterations count for for min error change stopping criteria
323 *
324 * @return number of iterations count for for min error change stopping criteria
325 */
326 public int getMinErrorChangeIterationsCount() {
327 return minErrorChangeIterationsCount;
328 }
329
330 public ErrorFunction getErrorFunction() {
331 return errorFunction;
332 }
333
334 public void setErrorFunction(ErrorFunction errorFunction) {
335 this.errorFunction = errorFunction;
336 }
337
338
339 public double getTotalNetworkError() {
340 return errorFunction.getTotalError();
341 }
342
343 private void applyWeightChanges() {
344 List<Layer> layers = neuralNetwork.getLayers();
345 for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
346 // iterate neurons at each layer
347 for (Neuron neuron : layers.get(i)) {
348 // iterate connections/weights for each neuron
349 for (Connection connection : neuron.getInputConnections()) {
350 // for each connection weight apply accumulated weight change
351 Weight weight = connection.getWeight();
352 if (!isBatchMode()) {
353 weight.value += weight.weightChange;
354 } else {
355 weight.value += (weight.weightChange / getTrainingSet().size());
356 }
357
358 weight.weightChange = 0; // reset deltaWeight
359 }
360 }
361 }
362 }
363}
Note: See TracBrowser for help on using the repository browser.