source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet4j/CrossEntropyError.java

Last change to this file was revision 1, checked in by wouter, 4 years ago.

#1910 added anac2020 parties

File size: 2.0 KB
Line 
1package geniusweb.blingbling.Ranknet4j;
2
3import org.neuroph.core.learning.error.ErrorFunction;
4import org.neuroph.core.transfer.Sigmoid;
5
6import java.io.Serializable;
7
8/**
9 * Special error function which is recommended to be used in classification models
10 */
11public class CrossEntropyError implements ErrorFunction, Serializable {
12
13// private double[] errorDerivative;
14 private transient double totalError;
15 private transient double n;
16 Sigmoid sigfunc = new Sigmoid();
17
18 @Override
19 public double getTotalError() {
20 return -totalError / n ;
21 }
22
23 @Override
24 public void reset() {
25 totalError = 0;
26 n = 0;
27 }
28
29 @Override
30 public double[] addPatternError(double[] predictedOutput, double[] targetOutput) {
31 double[] errorDerivative = new double[targetOutput.length];
32
33 if (predictedOutput.length != targetOutput.length)
34 throw new IllegalArgumentException("Output array length and desired output array length must be the same size!");
35
36 for (int i = 0; i < predictedOutput.length; i++) {
37 Double sigm = sigfunc.getOutput(predictedOutput[i]);
38 errorDerivative[i] = -1.0/sigm * sigm * (1-sigm);
39 totalError += targetOutput[i] * Math.log(sigm);
40
41 }
42 n++;
43
44 return errorDerivative;
45 }
46
47 public double[] addPatternError2(double[] predictedOutput, double[] targetOutput) {
48 double[] errorDerivative = new double[targetOutput.length];
49
50 if (predictedOutput.length != targetOutput.length)
51 throw new IllegalArgumentException("Output array length and desired output array length must be the same size!");
52
53 for (int i = 0; i < predictedOutput.length; i++) {
54 Double sigm = sigfunc.getOutput(predictedOutput[i]);
55 errorDerivative[i] = 1.0/sigm * sigm * (1-sigm);
56// totalError += targetOutput[i] * Math.log(sigm);
57
58 }
59// n++;
60
61 return errorDerivative;
62 }
63
64}
Note: See TracBrowser for help on using the repository browser.