source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet/Layer.java@ 43

Last change on this file since 43 was 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 3.0 KB
Line 
1package geniusweb.blingbling.Ranknet;
2
3//import neuralnet.activationfunction.IActivationFunction;
4import org.nd4j.linalg.api.ndarray.INDArray;
5import org.nd4j.linalg.factory.Nd4j;
6//import org.nd4j.shade.guava.collect.Lists;
7
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
10
11public class Layer {
12 private INDArray weights;
13 private INDArray biases;
14 private IActivationFunction activationFunction;
15
16 // Iteration instance variables.
17 private INDArray input;
18 private INDArray z;
19 private INDArray activation;
20 private INDArray activationDerivative;
21
22 public static LayerBuilder Builder() {
23 return new LayerBuilder();
24 }
25
26 protected Layer(INDArray weights, INDArray biases, IActivationFunction activationFunction) {
27 this.weights = weights;
28 this.biases = biases;
29 this.activationFunction = activationFunction;
30 }
31
32 public void activate(INDArray input) {
33 this.input = input;
34 z = calculateZ(input);
35 activation = calculateActivation(z);
36 activationDerivative = calculateActivationDerivative(z);
37 }
38
39 public INDArray getInput() {
40 return input;
41 }
42
43 public INDArray getZ() {
44 return z;
45 }
46
47 public INDArray getActivation() {
48 return activation;
49 }
50
51 public INDArray getActivationDerivative() {
52 return activationDerivative;
53 }
54
55 public INDArray getErrorGradient(INDArray error, INDArray prevActivationDerivative) {
56 return error.mmul(weights).mul(prevActivationDerivative);
57 }
58
59 public void updateWeights(INDArray gradients) {
60 List<INDArray> rows = new LinkedList<INDArray>();
61 int stride = activation.rows() * weights.rows();
62
63 for (int row = 0; row < gradients.rows(); row = row + stride) {
64 int[] extractRows = new int[stride];
65 for (int i = 0; i < stride; i++) {
66 extractRows[i] = (row * stride) + i;
67 }
68 rows.add(gradients.getRows(extractRows));
69 }
70 weights = weights.add(Nd4j.averageAndPropagate(rows));
71 }
72
73 public void updateBiases(INDArray gradients) {
74 List<INDArray> rows = new LinkedList<INDArray>();
75 int stride = activation.rows();
76
77 for (int row = 0; row < gradients.rows(); row = row + stride) {
78 int[] extractRows = new int[stride];
79 for (int i = 0; i < stride; i++) {
80 extractRows[i] = (row * stride) + i;
81 }
82 rows.add(gradients.getRows(extractRows));
83 }
84 biases = biases.add(Nd4j.averageAndPropagate(rows));
85 }
86
87 public INDArray calculateZ(INDArray input) {
88 INDArray biasMatrix = biases;
89 for (int i = 1; i < input.rows(); i++) {
90 biasMatrix = Nd4j.hstack(biasMatrix, biases);
91 }
92 return input.mmul(weights.transpose()).add(biasMatrix);
93 }
94
95 public INDArray calculateActivation(INDArray z) {
96 return activationFunction.output(z);
97 }
98
99 public INDArray calculateActivationDerivative(INDArray z) {
100 return activationFunction.derivative(z);
101 }
102}
Note: See TracBrowser for help on using the repository browser.