1 | import logging
|
---|
2 | import time
|
---|
3 | import numpy as np
|
---|
4 | from decimal import Decimal
|
---|
5 | from typing import cast, Dict
|
---|
6 |
|
---|
7 | from geniusweb.actions.Accept import Accept
|
---|
8 | from geniusweb.actions.Action import Action
|
---|
9 | from geniusweb.actions.Offer import Offer
|
---|
10 | from geniusweb.bidspace.AllBidsList import AllBidsList
|
---|
11 | from geniusweb.bidspace.BidsWithUtility import BidsWithUtility
|
---|
12 | from geniusweb.bidspace.Interval import Interval
|
---|
13 | from geniusweb.inform.ActionDone import ActionDone
|
---|
14 | from geniusweb.inform.Finished import Finished
|
---|
15 | from geniusweb.inform.Inform import Inform
|
---|
16 | from geniusweb.inform.Settings import Settings
|
---|
17 | from geniusweb.inform.YourTurn import YourTurn
|
---|
18 | from geniusweb.issuevalue.Bid import Bid
|
---|
19 | from geniusweb.party.Capabilities import Capabilities
|
---|
20 | from geniusweb.party.DefaultParty import DefaultParty
|
---|
21 | from geniusweb.profile.utilityspace.LinearAdditive import LinearAdditive
|
---|
22 | from geniusweb.profileconnection.ProfileConnectionFactory import (
|
---|
23 | ProfileConnectionFactory,
|
---|
24 | )
|
---|
25 | from geniusweb.progress.ProgressRounds import ProgressRounds
|
---|
26 | from tudelft_utilities_logging.Reporter import Reporter
|
---|
27 |
|
---|
28 | """Author:
|
---|
29 | Aleksander Buszydlik
|
---|
30 | Karol Dobiczek
|
---|
31 | Eva Noritsyna
|
---|
32 | Andra Sav
|
---|
33 | """
|
---|
34 |
|
---|
35 |
|
---|
class Agent3(DefaultParty):
    """Negotiation agent that models its opponent with frequency analysis and orders
    its own offers by an estimate of social welfare (see ``getDescription``)."""

    def __init__(self, reporter: Reporter = None):
        super().__init__(reporter)
        self.getReporter().log(logging.INFO, "party is initialized")
        self._profile = None
        # Session state that is filled in when the Settings message arrives
        self._me = None
        self._progress: ProgressRounds = None
        self._reservation_utility = 0
        # Last bid sent by this agent
        self._my_last_bid: Bid = None
        # Last bid received from the opponent
        self._last_received_bid: Bid = None
        # Utility of the last bid received from the opponent
        self._last_received_utility = -1
        # Bid received from the opponent two rounds ago (short term memory)
        self._previous_to_last_bid: Bid = None
        # Current statistics of opponent bids
        self._stat_dict = None
        # Statistics of opponent bids before this round (an independent copy)
        self._last_stat_dict = None
        # Bids which should be taken into consideration
        self._possible_bids = None
        # Index of the current bid in the stored list of bids
        self._last_index = 0
        # Prediction for opponent's weights of issues
        self._opponent_weights = None
        # Prediction for opponent's preferences for issue values
        self._opponent_value_weights = None
        self._sorted_issue_values = dict()
        # Last value the opponent sent for every issue
        self._last_bid_to_process = None
        # Best welfare of opponent's bid seen so far
        self._best_bid_welfare = -1
        # Best utility of opponent's bid seen so far
        self._best_bid_utility = -1
        # Best bid (by welfare) seen so far
        self._best_received_bid: Bid = None
        # Progress when the bids were reranked last time
        self._last_calculation_progress = 0
        # Willingness to make big concessions (rerank bids)
        self._big_concessions_index = 0
        # Willingness to make small concessions
        self._small_concessions_index = 0

        # With small probability be the first to make a concession
        self._random_concessions_coefficient = 0.015
        # Steers the length of time when bids are not accepted
        self._exploration_coefficient = 0.9
        # Steers the length of time when bids are not reranked
        self._progress_coefficient = 0.1
        # Steers willingness to prioritize welfare over own utility
        self._selfishness_coefficient = 0.8

    def notifyChange(self, info: Inform):
        """Entry point of all interaction with the agent after it has been initialised.

        Args:
            info (Inform): Either a request for action (YourTurn) or information about
                the session (Settings, ActionDone, Finished).
        """

        # Settings message is the first message that will be sent to the
        # agent containing all the information about the negotiation session.
        if isinstance(info, Settings):
            self._settings: Settings = cast(Settings, info)
            self._me = self._settings.getID()

            # Progress towards the deadline has to be tracked manually through the use of the Progress object
            self._progress: ProgressRounds = self._settings.getProgress()

            # Profile contains the preferences of the agent over the domain
            self._profile = ProfileConnectionFactory.create(
                info.getProfile().getURI(), self.getReporter()
            )

            # Store reservation utility if it exists
            profile = self._profile.getProfile()
            reservation_bid = profile.getReservationBid()
            if reservation_bid is not None:
                self._reservation_utility = profile.getUtility(reservation_bid)
            else:
                self._reservation_utility = 0

            # Prepare data structures for recording opponent bids. The snapshot is a
            # copy so that later count updates to _stat_dict do not leak into it.
            self._stat_dict = self._prepare_stat_dict()
            self._last_stat_dict = self._copy_stat_dict(self._stat_dict)

            self._prepare_bid_data()
            self._create_possible_bids()

        # ActionDone is an action sent by an opponent (an offer or an accept)
        elif isinstance(info, ActionDone):
            action: Action = cast(ActionDone, info).getAction()
            actor = action.getActor()

            # Ignore action if it is our action
            if actor != self._me:
                # If it is an offer, set the last received bid
                if isinstance(action, Offer):
                    self._last_received_bid = cast(Offer, action).getBid()

        # Execute the move
        elif isinstance(info, YourTurn):
            action = self._myTurn()
            if isinstance(self._progress, ProgressRounds):
                self._progress = self._progress.advance()
            self.getConnection().send(action)

        # Finish the negotiation on agreement or deadline
        elif isinstance(info, Finished):
            self.terminate()

        else:
            self.getReporter().log(logging.WARNING, "Ignoring unknown info " + str(info))

    # Lets the geniusweb system know what settings this agent can handle
    def getCapabilities(self) -> Capabilities:
        return Capabilities(
            {"SAOP"},
            {"geniusweb.profile.utilityspace.LinearAdditive"}
        )

    # Terminates the agent and its connections
    # leave it as it is for this competition
    def terminate(self):
        self.getReporter().log(logging.INFO, "party is terminating:")
        super().terminate()
        if self._profile is not None:
            self._profile.close()
            self._profile = None

    def getDescription(self) -> str:
        return (
            "Agent which employs frequency modelling to optimize for welfare of bids.\n"
            "At first bids are returned based on highest individual utility, then based on welfare.\n"
            "It concedes after the opponent concedes enough times or on its own with small probability.\n"
            "Acceptance is based on long exploration and then at the end choosing a bid that is at least\n"
            "as good as what has been previously seen. Always agrees in the last round."
        )

    def _current_progress(self) -> float:
        """Return the current progress towards the deadline as reported by the Progress object."""
        return self._progress.get(time.time() * 1000)

    @staticmethod
    def _copy_stat_dict(stats: Dict) -> Dict:
        """Return a copy of a statistics dictionary whose per-issue count dictionaries are
        independent of the original, so later count updates do not affect the copy."""
        return {issue: dict(counts) for issue, counts in stats.items()}

    # Execute a turn
    def _myTurn(self):
        """Update the opponent model, then either accept the opponent's last offer or
        produce a counter offer.

        Returns:
            Action: an Accept of the last received bid or an Offer of a new bid.
        """
        self._collect_opponent_bid_data()

        # Check if the last received offer of the opponent is good enough
        if self._isGood(self._last_received_bid):
            # If so, accept the offer
            action = Accept(self._me, self._last_received_bid)

        # If not, find a bid to propose as counter offer
        else:
            bid = self._findBid()
            self._my_last_bid = bid
            action = Offer(self._me, bid)

        # The caller sends the action
        return action

    # method that checks if we would agree with an offer
    def _isGood(self, bid: Bid) -> bool:
        """Evaluates the opponent's bid based on its utility and welfare.
        Also updates the bookkeeping of the best bid seen and of opponent concessions.

        Args:
            bid: Set of values for every issue suggested by the opponent.

        Returns:
            bool: Confirmation whether the current bid is acceptable.
        """

        # If no bid was received then it is definitely bad
        if bid is None:
            return False

        # If we have never stored a bid previously, store it
        if self._best_received_bid is None:
            self._best_received_bid = bid

        profile = self._profile.getProfile()
        progress = self._current_progress()

        # Find our utility for the opponent's bid
        current_utility = profile.getUtility(bid)
        # Find welfare for the opponent's bid
        new_bid_welfare = self._calculate_welfare(bid)
        # Welfare of the best stored bid may change in time so we have to recalculate it
        self._best_bid_welfare = self._calculate_welfare(self._best_received_bid)

        # Small concession index corresponds to the willingness to take the next best bet
        # Big concession index corresponds to the willingness to rerank bets
        if 0 < self._last_received_utility < current_utility:
            self._big_concessions_index += 1
            self._small_concessions_index += 1

        # Always store the best bid seen from the opponent
        if new_bid_welfare >= self._best_bid_welfare:
            self._best_received_bid = bid
            self._best_bid_welfare = new_bid_welfare

        # Also update the best bid utility if applicable
        if current_utility > self._best_bid_utility:
            self._best_bid_utility = current_utility

        self._last_received_utility = current_utility

        # Spend most of the time (_exploration_coefficient) looking for the best option
        # the opponent can send
        if progress <= self._exploration_coefficient:
            return False

        # At the end of the negotiation accept anything that exceeds the reservation utility.
        # It is always better to have an agreement than not.
        if progress >= 0.99 and current_utility > self._reservation_utility:
            return True

        # Accept a bid if it is at least as good as the best bid seen so far
        # and has a utility at least equal to our reservation value.
        if new_bid_welfare >= self._best_bid_welfare \
                and current_utility >= self._reservation_utility:
            return True

        # If none of the conditions hold, it is not a good bid.
        return False

    def _findBid(self) -> Bid:
        """Searches for a bid that can be suggested to the opponent.
        This uses a model of the opponent that is being created online.

        Returns:
            Bid: Set of values for every issue.
        """

        # If the negotiation is finishing, resend the best received bid if it at least
        # fulfills the reservation utility expectation. The check must use that bid's own
        # utility: _best_bid_utility may belong to a different received bid.
        if self._current_progress() >= 0.99 and self._best_received_bid is not None:
            best_bid_own_utility = self._profile.getProfile().getUtility(self._best_received_bid)
            if best_bid_own_utility >= self._reservation_utility:
                return self._best_received_bid

        # If it is time to run the welfare calculation the order of bids will change.
        # As we learn more about the opponent's bids, we can model their behaviour better.
        if self._run_welfare_calculation() and self._big_concessions_index >= 20:
            self._rerank_bids()
            self._big_concessions_index = 0
            self._last_index = 0

        # Choose the next bid from our list of available bids (clamped to the list bounds)
        num_bids = len(self._possible_bids)
        bid = self._possible_bids[max(0, min(self._last_index, num_bids - 1))][0]

        # Move to the next bid after an opponent concession, or at random with small probability
        if self._small_concessions_index >= 1 \
                or np.random.rand() < self._random_concessions_coefficient:
            self._small_concessions_index = 0
            self._last_index += 1

        return bid

    def _run_welfare_calculation(self, step=0.1) -> bool:
        """Used to assess based on the progress of the negotiation whether the available bids
        should be reranked with the current prediction of social welfare.

        Every call records the current progress bucket, so only the first call inside a new
        bucket of width ``step`` can return True.

        Args:
            step (float, optional): Informs how often the reranking should happen, defaults to 0.1.

        Returns:
            bool: True if a new ordering of bids should be generated.
        """
        progress = self._current_progress()
        current = np.floor(progress / step) * step
        result = current != self._last_calculation_progress
        self._last_calculation_progress = current

        # Never rerank during the initial phase controlled by _progress_coefficient
        return result and self._progress_coefficient < progress

    def _prepare_stat_dict(self) -> Dict:
        """Before the negotiation starts, generate a dictionary storing the frequency
        of opponent's bids for every value of every issue.

        Returns:
            Dict: Statistics of the opponent's bids initialized to 0
        """
        domain = self._profile.getProfile().getDomain()
        return {
            issue: {value: 0 for value in domain.getValues(issue)}
            for issue in domain.getIssues()
        }

    def _prepare_bid_data(self):
        """Before the negotiation starts, generate dictionaries storing the opponent's
        decisions and the model of their utility function.
        Requires ``self._stat_dict`` to be prepared first.
        """
        utilities = self._profile.getProfile().getUtilities()
        num_issues = len(utilities)
        self._last_bid_to_process = dict()
        self._opponent_weights = dict()  # Predictions for the weights of every issue
        self._opponent_value_weights = dict()  # Predictions for the weights of every value

        for issue in utilities:
            # Keyed by value (like the dictionaries produced by calculate_weights) so that
            # _calculate_opponent_utility can look up any value even before data is collected
            self._opponent_value_weights[issue] = {value: 0.0 for value in self._stat_dict[issue]}
            self._opponent_weights[issue] = 1 / num_issues  # At the beginning all weights are equal
            self._last_bid_to_process[issue] = 0

    def _collect_opponent_bid_data(self):
        """Process the opponent's bid to update our model.
        Works based on the heuristics that if an opponent sends the same value for an issue frequently,
        then it is most likely very important for that opponent.
        Also, if an opponent changes their mind about an issue frequently,
        then the issue probably doesn't matter for the opponent too much.
        """
        # Snapshot the statistics as they were before this round (a copy, not an alias)
        self._last_stat_dict = self._copy_stat_dict(self._stat_dict)
        last_bid = self._last_received_bid
        if last_bid is None:
            return
        bid_data = last_bid.getIssueValues()
        alpha = 0.03  # Serves as the "learning rate" for the issue weights

        for issue, value in bid_data.items():
            # Record the use of certain value
            self._stat_dict[issue][value] += 1

            if value == self._last_bid_to_process[issue]:
                # Logarithm is used as some issues have less values than others which often makes
                # opponents unwilling to change even if the issue weight is relatively low.
                self._opponent_weights[issue] += alpha * np.log(len(self._stat_dict[issue])) * 0.3

            # Overwrite last used value
            self._last_bid_to_process[issue] = value
            self._opponent_value_weights[issue] = calculate_weights(self._stat_dict[issue].copy(), method="normalize")

        # Renormalize issue weights so they sum to 1
        total_weight = np.sum(list(self._opponent_weights.values()))
        for issue in self._opponent_weights:
            self._opponent_weights[issue] = self._opponent_weights[issue] / total_weight

    def _create_possible_bids(self):
        """Generates a list of bids that may be acceptable for this agent.
        Each entry is ``[bid, own_utility, 0]``. The list is sorted by decreasing own
        utility here; it is later reordered by welfare in ``_rerank_bids``.
        """
        # Domains up to this size are enumerated completely
        max_enumerated_bids = 50000
        # Number of random draws used to diversify the list on large domains
        num_random_samples = 40000

        profile = self._profile.getProfile()
        bids = BidsWithUtility.create(cast(LinearAdditive, profile))
        utility_range = bids.getRange()

        domain_spread = utility_range.getMax() - utility_range.getMin()
        all_bids = AllBidsList(profile.getDomain())
        domain_size = all_bids.size()
        possible_bids = []

        # On small domains just save all bids
        if domain_size <= max_enumerated_bids:
            interval = Interval(Decimal(self._reservation_utility), Decimal(1.0))
            for bid in bids.getBids(interval):
                # Save along with bid utility and the opponent's utility (to be calculated later)
                possible_bids.append([bid, profile.getUtility(bid), 0])

            # _findBid indexes this list, so it must never be empty
            if not possible_bids:
                max_bid = bids.getExtremeBid(isMax=True)
                possible_bids.append([max_bid, profile.getUtility(max_bid), 0])

        # On large domains we need to limit the number of bids taken into consideration
        else:
            max_bid = bids.getExtremeBid(isMax=True)
            # Strong assumption: utilities are uniformly distributed in range of domain.
            # Take the top slice of the utility range expected to hold about
            # max_enumerated_bids bids.
            min_utility = utility_range.getMax() - (domain_spread * max_enumerated_bids) / domain_size
            interval = Interval(Decimal(min_utility), utility_range.getMax())
            for bid in bids.getBids(interval):
                bid_utility = profile.getUtility(bid)
                if bid != max_bid and bid_utility > self._reservation_utility:
                    possible_bids.append([bid, bid_utility, 0])

            # Add random bids from the whole domain (randint's upper bound is exclusive)
            for _ in range(num_random_samples + 1):
                bid = all_bids.get(np.random.randint(0, domain_size))
                bid_utility = profile.getUtility(bid)
                if bid_utility > self._reservation_utility:
                    possible_bids.append([bid, bid_utility, 0])

            # We always want at least one bid
            possible_bids.append([max_bid, profile.getUtility(max_bid), 0])

        # Sort by utility in descending order
        possible_bids.sort(key=lambda x: x[1], reverse=True)
        self._possible_bids = possible_bids

    def _rerank_bids(self):
        """Sort the list of all acceptable bids based on the current estimate of their welfare
        """
        self._possible_bids.sort(key=lambda x: self._calculate_welfare(x[0]), reverse=True)

    def _calculate_welfare(self, bid, method="weighted_sum") -> Decimal:
        """Calculate welfare which is understood as the sum of own and opponent's utilities.
        Selfishness_coefficient can be used to steer preference for optimizing own utility.
        This seems to give better results than optimizing for the minimal utility.

        Args:
            bid (Bid): Set of values for every issue. At different stages of the negotiation,
                the welfare of the same bid may differ (due to refined opponent model).
            method (str, optional): "weighted_sum" (default) weighs own and opponent utility
                with the selfishness coefficient; any other value uses the minimum of the two.

        Returns:
            Decimal: Prediction of the welfare of a bid
        """
        own_utility = self._profile.getProfile().getUtility(bid)
        opponent_utility = self._calculate_opponent_utility(bid)

        if method == "weighted_sum":
            return Decimal(self._selfishness_coefficient) * Decimal(own_utility) \
                + Decimal(1 - self._selfishness_coefficient) * Decimal(opponent_utility)

        return min(Decimal(own_utility), Decimal(opponent_utility))

    def _calculate_opponent_utility(self, bid) -> float:
        """Calculate the utility of a bid for the opponent based on the available model:
        the sum over issues of predicted issue weight times predicted value weight.

        Args:
            bid (Bid): Set of values for every issue. At different stages of the negotiation,
                the opponent's utility of the same bid may differ (due to refined opponent model).

        Returns:
            float: Prediction of the utility of a bid for the opponent
        """
        domain = self._profile.getProfile().getDomain()
        opponent_utility = 0
        for issue in domain.getIssues():
            opponent_utility += self._opponent_weights[issue] \
                * self._opponent_value_weights[issue][bid.getValue(issue)]
        return opponent_utility
460 |
|
---|
461 |
|
---|
def calculate_weights(count_dict, method="linear") -> Dict:
    """Models the predicted weights of an opponent for each value of an issue.

    Args:
        count_dict (Dict): Maps every value of one issue to how often the opponent
            used it. The dictionary is updated in place and also returned.
        method (str, optional): Method used to calculate weights, defaults to "linear".
            "linear": the most frequent value gets 1 and the weight decreases linearly
            with the distance (in value order) to it, reaching 0 at the farthest end.
            "normalize": every count is divided by the largest count (all zeros if no
            value has been used yet).
            Any other method leaves the counts unchanged.

    Returns:
        Dict: modified dictionary with a model of opponent's weights
    """
    counts = list(count_dict.values())

    # An issue without values has nothing to model (np.argmax / np.max would raise)
    if not counts:
        return count_dict

    # Predict the weights in a linear manner based on available counts
    if method == "linear":
        max_pos = int(np.argmax(counts))
        # Distance from the most frequent value to the farthest end of the value list
        farthest = max(max_pos, len(counts) - 1 - max_pos)

        if farthest == 0:
            # A single value is the only possible choice, so it carries the full weight
            counts = [1.0]
        else:
            step = 1 / farthest
            counts = [step * (farthest - abs(i - max_pos)) for i in range(len(counts))]

    # Predict the weights by normalizing counts
    elif method == "normalize":
        peak = np.max(counts)
        if peak > 0:
            counts = np.asarray(counts, dtype=float) / peak
        else:
            # No value has been observed yet; avoid a 0/0 division producing NaN
            counts = [0.0] * len(counts)

    for key, weight in zip(count_dict, counts):
        count_dict[key] = weight
    return count_dict