source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet/LayerBuilder.java

Last change on this file was revision 1, checked in by wouter, 4 years ago.

#1910 added anac2020 parties

File size: 1.2 KB
Line 
1package geniusweb.blingbling.Ranknet;
2
3
4import org.nd4j.linalg.api.ndarray.INDArray;
5import org.nd4j.linalg.factory.Nd4j;
6
7public class LayerBuilder {
8 private int inCount;
9 private int outCount;
10 private IWeightInit weightInit;
11 private IActivationFunction activationFunction = IdentityActivationFunction.INSTANCE;
12
13 public LayerBuilder setInCount(int count) {
14 this.inCount = count;
15
16 return this;
17 }
18
19 public LayerBuilder setOutCount(int count) {
20 this.outCount = count;
21
22 return this;
23 }
24
25 public LayerBuilder setWeightInit(IWeightInit weightInit) {
26 this.weightInit = weightInit;
27
28 return this;
29 }
30
31 public LayerBuilder setActivationFunction(IActivationFunction activationFunction) {
32 this.activationFunction = activationFunction;
33
34 return this;
35 }
36
37 public Layer build() {
38 INDArray weights = Nd4j.rand(outCount, inCount).div(inCount);
39// INDArray weights = Nd4j.rand(outCount, inCount);
40// INDArray biases = Nd4j.create(1, outCount);
41 INDArray biases = Nd4j.zeros(1, outCount).div(inCount);
42
43 Layer layer = new Layer(weights, biases, activationFunction);
44
45 return layer;
46 }
47}
Note: See TracBrowser for help on using the repository browser.