Line | |
---|
1 | package geniusweb.blingbling.Ranknet;
|
---|
2 |
|
---|
3 |
|
---|
4 | import org.nd4j.linalg.api.ndarray.INDArray;
|
---|
5 | import org.nd4j.linalg.factory.Nd4j;
|
---|
6 |
|
---|
7 | public class LayerBuilder {
|
---|
8 | private int inCount;
|
---|
9 | private int outCount;
|
---|
10 | private IWeightInit weightInit;
|
---|
11 | private IActivationFunction activationFunction = IdentityActivationFunction.INSTANCE;
|
---|
12 |
|
---|
13 | public LayerBuilder setInCount(int count) {
|
---|
14 | this.inCount = count;
|
---|
15 |
|
---|
16 | return this;
|
---|
17 | }
|
---|
18 |
|
---|
19 | public LayerBuilder setOutCount(int count) {
|
---|
20 | this.outCount = count;
|
---|
21 |
|
---|
22 | return this;
|
---|
23 | }
|
---|
24 |
|
---|
25 | public LayerBuilder setWeightInit(IWeightInit weightInit) {
|
---|
26 | this.weightInit = weightInit;
|
---|
27 |
|
---|
28 | return this;
|
---|
29 | }
|
---|
30 |
|
---|
31 | public LayerBuilder setActivationFunction(IActivationFunction activationFunction) {
|
---|
32 | this.activationFunction = activationFunction;
|
---|
33 |
|
---|
34 | return this;
|
---|
35 | }
|
---|
36 |
|
---|
37 | public Layer build() {
|
---|
38 | INDArray weights = Nd4j.rand(outCount, inCount).div(inCount);
|
---|
39 | // INDArray weights = Nd4j.rand(outCount, inCount);
|
---|
40 | // INDArray biases = Nd4j.create(1, outCount);
|
---|
41 | INDArray biases = Nd4j.zeros(1, outCount).div(inCount);
|
---|
42 |
|
---|
43 | Layer layer = new Layer(weights, biases, activationFunction);
|
---|
44 |
|
---|
45 | return layer;
|
---|
46 | }
|
---|
47 | } |
---|
Note:
See
TracBrowser
for help on using the repository browser.