package geniusweb.blingbling.Ranknet;

import org.nd4j.linalg.activations.impl.ActivationIdentity;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.lossfunctions.ILossFunction;
import org.nd4j.linalg.lossfunctions.impl.LossMSE;
import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
import org.nd4j.linalg.ops.transforms.Transforms;

import java.util.LinkedList;
import java.util.List;
import java.util.Stack;

/**
 * A pairwise RankNet: a small feed-forward neural network trained on pairs of
 * inputs (i, j). The network scores each input separately; the modelled
 * probability that i should rank above j is sigmoid(score_i - score_j), and
 * the gradient of a pairwise loss on that probability is backpropagated
 * through both forward passes.
 *
 * <p>NOTE(review): this class is not thread-safe — {@code feedForward} stores
 * per-call activations inside the shared {@code Layer} objects.
 */
public class NeuralRankNet {

    /** Network layers, ordered from input side to output side. */
    protected List<Layer> layers;

    /** Step size applied to the weight/bias deltas in {@link #updateParams}. */
    protected double learningRate;

    /** @return a fresh builder for configuring and constructing a network. */
    public static RanknetBuilder Builder() {
        return new RanknetBuilder();
    }

    protected void setLayers(List<Layer> layers) {
        this.layers = layers;
    }

    protected void setLearningRate(double learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Runs a forward pass through all layers.
     *
     * @param input the network input (row vector of features)
     * @return the activation of every stage: index 0 is the raw input, index
     *         k+1 is the activation produced by layer k. Size is
     *         {@code layers.size() + 1}.
     */
    public List<INDArray> feedForward(INDArray input) {
        List<INDArray> activations = new LinkedList<>();
        activations.add(input);
        INDArray layerInput = input;
        for (Layer layer : layers) {
            // Each layer caches its own z / activation; we chain the output forward.
            layer.activate(layerInput);
            INDArray output = layer.getActivation();
            activations.add(output);
            layerInput = output;
        }
        return activations;
    }

    /**
     * Performs one pairwise RankNet training step on the pair (inputI, inputJ).
     *
     * @param inputI   features of the first item of the pair
     * @param inputJ   features of the second item of the pair
     * @param expected target probability that item I ranks above item J
     *                 (typically 1, 0 or 0.5)
     */
    public void train(INDArray inputI, INDArray inputJ, INDArray expected) {
        Layer outputLayer = layers.get(layers.size() - 1);

        // Two independent forward passes; only the activations are needed here.
        Pair<List<INDArray>, List<INDArray>> pairI = getActivationZPair(inputI);
        Pair<List<INDArray>, List<INDArray>> pairJ = getActivationZPair(inputJ);
        List<INDArray> activationsI = pairI.fst;
        List<INDArray> activationsJ = pairJ.fst;

        INDArray oI = activationsI.get(activationsI.size() - 1);
        INDArray oJ = activationsJ.get(activationsJ.size() - 1);

        // Modelled probability that I ranks above J: sigmoid(oI - oJ).
        INDArray oIJ = oI.sub(oJ);
        INDArray pIJ = getProbability(oIJ);

        // dC/d(oIJ); the same scalar drives both branches of the pair.
        INDArray deltaC = getOutputCostDerivative(pIJ, expected);
        INDArray outputErrorGradientI =
                deltaC.mul(outputLayer.calculateActivationDerivative(oI));
        INDArray outputErrorGradientJ =
                deltaC.mul(outputLayer.calculateActivationDerivative(oJ));

        List<INDArray> errorGradientsI =
                backpropagateError(outputErrorGradientI, activationsI);
        List<INDArray> errorGradientsJ =
                backpropagateError(outputErrorGradientJ, activationsJ);

        updateParams(
                new Pair<List<INDArray>, List<INDArray>>(errorGradientsI, errorGradientsJ),
                new Pair<List<INDArray>, List<INDArray>>(activationsI, activationsJ));
    }

    /**
     * Propagates the output-layer error gradient back through the network.
     *
     * @param outputErrorGradient gradient at the final layer
     * @param activations         the per-stage activations from the forward pass
     * @return one gradient per layer, ordered front-to-back so that index k is
     *         the error gradient of layer k
     */
    private List<INDArray> backpropagateError(INDArray outputErrorGradient,
            List<INDArray> activations) {
        // Prepend as we walk backwards so the result ends up in layer order,
        // avoiding the legacy Stack push/pop round-trip.
        LinkedList<INDArray> errorGradients = new LinkedList<>();
        INDArray errorGradient = outputErrorGradient;
        errorGradients.addFirst(errorGradient);
        for (int i = layers.size() - 2; i >= 0; i--) {
            Layer layer = layers.get(i);
            Layer forwardLayer = layers.get(i + 1);
            // activations.get(i + 1) is the output of layer i (index 0 is the input).
            INDArray activationDerivative =
                    layer.calculateActivationDerivative(activations.get(i + 1));
            errorGradient = forwardLayer.getErrorGradient(errorGradient, activationDerivative);
            errorGradients.addFirst(errorGradient);
        }
        return errorGradients;
    }

    /**
     * Applies the paired gradient update: each layer's weights move by
     * learningRate * (gradI^T · actI - gradJ^T · actJ), and biases by
     * (gradI - gradJ), mirroring the symmetric RankNet update.
     *
     * @param errorGradients per-layer error gradients for items I and J
     * @param activations    per-stage activations for items I and J
     */
    private void updateParams(Pair<List<INDArray>, List<INDArray>> errorGradients,
            Pair<List<INDArray>, List<INDArray>> activations) {
        List<INDArray> errorGradientsI = errorGradients.fst;
        List<INDArray> errorGradientsJ = errorGradients.snd;
        List<INDArray> activationsI = activations.fst;
        List<INDArray> activationsJ = activations.snd;
        for (int i = layers.size() - 1; i >= 0; i--) {
            Layer layer = layers.get(i);
            INDArray errorGradientI = errorGradientsI.get(i);
            INDArray errorGradientJ = errorGradientsJ.get(i);
            // activations index i is the INPUT to layer i.
            INDArray activationI = activationsI.get(i);
            INDArray activationJ = activationsJ.get(i);
            INDArray weightDeltaI =
                    errorGradientI.transpose().mmul(activationI).mul(learningRate);
            INDArray weightDeltaJ =
                    errorGradientJ.transpose().mmul(activationJ).mul(learningRate);
            layer.updateWeights(weightDeltaI.sub(weightDeltaJ));
            layer.updateBiases(errorGradientI.sub(errorGradientJ));
        }
    }

    /** Maps a score difference to a rank probability via the logistic sigmoid. */
    private static INDArray getProbability(INDArray data) {
        return Transforms.sigmoid(data);
    }

    /**
     * Gradient of the pairwise cost w.r.t. the network output.
     *
     * @param output   the predicted probability pIJ
     * @param expected the target probability
     * @return dC/d(output), using binary cross-entropy by default
     */
    private static INDArray getOutputCostDerivative(INDArray output, INDArray expected) {
        // Naive MSE alternative: return expected.sub(output);
        ILossFunction lf = new LossBinaryXENT(); // cross-entropy
        // ILossFunction lf = new LossMSE();
        // ActivationIdentity: the sigmoid was already applied in getProbability.
        return lf.computeGradient(expected, output, new ActivationIdentity(), null);
    }

    /**
     * Runs a forward pass and collects both the activations and the
     * pre-activation values (z) of every layer.
     *
     * @param input the network input
     * @return pair of (activations, zs); both lists start with the raw input
     */
    public Pair<List<INDArray>, List<INDArray>> getActivationZPair(INDArray input) {
        List<INDArray> activations = feedForward(input);
        List<INDArray> zs = new LinkedList<>();
        zs.add(input);
        for (Layer layer : layers) {
            // getZ() returns the z cached by the feedForward call above.
            zs.add(layer.getZ());
        }
        return new Pair<List<INDArray>, List<INDArray>>(activations, zs);
    }
}