Last change on this file was revision 341, checked in by Katsuhide Fujita 5 years ago: "Katsuhide Fujita added ANAC2018 agents."
|
File size: 1.3 KB
|
Rev | Line | |
---|
[341] | 1 | package agents.anac.y2018.agentherb;
|
---|
| 2 |
|
---|
| 3 | import java.util.concurrent.ThreadLocalRandom;
|
---|
| 4 |
|
---|
| 5 | public class LogisticRegression {
|
---|
| 6 |
|
---|
| 7 | private static final double RATE = 0.5;
|
---|
| 8 |
|
---|
| 9 | private final double[] weights;
|
---|
| 10 |
|
---|
| 11 | /**
|
---|
| 12 | * @param sizeOfVector The size of each vector
|
---|
| 13 | */
|
---|
| 14 | public LogisticRegression(int sizeOfVector) {
|
---|
| 15 | this.weights = new double[sizeOfVector];
|
---|
| 16 | for (int i = 0; i < sizeOfVector; i++) {
|
---|
| 17 | this.weights[i] = ThreadLocalRandom.current().nextDouble(-1, 1);
|
---|
| 18 | }
|
---|
| 19 | }
|
---|
| 20 |
|
---|
| 21 | /**
|
---|
| 22 | * @param number The number to to sigmoid
|
---|
| 23 | * @return The sigmoid of the number
|
---|
| 24 | */
|
---|
| 25 | private static double sigmoid(double number) {
|
---|
| 26 | return 1.0 / (1.0 + Math.exp(-number));
|
---|
| 27 | }
|
---|
| 28 |
|
---|
| 29 | public void train(Vector vector, int label) {
|
---|
| 30 | double predicted = classify(vector);
|
---|
| 31 | for (int i = 0; i < this.weights.length; i++) {
|
---|
| 32 | this.weights[i] = this.weights[i] + RATE * (label - predicted) * vector.get(i);
|
---|
| 33 | }
|
---|
| 34 |
|
---|
| 35 | }
|
---|
| 36 |
|
---|
| 37 | /**
|
---|
| 38 | * @param vector The vector to classify
|
---|
| 39 | * @return The classification of the vector
|
---|
| 40 | */
|
---|
| 41 | public double classify(Vector vector) {
|
---|
| 42 | double sum = 0;
|
---|
| 43 | for (int i = 0; i < this.weights.length ; i++) {
|
---|
| 44 | sum += this.weights[i] * vector.get(i);
|
---|
| 45 | }
|
---|
| 46 | return sigmoid(sum);
|
---|
| 47 | }
|
---|
| 48 | }
|
---|
Note: See TracBrowser for help on using the repository browser.