source: anac2020/BlingBling/src/main/java/geniusweb/blingbling/Ranknet4j/Test1.java@ 31

Last change on this file since 31 was 1, checked in by wouter, 4 years ago

#1910 added anac2020 parties

File size: 2.1 KB
Line 
1package geniusweb.blingbling.Ranknet4j;
2
3import java.util.ArrayList;
4import java.util.List;
5import java.util.Random;
6
7import org.neuroph.core.Layer;
8import org.neuroph.core.NeuralNetwork;
9import org.neuroph.core.data.DataSet;
10import org.neuroph.core.data.DataSetRow;
11import org.neuroph.core.learning.LearningRule;
12import org.neuroph.util.TransferFunctionType;
13
14public class Test1 {
15
16 public static void main(String[] args) {
17 // TODO Auto-generated method stub
18 int inputsCount = 10;
19 int samplenum = 10;
20 NeuralNetwork ann = new Ranknet(TransferFunctionType.SIGMOID, inputsCount, 10, 1);
21
22 LearningRule lr = ann.getLearningRule();
23
24 System.out.print(lr.toString());
25 List<double[]> datalist = getData(inputsCount, samplenum);
26 for (double[] input: datalist) {
27 ann.setInput(input);
28 ann.calculate();
29 System.out.println(ann.getOutput()[0]);
30 }
31
32 DataSet dataset = getDataset(datalist, inputsCount, samplenum);
33 dataset.shuffle();
34 ann.learn(dataset);
35 System.out.println("-----------------------------");
36 for (double[] input: datalist) {
37 ann.setInput(input);
38 ann.calculate();
39// System.out.print(input);
40 System.out.println(ann.getOutput()[0]);
41 }
42
43 }
44
45 public static List<double[]> getData(int inputsize, int sample) {
46
47 Random rand = new Random();
48
49 List<double[]> l = new ArrayList<double[]>();
50 for (int num = 0; num<sample; num++) {
51 double[] d = new double[inputsize];
52 for (int i=0; i<inputsize; i++) {
53 if (i==num) {
54 d[i]=1.0;
55 }else {
56 d[i]=0.0;
57 }
58// d[i] = rand.nextDouble();
59 }
60 l.add(d);
61 }
62
63
64 return l;
65 }
66
67 public static DataSet getDataset(List<double[]> l,int inputsize, int sample) {
68 DataSet ds = new DataSet(inputsize*2, 1);
69 double[] output = new double[1];
70 output[0] = 1.0;
71 for (int i =0; i <sample; i++) {
72 for (int j=i+1; j<sample; j++) {
73 double[] data = new double[inputsize*2];
74 for (int ind=0; ind<inputsize*2; ind++) {
75 if (ind<inputsize) {
76 data[ind] = l.get(i)[ind];
77 }else {
78 data[ind] = l.get(j)[ind-inputsize];
79 }
80 }
81 ds.add(data, output);
82 }
83 }
84 ds.shuffle();
85 return ds.split(10)[0];
86// return ds;
87
88 }
89}
Note: See TracBrowser for help on using the repository browser.