package geniusweb.blingbling.Ranknet4j;

/**
 * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

import java.io.Serializable;
import java.util.Iterator;
import java.util.List;

import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;
import org.neuroph.core.data.DataSet;
import org.neuroph.core.data.DataSetRow;
import org.neuroph.core.learning.error.ErrorFunction;
//import org.neuroph.core.learning.error.MeanSquaredError;
import org.neuroph.core.learning.stop.MaxErrorStop;
import org.neuroph.core.learning.IterativeLearning;

import org.neuroph.core.transfer.Sigmoid;

/**
 * Base class for all supervised learning algorithms.
 * It extends IterativeLearning and provides general supervised learning principles.
 * Based on the Template Method pattern, with the abstract method calculateWeightChanges.
 *
 * TODO: random pattern order
 *
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public abstract class SupervisedLearning extends IterativeLearning implements
        Serializable {

    /**
     * The class fingerprint that is set to indicate serialization
     * compatibility with a previous version of the class
     */
    private static final long serialVersionUID = 3L;

    /**
     * Total network error in the previous epoch
     */
    protected transient double previousEpochError;

    /**
     * Max allowed network error (condition to stop learning)
     */
    protected double maxError = 0.01d;

    /**
     * Stopping condition: training stops if the total network error change is smaller than minErrorChange
     * for minErrorChangeIterationsLimit consecutive iterations
     */
    private double minErrorChange = Double.POSITIVE_INFINITY;

    /**
     * Stopping condition: the number of iterations for which the total network error change must stay
     * below minErrorChange before training stops
     */
    private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;

    /**
     * Counts iterations in which the error change is smaller than minErrorChange.
     */
    private transient int minErrorChangeIterationsCount;

    /**
     * Determines whether learning (weight updates) is performed in batch mode.
     * False by default.
     */
    private boolean batchMode = false;

    private ErrorFunction errorFunction;

    /**
     * Creates a new supervised learning rule
     */
    public SupervisedLearning() {
        super();
        errorFunction = new CrossEntropyError();
        // stopConditions.add(new MaxErrorStop(this));
    }

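    // Note (added for clarity): CrossEntropyError is not imported above, so it is expected
    // to live in this same package (geniusweb.blingbling.Ranknet4j). learnPattern() below
    // relies on its addPatternError()/addPatternError2() methods to compute the pairwise
    // error for the two halves of each training row.
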
    /**
     * This method should implement the weight update procedure for the whole network
     * for the given output error vector.
     *
     * @param outputError output error vector for some network input (a.k.a. patternError, network error),
     *                    usually the difference between the desired and actual output
     */
    protected abstract void calculateWeightChanges(double[] outputError);

    /**
     * Trains the network with the specified training set and maxError
     *
     * @param trainingSet training set to learn
     * @param maxError    learning stop condition; when maxError is reached, learning stops
     */
    public final void learn(DataSet trainingSet, double maxError) {
        this.maxError = maxError;
        learn(trainingSet);
    }

    /**
     * Trains the network with the specified training set, maxError and number of iterations
     *
     * @param trainingSet   training set to learn
     * @param maxError      learning stop condition; when maxError is reached, learning stops
     * @param maxIterations maximum number of learning iterations
     */
    public final void learn(DataSet trainingSet, double maxError, int maxIterations) {
        this.trainingSet = trainingSet;
        this.maxError = maxError;
        setMaxIterations(maxIterations);
        learn(trainingSet);
    }

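    /*
     * Usage sketch (illustrative only; the concrete learning-rule subclass and the
     * network wiring are assumptions, not defined in this file):
     *
     *   SupervisedLearning rule = ...;        // some concrete subclass of this class
     *   neuralNetwork.setLearningRule(rule);  // attach the rule to a Neuroph NeuralNetwork
     *   rule.learn(trainingSet, 0.01, 1000);  // stop at total error 0.01 or after 1000 iterations
     */
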
    @Override
    protected void onStart() {
        super.onStart(); // reset iteration counter
        minErrorChangeIterationsCount = 0;
        previousEpochError = 0d;
    }

    @Override
    protected void beforeEpoch() {
        previousEpochError = errorFunction.getTotalError();
        errorFunction.reset();
    }

    @Override
    protected void afterEpoch() {
        // calculate the absolute error change and count iterations in which it is below the
        // specified min error change (used for a stop condition)
        double absErrorChange = Math.abs(previousEpochError - errorFunction.getTotalError());
        if (absErrorChange <= this.minErrorChange) {
            minErrorChangeIterationsCount++;
        } else {
            minErrorChangeIterationsCount = 0;
        }

        // if learning is performed in batch mode, apply the weight changes accumulated during this epoch
        if (batchMode) {
            doBatchWeightsUpdate();
        }
    }

    /**
     * This method implements the basic logic of one learning epoch for
     * supervised learning algorithms. An epoch is one pass through the
     * training set. This method iterates through the training set
     * and trains the network on each element. It also sets a flag once the conditions
     * to stop learning have been reached: network error below the allowed
     * value, or the maximum iteration count.
     *
     * @param trainingSet training set for training the network
     */
    @Override
    public void doLearningEpoch(DataSet trainingSet) {
        Iterator<DataSetRow> iterator = trainingSet.iterator();
        while (iterator.hasNext() && !isStopped()) { // iterate all elements from the training set - maybe remove isStopped from here
            DataSetRow dataSetRow = iterator.next();
            learnPattern(dataSetRow); // learn the current input/output pattern defined by SupervisedTrainingElement
        }
    }

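    /*
     * Data layout note (added for clarity, inferred from learnPattern() below): each
     * DataSetRow is expected to hold a concatenated pair of feature vectors - the first
     * half is the input of the first candidate, the second half the input of the second
     * candidate - together with the desired pairwise output.
     */
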
    /**
     * Trains the network with the input pair and desired output from the specified training element
     *
     * @param trainingElement supervised training element which contains the pair of inputs and the desired output
     */
    protected final void learnPattern(DataSetRow trainingElement) {
        // split the row input into the two candidate inputs of the pair
        int size = trainingElement.getInput().length / 2;
        double[] input1 = new double[size];
        double[] input2 = new double[size];
        System.arraycopy(trainingElement.getInput(), 0, input1, 0, size);
        System.arraycopy(trainingElement.getInput(), size, input2, 0, size);

        // score both inputs with the current network
        neuralNetwork.setInput(input1);
        neuralNetwork.calculate();
        double[] output1 = neuralNetwork.getOutput();
        neuralNetwork.setInput(input2);
        neuralNetwork.calculate();
        double[] output2 = neuralNetwork.getOutput();

        // pairwise score difference s1 - s2
        double[] output = new double[1];
        output[0] = output1[0] - output2[0];

        // update the weights for input2 (the network state still holds input2's activations)
        double[] patternError2 = ((CrossEntropyError) errorFunction).addPatternError2(output, trainingElement.getDesiredOutput());
        calculateWeightChanges(patternError2);
        if (!batchMode) applyWeightChanges(); // batch mode updates are done in doBatchWeightsUpdate

        // update the weights for input1 (recalculate so the network state holds input1's activations)
        neuralNetwork.setInput(input1);
        neuralNetwork.calculate();
        double[] patternError1 = ((CrossEntropyError) errorFunction).addPatternError(output, trainingElement.getDesiredOutput());
        calculateWeightChanges(patternError1);
        if (!batchMode) applyWeightChanges(); // batch mode updates are done in doBatchWeightsUpdate
    }

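    /*
     * Background (added; not part of the original Neuroph base class): learnPattern()
     * above performs a RankNet-style pairwise update. For a pair with network scores
     * s1 and s2, the model's probability that the first item should be ranked above the
     * second is sigmoid(s1 - s2), and the cross-entropy of that probability against the
     * desired label is what CrossEntropyError accumulates. A minimal sketch of the
     * modelled pair probability, assuming a plain logistic function:
     *
     *   double scoreDiff = output1[0] - output2[0];
     *   double pairProbability = 1.0 / (1.0 + Math.exp(-scoreDiff));
     */
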
    /**
     * This method updates the network weights in batch mode, using the accumulated weight changes stored in Weight.weightChange.
     * It is executed after each learning epoch, but only if learning is done in batch mode.
     *
     * @see SupervisedLearning#doLearningEpoch(org.neuroph.core.data.DataSet)
     */
    protected void doBatchWeightsUpdate() {
        // iterate layers from output to input
        List<Layer> layers = neuralNetwork.getLayers();
        for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
            // iterate neurons in each layer
            for (Neuron neuron : layers.get(i).getNeurons()) {
                // iterate connections/weights of each neuron
                for (Connection connection : neuron.getInputConnections()) {
                    // apply the accumulated weight change to each connection weight
                    Weight weight = connection.getWeight();
                    weight.value += weight.weightChange / getTrainingSet().size(); // apply the averaged sum of delta weights accumulated during the epoch - TODO: add mini batch
                    weight.weightChange = 0; // reset deltaWeight
                }
            }
        }
    }

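    /*
     * Note (added for clarity): in batch mode, per-pattern deltas only accumulate in
     * Weight.weightChange during the epoch; doBatchWeightsUpdate() then applies them once,
     * averaged over the training-set size. In online mode that normalisation is skipped
     * and each delta is applied immediately via applyWeightChanges().
     */
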
    /**
     * Returns true if learning is performed in batch mode, false otherwise
     *
     * @return true if learning is performed in batch mode, false otherwise
     */
    public boolean isBatchMode() {
        return batchMode;
    }

    /**
     * Sets batch mode on/off (true/false)
     *
     * @param batchMode batch mode setting
     */
    public void setBatchMode(boolean batchMode) {
        this.batchMode = batchMode;
    }

    /**
     * Sets the allowed network error, which indicates when to stop training
     *
     * @param maxError network error
     */
    public void setMaxError(double maxError) {
        this.maxError = maxError;
    }

    /**
     * Returns the learning error tolerance - the value of total network error at which to stop learning.
     *
     * @return learning error tolerance
     */
    public double getMaxError() {
        return maxError;
    }

    /**
     * Returns the total network error in the previous learning epoch
     *
     * @return total network error in the previous learning epoch
     */
    public double getPreviousEpochError() {
        return previousEpochError;
    }

    /**
     * Returns the min error change stopping criterion
     *
     * @return min error change stopping criterion
     */
    public double getMinErrorChange() {
        return minErrorChange;
    }

    /**
     * Sets the min error change stopping criterion
     *
     * @param minErrorChange value for the min error change stopping criterion
     */
    public void setMinErrorChange(double minErrorChange) {
        this.minErrorChange = minErrorChange;
    }

    /**
     * Returns the number of iterations for the min error change stopping criterion
     *
     * @return number of iterations for the min error change stopping criterion
     */
    public int getMinErrorChangeIterationsLimit() {
        return minErrorChangeIterationsLimit;
    }

    /**
     * Sets the number of iterations for the min error change stopping criterion
     *
     * @param minErrorChangeIterationsLimit number of iterations for the min error change stopping criterion
     */
    public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
        this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
    }

    /**
     * Returns the current iteration count for the min error change stopping criterion
     *
     * @return current iteration count for the min error change stopping criterion
     */
    public int getMinErrorChangeIterationsCount() {
        return minErrorChangeIterationsCount;
    }

    public ErrorFunction getErrorFunction() {
        return errorFunction;
    }

    public void setErrorFunction(ErrorFunction errorFunction) {
        this.errorFunction = errorFunction;
    }

    public double getTotalNetworkError() {
        return errorFunction.getTotalError();
    }

    private void applyWeightChanges() {
        List<Layer> layers = neuralNetwork.getLayers();
        for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
            // iterate neurons in each layer
            for (Neuron neuron : layers.get(i).getNeurons()) {
                // iterate connections/weights of each neuron
                for (Connection connection : neuron.getInputConnections()) {
                    // apply the accumulated weight change to each connection weight
                    Weight weight = connection.getWeight();
                    if (!isBatchMode()) {
                        weight.value += weight.weightChange;
                    } else {
                        weight.value += (weight.weightChange / getTrainingSet().size());
                    }

                    weight.weightChange = 0; // reset deltaWeight
                }
            }
        }
    }
}