1 | import os
|
---|
2 |
|
---|
3 | import plotly.graph_objects as go
|
---|
4 | import numpy as np
|
---|
5 | from collections import defaultdict
|
---|
6 |
|
---|
class PlotTournament:
    """Aggregate tournament match summaries for one agent and plot the averages.

    Each entry of ``results_summaries`` is read as a mapping. Keys starting
    with ``"agent_"`` hold agent names, keys starting with ``"utility_"`` hold
    utilities, and ``"nash_product"`` / ``"social_welfare"`` hold the joint
    scores of the match.
    """

    def __init__(self, results_summaries, my_agent):
        """Store the raw summaries and create empty per-opponent accumulators.

        Args:
            results_summaries: Iterable of per-match summary mappings.
            my_agent: Name of the agent whose performance is analysed.
        """
        self.results_summaries = results_summaries
        self.my_agent = my_agent

        # Each accumulator maps an opponent name to the list of scores
        # recorded in the matches played against that opponent.
        self.utilities = defaultdict(list)
        self.opponent_utilities = defaultdict(list)
        self.nash_products = defaultdict(list)
        self.social_welfares = defaultdict(list)

    def update_tournament_results(self):
        """Rebuild the per-opponent score lists from ``results_summaries``.

        Only matches in which ``my_agent`` appears among the values are used.
        The accumulators are cleared first, so calling this method repeatedly
        (it is also invoked by ``plot_tournament_utilities``) never records
        the same match more than once.

        Raises:
            KeyError: If a relevant match lacks ``"nash_product"`` or
                ``"social_welfare"``.
        """
        # Start from a clean slate; clear() keeps the dict objects themselves
        # so references held elsewhere stay valid.
        for scores in (
            self.utilities,
            self.opponent_utilities,
            self.nash_products,
            self.social_welfares,
        ):
            scores.clear()

        for match in self.results_summaries:
            # Only interested in the matches where our agent appears.
            if self.my_agent not in match.values():
                continue

            # The first "agent_" / "utility_" key fills slot 1, any later one
            # overwrites slot 2.
            # NOTE(review): assumes the first utility key belongs to the first
            # agent key -- confirm against the producer of the summaries.
            agent1 = util1 = agent2 = util2 = None
            for key, value in match.items():
                if key.startswith("agent_"):
                    if agent1 is None:
                        agent1 = value
                    else:
                        agent2 = value
                if key.startswith("utility_"):
                    if util1 is None:
                        util1 = value
                    else:
                        util2 = value

            # Look at the match from each side in turn. When `agent` is the
            # opponent, the other side's utility is ours; when `agent` is us,
            # the other side's utility is the opponent's.
            for agent, other_agent, other_util in (
                (agent1, agent2, util2),
                (agent2, agent1, util1),
            ):
                if agent != self.my_agent:
                    self.utilities[agent].append(other_util)
                    self.nash_products[agent].append(match["nash_product"])
                    self.social_welfares[agent].append(match["social_welfare"])
                else:
                    self.opponent_utilities[other_agent].append(other_util)

    def plot_tournament_utilities(self, plot_file):
        """Write a grouped bar chart of the average scores per opponent.

        Refreshes the aggregated results, then plots four bar series: our
        utility, the opponent's utility, the Nash product and the social
        welfare, each averaged over the matches against that opponent.

        Args:
            plot_file: Output path; its extension is replaced by ``.html``.
        """
        self.update_tournament_results()

        # utilities, nash_products and social_welfares are always appended to
        # together, so they share the same key order.
        opponents = list(self.utilities)

        own_utility_bars = go.Bar(
            x=opponents,
            y=[np.mean(scores) for scores in self.utilities.values()],
            name=f"{self.my_agent} Utility",
        )
        nash_bars = go.Bar(
            x=opponents,
            y=[np.mean(scores) for scores in self.nash_products.values()],
            name="Nash Product",
        )
        welfare_bars = go.Bar(
            x=opponents,
            y=[np.mean(scores) for scores in self.social_welfares.values()],
            name="Social Welfare",
        )
        opponent_utility_bars = go.Bar(
            x=list(self.opponent_utilities),
            y=[np.mean(scores) for scores in self.opponent_utilities.values()],
            name="Opponent Utility",
        )

        fig = go.Figure(
            data=[own_utility_bars, opponent_utility_bars, nash_bars, welfare_bars],
            layout=go.Layout(barmode="group"),
        )

        title = f"Average performance of {self.my_agent} against other agents"
        fig.update_layout(title_text=title, title_x=0.5)
        fig.update_yaxes(title_text="Average Score", ticks="outside")

        fig.write_html(f"{os.path.splitext(plot_file)[0]}.html")