1 | package geniusweb.blingbling.Ranknet;
|
---|
2 |
|
---|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Deque;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;
|
---|
13 |
|
---|
14 | public class NeuralRankNet {
|
---|
15 | protected List<Layer> layers;
|
---|
16 | protected double learningRate;
|
---|
17 |
|
---|
18 | public static RanknetBuilder Builder() {
|
---|
19 | return new RanknetBuilder();
|
---|
20 | }
|
---|
21 |
|
---|
22 | protected void setLayers(List<Layer> layers) {
|
---|
23 | this.layers = layers;
|
---|
24 | }
|
---|
25 |
|
---|
26 | protected void setLearningRate(double learningRate) {
|
---|
27 | this.learningRate = learningRate;
|
---|
28 | }
|
---|
29 |
|
---|
30 | public List<INDArray> feedForward(INDArray input) {
|
---|
31 | //input the
|
---|
32 | List<INDArray> activations = new LinkedList<INDArray>();
|
---|
33 | activations.add(input);
|
---|
34 |
|
---|
35 | INDArray layerInput = input;
|
---|
36 | for (int i = 0; i < layers.size(); i++) {
|
---|
37 | Layer layer = layers.get(i);
|
---|
38 | layer.activate(layerInput);
|
---|
39 | INDArray output = layer.getActivation();
|
---|
40 |
|
---|
41 | activations.add(output);
|
---|
42 | layerInput = output;
|
---|
43 | }
|
---|
44 |
|
---|
45 | return activations;
|
---|
46 | }
|
---|
47 |
|
---|
48 | public void train(INDArray inputI, INDArray inputJ, INDArray expected) {
|
---|
49 | Layer outputLayer = layers.get(layers.size() - 1);
|
---|
50 |
|
---|
51 | Pair<List<INDArray>, List<INDArray>> pairI = getActivationZPair(inputI);
|
---|
52 | Pair<List<INDArray>, List<INDArray>> pairJ = getActivationZPair(inputJ);
|
---|
53 |
|
---|
54 | List<INDArray> activationsI = pairI.fst;
|
---|
55 | List<INDArray> zI = pairI.snd;
|
---|
56 |
|
---|
57 | List<INDArray> activationsJ = pairJ.fst;
|
---|
58 | List<INDArray> zJ = pairJ.snd;
|
---|
59 |
|
---|
60 | INDArray oI = activationsI.get(activationsI.size() - 1);
|
---|
61 | INDArray oJ = activationsJ.get(activationsJ.size() - 1);
|
---|
62 | INDArray oIJ = oI.sub(oJ);
|
---|
63 | INDArray pIJ = getProbability(oIJ);
|
---|
64 | // System.out.println(pIJ);
|
---|
65 |
|
---|
66 | INDArray deltaC = getOutputCostDerivative(pIJ, expected);
|
---|
67 | INDArray outputErrorGradientI = deltaC.mul(outputLayer.calculateActivationDerivative(oI));
|
---|
68 | INDArray outputErrorGradientJ = deltaC.mul(outputLayer.calculateActivationDerivative(oJ));
|
---|
69 |
|
---|
70 | List<INDArray> errorGradientsI = backpropagateError(outputErrorGradientI, activationsI);
|
---|
71 | List<INDArray> errorGradientsJ = backpropagateError(outputErrorGradientJ, activationsJ);
|
---|
72 |
|
---|
73 | updateParams(
|
---|
74 | new Pair<List<INDArray>, List<INDArray>>(errorGradientsI, errorGradientsJ),
|
---|
75 | new Pair<List<INDArray>, List<INDArray>>(activationsI, activationsJ));
|
---|
76 | }
|
---|
77 |
|
---|
78 | private List<INDArray> backpropagateError(INDArray outputErrorGradient, List<INDArray> activations) {
|
---|
79 | Stack<INDArray> reverseGradients = new Stack<INDArray>();
|
---|
80 | INDArray errorGradient = outputErrorGradient;
|
---|
81 | reverseGradients.add(errorGradient);
|
---|
82 |
|
---|
83 | for (int i = layers.size() - 2; i >= 0; i--) {
|
---|
84 | Layer layer = layers.get(i);
|
---|
85 | Layer forwardLayer = layers.get(i + 1);
|
---|
86 | INDArray activationDerivative = layer.calculateActivationDerivative(activations.get(i + 1));
|
---|
87 | errorGradient = forwardLayer.getErrorGradient(errorGradient, activationDerivative);
|
---|
88 | reverseGradients.add(errorGradient);
|
---|
89 | }
|
---|
90 |
|
---|
91 | List<INDArray> errorGradients = new LinkedList<INDArray>();
|
---|
92 | while (!reverseGradients.isEmpty()) {
|
---|
93 | errorGradients.add(reverseGradients.pop());
|
---|
94 | }
|
---|
95 |
|
---|
96 | return errorGradients;
|
---|
97 | }
|
---|
98 |
|
---|
99 | private void updateParams(Pair<List<INDArray>, List<INDArray>> errorGradients, Pair<List<INDArray>, List<INDArray>> activations) {
|
---|
100 | List<INDArray> errorGradientsI = errorGradients.fst;
|
---|
101 | List<INDArray> errorGradientsJ = errorGradients.snd;
|
---|
102 |
|
---|
103 | List<INDArray> activationsI = activations.fst;
|
---|
104 | List<INDArray> activationsJ = activations.snd;
|
---|
105 |
|
---|
106 | for (int i = layers.size() - 1; i >= 0; i--) {
|
---|
107 | Layer layer = layers.get(i);
|
---|
108 | INDArray errorGradientI = errorGradientsI.get(i);
|
---|
109 | INDArray errorGradientJ = errorGradientsJ.get(i);
|
---|
110 |
|
---|
111 | INDArray activationI = activationsI.get(i);
|
---|
112 | INDArray activationJ = activationsJ.get(i);
|
---|
113 |
|
---|
114 | INDArray weightDeltaI = errorGradientI.transpose().mmul(activationI).mul(learningRate);
|
---|
115 | INDArray weightDeltaJ = errorGradientJ.transpose().mmul(activationJ).mul(learningRate);
|
---|
116 |
|
---|
117 | layer.updateWeights(weightDeltaI.sub(weightDeltaJ));
|
---|
118 | layer.updateBiases(errorGradientI.sub(errorGradientJ));
|
---|
119 | }
|
---|
120 | }
|
---|
121 |
|
---|
122 | private static INDArray getProbability(INDArray data) {
|
---|
123 | return Transforms.sigmoid(data);
|
---|
124 | }
|
---|
125 |
|
---|
126 | private static INDArray getOutputCostDerivative(INDArray output, INDArray expected) {
|
---|
127 | // Naive MSE.
|
---|
128 | // return expected.sub(output);
|
---|
129 | ILossFunction lf = new LossBinaryXENT();
|
---|
130 | // XE
|
---|
131 | // ILossFunction lf = new LossMSE();
|
---|
132 | return lf.computeGradient(expected, output, new ActivationIdentity(), null);
|
---|
133 | }
|
---|
134 |
|
---|
135 | public Pair<List<INDArray>, List<INDArray>> getActivationZPair(INDArray input) {
|
---|
136 | List<INDArray> activations = feedForward(input);
|
---|
137 | List<INDArray> zs = new LinkedList<INDArray>();
|
---|
138 | zs.add(input);
|
---|
139 | for (Layer layer : layers) {
|
---|
140 | zs.add(layer.getZ());
|
---|
141 | }
|
---|
142 |
|
---|
143 | return new Pair<List<INDArray>, List<INDArray>>(activations, zs);
|
---|
144 | }
|
---|
145 | } |
---|