[84] | 1 | from time import sleep, time
|
---|
| 2 | from typing import cast
|
---|
| 3 | import unittest
|
---|
| 4 | from unittest.mock import Mock
|
---|
| 5 |
|
---|
| 6 | from tudelft_utilities_logging.Reporter import Reporter
|
---|
| 7 | from uri.uri import URI
|
---|
| 8 |
|
---|
| 9 | from geniusweb.actions.EndNegotiation import EndNegotiation
|
---|
| 10 | from geniusweb.actions.LearningDone import LearningDone
|
---|
| 11 | from geniusweb.actions.PartyId import PartyId
|
---|
| 12 | from geniusweb.deadline.DeadlineTime import DeadlineTime
|
---|
| 13 | from geniusweb.inform.Finished import Finished
|
---|
| 14 | from geniusweb.protocol.partyconnection.ProtocolToPartyConn import ProtocolToPartyConn
|
---|
| 15 | from geniusweb.protocol.partyconnection.ProtocolToPartyConnFactory import ProtocolToPartyConnFactory
|
---|
| 16 | from geniusweb.protocol.session.TeamInfo import TeamInfo
|
---|
| 17 | from geniusweb.protocol.session.learn.Learn import Learn
|
---|
| 18 | from geniusweb.protocol.session.learn.LearnSettings import LearnSettings
|
---|
| 19 | from geniusweb.protocol.session.learn.LearnState import LearnState
|
---|
| 20 | from geniusweb.references.Parameters import Parameters
|
---|
| 21 | from geniusweb.references.PartyRef import PartyRef
|
---|
| 22 | from geniusweb.references.PartyWithParameters import PartyWithParameters
|
---|
| 23 | from geniusweb.references.PartyWithProfile import PartyWithProfile
|
---|
| 24 | from geniusweb.references.ProfileRef import ProfileRef
|
---|
| 25 |
|
---|
| 26 |
|
---|
class LearnTest(unittest.TestCase):
    """Unit tests for the Learn session protocol.

    A Learn instance is either built directly around mocked state, or
    obtained from real LearnSettings and driven through mocked party
    connections, so no live parties are needed.
    """

    reporter = Mock(Reporter)
    params = Parameters()
    # Deadline handed to LearnSettings. testStartStopBasic expects the
    # session to be final after a 2 second sleep, so this value is
    # presumably in milliseconds -- confirm against DeadlineTime.
    deadline = DeadlineTime(1800)
    party1id = PartyId("party1")
    party2id = PartyId("party2")

    def setUp(self):
        """Create fresh mocked connections and the party parameters."""
        self.conn1 = Mock(ProtocolToPartyConn)
        self.conn2 = Mock(ProtocolToPartyConn)
        # The result is stored on the instance and shadows the (empty)
        # class-level ``params``.
        self.params = self.params.With(
            "persistentstate", "6bb5f909-0079-43ac-a8ac-a31794391074")
        self.params = self.params.With(
            "negotiationdata", ["12b5f909-0079-43ac-a8ac-a31794391012"])

    def testsmokeTest(self):
        """Constructing a Learn around mocks must not raise."""
        Learn(Mock(LearnState), Mock(Reporter))

    def testgetDescrTest(self):
        """getDescription must return something other than None."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        self.assertIsNotNone(learn.getDescription())

    def testgetStateTest(self):
        """getState must hand back the state given to the constructor."""
        state = Mock(LearnState)
        learn = Learn(state, Mock(Reporter))
        self.assertEqual(state, learn.getState())

    def testgetRefTest(self):
        """The path of the protocol reference URI must be "Learn"."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        self.assertEqual("Learn", str(learn.getRef().getURI().getPath()))

    def testfinalStateNotificationTest(self):
        """Placeholder: listeners should be notified when the session ends."""
        # TODO: not implemented yet; this test currently checks nothing.
        pass

    def testStartStopBasic(self):
        """A started session is not final at first, but is after 2 seconds."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        # extra, check if start worked ok
        self.assertIsNone(learn.getState().getError())
        self.assertFalse(self._isFinalNow(learn))
        # Wait longer than the deadline configured on this class.
        sleep(2.000)
        self.assertTrue(self._isFinalNow(learn))

    def testStartStopLearn(self):
        """The session becomes final only once both parties sent LearningDone."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        self.assertFalse(self._isFinalNow(learn))

        # Instead of mocking the connection we call _actionRequest directly.
        learn._actionRequest(self.conn1, LearningDone(self.party1id))
        self.assertFalse(self._isFinalNow(learn))
        learn._actionRequest(self.conn2, LearningDone(self.party2id))
        self.assertTrue(self._isFinalNow(learn))

    def testAddParty(self):
        """addParticipant must raise ValueError."""
        learn = Learn(Mock(LearnState), Mock(Reporter))
        with self.assertRaises(ValueError):
            learn.addParticipant(Mock(PartyWithProfile))

    def testisFinishSentNormally(self):
        """Each connection gets exactly one Finished after both are done."""
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        learn._actionRequest(self.conn1, LearningDone(self.party1id))
        learn._actionRequest(self.conn2, LearningDone(self.party2id))

        self.assertEqual(1, self._countFinished(self.conn1))
        self.assertEqual(1, self._countFinished(self.conn2))

    def testisFinishSentInError(self):
        """Each connection gets exactly one Finished after an EndNegotiation.

        NOTE(review): going by the test name, EndNegotiation is presumably
        treated as an error by Learn -- confirm against the protocol.
        """
        learn = self.createBasicLearn()
        learn.start(self.createFactory())
        learn._actionRequest(self.conn1, EndNegotiation(self.party1id))

        self.assertEqual(1, self._countFinished(self.conn1))
        self.assertEqual(1, self._countFinished(self.conn2))

    @staticmethod
    def _isFinalNow(learn) -> bool:
        """Return whether the state of learn is final at the current time (ms)."""
        return learn.getState().isFinal(time() * 1000)

    @staticmethod
    def _countFinished(conn) -> int:
        """Count the send() calls on conn whose first argument is a Finished."""
        return sum(1 for call in conn.send.call_args_list
                   if isinstance(call[0][0], Finished))

    def createBasicLearn(self) -> Learn:
        """Build a Learn protocol for two single-party teams.

        Uses the class deadline and reporter; the protocol comes from
        LearnSettings.getProtocol and is cast to Learn for type checkers.
        """
        teams = [self.createTeam(1), self.createTeam(2)]
        settings = LearnSettings(teams, self.deadline)
        return cast(Learn, settings.getProtocol(self.reporter))

    def createTeam(self, nr: int) -> TeamInfo:
        """Build a one-party team with URIs "party<nr>" and "prof<nr>".

        The party is given the parameters currently held in self.params.
        """
        partyRef = PartyRef(URI("party" + str(nr)))
        party = PartyWithParameters(partyRef, self.params)
        profile = ProfileRef(URI("prof" + str(nr)))
        return TeamInfo([PartyWithProfile(party, profile)])

    def createFactory(self) -> ProtocolToPartyConnFactory:
        """Return a mocked factory whose connectAll yields conn1 and conn2.

        As a side effect, getParty on the two mocked connections is set up
        to return party1id and party2id respectively.
        """
        self.conn1.getParty = Mock(return_value=self.party1id)
        self.conn2.getParty = Mock(return_value=self.party2id)

        factory = Mock(ProtocolToPartyConnFactory)
        factory.connectAll = Mock(return_value=[self.conn1, self.conn2])
        return factory
| 137 |
|
---|