1 | package geniusweb.blingbling.Ranknet4j;
|
---|
2 |
|
---|
3 | import java.util.List;
|
---|
4 | import org.neuroph.core.Connection;
|
---|
5 | import org.neuroph.core.Layer;
|
---|
6 | import org.neuroph.core.Neuron;
|
---|
7 | import org.neuroph.core.transfer.TransferFunction;
|
---|
8 |
|
---|
9 | /**
|
---|
10 | * Back Propagation learning rule for Multi Layer Perceptron neural networks.
|
---|
11 | *
|
---|
12 | * @author Zoran Sevarac <sevarac@gmail.com>
|
---|
13 | */
|
---|
14 | public class BackPropagation extends LMS {
|
---|
15 |
|
---|
16 | /**
|
---|
17 | * The class fingerprint that is set to indicate serialization
|
---|
18 | * compatibility with a previous version of the class.
|
---|
19 | */
|
---|
20 | private static final long serialVersionUID = 1L;
|
---|
21 |
|
---|
22 | /**
|
---|
23 | * Creates new instance of BackPropagation learning
|
---|
24 | */
|
---|
25 | public BackPropagation() {
|
---|
26 | super();
|
---|
27 | }
|
---|
28 |
|
---|
29 |
|
---|
30 | /**
|
---|
31 | * This method implements weight update procedure for the whole network
|
---|
32 | * for the specified output error vector.
|
---|
33 | *
|
---|
34 | * @param outputError output error vector
|
---|
35 | */
|
---|
36 | @Override
|
---|
37 | protected void calculateWeightChanges(double[] outputError) {
|
---|
38 | calculateErrorAndUpdateOutputNeurons(outputError);
|
---|
39 | calculateErrorAndUpdateHiddenNeurons();
|
---|
40 | }
|
---|
41 |
|
---|
42 |
|
---|
43 | /**
|
---|
44 | * This method implements weights update procedure for the output neurons
|
---|
45 | * Calculates delta/error and calls updateNeuronWeights to update neuron's weights
|
---|
46 | * for each output neuron
|
---|
47 | *
|
---|
48 | * @param outputError error vector for output neurons
|
---|
49 | */
|
---|
50 | protected void calculateErrorAndUpdateOutputNeurons(double[] outputError) {
|
---|
51 | int i = 0;
|
---|
52 |
|
---|
53 | // for all output neurons
|
---|
54 | final List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
|
---|
55 | for (Neuron neuron : outputNeurons) {
|
---|
56 | // if error is zero, just set zero error and continue to next neuron
|
---|
57 | if (outputError[i] == 0) {
|
---|
58 | neuron.setDelta(0);
|
---|
59 | i++;
|
---|
60 | continue;
|
---|
61 | }
|
---|
62 |
|
---|
63 | // otherwise calculate and set error/delta for the current neuron
|
---|
64 | final TransferFunction transferFunction = neuron.getTransferFunction();
|
---|
65 | final double neuronInput = neuron.getNetInput();
|
---|
66 | final double delta = outputError[i] * transferFunction.getDerivative(neuronInput); // delta = (y-d)*df(net)
|
---|
67 | neuron.setDelta(delta);
|
---|
68 |
|
---|
69 | // and update weights of the current neuron
|
---|
70 | calculateWeightChanges(neuron);
|
---|
71 | i++;
|
---|
72 | } // for
|
---|
73 | }
|
---|
74 |
|
---|
75 | /**
|
---|
76 | * This method implements weights adjustment for the hidden layers
|
---|
77 | */
|
---|
78 | protected void calculateErrorAndUpdateHiddenNeurons() {
|
---|
79 | List<Layer> layers = neuralNetwork.getLayers();
|
---|
80 | for (int layerIdx = layers.size() - 2; layerIdx > 0; layerIdx--) {
|
---|
81 | for (Neuron neuron : layers.get(layerIdx).getNeurons()) {
|
---|
82 | // calculate the neuron's error (delta)
|
---|
83 | final double delta = calculateHiddenNeuronError(neuron);
|
---|
84 | neuron.setDelta(delta);
|
---|
85 | calculateWeightChanges(neuron);
|
---|
86 | } // for
|
---|
87 | } // for
|
---|
88 | }
|
---|
89 |
|
---|
90 | /**
|
---|
91 | * Calculates and returns the neuron's error (neuron's delta) for the given neuron param
|
---|
92 | *
|
---|
93 | * @param neuron neuron to calculate error for
|
---|
94 | * @return neuron error (delta) for the specified neuron
|
---|
95 | */
|
---|
96 | protected double calculateHiddenNeuronError(Neuron neuron) {
|
---|
97 | double deltaSum = 0d;
|
---|
98 | for (Connection connection : neuron.getOutConnections()) {
|
---|
99 | double delta = connection.getToNeuron().getDelta() * connection.getWeight().value;
|
---|
100 | deltaSum += delta; // weighted delta sum from the next layer
|
---|
101 | } // for
|
---|
102 |
|
---|
103 | TransferFunction transferFunction = neuron.getTransferFunction();
|
---|
104 | double netInput = neuron.getNetInput();
|
---|
105 | double f1 = transferFunction.getDerivative(netInput); // does this use netInput or cached output in order to avoid double caluclation?
|
---|
106 | double delta = f1 * deltaSum;
|
---|
107 | return delta;
|
---|
108 | }
|
---|
109 |
|
---|
110 | } |
---|