source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet/NeuralRankNet.java@ 31

Last change on this file since 31 was 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 5.4 KB
Line 
1package geniusweb.blingbling.Ranknet;
2
3import org.nd4j.linalg.activations.impl.ActivationIdentity;
4import org.nd4j.linalg.api.ndarray.INDArray;
5import org.nd4j.linalg.lossfunctions.ILossFunction;
6import org.nd4j.linalg.lossfunctions.impl.LossMSE;
7import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
8import org.nd4j.linalg.ops.transforms.Transforms;
9
10import java.util.LinkedList;
11import java.util.List;
12import java.util.Stack;
13
14public class NeuralRankNet {
15 protected List<Layer> layers;
16 protected double learningRate;
17
18 public static RanknetBuilder Builder() {
19 return new RanknetBuilder();
20 }
21
22 protected void setLayers(List<Layer> layers) {
23 this.layers = layers;
24 }
25
26 protected void setLearningRate(double learningRate) {
27 this.learningRate = learningRate;
28 }
29
30 public List<INDArray> feedForward(INDArray input) {
31 //input the
32 List<INDArray> activations = new LinkedList<INDArray>();
33 activations.add(input);
34
35 INDArray layerInput = input;
36 for (int i = 0; i < layers.size(); i++) {
37 Layer layer = layers.get(i);
38 layer.activate(layerInput);
39 INDArray output = layer.getActivation();
40
41 activations.add(output);
42 layerInput = output;
43 }
44
45 return activations;
46 }
47
48 public void train(INDArray inputI, INDArray inputJ, INDArray expected) {
49 Layer outputLayer = layers.get(layers.size() - 1);
50
51 Pair<List<INDArray>, List<INDArray>> pairI = getActivationZPair(inputI);
52 Pair<List<INDArray>, List<INDArray>> pairJ = getActivationZPair(inputJ);
53
54 List<INDArray> activationsI = pairI.fst;
55 List<INDArray> zI = pairI.snd;
56
57 List<INDArray> activationsJ = pairJ.fst;
58 List<INDArray> zJ = pairJ.snd;
59
60 INDArray oI = activationsI.get(activationsI.size() - 1);
61 INDArray oJ = activationsJ.get(activationsJ.size() - 1);
62 INDArray oIJ = oI.sub(oJ);
63 INDArray pIJ = getProbability(oIJ);
64// System.out.println(pIJ);
65
66 INDArray deltaC = getOutputCostDerivative(pIJ, expected);
67 INDArray outputErrorGradientI = deltaC.mul(outputLayer.calculateActivationDerivative(oI));
68 INDArray outputErrorGradientJ = deltaC.mul(outputLayer.calculateActivationDerivative(oJ));
69
70 List<INDArray> errorGradientsI = backpropagateError(outputErrorGradientI, activationsI);
71 List<INDArray> errorGradientsJ = backpropagateError(outputErrorGradientJ, activationsJ);
72
73 updateParams(
74 new Pair<List<INDArray>, List<INDArray>>(errorGradientsI, errorGradientsJ),
75 new Pair<List<INDArray>, List<INDArray>>(activationsI, activationsJ));
76 }
77
78 private List<INDArray> backpropagateError(INDArray outputErrorGradient, List<INDArray> activations) {
79 Stack<INDArray> reverseGradients = new Stack<INDArray>();
80 INDArray errorGradient = outputErrorGradient;
81 reverseGradients.add(errorGradient);
82
83 for (int i = layers.size() - 2; i >= 0; i--) {
84 Layer layer = layers.get(i);
85 Layer forwardLayer = layers.get(i + 1);
86 INDArray activationDerivative = layer.calculateActivationDerivative(activations.get(i + 1));
87 errorGradient = forwardLayer.getErrorGradient(errorGradient, activationDerivative);
88 reverseGradients.add(errorGradient);
89 }
90
91 List<INDArray> errorGradients = new LinkedList<INDArray>();
92 while (!reverseGradients.isEmpty()) {
93 errorGradients.add(reverseGradients.pop());
94 }
95
96 return errorGradients;
97 }
98
99 private void updateParams(Pair<List<INDArray>, List<INDArray>> errorGradients, Pair<List<INDArray>, List<INDArray>> activations) {
100 List<INDArray> errorGradientsI = errorGradients.fst;
101 List<INDArray> errorGradientsJ = errorGradients.snd;
102
103 List<INDArray> activationsI = activations.fst;
104 List<INDArray> activationsJ = activations.snd;
105
106 for (int i = layers.size() - 1; i >= 0; i--) {
107 Layer layer = layers.get(i);
108 INDArray errorGradientI = errorGradientsI.get(i);
109 INDArray errorGradientJ = errorGradientsJ.get(i);
110
111 INDArray activationI = activationsI.get(i);
112 INDArray activationJ = activationsJ.get(i);
113
114 INDArray weightDeltaI = errorGradientI.transpose().mmul(activationI).mul(learningRate);
115 INDArray weightDeltaJ = errorGradientJ.transpose().mmul(activationJ).mul(learningRate);
116
117 layer.updateWeights(weightDeltaI.sub(weightDeltaJ));
118 layer.updateBiases(errorGradientI.sub(errorGradientJ));
119 }
120 }
121
122 private static INDArray getProbability(INDArray data) {
123 return Transforms.sigmoid(data);
124 }
125
126 private static INDArray getOutputCostDerivative(INDArray output, INDArray expected) {
127 // Naive MSE.
128 // return expected.sub(output);
129 ILossFunction lf = new LossBinaryXENT();
130 // XE
131// ILossFunction lf = new LossMSE();
132 return lf.computeGradient(expected, output, new ActivationIdentity(), null);
133 }
134
135 public Pair<List<INDArray>, List<INDArray>> getActivationZPair(INDArray input) {
136 List<INDArray> activations = feedForward(input);
137 List<INDArray> zs = new LinkedList<INDArray>();
138 zs.add(input);
139 for (Layer layer : layers) {
140 zs.add(layer.getZ());
141 }
142
143 return new Pair<List<INDArray>, List<INDArray>>(activations, zs);
144 }
145}
Note: See TracBrowser for help on using the repository browser.