source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet4j/LMS.java

Last change on this file was 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 3.8 KB
Line 
1package geniusweb.blingbling.Ranknet4j;
2
3import java.io.Serializable;
4import java.util.List;
5import org.neuroph.core.Connection;
6import org.neuroph.core.Neuron;
7import org.neuroph.core.Weight;
8
9/**
10 * LMS learning rule for neural networks. This learning rule is used to train
11 * Adaline neural network, and this class is base for all LMS based learning
12 * rules like PerceptronLearning, DeltaRule, SigmoidDeltaRule, Backpropagation
13 * etc.
14 *
15 * @author Zoran Sevarac <sevarac@gmail.com>
16 */
17public class LMS extends SupervisedLearning implements Serializable {
18
19 /**
20 * The class fingerprint that is set to indicate serialization
21 * compatibility with a previous version of the class.
22 */
23 private static final long serialVersionUID = 2L;
24
25
26 /**
27 * Creates a new LMS learning rule
28 */
29 public LMS() {
30
31 }
32
33
34 /**
35 * This method implements the weights update procedure for the whole network
36 * for the given output error vector.
37 *
38 * @param outputError
39 * output error vector for some network input- the difference between desired and actual output
40 * @see SupervisedLearning#learnPattern(org.neuroph.core.data.DataSetRow) learnPattern
41 */
42 @Override
43 protected void calculateWeightChanges(final double[] outputError) {
44 int i = 0;
45 // for each neuron in output layer
46 List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
47 for (Neuron neuron : outputNeurons) {
48 neuron.setDelta(outputError[i]); // set the neuron error, as difference between desired and actual output
49 calculateWeightChanges(neuron); // and update neuron weights -- this should be renamed to calculate weight changes
50 i++;
51 }
52
53 // outputNeurons.forEach( neuron -> updateNeuronWeights(neuron));
54
55 }
56
57 /**
58 * This method calculates weights changes for the single neuron.
59 * It iterates through all neuron's input connections, and calculates/set weight change for each weight
60 * using formula
61 * deltaWeight = -learningRate * delta * input
62 *
63 * where delta is a neuron error, a difference between desired/target and actual output for specific neuron
64 * neuronError = desiredOutput[i] - actualOutput[i] (see method SuprevisedLearning.calculateOutputError)
65 *
66 * @param neuron
67 * neuron to update weights
68 *
69 * @see LMS#calculateWeightChanges(double[])
70 */
71 protected void calculateWeightChanges(Neuron neuron) {
72 // get the error(delta) for specified neuron,
73 double delta = neuron.getDelta();
74
75 // tanh can be used to minimise the impact of big error values, which can cause network instability
76 // suggested at https://sourceforge.net/tracker/?func=detail&atid=1107579&aid=3130561&group_id=238532
77 // double neuronError = Math.tanh(neuron.getError());
78
79 // iterate through all neuron's input connections
80 for (Connection connection : neuron.getInputConnections()) {
81 // get the input from current connection
82 final double input = connection.getInput();
83 // calculate the weight change
84 final double weightChange = -learningRate * delta * input;
85
86 // get the connection weight
87 final Weight weight = connection.getWeight();
88 // if the learning is in online mode (not batch) apply the weight change immediately
89 if (!this.isBatchMode()) {
90 weight.weightChange = weightChange;
91 } else { // otherwise if its in batch mode, accumulate weight changes and apply them after the current epoch (see SupervisedLearning.doLearningEpoch method)
92 weight.weightChange += weightChange;
93 }
94 }
95 }
96
97}
Note: See TracBrowser for help on using the repository browser.