import os
from collections import defaultdict
from typing import Any, Dict, Iterable, List, Mapping, Optional, Tuple

import numpy as np
import plotly.graph_objects as go


class PlotTournament:
    """Aggregate one agent's tournament results and plot its average scores.

    Each entry of ``results_summaries`` is expected to be a mapping that
    holds the two participants under keys starting with ``agent_``, their
    scores under keys starting with ``utility_`` (in the same key order as
    the agents), and the match-level ``nash_product`` and ``social_welfare``.

    Attributes:
        utilities: Utilities obtained by ``my_agent``, grouped by opponent.
        opponent_utilities: Utilities obtained by each opponent when playing
            against ``my_agent``, grouped by opponent.
        nash_products: ``nash_product`` of each match, grouped by opponent.
        social_welfares: ``social_welfare`` of each match, grouped by
            opponent.
        results_summaries: The match summaries passed to the constructor.
        my_agent: Name of the agent the statistics are computed for.
    """

    def __init__(
        self,
        results_summaries: Iterable[Mapping[str, Any]],
        my_agent: str,
    ) -> None:
        """Store the raw match summaries and create empty score stores.

        Args:
            results_summaries: Match summaries to aggregate. They are
                iterated again on every call to
                ``update_tournament_results``, so pass a re-iterable
                collection (e.g. a list) rather than a one-shot iterator.
            my_agent: Name of the agent to report on, as it appears in the
                ``agent_*`` fields of the summaries.
        """
        self.utilities: Dict[str, List[float]] = defaultdict(list)
        self.opponent_utilities: Dict[str, List[float]] = defaultdict(list)
        self.nash_products: Dict[str, List[float]] = defaultdict(list)
        self.social_welfares: Dict[str, List[float]] = defaultdict(list)
        self.results_summaries = results_summaries
        self.my_agent = my_agent

    @staticmethod
    def _first_and_last(
        match: Mapping[str, Any], prefix: str
    ) -> Optional[Tuple[Any, Any]]:
        """Return the first and last values whose key starts with ``prefix``.

        Values are collected positionally instead of using ``None`` as a
        "slot not filled yet" marker, so a value that is itself ``None``
        keeps its position. Returns ``None`` when fewer than two matching
        keys exist.
        """
        values = [value for key, value in match.items() if key.startswith(prefix)]
        if len(values) < 2:
            return None
        return values[0], values[-1]

    def _record_match(self, opponent: Any, own_utility: Any, match: Mapping[str, Any]) -> None:
        """Store ``my_agent``'s utility and the match scores under ``opponent``."""
        self.utilities[opponent].append(own_utility)
        self.nash_products[opponent].append(match["nash_product"])
        self.social_welfares[opponent].append(match["social_welfare"])

    def update_tournament_results(self) -> None:
        """Rebuild the per-opponent score stores from ``results_summaries``.

        The stores are cleared in place first, so calling this method more
        than once (``plot_tournament_utilities`` calls it on every
        invocation) does not add the same matches again.

        Matches are skipped when ``my_agent`` is not one of the two
        participants or when fewer than two ``agent_*`` / ``utility_*``
        fields are present.

        Raises:
            KeyError: If a relevant match lacks ``nash_product`` or
                ``social_welfare``.
        """
        for store in (
            self.utilities,
            self.opponent_utilities,
            self.nash_products,
            self.social_welfares,
        ):
            store.clear()

        me = self.my_agent
        for match in self.results_summaries:
            # Cheap pre-filter: only matches mentioning our agent at all.
            if me not in match.values():
                continue

            agents = self._first_and_last(match, "agent_")
            scores = self._first_and_last(match, "utility_")
            if agents is None or scores is None:
                continue

            agent1, agent2 = agents
            util1, util2 = scores

            # The pre-filter also matches non-agent fields (e.g. a result
            # label equal to the agent name); require a real seat.
            if me not in (agent1, agent2):
                continue

            if agent1 != me:
                # agent1 is the opponent, so our agent sits in seat 2.
                self._record_match(agent1, util2, match)
            else:
                self.opponent_utilities[agent2].append(util2)

            if agent2 != me:
                # agent2 is the opponent, so our agent sits in seat 1.
                self._record_match(agent2, util1, match)
            else:
                self.opponent_utilities[agent1].append(util1)

    def plot_tournament_utilities(self, plot_file: str) -> None:
        """Write a grouped bar chart of the average scores per opponent.

        Refreshes the aggregated results, then plots, for every opponent,
        the mean of: our agent's utility, the opponent's utility, the Nash
        product and the social welfare.

        Args:
            plot_file: Target path. Its extension (if any) is replaced by
                ``.html``.
        """
        self.update_tournament_results()

        def averages(scores: Mapping[Any, List[float]]) -> List[float]:
            return [np.mean(values) for values in scores.values()]

        # Utilities, Nash products and social welfares are always recorded
        # together, so they share the same opponent order.
        opponents = list(self.utilities)

        own_bar = go.Bar(
            x=opponents,
            y=averages(self.utilities),
            name=f"{self.my_agent} Utility",
        )
        nash_bar = go.Bar(
            x=opponents,
            y=averages(self.nash_products),
            name="Nash Product",
        )
        welfare_bar = go.Bar(
            x=opponents,
            y=averages(self.social_welfares),
            name="Social Welfare",
        )
        opponent_bar = go.Bar(
            x=list(self.opponent_utilities),
            y=averages(self.opponent_utilities),
            name="Opponent Utility",
        )

        fig = go.Figure(
            data=[own_bar, opponent_bar, nash_bar, welfare_bar],
            layout=go.Layout(barmode="group"),
        )
        fig.update_layout(
            title_text=f"Average performance of {self.my_agent} against other agents",
            title_x=0.5,
        )
        fig.update_yaxes(title_text="Average Score", ticks="outside")

        fig.write_html(f"{os.path.splitext(plot_file)[0]}.html")