Skip to content

Commit

Permalink
Merge pull request #233 from Limmen/dev
Browse files Browse the repository at this point in the history
probability function is weird
  • Loading branch information
Limmen authored Aug 18, 2023
2 parents 085d131 + 305c0d7 commit c562e4c
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 7 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Dict, Union, Any
import numpy as np
from numpy.typing import NDArray
from csle_common.dao.training.policy import Policy
from csle_common.dao.training.agent_type import AgentType
from csle_common.dao.training.player_type import PlayerType
Expand Down Expand Up @@ -36,7 +37,7 @@ def __init__(self, simulation_name: str, player_type: PlayerType, states: List[S
self.avg_R = avg_R
self.policy_type = PolicyType.MIXED_PPO_POLICY

def action(self, o: List[float]) -> Union[int, List[int], np.ndarray[Any, Any]]:
def action(self, o: List[float]) -> int:
"""
Multi-threshold stopping policy
Expand All @@ -47,15 +48,15 @@ def action(self, o: List[float]) -> Union[int, List[int], np.ndarray[Any, Any]]:
a = policy.action(o=o)
return a

def probability(self, o: List[float], a: int) -> int:
def probability(self, o: List[float], a: int) -> float:
"""
Probability of a given action
:param o: the current observation
:param a: a given action
:return: the probability of a
"""
return self.action(o=o) == a
return float(self.action(o=o) == a)

def to_dict(self) -> Dict[str, Any]:
"""
Expand Down Expand Up @@ -120,7 +121,7 @@ def copy(self) -> "MixedPPOPolicy":
"""
return self.from_dict(self.to_dict())

def stage_policy(self, o: Union[List[Union[int, float]], int, float]) -> List[List[float]]:
def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
"""
Returns the stage policy for a given observation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def __init__(self, agent_type: AgentType, player_type: PlayerType) -> None:
self.player_type = player_type

@abstractmethod
def action(self, o: Any) -> NDArray[Any]:
def action(self, o: Any) -> Union[int, NDArray[Any]]:
"""
Calculates the next action
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __init__(self, model: Optional[PPO], simulation_name: str, save_path: str, p
self.avg_R = avg_R
self.policy_type = PolicyType.PPO

def action(self, o: Union[List[float], List[int]]) -> NDArray[Any]:
def action(self, o: Union[List[float], List[int]]) -> int:
"""
Multi-threshold stopping policy
Expand Down Expand Up @@ -129,7 +129,8 @@ def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
stage_strategy[i][j] = self.probability(o=o, a=j)
stage_strategy[i] = iteround.saferound(stage_strategy[i], 2)
assert round(sum(stage_strategy[i]), 3) == 1
return stage_strategy.tolist()
stage_strategy.tolist()
return list(stage_strategy.tolist())

def _get_attacker_dist(self, obs) -> List[float]:
"""
Expand Down

0 comments on commit c562e4c

Please sign in to comment.