1 | from time import sleep, time
|
---|
2 | from typing import cast
|
---|
3 | import unittest
|
---|
4 | from unittest.mock import Mock
|
---|
5 |
|
---|
6 | from tudelft_utilities_logging.Reporter import Reporter
|
---|
7 | from uri.uri import URI
|
---|
8 |
|
---|
9 | from geniusweb.actions.EndNegotiation import EndNegotiation
|
---|
10 | from geniusweb.actions.LearningDone import LearningDone
|
---|
11 | from geniusweb.actions.PartyId import PartyId
|
---|
12 | from geniusweb.deadline.DeadlineTime import DeadlineTime
|
---|
13 | from geniusweb.inform.Finished import Finished
|
---|
14 | from geniusweb.protocol.partyconnection.ProtocolToPartyConn import ProtocolToPartyConn
|
---|
15 | from geniusweb.protocol.partyconnection.ProtocolToPartyConnFactory import ProtocolToPartyConnFactory
|
---|
16 | from geniusweb.protocol.session.TeamInfo import TeamInfo
|
---|
17 | from geniusweb.protocol.session.learn.Learn import Learn
|
---|
18 | from geniusweb.protocol.session.learn.LearnSettings import LearnSettings
|
---|
19 | from geniusweb.protocol.session.learn.LearnState import LearnState
|
---|
20 | from geniusweb.references.Parameters import Parameters
|
---|
21 | from geniusweb.references.PartyRef import PartyRef
|
---|
22 | from geniusweb.references.PartyWithParameters import PartyWithParameters
|
---|
23 | from geniusweb.references.PartyWithProfile import PartyWithProfile
|
---|
24 | from geniusweb.references.ProfileRef import ProfileRef
|
---|
25 |
|
---|
26 |
|
---|
class LearnTest(unittest.TestCase):
    """Unit tests for the ``Learn`` session protocol.

    The two party connections are ``Mock(ProtocolToPartyConn)`` objects, so no
    real parties run; several tests drive the protocol by calling
    ``Learn._actionRequest`` directly instead of going through a connection.
    """

    # Shared fixtures. ``reporter`` is a single Mock reused by every protocol
    # built through ``createBasicLearn``.
    reporter = Mock(Reporter)
    params = Parameters()
    # Presumably milliseconds: testStartStopBasic expects the session to be
    # non-final right after start and final after a 2 s sleep.
    deadline = DeadlineTime(1800)
    party1id = PartyId("party1")
    party2id = PartyId("party2")

    def setUp(self):
        """Create fresh connection mocks and the parameters used by each party."""
        self.conn1 = Mock(ProtocolToPartyConn)
        self.conn2 = Mock(ProtocolToPartyConn)
        # ``self.params`` first resolves to the class attribute; the results of
        # ``With`` are stored on the instance for use by ``createTeam``.
        self.params = self.params.With("persistentstate",
                                       "6bb5f909-0079-43ac-a8ac-a31794391074")
        self.params = self.params.With("negotiationdata",
                                       ["12b5f909-0079-43ac-a8ac-a31794391012"])

    def testsmokeTest(self):
        """Constructing a Learn protocol from a mocked state must not raise."""
        Learn(Mock(LearnState), Mock(Reporter))

    def testgetDescrTest(self):
        """getDescription() returns something other than None."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        self.assertNotEqual(None, learn.getDescription())

    def testgetStateTest(self):
        """getState() returns the very state object passed to the constructor."""
        state = Mock(LearnState)
        learn = Learn(state, Mock(Reporter))
        self.assertEqual(state, learn.getState())

    def testgetRefTest(self):
        """The protocol reference URI has path ``Learn``."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        self.assertEqual("Learn", str(learn.getRef().getURI().getPath()))

    @unittest.skip("not implemented: should check that listeners are "
                   "notified when the session ends")
    def testfinalStateNotificationTest(self):
        # Placeholder: skipped explicitly so it is not reported as a pass
        # while asserting nothing.
        pass

    def testStartStopBasic(self):
        """A started session is not final until its deadline expires."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        # Sanity-check that start() succeeded.
        self.assertEqual(None, learn.getState().getError())
        self.assertFalse(learn.getState().isFinal(time() * 1000))
        # Wait past the deadline (see ``deadline`` above).
        sleep(2.000)
        self.assertTrue(learn.getState().isFinal(time() * 1000))

    def testStartStopLearn(self):
        """The session becomes final only after every party sends LearningDone."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        self.assertFalse(learn.getState().isFinal(time() * 1000))

        # Instead of mocking the connection we call _actionRequest directly.
        learn._actionRequest(self.conn1, LearningDone(self.party1id))
        self.assertFalse(learn.getState().isFinal(time() * 1000))
        learn._actionRequest(self.conn2, LearningDone(self.party2id))
        self.assertTrue(learn.getState().isFinal(time() * 1000))

    def testAddParty(self):
        """Learn does not accept extra participants via addParticipant()."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        with self.assertRaises(ValueError):
            learn.addParticipant(Mock(PartyWithProfile))

    def testisFinishSentNormally(self):
        """Each party receives exactly one Finished after both send LearningDone."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        learn._actionRequest(self.conn1, LearningDone(self.party1id))
        learn._actionRequest(self.conn2, LearningDone(self.party2id))

        self.assertEqual(1, self._countFinished(self.conn1))
        self.assertEqual(1, self._countFinished(self.conn2))

    def testisFinishSentInError(self):
        """An EndNegotiation from one party still sends one Finished to each party."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        learn._actionRequest(self.conn1, EndNegotiation(self.party1id))

        self.assertEqual(1, self._countFinished(self.conn1))
        self.assertEqual(1, self._countFinished(self.conn2))

    @staticmethod
    def _countFinished(conn) -> int:
        """Return how many ``send`` calls on mock ``conn`` carried a Finished.

        ``call[0][0]`` is the first positional argument of each recorded call.
        """
        return sum(1 for call in conn.send.call_args_list
                   if isinstance(call[0][0], Finished))

    def createBasicLearn(self) -> Learn:
        """Build a Learn protocol for two single-party teams via LearnSettings."""
        settings = LearnSettings([self.createTeam(1), self.createTeam(2)],
                                 self.deadline)
        return cast(Learn, settings.getProtocol(self.reporter))

    def createTeam(self, nr: int) -> TeamInfo:
        """Return a one-member team with party URI ``party<nr>`` and profile ``prof<nr>``."""
        partyref = PartyRef(URI("party" + str(nr)))
        party = PartyWithParameters(partyref, self.params)
        profile = ProfileRef(URI("prof" + str(nr)))
        return TeamInfo([PartyWithProfile(party, profile)])

    def createFactory(self) -> ProtocolToPartyConnFactory:
        """Return a mock factory whose connectAll() yields [conn1, conn2].

        Also stubs ``getParty`` on both connection mocks so they report
        ``party1id`` and ``party2id`` respectively.
        """
        self.conn1.getParty = Mock(return_value=self.party1id)
        self.conn2.getParty = Mock(return_value=self.party2id)

        factory = Mock(ProtocolToPartyConnFactory)
        factory.connectAll = Mock(return_value=[self.conn1, self.conn2])
        return factory
137 |
|
---|