1 | package geniusweb.blingbling.Ranknet4j;
|
---|
2 |
|
---|
3 | import org.neuroph.core.learning.error.ErrorFunction;
|
---|
4 | import org.neuroph.core.transfer.Sigmoid;
|
---|
5 |
|
---|
6 | import java.io.Serializable;
|
---|
7 |
|
---|
8 | /**
|
---|
9 | * Special error function which is recommended to be used in classification models
|
---|
10 | */
|
---|
11 | public class CrossEntropyError implements ErrorFunction, Serializable {
|
---|
12 |
|
---|
13 | // private double[] errorDerivative;
|
---|
14 | private transient double totalError;
|
---|
15 | private transient double n;
|
---|
16 | Sigmoid sigfunc = new Sigmoid();
|
---|
17 |
|
---|
18 | @Override
|
---|
19 | public double getTotalError() {
|
---|
20 | return -totalError / n ;
|
---|
21 | }
|
---|
22 |
|
---|
23 | @Override
|
---|
24 | public void reset() {
|
---|
25 | totalError = 0;
|
---|
26 | n = 0;
|
---|
27 | }
|
---|
28 |
|
---|
29 | @Override
|
---|
30 | public double[] addPatternError(double[] predictedOutput, double[] targetOutput) {
|
---|
31 | double[] errorDerivative = new double[targetOutput.length];
|
---|
32 |
|
---|
33 | if (predictedOutput.length != targetOutput.length)
|
---|
34 | throw new IllegalArgumentException("Output array length and desired output array length must be the same size!");
|
---|
35 |
|
---|
36 | for (int i = 0; i < predictedOutput.length; i++) {
|
---|
37 | Double sigm = sigfunc.getOutput(predictedOutput[i]);
|
---|
38 | errorDerivative[i] = -1.0/sigm * sigm * (1-sigm);
|
---|
39 | totalError += targetOutput[i] * Math.log(sigm);
|
---|
40 |
|
---|
41 | }
|
---|
42 | n++;
|
---|
43 |
|
---|
44 | return errorDerivative;
|
---|
45 | }
|
---|
46 |
|
---|
47 | public double[] addPatternError2(double[] predictedOutput, double[] targetOutput) {
|
---|
48 | double[] errorDerivative = new double[targetOutput.length];
|
---|
49 |
|
---|
50 | if (predictedOutput.length != targetOutput.length)
|
---|
51 | throw new IllegalArgumentException("Output array length and desired output array length must be the same size!");
|
---|
52 |
|
---|
53 | for (int i = 0; i < predictedOutput.length; i++) {
|
---|
54 | Double sigm = sigfunc.getOutput(predictedOutput[i]);
|
---|
55 | errorDerivative[i] = 1.0/sigm * sigm * (1-sigm);
|
---|
56 | // totalError += targetOutput[i] * Math.log(sigm);
|
---|
57 |
|
---|
58 | }
|
---|
59 | // n++;
|
---|
60 |
|
---|
61 | return errorDerivative;
|
---|
62 | }
|
---|
63 |
|
---|
64 | } |
---|