[1] | 1 | package geniusweb.blingbling.Ranknet4j;
|
---|
| 2 |
|
---|
| 3 | import java.util.ArrayList;
|
---|
| 4 | import java.util.List;
|
---|
| 5 | import java.util.Random;
|
---|
| 6 |
|
---|
| 7 | import org.neuroph.core.Layer;
|
---|
| 8 | import org.neuroph.core.NeuralNetwork;
|
---|
| 9 | import org.neuroph.core.data.DataSet;
|
---|
| 10 | import org.neuroph.core.data.DataSetRow;
|
---|
| 11 | import org.neuroph.core.learning.LearningRule;
|
---|
| 12 | import org.neuroph.util.TransferFunctionType;
|
---|
| 13 |
|
---|
| 14 | public class Test1 {
|
---|
| 15 |
|
---|
| 16 | public static void main(String[] args) {
|
---|
| 17 | // TODO Auto-generated method stub
|
---|
| 18 | int inputsCount = 10;
|
---|
| 19 | int samplenum = 10;
|
---|
| 20 | NeuralNetwork ann = new Ranknet(TransferFunctionType.SIGMOID, inputsCount, 10, 1);
|
---|
| 21 |
|
---|
| 22 | LearningRule lr = ann.getLearningRule();
|
---|
| 23 |
|
---|
| 24 | System.out.print(lr.toString());
|
---|
| 25 | List<double[]> datalist = getData(inputsCount, samplenum);
|
---|
| 26 | for (double[] input: datalist) {
|
---|
| 27 | ann.setInput(input);
|
---|
| 28 | ann.calculate();
|
---|
| 29 | System.out.println(ann.getOutput()[0]);
|
---|
| 30 | }
|
---|
| 31 |
|
---|
| 32 | DataSet dataset = getDataset(datalist, inputsCount, samplenum);
|
---|
| 33 | dataset.shuffle();
|
---|
| 34 | ann.learn(dataset);
|
---|
| 35 | System.out.println("-----------------------------");
|
---|
| 36 | for (double[] input: datalist) {
|
---|
| 37 | ann.setInput(input);
|
---|
| 38 | ann.calculate();
|
---|
| 39 | // System.out.print(input);
|
---|
| 40 | System.out.println(ann.getOutput()[0]);
|
---|
| 41 | }
|
---|
| 42 |
|
---|
| 43 | }
|
---|
| 44 |
|
---|
| 45 | public static List<double[]> getData(int inputsize, int sample) {
|
---|
| 46 |
|
---|
| 47 | Random rand = new Random();
|
---|
| 48 |
|
---|
| 49 | List<double[]> l = new ArrayList<double[]>();
|
---|
| 50 | for (int num = 0; num<sample; num++) {
|
---|
| 51 | double[] d = new double[inputsize];
|
---|
| 52 | for (int i=0; i<inputsize; i++) {
|
---|
| 53 | if (i==num) {
|
---|
| 54 | d[i]=1.0;
|
---|
| 55 | }else {
|
---|
| 56 | d[i]=0.0;
|
---|
| 57 | }
|
---|
| 58 | // d[i] = rand.nextDouble();
|
---|
| 59 | }
|
---|
| 60 | l.add(d);
|
---|
| 61 | }
|
---|
| 62 |
|
---|
| 63 |
|
---|
| 64 | return l;
|
---|
| 65 | }
|
---|
| 66 |
|
---|
| 67 | public static DataSet getDataset(List<double[]> l,int inputsize, int sample) {
|
---|
| 68 | DataSet ds = new DataSet(inputsize*2, 1);
|
---|
| 69 | double[] output = new double[1];
|
---|
| 70 | output[0] = 1.0;
|
---|
| 71 | for (int i =0; i <sample; i++) {
|
---|
| 72 | for (int j=i+1; j<sample; j++) {
|
---|
| 73 | double[] data = new double[inputsize*2];
|
---|
| 74 | for (int ind=0; ind<inputsize*2; ind++) {
|
---|
| 75 | if (ind<inputsize) {
|
---|
| 76 | data[ind] = l.get(i)[ind];
|
---|
| 77 | }else {
|
---|
| 78 | data[ind] = l.get(j)[ind-inputsize];
|
---|
| 79 | }
|
---|
| 80 | }
|
---|
| 81 | ds.add(data, output);
|
---|
| 82 | }
|
---|
| 83 | }
|
---|
| 84 | ds.shuffle();
|
---|
| 85 | return ds.split(10)[0];
|
---|
| 86 | // return ds;
|
---|
| 87 |
|
---|
| 88 | }
|
---|
| 89 | }
|
---|