Rev | Line | |
---|
[1] | 1 | package geniusweb.blingbling.Ranknet;
|
---|
| 2 |
|
---|
| 3 |
|
---|
| 4 | import org.nd4j.linalg.api.ndarray.INDArray;
|
---|
| 5 | import org.nd4j.linalg.factory.Nd4j;
|
---|
| 6 |
|
---|
| 7 | public class LayerBuilder {
|
---|
| 8 | private int inCount;
|
---|
| 9 | private int outCount;
|
---|
| 10 | private IWeightInit weightInit;
|
---|
| 11 | private IActivationFunction activationFunction = IdentityActivationFunction.INSTANCE;
|
---|
| 12 |
|
---|
| 13 | public LayerBuilder setInCount(int count) {
|
---|
| 14 | this.inCount = count;
|
---|
| 15 |
|
---|
| 16 | return this;
|
---|
| 17 | }
|
---|
| 18 |
|
---|
| 19 | public LayerBuilder setOutCount(int count) {
|
---|
| 20 | this.outCount = count;
|
---|
| 21 |
|
---|
| 22 | return this;
|
---|
| 23 | }
|
---|
| 24 |
|
---|
| 25 | public LayerBuilder setWeightInit(IWeightInit weightInit) {
|
---|
| 26 | this.weightInit = weightInit;
|
---|
| 27 |
|
---|
| 28 | return this;
|
---|
| 29 | }
|
---|
| 30 |
|
---|
| 31 | public LayerBuilder setActivationFunction(IActivationFunction activationFunction) {
|
---|
| 32 | this.activationFunction = activationFunction;
|
---|
| 33 |
|
---|
| 34 | return this;
|
---|
| 35 | }
|
---|
| 36 |
|
---|
| 37 | public Layer build() {
|
---|
| 38 | INDArray weights = Nd4j.rand(outCount, inCount).div(inCount);
|
---|
| 39 | // INDArray weights = Nd4j.rand(outCount, inCount);
|
---|
| 40 | // INDArray biases = Nd4j.create(1, outCount);
|
---|
| 41 | INDArray biases = Nd4j.zeros(1, outCount).div(inCount);
|
---|
| 42 |
|
---|
| 43 | Layer layer = new Layer(weights, biases, activationFunction);
|
---|
| 44 |
|
---|
| 45 | return layer;
|
---|
| 46 | }
|
---|
| 47 | } |
---|
Note:
See
TracBrowser
for help on using the repository browser.