[1] | 1 | package geniusweb.blingbling.Ranknet4j;
|
---|
| 2 |
|
---|
| 3 | /**
|
---|
| 4 | * Copyright 2010 Neuroph Project http://neuroph.sourceforge.net
|
---|
| 5 | *
|
---|
| 6 | * Licensed under the Apache License, Version 2.0 (the "License");
|
---|
| 7 | * you may not use this file except in compliance with the License.
|
---|
| 8 | * You may obtain a copy of the License at
|
---|
| 9 | *
|
---|
| 10 | * http://www.apache.org/licenses/LICENSE-2.0
|
---|
| 11 | *
|
---|
| 12 | * Unless required by applicable law or agreed to in writing, software
|
---|
| 13 | * distributed under the License is distributed on an "AS IS" BASIS,
|
---|
| 14 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
---|
| 15 | * See the License for the specific language governing permissions and
|
---|
| 16 | * limitations under the License.
|
---|
| 17 | */
|
---|
| 18 |
|
---|
| 19 |
|
---|
| 20 | import java.io.Serializable;
|
---|
| 21 | import java.util.Iterator;
|
---|
| 22 | import java.util.List;
|
---|
| 23 |
|
---|
| 24 | import org.neuroph.core.Connection;
|
---|
| 25 | import org.neuroph.core.Layer;
|
---|
| 26 | import org.neuroph.core.Neuron;
|
---|
| 27 | import org.neuroph.core.Weight;
|
---|
| 28 | import org.neuroph.core.data.DataSet;
|
---|
| 29 | import org.neuroph.core.data.DataSetRow;
|
---|
| 30 | import org.neuroph.core.learning.error.ErrorFunction;
|
---|
| 31 | //import org.neuroph.core.learning.error.MeanSquaredError;
|
---|
| 32 | import org.neuroph.core.learning.stop.MaxErrorStop;
|
---|
| 33 | import org.neuroph.core.learning.IterativeLearning;
|
---|
| 34 |
|
---|
| 35 | import org.neuroph.core.transfer.Sigmoid;
|
---|
| 36 |
|
---|
| 37 |
|
---|
| 38 | /**
|
---|
| 39 | * Base class for all supervised learning algorithms.
|
---|
| 40 | * It extends IterativeLearning, and provides general supervised learning principles.
|
---|
| 41 | * Based on Template Method Pattern with abstract method calculateWeightChanges
|
---|
| 42 | *
|
---|
| 43 | * TODO: random pattern order
|
---|
| 44 | *
|
---|
| 45 | * @author Zoran Sevarac <sevarac@gmail.com>
|
---|
| 46 | */
|
---|
| 47 | abstract public class SupervisedLearning extends IterativeLearning implements
|
---|
| 48 | Serializable {
|
---|
| 49 |
|
---|
| 50 | /**
|
---|
| 51 | * The class fingerprint that is set to indicate serialization
|
---|
| 52 | * compatibility with a previous version of the class
|
---|
| 53 | */
|
---|
| 54 | private static final long serialVersionUID = 3L;
|
---|
| 55 |
|
---|
| 56 | /**
|
---|
| 57 | * Total network error in previous epoch
|
---|
| 58 | */
|
---|
| 59 | protected transient double previousEpochError;
|
---|
| 60 |
|
---|
| 61 | /**
|
---|
| 62 | * Max allowed network error (condition to stop learning)
|
---|
| 63 | */
|
---|
| 64 | protected double maxError = 0.01d;
|
---|
| 65 |
|
---|
| 66 | /**
|
---|
| 67 | * Stopping condition: training stops if total network error change is smaller than minErrorChange
|
---|
| 68 | * for minErrorChangeIterationsLimit number of iterations
|
---|
| 69 | */
|
---|
| 70 | private double minErrorChange = Double.POSITIVE_INFINITY;
|
---|
| 71 |
|
---|
| 72 | /**
|
---|
| 73 | * Stopping condition: training stops if total network error change is smaller than minErrorChange
|
---|
| 74 | * for minErrorChangeStopIterations number of iterations
|
---|
| 75 | */
|
---|
| 76 | private int minErrorChangeIterationsLimit = Integer.MAX_VALUE;
|
---|
| 77 |
|
---|
| 78 | /**
|
---|
| 79 | * Count iterations where error change is smaller then minErrorChange.
|
---|
| 80 | */
|
---|
| 81 | private transient int minErrorChangeIterationsCount;
|
---|
| 82 |
|
---|
| 83 | /**
|
---|
| 84 | * Setting to determine if learning (weights update) is in batch mode.
|
---|
| 85 | * False by default.
|
---|
| 86 | */
|
---|
| 87 | private boolean batchMode = false;
|
---|
| 88 |
|
---|
| 89 | private ErrorFunction errorFunction;
|
---|
| 90 |
|
---|
| 91 | /**
|
---|
| 92 | * Creates new supervised learning rule
|
---|
| 93 | */
|
---|
| 94 | public SupervisedLearning() {
|
---|
| 95 | super();
|
---|
| 96 | errorFunction = new CrossEntropyError();
|
---|
| 97 | // stopConditions.add(new MaxErrorStop(this));
|
---|
| 98 | }
|
---|
| 99 |
|
---|
| 100 | /**
|
---|
| 101 | * This method should implement the weights update procedure for the whole network
|
---|
| 102 | * for the given output error vector.
|
---|
| 103 | *
|
---|
| 104 | * @param outputError output error vector for some network input (aka. patternError, network error)
|
---|
| 105 | * usually the difference between desired and actual output
|
---|
| 106 | */
|
---|
| 107 | abstract protected void calculateWeightChanges(double[] outputError);
|
---|
| 108 |
|
---|
| 109 | /**
|
---|
| 110 | * Trains network for the specified training set and maxError
|
---|
| 111 | *
|
---|
| 112 | * @param trainingSet training set to learn
|
---|
| 113 | * @param maxError learning stop condition. If maxError is reached learning stops
|
---|
| 114 | */
|
---|
| 115 | public final void learn(DataSet trainingSet, double maxError) {
|
---|
| 116 | this.maxError = maxError;
|
---|
| 117 | learn(trainingSet);
|
---|
| 118 | }
|
---|
| 119 |
|
---|
| 120 | /**
|
---|
| 121 | * Trains network for the specified training set, maxError and number of iterations
|
---|
| 122 | *
|
---|
| 123 | * @param trainingSet training set to learn
|
---|
| 124 | * @param maxError learning stop condition. if maxError is reached learning stops
|
---|
| 125 | * @param maxIterations maximum number of learning iterations
|
---|
| 126 | */
|
---|
| 127 | public final void learn(DataSet trainingSet, double maxError, int maxIterations) {
|
---|
| 128 | this.trainingSet = trainingSet;
|
---|
| 129 | this.maxError = maxError;
|
---|
| 130 | setMaxIterations(maxIterations);
|
---|
| 131 | learn(trainingSet);
|
---|
| 132 | }
|
---|
| 133 |
|
---|
| 134 | @Override
|
---|
| 135 | protected void onStart() {
|
---|
| 136 | super.onStart(); // reset iteration counter
|
---|
| 137 | minErrorChangeIterationsCount = 0;
|
---|
| 138 | previousEpochError = 0d;
|
---|
| 139 | }
|
---|
| 140 |
|
---|
| 141 | @Override
|
---|
| 142 | protected void beforeEpoch() {
|
---|
| 143 | previousEpochError = errorFunction.getTotalError();
|
---|
| 144 | errorFunction.reset();
|
---|
| 145 | }
|
---|
| 146 |
|
---|
| 147 | @Override
|
---|
| 148 | protected void afterEpoch() {
|
---|
| 149 | // calculate abs error change and count iterations if its below specified min error change (used for stop condition)
|
---|
| 150 | double absErrorChange = Math.abs(previousEpochError - errorFunction.getTotalError());
|
---|
| 151 | if (absErrorChange <= this.minErrorChange) {
|
---|
| 152 | minErrorChangeIterationsCount++;
|
---|
| 153 | } else {
|
---|
| 154 | minErrorChangeIterationsCount = 0;
|
---|
| 155 | }
|
---|
| 156 |
|
---|
| 157 | // if learning is performed in batch mode, apply accumulated weight changes from this epoch
|
---|
| 158 | if (batchMode == true) {
|
---|
| 159 | doBatchWeightsUpdate();
|
---|
| 160 | }
|
---|
| 161 | }
|
---|
| 162 |
|
---|
| 163 | /**
|
---|
| 164 | * This method implements basic logic for one learning epoch for the
|
---|
| 165 | * supervised learning algorithms. Epoch is the one pass through the
|
---|
| 166 | * training set. This method iterates through the training set
|
---|
| 167 | * and trains network for each element. It also sets flag if conditions
|
---|
| 168 | * to stop learning has been reached: network error below some allowed
|
---|
| 169 | * value, or maximum iteration count
|
---|
| 170 | *
|
---|
| 171 | * @param trainingSet training set for training network
|
---|
| 172 | */
|
---|
| 173 | @Override
|
---|
| 174 | public void doLearningEpoch(DataSet trainingSet) {
|
---|
| 175 | Iterator<DataSetRow> iterator = trainingSet.iterator();
|
---|
| 176 | while (iterator.hasNext() && !isStopped()) { // iterate all elements from training set - maybe remove isStopped from here
|
---|
| 177 | DataSetRow dataSetRow = iterator.next();
|
---|
| 178 | learnPattern(dataSetRow); // learn current input/output pattern defined by SupervisedTrainingElement
|
---|
| 179 | }
|
---|
| 180 | }
|
---|
| 181 |
|
---|
| 182 | /**
|
---|
| 183 | * Trains network with the input and desired output pattern from the specified training element
|
---|
| 184 | *
|
---|
| 185 | * @param trainingElement supervised training element which contains input and desired output
|
---|
| 186 | */
|
---|
| 187 | protected final void learnPattern(DataSetRow trainingElement) {
|
---|
| 188 | int size = trainingElement.getInput().length/2;
|
---|
| 189 | double[] input1 = new double[size];
|
---|
| 190 | double[] input2 = new double[size];
|
---|
| 191 | for (int i=0; i<size; i++) {
|
---|
| 192 | input1[i] = trainingElement.getInput()[i];
|
---|
| 193 | }
|
---|
| 194 | for (int i=0; i<size; i++) {
|
---|
| 195 | input2[i] = trainingElement.getInput()[size+i];
|
---|
| 196 | }
|
---|
| 197 | neuralNetwork.setInput(input1);
|
---|
| 198 | neuralNetwork.calculate();
|
---|
| 199 | double[] output1 = neuralNetwork.getOutput();
|
---|
| 200 | neuralNetwork.setInput(input2);
|
---|
| 201 | neuralNetwork.calculate();
|
---|
| 202 | double[] output2 = neuralNetwork.getOutput();
|
---|
| 203 | double[] output = new double[1];
|
---|
| 204 | output[0] = output1[0]-output2[0];
|
---|
| 205 | //update the under input2
|
---|
| 206 | double[] patternError2 = ((CrossEntropyError) errorFunction).addPatternError2(output, trainingElement.getDesiredOutput());
|
---|
| 207 | calculateWeightChanges(patternError2);
|
---|
| 208 | if (!batchMode) applyWeightChanges(); // batch mode updates are done i doBatchWeightsUpdate
|
---|
| 209 | //update under input1
|
---|
| 210 | neuralNetwork.setInput(input1);
|
---|
| 211 | neuralNetwork.calculate();
|
---|
| 212 | double[] patternError1 = ((CrossEntropyError) errorFunction).addPatternError(output, trainingElement.getDesiredOutput());
|
---|
| 213 | calculateWeightChanges(patternError1);
|
---|
| 214 | if (!batchMode) applyWeightChanges(); // batch mode updates are done i doBatchWeightsUpdate
|
---|
| 215 | }
|
---|
| 216 |
|
---|
| 217 | /**
|
---|
| 218 | * This method updates network weights in batch mode - use accumulated weights change stored in Weight.deltaWeight
|
---|
| 219 | * It is executed after each learning epoch, only if learning is done in batch mode.
|
---|
| 220 | *
|
---|
| 221 | * @see SupervisedLearning#doLearningEpoch(org.neuroph.core.data.DataSet)
|
---|
| 222 | */
|
---|
| 223 | protected void doBatchWeightsUpdate() {
|
---|
| 224 | // iterate layers from output to input
|
---|
| 225 | List<Layer> layers = neuralNetwork.getLayers();
|
---|
| 226 | for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
|
---|
| 227 | // iterate neurons at each layer
|
---|
| 228 | for (Neuron neuron : layers.get(i).getNeurons()) {
|
---|
| 229 | // iterate connections/weights for each neuron
|
---|
| 230 | for (Connection connection : neuron.getInputConnections()) {
|
---|
| 231 | // for each connection weight apply accumulated weight change
|
---|
| 232 | Weight weight = connection.getWeight();
|
---|
| 233 | weight.value += weight.weightChange / getTrainingSet().size(); // apply delta weight which is the sum of delta weights in batch mode - TODO: add mini batch
|
---|
| 234 | weight.weightChange = 0; // reset deltaWeight
|
---|
| 235 | }
|
---|
| 236 | }
|
---|
| 237 | }
|
---|
| 238 | }
|
---|
| 239 |
|
---|
| 240 | /**
|
---|
| 241 | * Returns true if learning is performed in batch mode, false otherwise
|
---|
| 242 | *
|
---|
| 243 | * @return true if learning is performed in batch mode, false otherwise
|
---|
| 244 | */
|
---|
| 245 | public boolean isBatchMode() {
|
---|
| 246 | return batchMode;
|
---|
| 247 | }
|
---|
| 248 |
|
---|
| 249 | /**
|
---|
| 250 | * Sets batch mode on/off (true/false)
|
---|
| 251 | *
|
---|
| 252 | * @param batchMode batch mode setting
|
---|
| 253 | */
|
---|
| 254 | public void setBatchMode(boolean batchMode) {
|
---|
| 255 | this.batchMode = batchMode;
|
---|
| 256 | }
|
---|
| 257 |
|
---|
| 258 | /**
|
---|
| 259 | * Sets allowed network error, which indicates when to stopLearning training
|
---|
| 260 | *
|
---|
| 261 | * @param maxError network error
|
---|
| 262 | */
|
---|
| 263 | public void setMaxError(double maxError) {
|
---|
| 264 | this.maxError = maxError;
|
---|
| 265 | }
|
---|
| 266 |
|
---|
| 267 | /**
|
---|
| 268 | * Returns learning error tolerance - the value of total network error to stop learning.
|
---|
| 269 | *
|
---|
| 270 | * @return learning error tolerance
|
---|
| 271 | */
|
---|
| 272 | public double getMaxError() {
|
---|
| 273 | return maxError;
|
---|
| 274 | }
|
---|
| 275 |
|
---|
| 276 | /**
|
---|
| 277 | * Returns total network error in previous learning epoch
|
---|
| 278 | *
|
---|
| 279 | * @return total network error in previous learning epoch
|
---|
| 280 | */
|
---|
| 281 | public double getPreviousEpochError() {
|
---|
| 282 | return previousEpochError;
|
---|
| 283 | }
|
---|
| 284 |
|
---|
| 285 | /**
|
---|
| 286 | * Returns min error change stopping criteria
|
---|
| 287 | *
|
---|
| 288 | * @return min error change stopping criteria
|
---|
| 289 | */
|
---|
| 290 | public double getMinErrorChange() {
|
---|
| 291 | return minErrorChange;
|
---|
| 292 | }
|
---|
| 293 |
|
---|
| 294 | /**
|
---|
| 295 | * Sets min error change stopping criteria
|
---|
| 296 | *
|
---|
| 297 | * @param minErrorChange value for min error change stopping criteria
|
---|
| 298 | */
|
---|
| 299 | public void setMinErrorChange(double minErrorChange) {
|
---|
| 300 | this.minErrorChange = minErrorChange;
|
---|
| 301 | }
|
---|
| 302 |
|
---|
| 303 | /**
|
---|
| 304 | * Returns number of iterations for min error change stopping criteria
|
---|
| 305 | *
|
---|
| 306 | * @return number of iterations for min error change stopping criteria
|
---|
| 307 | */
|
---|
| 308 | public int getMinErrorChangeIterationsLimit() {
|
---|
| 309 | return minErrorChangeIterationsLimit;
|
---|
| 310 | }
|
---|
| 311 |
|
---|
| 312 | /**
|
---|
| 313 | * Sets number of iterations for min error change stopping criteria
|
---|
| 314 | *
|
---|
| 315 | * @param minErrorChangeIterationsLimit number of iterations for min error change stopping criteria
|
---|
| 316 | */
|
---|
| 317 | public void setMinErrorChangeIterationsLimit(int minErrorChangeIterationsLimit) {
|
---|
| 318 | this.minErrorChangeIterationsLimit = minErrorChangeIterationsLimit;
|
---|
| 319 | }
|
---|
| 320 |
|
---|
| 321 | /**
|
---|
| 322 | * Returns number of iterations count for for min error change stopping criteria
|
---|
| 323 | *
|
---|
| 324 | * @return number of iterations count for for min error change stopping criteria
|
---|
| 325 | */
|
---|
| 326 | public int getMinErrorChangeIterationsCount() {
|
---|
| 327 | return minErrorChangeIterationsCount;
|
---|
| 328 | }
|
---|
| 329 |
|
---|
| 330 | public ErrorFunction getErrorFunction() {
|
---|
| 331 | return errorFunction;
|
---|
| 332 | }
|
---|
| 333 |
|
---|
| 334 | public void setErrorFunction(ErrorFunction errorFunction) {
|
---|
| 335 | this.errorFunction = errorFunction;
|
---|
| 336 | }
|
---|
| 337 |
|
---|
| 338 |
|
---|
| 339 | public double getTotalNetworkError() {
|
---|
| 340 | return errorFunction.getTotalError();
|
---|
| 341 | }
|
---|
| 342 |
|
---|
| 343 | private void applyWeightChanges() {
|
---|
| 344 | List<Layer> layers = neuralNetwork.getLayers();
|
---|
| 345 | for (int i = neuralNetwork.getLayersCount() - 1; i > 0; i--) {
|
---|
| 346 | // iterate neurons at each layer
|
---|
| 347 | for (Neuron neuron : layers.get(i)) {
|
---|
| 348 | // iterate connections/weights for each neuron
|
---|
| 349 | for (Connection connection : neuron.getInputConnections()) {
|
---|
| 350 | // for each connection weight apply accumulated weight change
|
---|
| 351 | Weight weight = connection.getWeight();
|
---|
| 352 | if (!isBatchMode()) {
|
---|
| 353 | weight.value += weight.weightChange;
|
---|
| 354 | } else {
|
---|
| 355 | weight.value += (weight.weightChange / getTrainingSet().size());
|
---|
| 356 | }
|
---|
| 357 |
|
---|
| 358 | weight.weightChange = 0; // reset deltaWeight
|
---|
| 359 | }
|
---|
| 360 | }
|
---|
| 361 | }
|
---|
| 362 | }
|
---|
| 363 | } |
---|