1 | package geniusweb.blingbling.Ranknet4j;
|
---|
2 |
|
---|
3 | import java.util.ArrayList;
|
---|
4 | import java.util.List;
|
---|
5 | import java.util.Random;
|
---|
6 |
|
---|
7 | import org.neuroph.core.Layer;
|
---|
8 | import org.neuroph.core.NeuralNetwork;
|
---|
9 | import org.neuroph.core.data.DataSet;
|
---|
10 | import org.neuroph.core.data.DataSetRow;
|
---|
11 | import org.neuroph.core.learning.LearningRule;
|
---|
12 | import org.neuroph.util.TransferFunctionType;
|
---|
13 |
|
---|
14 | public class Test1 {
|
---|
15 |
|
---|
16 | public static void main(String[] args) {
|
---|
17 | // TODO Auto-generated method stub
|
---|
18 | int inputsCount = 10;
|
---|
19 | int samplenum = 10;
|
---|
20 | NeuralNetwork ann = new Ranknet(TransferFunctionType.SIGMOID, inputsCount, 10, 1);
|
---|
21 |
|
---|
22 | LearningRule lr = ann.getLearningRule();
|
---|
23 |
|
---|
24 | System.out.print(lr.toString());
|
---|
25 | List<double[]> datalist = getData(inputsCount, samplenum);
|
---|
26 | for (double[] input: datalist) {
|
---|
27 | ann.setInput(input);
|
---|
28 | ann.calculate();
|
---|
29 | System.out.println(ann.getOutput()[0]);
|
---|
30 | }
|
---|
31 |
|
---|
32 | DataSet dataset = getDataset(datalist, inputsCount, samplenum);
|
---|
33 | dataset.shuffle();
|
---|
34 | ann.learn(dataset);
|
---|
35 | System.out.println("-----------------------------");
|
---|
36 | for (double[] input: datalist) {
|
---|
37 | ann.setInput(input);
|
---|
38 | ann.calculate();
|
---|
39 | // System.out.print(input);
|
---|
40 | System.out.println(ann.getOutput()[0]);
|
---|
41 | }
|
---|
42 |
|
---|
43 | }
|
---|
44 |
|
---|
45 | public static List<double[]> getData(int inputsize, int sample) {
|
---|
46 |
|
---|
47 | Random rand = new Random();
|
---|
48 |
|
---|
49 | List<double[]> l = new ArrayList<double[]>();
|
---|
50 | for (int num = 0; num<sample; num++) {
|
---|
51 | double[] d = new double[inputsize];
|
---|
52 | for (int i=0; i<inputsize; i++) {
|
---|
53 | if (i==num) {
|
---|
54 | d[i]=1.0;
|
---|
55 | }else {
|
---|
56 | d[i]=0.0;
|
---|
57 | }
|
---|
58 | // d[i] = rand.nextDouble();
|
---|
59 | }
|
---|
60 | l.add(d);
|
---|
61 | }
|
---|
62 |
|
---|
63 |
|
---|
64 | return l;
|
---|
65 | }
|
---|
66 |
|
---|
67 | public static DataSet getDataset(List<double[]> l,int inputsize, int sample) {
|
---|
68 | DataSet ds = new DataSet(inputsize*2, 1);
|
---|
69 | double[] output = new double[1];
|
---|
70 | output[0] = 1.0;
|
---|
71 | for (int i =0; i <sample; i++) {
|
---|
72 | for (int j=i+1; j<sample; j++) {
|
---|
73 | double[] data = new double[inputsize*2];
|
---|
74 | for (int ind=0; ind<inputsize*2; ind++) {
|
---|
75 | if (ind<inputsize) {
|
---|
76 | data[ind] = l.get(i)[ind];
|
---|
77 | }else {
|
---|
78 | data[ind] = l.get(j)[ind-inputsize];
|
---|
79 | }
|
---|
80 | }
|
---|
81 | ds.add(data, output);
|
---|
82 | }
|
---|
83 | }
|
---|
84 | ds.shuffle();
|
---|
85 | return ds.split(10)[0];
|
---|
86 | // return ds;
|
---|
87 |
|
---|
88 | }
|
---|
89 | }
|
---|