package geniusweb.blingbling.Ranknet; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; public class LayerBuilder { private int inCount; private int outCount; private IWeightInit weightInit; private IActivationFunction activationFunction = IdentityActivationFunction.INSTANCE; public LayerBuilder setInCount(int count) { this.inCount = count; return this; } public LayerBuilder setOutCount(int count) { this.outCount = count; return this; } public LayerBuilder setWeightInit(IWeightInit weightInit) { this.weightInit = weightInit; return this; } public LayerBuilder setActivationFunction(IActivationFunction activationFunction) { this.activationFunction = activationFunction; return this; } public Layer build() { INDArray weights = Nd4j.rand(outCount, inCount).div(inCount); // INDArray weights = Nd4j.rand(outCount, inCount); // INDArray biases = Nd4j.create(1, outCount); INDArray biases = Nd4j.zeros(1, outCount).div(inCount); Layer layer = new Layer(weights, biases, activationFunction); return layer; } }