Skip to content

Commit

Permalink
fixed
Browse files Browse the repository at this point in the history
  • Loading branch information
nforsg committed Aug 18, 2023
1 parent 9df8a39 commit 305c0d7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(self, simulation_name: str, player_type: PlayerType, states: List[S
self.avg_R = avg_R
self.policy_type = PolicyType.MIXED_PPO_POLICY

def action(self, o: List[float]) -> NDArray[Any]:
def action(self, o: List[float]) -> int:
"""
Multi-threshold stopping policy
Expand All @@ -48,15 +48,15 @@ def action(self, o: List[float]) -> NDArray[Any]:
a = policy.action(o=o)
return a

def probability(self, o: List[float], a: int) -> int:
def probability(self, o: List[float], a: int) -> float:
"""
Probability of a given action
:param o: the current observation
:param a: a given action
:return: the probability of a
"""
return self.action(o=o) == a
return float(self.action(o=o) == a)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, agent_type: AgentType, player_type: PlayerType) -> None:
self.player_type = player_type

@abstractmethod
def action(self, o: Any) -> NDArray[Any]:
def action(self, o: Any) -> Union[int, NDArray[Any]]:
"""
Calculates the next action
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, model: Optional[PPO], simulation_name: str, save_path: str, p
self.avg_R = avg_R
self.policy_type = PolicyType.PPO

def action(self, o: Union[List[float], List[int]]) -> NDArray[Any]:
def action(self, o: Union[List[float], List[int]]) -> int:
"""
Multi-threshold stopping policy
Expand Down Expand Up @@ -129,7 +129,8 @@ def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
stage_strategy[i][j] = self.probability(o=o, a=j)
stage_strategy[i] = iteround.saferound(stage_strategy[i], 2)
assert round(sum(stage_strategy[i]), 3) == 1
return stage_strategy.tolist()
stage_strategy.tolist()
return list(stage_strategy.tolist())

def _get_attacker_dist(self, obs) -> List[float]:
"""
Expand Down

0 comments on commit 305c0d7

Please sign in to comment.