[1] | 1 | package geniusweb.blingbling.Ranknet;
|
---|
| 2 |
|
---|
import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.ArrayList;
import java.util.Collections;
import java.util.LinkedList;
import java.util.List;
import java.util.Stack;
|
---|
| 13 |
|
---|
| 14 | public class NeuralRankNet {
|
---|
| 15 | protected List<Layer> layers;
|
---|
| 16 | protected double learningRate;
|
---|
| 17 |
|
---|
| 18 | public static RanknetBuilder Builder() {
|
---|
| 19 | return new RanknetBuilder();
|
---|
| 20 | }
|
---|
| 21 |
|
---|
| 22 | protected void setLayers(List<Layer> layers) {
|
---|
| 23 | this.layers = layers;
|
---|
| 24 | }
|
---|
| 25 |
|
---|
| 26 | protected void setLearningRate(double learningRate) {
|
---|
| 27 | this.learningRate = learningRate;
|
---|
| 28 | }
|
---|
| 29 |
|
---|
| 30 | public List<INDArray> feedForward(INDArray input) {
|
---|
| 31 | //input the
|
---|
| 32 | List<INDArray> activations = new LinkedList<INDArray>();
|
---|
| 33 | activations.add(input);
|
---|
| 34 |
|
---|
| 35 | INDArray layerInput = input;
|
---|
| 36 | for (int i = 0; i < layers.size(); i++) {
|
---|
| 37 | Layer layer = layers.get(i);
|
---|
| 38 | layer.activate(layerInput);
|
---|
| 39 | INDArray output = layer.getActivation();
|
---|
| 40 |
|
---|
| 41 | activations.add(output);
|
---|
| 42 | layerInput = output;
|
---|
| 43 | }
|
---|
| 44 |
|
---|
| 45 | return activations;
|
---|
| 46 | }
|
---|
| 47 |
|
---|
| 48 | public void train(INDArray inputI, INDArray inputJ, INDArray expected) {
|
---|
| 49 | Layer outputLayer = layers.get(layers.size() - 1);
|
---|
| 50 |
|
---|
| 51 | Pair<List<INDArray>, List<INDArray>> pairI = getActivationZPair(inputI);
|
---|
| 52 | Pair<List<INDArray>, List<INDArray>> pairJ = getActivationZPair(inputJ);
|
---|
| 53 |
|
---|
| 54 | List<INDArray> activationsI = pairI.fst;
|
---|
| 55 | List<INDArray> zI = pairI.snd;
|
---|
| 56 |
|
---|
| 57 | List<INDArray> activationsJ = pairJ.fst;
|
---|
| 58 | List<INDArray> zJ = pairJ.snd;
|
---|
| 59 |
|
---|
| 60 | INDArray oI = activationsI.get(activationsI.size() - 1);
|
---|
| 61 | INDArray oJ = activationsJ.get(activationsJ.size() - 1);
|
---|
| 62 | INDArray oIJ = oI.sub(oJ);
|
---|
| 63 | INDArray pIJ = getProbability(oIJ);
|
---|
| 64 | // System.out.println(pIJ);
|
---|
| 65 |
|
---|
| 66 | INDArray deltaC = getOutputCostDerivative(pIJ, expected);
|
---|
| 67 | INDArray outputErrorGradientI = deltaC.mul(outputLayer.calculateActivationDerivative(oI));
|
---|
| 68 | INDArray outputErrorGradientJ = deltaC.mul(outputLayer.calculateActivationDerivative(oJ));
|
---|
| 69 |
|
---|
| 70 | List<INDArray> errorGradientsI = backpropagateError(outputErrorGradientI, activationsI);
|
---|
| 71 | List<INDArray> errorGradientsJ = backpropagateError(outputErrorGradientJ, activationsJ);
|
---|
| 72 |
|
---|
| 73 | updateParams(
|
---|
| 74 | new Pair<List<INDArray>, List<INDArray>>(errorGradientsI, errorGradientsJ),
|
---|
| 75 | new Pair<List<INDArray>, List<INDArray>>(activationsI, activationsJ));
|
---|
| 76 | }
|
---|
| 77 |
|
---|
| 78 | private List<INDArray> backpropagateError(INDArray outputErrorGradient, List<INDArray> activations) {
|
---|
| 79 | Stack<INDArray> reverseGradients = new Stack<INDArray>();
|
---|
| 80 | INDArray errorGradient = outputErrorGradient;
|
---|
| 81 | reverseGradients.add(errorGradient);
|
---|
| 82 |
|
---|
| 83 | for (int i = layers.size() - 2; i >= 0; i--) {
|
---|
| 84 | Layer layer = layers.get(i);
|
---|
| 85 | Layer forwardLayer = layers.get(i + 1);
|
---|
| 86 | INDArray activationDerivative = layer.calculateActivationDerivative(activations.get(i + 1));
|
---|
| 87 | errorGradient = forwardLayer.getErrorGradient(errorGradient, activationDerivative);
|
---|
| 88 | reverseGradients.add(errorGradient);
|
---|
| 89 | }
|
---|
| 90 |
|
---|
| 91 | List<INDArray> errorGradients = new LinkedList<INDArray>();
|
---|
| 92 | while (!reverseGradients.isEmpty()) {
|
---|
| 93 | errorGradients.add(reverseGradients.pop());
|
---|
| 94 | }
|
---|
| 95 |
|
---|
| 96 | return errorGradients;
|
---|
| 97 | }
|
---|
| 98 |
|
---|
| 99 | private void updateParams(Pair<List<INDArray>, List<INDArray>> errorGradients, Pair<List<INDArray>, List<INDArray>> activations) {
|
---|
| 100 | List<INDArray> errorGradientsI = errorGradients.fst;
|
---|
| 101 | List<INDArray> errorGradientsJ = errorGradients.snd;
|
---|
| 102 |
|
---|
| 103 | List<INDArray> activationsI = activations.fst;
|
---|
| 104 | List<INDArray> activationsJ = activations.snd;
|
---|
| 105 |
|
---|
| 106 | for (int i = layers.size() - 1; i >= 0; i--) {
|
---|
| 107 | Layer layer = layers.get(i);
|
---|
| 108 | INDArray errorGradientI = errorGradientsI.get(i);
|
---|
| 109 | INDArray errorGradientJ = errorGradientsJ.get(i);
|
---|
| 110 |
|
---|
| 111 | INDArray activationI = activationsI.get(i);
|
---|
| 112 | INDArray activationJ = activationsJ.get(i);
|
---|
| 113 |
|
---|
| 114 | INDArray weightDeltaI = errorGradientI.transpose().mmul(activationI).mul(learningRate);
|
---|
| 115 | INDArray weightDeltaJ = errorGradientJ.transpose().mmul(activationJ).mul(learningRate);
|
---|
| 116 |
|
---|
| 117 | layer.updateWeights(weightDeltaI.sub(weightDeltaJ));
|
---|
| 118 | layer.updateBiases(errorGradientI.sub(errorGradientJ));
|
---|
| 119 | }
|
---|
| 120 | }
|
---|
| 121 |
|
---|
| 122 | private static INDArray getProbability(INDArray data) {
|
---|
| 123 | return Transforms.sigmoid(data);
|
---|
| 124 | }
|
---|
| 125 |
|
---|
| 126 | private static INDArray getOutputCostDerivative(INDArray output, INDArray expected) {
|
---|
| 127 | // Naive MSE.
|
---|
| 128 | // return expected.sub(output);
|
---|
| 129 | ILossFunction lf = new LossBinaryXENT();
|
---|
| 130 | // XE
|
---|
| 131 | // ILossFunction lf = new LossMSE();
|
---|
| 132 | return lf.computeGradient(expected, output, new ActivationIdentity(), null);
|
---|
| 133 | }
|
---|
| 134 |
|
---|
| 135 | public Pair<List<INDArray>, List<INDArray>> getActivationZPair(INDArray input) {
|
---|
| 136 | List<INDArray> activations = feedForward(input);
|
---|
| 137 | List<INDArray> zs = new LinkedList<INDArray>();
|
---|
| 138 | zs.add(input);
|
---|
| 139 | for (Layer layer : layers) {
|
---|
| 140 | zs.add(layer.getZ());
|
---|
| 141 | }
|
---|
| 142 |
|
---|
| 143 | return new Pair<List<INDArray>, List<INDArray>>(activations, zs);
|
---|
| 144 | }
|
---|
| 145 | } |
---|