source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet4j/BackPropagation.java@ 21

Last change on this file (viewed at revision 21) was revision 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 3.8 KB
package geniusweb.blingbling.Ranknet4j;

import java.util.List;
import org.neuroph.core.Connection;
import org.neuroph.core.Layer;
import org.neuroph.core.Neuron;
import org.neuroph.core.transfer.TransferFunction;

/**
 * Back Propagation learning rule for Multi Layer Perceptron neural networks.
 *
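 * <p>
 * Minimal usage sketch (illustrative only; it assumes this Ranknet4j copy
 * keeps Neuroph's standard training API, and {@code trainingSet} is a
 * hypothetical {@code DataSet} of supervised rows):
 * <pre>
 * NeuralNetwork net = new MultiLayerPerceptron(2, 3, 1); // example topology
 * net.setLearningRule(new BackPropagation());
 * ((BackPropagation) net.getLearningRule()).setLearningRate(0.1);
 * net.learn(trainingSet);
 * </pre>
 *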
 * @author Zoran Sevarac <sevarac@gmail.com>
 */
public class BackPropagation extends LMS {

    /**
     * The class fingerprint that is set to indicate serialization
     * compatibility with a previous version of the class.
     */
    private static final long serialVersionUID = 1L;

    /**
     * Creates a new instance of BackPropagation learning.
     */
    public BackPropagation() {
        super();
    }

    /**
     * Implements the weight update procedure for the whole network, given the
     * specified output error vector.
     *
     * @param outputError output error vector
     */
    @Override
    protected void calculateWeightChanges(double[] outputError) {
        calculateErrorAndUpdateOutputNeurons(outputError);
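        // Hidden deltas are derived from downstream deltas, so the output
        // neurons must be processed first: they seed the backward pass below.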
        calculateErrorAndUpdateHiddenNeurons();
    }

    /**
     * Implements the weight update procedure for the output neurons.
     * Calculates the delta (error) for each output neuron and calls
     * calculateWeightChanges to update that neuron's weights.
     *
     * @param outputError error vector for output neurons
     */
    protected void calculateErrorAndUpdateOutputNeurons(double[] outputError) {
        int i = 0;
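        // i pairs each output neuron with its entry in outputError, relying
        // on the fixed ordering of the network's output-neuron list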

        // for all output neurons
        final List<Neuron> outputNeurons = neuralNetwork.getOutputNeurons();
        for (Neuron neuron : outputNeurons) {
            // if error is zero, just set zero error and continue to next neuron
            if (outputError[i] == 0) {
                neuron.setDelta(0);
                i++;
                continue;
            }

            // otherwise calculate and set error/delta for the current neuron
            final TransferFunction transferFunction = neuron.getTransferFunction();
            final double neuronInput = neuron.getNetInput();
            final double delta = outputError[i] * transferFunction.getDerivative(neuronInput); // delta = (d - y) * f'(net)
            neuron.setDelta(delta);

            // and update weights of the current neuron
            calculateWeightChanges(neuron);
            i++;
        } // for
    }

    /**
     * Implements the delta calculation and weight adjustment for the hidden
     * layers.
     */
    protected void calculateErrorAndUpdateHiddenNeurons() {
        List<Layer> layers = neuralNetwork.getLayers();
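        // iterate backwards over the hidden layers only: size() - 2 skips the
        // output layer (already handled above), and layerIdx > 0 skips the
        // input layer, which has no incoming weights to adjust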
        for (int layerIdx = layers.size() - 2; layerIdx > 0; layerIdx--) {
            for (Neuron neuron : layers.get(layerIdx).getNeurons()) {
                // calculate the neuron's error (delta)
                final double delta = calculateHiddenNeuronError(neuron);
                neuron.setDelta(delta);
                calculateWeightChanges(neuron);
            } // for
        } // for
    }

    /**
     * Calculates and returns the error (delta) for the given neuron.
     *
     * @param neuron neuron to calculate error for
     * @return neuron error (delta) for the specified neuron
     */
    protected double calculateHiddenNeuronError(Neuron neuron) {
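        // standard backprop rule for a hidden neuron j:
        //   delta_j = f'(net_j) * sum_k( w_jk * delta_k )
        // where k ranges over the next-layer neurons that j feeds into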
        double deltaSum = 0d;
        for (Connection connection : neuron.getOutConnections()) {
            double delta = connection.getToNeuron().getDelta() * connection.getWeight().value;
            deltaSum += delta; // weighted delta sum from the next layer
        } // for

        TransferFunction transferFunction = neuron.getTransferFunction();
        double netInput = neuron.getNetInput();
        double f1 = transferFunction.getDerivative(netInput); // does this use netInput or the cached output, to avoid double calculation?
        double delta = f1 * deltaSum;
        return delta;
    }

}