diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py index 98beb4a12..e359447bb 100644 --- a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py +++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py @@ -37,7 +37,7 @@ def __init__(self, simulation_name: str, player_type: PlayerType, states: List[S self.avg_R = avg_R self.policy_type = PolicyType.MIXED_PPO_POLICY - def action(self, o: List[float]) -> NDArray[Any]: + def action(self, o: List[float]) -> int: """ Multi-threshold stopping policy @@ -48,7 +48,7 @@ def action(self, o: List[float]) -> NDArray[Any]: a = policy.action(o=o) return a - def probability(self, o: List[float], a: int) -> int: + def probability(self, o: List[float], a: int) -> float: """ Probability of a given action @@ -56,7 +56,7 @@ def probability(self, o: List[float], a: int) -> int: :param a: a given action :return: the probability of a """ - return self.action(o=o) == a + return float(self.action(o=o) == a) def to_dict(self) -> Dict[str, Any]: """ diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py index 561287092..40d40fc6e 100644 --- a/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py +++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py @@ -22,7 +22,7 @@ def __init__(self, agent_type: AgentType, player_type: PlayerType) -> None: self.player_type = player_type @abstractmethod - def action(self, o: Any) -> NDArray[Any]: + def action(self, o: Any) -> Union[int, NDArray[Any]]: """ Calculates the next action diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py index 
575f844ce..b9d0c928c 100644 --- a/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py +++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py @@ -51,7 +51,7 @@ def __init__(self, model: Optional[PPO], simulation_name: str, save_path: str, p self.avg_R = avg_R self.policy_type = PolicyType.PPO - def action(self, o: Union[List[float], List[int]]) -> NDArray[Any]: + def action(self, o: Union[List[float], List[int]]) -> int: """ Multi-threshold stopping policy @@ -129,7 +129,7 @@ def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]: stage_strategy[i][j] = self.probability(o=o, a=j) stage_strategy[i] = iteround.saferound(stage_strategy[i], 2) assert round(sum(stage_strategy[i]), 3) == 1 - return stage_strategy.tolist() + return list(stage_strategy.tolist()) def _get_attacker_dist(self, obs) -> List[float]: """