source: src/main/java/uva/projectai/y2018/jasparon/Qlearner.java@ 67

Last change on this file since 67 was 67, checked in by Aron Hammond, 6 years ago

Added support for agents that learn via ReinforcementLearning, including an implementation of an agent that uses tabular Q-learning

File size: 5.8 KB
Line 
1package uva.projectai.y2018.jasparon;
2
3import java.io.FileInputStream;
4import java.io.File;
5import java.io.FileNotFoundException;
6import java.io.FileOutputStream;
7import java.io.ObjectInputStream;
8import java.io.ObjectOutputStream;
9import java.util.ArrayList;
10import java.util.HashMap;
11
12import genius.core.Bid;
13import genius.core.actions.Action;
14import genius.core.actions.EndNegotiation;
15import genius.core.boaframework.BOAagent;
16import genius.core.events.MultipartyNegoActionEvent;
17import genius.core.events.NegotiationEvent;
18import negotiator.boaframework.acceptanceconditions.other.AC_Next;
19import negotiator.boaframework.omstrategy.BestBid;
20import negotiator.boaframework.opponentmodel.PerfectModel;
21
22import uva.projectai.y2018.jasparon.QlearningStrategy;
23
24@SuppressWarnings("serial")
25public class Qlearner extends BOAagent implements RLBOA {
26
27 QlearningStrategy offeringStrategy;
28 String state;
29
30 @Override
31 public void agentSetup() {
32
33 System.out.print(this.instanceIdentifier());
34 HashMap<String, Double> params = new HashMap<String, Double>();
35
36 // Use perfect opponent model to decrease noise in the environment
37 opponentModel = new PerfectModel();
38 opponentModel.init(negotiationSession, params);
39
40 // Load existing qTable
41 @SuppressWarnings("unchecked")
42 HashMap<Integer, ArrayList<Double>> qTable = (HashMap<Integer, ArrayList<Double>>) readObjectFromFile("qTables/" + this.instanceIdentifier());
43
44 // Initialize offeringStrategy (is a RL-component)
45 switch (this.getStrategyParameters().getValueAsString("strategy")) {
46 case "QlearningStrategy":
47 offeringStrategy = new QlearningStrategy(negotiationSession, opponentModel);
48 break;
49 case "PriorBeliefQlearningStrategy":
50 offeringStrategy = new PriorBeliefQlearningStrategy(negotiationSession, opponentModel);
51 break;
52 case "QLambdaStrategy":
53 offeringStrategy = new QLambdaStrategy(negotiationSession, opponentModel);
54 break;
55 }
56 offeringStrategy.setHyperparameters(this.getStrategyParameters());
57 offeringStrategy.initQtable(qTable);
58
59 // Accept if the incoming offer is higher than what you would offer yourself
60 acceptConditions = new AC_Next(negotiationSession, offeringStrategy, 1, 0);
61
62 // Opponent model strategy always selects best bid it has available
63 omStrategy = new BestBid();
64 omStrategy.init(negotiationSession, opponentModel, params);
65 setDecoupledComponents(acceptConditions, offeringStrategy, opponentModel, omStrategy);
66 }
67
68 @Override
69 public String getName() {
70 return "Q-learner";
71 }
72
73 public State getStateRepresentation(MultipartyNegoActionEvent negoEvent) {
74 Bid oppLastBid = negotiationSession.getOpponentBidHistory().getLastBid();
75 Bid myLastBid = negotiationSession.getOwnBidHistory().getLastBid();
76 Bid agreement = negoEvent.getAgreement();
77 Action currentAction = negoEvent.getAction();
78
79 if (agreement != null || currentAction.getClass() == EndNegotiation.class) {
80 return State.TERMINAL;
81 }
82
83 int myBin;
84 if (myLastBid != null) {
85 double myBidUtil = this.getUtility(myLastBid);
86 myBin = this.getBinIndex(myBidUtil);
87 }
88 else {
89 myBin = Integer.MIN_VALUE;
90 }
91
92 int oppBin;
93 if (oppLastBid != null) {
94 double oppBidUtil = this.getUtility(oppLastBid);
95 oppBin = this.getBinIndex(oppBidUtil);
96 }
97 else {
98 oppBin = Integer.MIN_VALUE;
99 }
100
101 double time = negotiationSession.getTime();
102
103 State state = new State(myBin, oppBin, this.getTimeBinIndex(time));
104
105 return state;
106 }
107
108 private int getBinIndex(double util) {
109 int n_bins = offeringStrategy.getNBins();
110 return (int) Math.floor(util * n_bins);
111 }
112
113 private int getTimeBinIndex(double time) {
114 // TODO: Remove magic number
115 return (int) Math.floor(time * 4);
116 }
117
118 @Override
119 public void notifyChange(NegotiationEvent data) {
120 if (data instanceof MultipartyNegoActionEvent) {
121 // Get relevant information from negotiation event
122 MultipartyNegoActionEvent negoEvent = (MultipartyNegoActionEvent) data;
123
124 // Observe state
125 State newState = this.getStateRepresentation(negoEvent);
126
127 double reward = this.getReward(negoEvent);
128 boolean myTurn = negoEvent.getAction().getAgent() == this.getAgentID();
129
130 if (newState.isTerminalState() || !myTurn) {
131 this.observeEnvironment(reward, newState);
132 }
133 }
134 }
135
136 @Override
137 public double getReward(MultipartyNegoActionEvent negoEvent) {
138 double reward = 0.0;
139 Bid agreement = negoEvent.getAgreement();
140
141 if (agreement != null) {
142 reward = this.getUtility(agreement);
143 }
144
145 return reward;
146 }
147
148 @Override
149 public void observeEnvironment(double reward, State newState) {
150 this.offeringStrategy.observeEnvironment(reward, newState);
151
152 if (newState.isTerminalState()) {
153 this.writeObjectToFile(this.offeringStrategy.getQTable());
154 }
155 }
156
157 public void writeObjectToFile(Object serObj) {
158 String DIRECTORY = "qTables/";
159 String filepath = DIRECTORY + this.instanceIdentifier();
160
161 File directory = new File(DIRECTORY);
162 if (!directory.exists()) {
163 directory.mkdir();
164 }
165
166 try {
167 FileOutputStream fileOut = new FileOutputStream(filepath);
168 ObjectOutputStream objectOut = new ObjectOutputStream(fileOut);
169 objectOut.writeObject(serObj);
170 objectOut.close();
171 System.out.println("The Object was succesfully written to a file");
172
173 } catch (Exception ex) {
174 ex.printStackTrace();
175 }
176 }
177
178 public Object readObjectFromFile(String filepath) {
179 Object obj = null;
180
181 try {
182 FileInputStream fileIn = new FileInputStream(filepath);
183 ObjectInputStream objectIn = new ObjectInputStream(fileIn);
184 obj = objectIn.readObject();
185 objectIn.close();
186 System.out.println("Succesfully read object");
187 }
188 catch (Exception ex) {
189 if (ex instanceof FileNotFoundException ) {
190 System.out.println("qTable file does not exist. A new file will be created.");
191 }
192 else {
193 ex.printStackTrace();
194 }
195 }
196
197 return obj;
198 }
199
200 public String instanceIdentifier() {
201 return String.format("Qlearner-%s", this.getStrategyParameters().toString());
202 }
203}
Note: See TracBrowser for help on using the repository browser.