package geniusweb.blingbling.Ranknet4j;

import java.io.Serializable;
import java.util.List;

import org.neuroph.core.Connection;
import org.neuroph.core.Neuron;
import org.neuroph.core.Weight;

/**
 * LMS (Least Mean Squares) learning rule for neural networks. This rule is
 * used to train the Adaline neural network, and it also serves as the base
 * class for all LMS-derived learning rules such as PerceptronLearning,
 * DeltaRule, SigmoidDeltaRule and Backpropagation.
 *
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public class LMS extends SupervisedLearning implements Serializable {

    /**
     * Class fingerprint indicating serialization compatibility with a
     * previous version of this class.
     */
    private static final long serialVersionUID = 2L;

    /**
     * Creates a new LMS learning rule.
     */
    public LMS() {
    }

    /**
     * Computes the weight changes of the whole network for the given output
     * error vector: each output neuron receives its error component as its
     * delta, and the changes for its input weights are then derived from
     * that delta.
     *
     * @param outputError output error vector for some network input — the
     *                    difference between desired and actual output, one
     *                    component per output neuron
     * @see SupervisedLearning#learnPattern(org.neuroph.core.data.DataSetRow) learnPattern
     */
    @Override
    protected void calculateWeightChanges(final double[] outputError) {
        final List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
        int errorIdx = 0;
        for (final Neuron outputNeuron : outputNeurons) {
            // The neuron's delta is its error component (desired - actual).
            outputNeuron.setDelta(outputError[errorIdx]);
            // Derive the weight changes for this neuron's input connections.
            calculateWeightChanges(outputNeuron);
            errorIdx++;
        }
    }

    /**
     * Computes the weight change for every input connection of a single
     * neuron, using the formula
     *
     *   deltaWeight = -learningRate * delta * input
     *
     * where delta is the neuron error, i.e. the difference between the
     * desired/target and actual output for this neuron (see
     * SupervisedLearning.calculateOutputError). In online mode the new change
     * replaces the previous one so it can be applied immediately; in batch
     * mode the changes are accumulated and applied after the current epoch
     * (see SupervisedLearning.doLearningEpoch).
     *
     * @param neuron neuron whose input-connection weight changes are computed
     *
     * @see LMS#calculateWeightChanges(double[])
     */
    protected void calculateWeightChanges(Neuron neuron) {
        // Error (delta) of this neuron, set by the caller.
        final double delta = neuron.getDelta();

        // NOTE(review): tanh(delta) could be used here to damp very large
        // error values that may destabilize the network — suggested at
        // https://sourceforge.net/tracker/?func=detail&atid=1107579&aid=3130561&group_id=238532
        // but intentionally not enabled.

        for (final Connection inputConnection : neuron.getInputConnections()) {
            final Weight weight = inputConnection.getWeight();
            // LMS weight-change formula for this connection's input signal.
            final double change = -learningRate * delta * inputConnection.getInput();
            if (this.isBatchMode()) {
                // Batch mode: accumulate; applied at the end of the epoch.
                weight.weightChange += change;
            } else {
                // Online mode: overwrite; applied immediately.
                weight.weightChange = change;
            }
        }
    }

}
---|