Skip to content

Commit

Permalink
work in progress
Browse files Browse the repository at this point in the history
  • Loading branch information
nforsg committed Aug 16, 2023
1 parent 6ac927c commit 829cfdd
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ def stage_policy(self, o: Union[List[Union[int, float]], int, float]) -> List[Li
"""
raise NotImplementedError("Not implemented")

def _defender_action(self, o) -> Tuple[int, float]:
def _defender_action(self, o) -> Tuple[int, int]:
"""
Linear threshold stopping policy of the defender
Expand Down Expand Up @@ -124,7 +124,7 @@ def to_dict(self) -> Dict[str, List[float]]:
return d

@staticmethod
def from_dict(d: Dict) -> "LinearThresholdStoppingPolicy":
def from_dict(d: Dict[str, List[float]]) -> "LinearThresholdStoppingPolicy":
"""
Converst a dict representation of the object to an instance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def action(self, o: List[float]) -> int:
a, _ = self._attacker_action(o=o)
return a

def probability(self, o: List[float], a: int) -> int:
def probability(self, o: List[float], a: int) -> float:
"""
Probability of a given action
Expand Down Expand Up @@ -124,6 +124,8 @@ def stage_policy(self, o: Union[List[Union[int, float]], int, float]) -> List[Li
return stage_policy
else:
stage_policy = []
if self.opponent_strategy is None:
raise ValueError("The opponent strategy is None")
a1, defender_stopping_probability = self.opponent_strategy._defender_action(o=o)
if a1 == 0:
defender_stopping_probability = 1 - defender_stopping_probability
Expand Down Expand Up @@ -239,7 +241,7 @@ def to_dict(self) -> Dict[str, List[float]]:
return d

@staticmethod
def from_dict(d: Dict) -> "MultiThresholdStoppingPolicy":
def from_dict(d: Dict[str, List[float]]) -> "MultiThresholdStoppingPolicy":
"""
Converst a dict representation of the object to an instance
Expand All @@ -261,7 +263,7 @@ def thresholds(self) -> List[float]:
"""
return list(map(lambda x: round(MultiThresholdStoppingPolicy.sigmoid(x), 3), self.theta))

def stop_distributions(self) -> Dict[str, Dict[str, List[float]]]:
def stop_distributions(self) -> Dict[str, List[float]]:
"""
:return: the stop distributions and their names
"""
Expand Down

0 comments on commit 829cfdd

Please sign in to comment.