Last change on this file was revision 341, checked in by Katsuhide Fujita 5 years ago:
"Katsuhide Fujita added ANAC2018 agents."
|
File size:
1.3 KB
|
Line | |
---|
1 | package agents.anac.y2018.agentherb;
|
---|
2 |
|
---|
3 | import java.util.concurrent.ThreadLocalRandom;
|
---|
4 |
|
---|
5 | public class LogisticRegression {
|
---|
6 |
|
---|
7 | private static final double RATE = 0.5;
|
---|
8 |
|
---|
9 | private final double[] weights;
|
---|
10 |
|
---|
11 | /**
|
---|
12 | * @param sizeOfVector The size of each vector
|
---|
13 | */
|
---|
14 | public LogisticRegression(int sizeOfVector) {
|
---|
15 | this.weights = new double[sizeOfVector];
|
---|
16 | for (int i = 0; i < sizeOfVector; i++) {
|
---|
17 | this.weights[i] = ThreadLocalRandom.current().nextDouble(-1, 1);
|
---|
18 | }
|
---|
19 | }
|
---|
20 |
|
---|
21 | /**
|
---|
22 | * @param number The number to to sigmoid
|
---|
23 | * @return The sigmoid of the number
|
---|
24 | */
|
---|
25 | private static double sigmoid(double number) {
|
---|
26 | return 1.0 / (1.0 + Math.exp(-number));
|
---|
27 | }
|
---|
28 |
|
---|
29 | public void train(Vector vector, int label) {
|
---|
30 | double predicted = classify(vector);
|
---|
31 | for (int i = 0; i < this.weights.length; i++) {
|
---|
32 | this.weights[i] = this.weights[i] + RATE * (label - predicted) * vector.get(i);
|
---|
33 | }
|
---|
34 |
|
---|
35 | }
|
---|
36 |
|
---|
37 | /**
|
---|
38 | * @param vector The vector to classify
|
---|
39 | * @return The classification of the vector
|
---|
40 | */
|
---|
41 | public double classify(Vector vector) {
|
---|
42 | double sum = 0;
|
---|
43 | for (int i = 0; i < this.weights.length ; i++) {
|
---|
44 | sum += this.weights[i] * vector.get(i);
|
---|
45 | }
|
---|
46 | return sigmoid(sum);
|
---|
47 | }
|
---|
48 | }
|
---|
Note:
See
TracBrowser
for help on using the repository browser.