From 9df8a3917f58611f01009c1a2fd754a62e5d26d3 Mon Sep 17 00:00:00 2001
From: nforsg
Date: Fri, 18 Aug 2023 09:04:31 +0200
Subject: [PATCH 1/2] probability function is weird

---
 .../src/csle_common/dao/training/mixed_ppo_policy.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
index 5d8510538..98beb4a12 100644
--- a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
+++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
@@ -1,5 +1,6 @@
 from typing import List, Dict, Union, Any
 import numpy as np
+from numpy.typing import NDArray
 from csle_common.dao.training.policy import Policy
 from csle_common.dao.training.agent_type import AgentType
 from csle_common.dao.training.player_type import PlayerType
@@ -36,7 +37,7 @@ def __init__(self, simulation_name: str, player_type: PlayerType, states: List[S
         self.avg_R = avg_R
         self.policy_type = PolicyType.MIXED_PPO_POLICY
 
-    def action(self, o: List[float]) -> Union[int, List[int], np.ndarray[Any, Any]]:
+    def action(self, o: List[float]) -> NDArray[Any]:
         """
         Multi-threshold stopping policy
 
@@ -120,7 +121,7 @@ def copy(self) -> "MixedPPOPolicy":
         """
         return self.from_dict(self.to_dict())
 
-    def stage_policy(self, o: Union[List[Union[int, float]], int, float]) -> List[List[float]]:
+    def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
         """
         Returns the stage policy for a given observation
 
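Note on patch 1: it swaps the ad-hoc annotation np.ndarray[Any, Any] for
numpy.typing.NDArray, the public alias numpy provides for annotating arrays.
A minimal sketch of how the alias is used; the function name and dtype below
are illustrative assumptions, not code from this series:

    from typing import Any

    import numpy as np
    from numpy.typing import NDArray

    def zero_actions(n: int) -> NDArray[Any]:
        # NDArray[Any] expands to np.ndarray[Any, np.dtype[Any]]:
        # an array of any shape whose dtype is unconstrained.
        return np.zeros(n, dtype=np.int64)

When the element type is known, pinning the dtype (e.g. NDArray[np.int64])
gives the type checker strictly more information and is usually preferable.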
From 305c0d74cd25bcc43e369cd4d6a06d956be4bc2f Mon Sep 17 00:00:00 2001
From: nforsg
Date: Fri, 18 Aug 2023 10:32:13 +0200
Subject: [PATCH 2/2] fixed action, probability, and stage_policy return types

---
 .../src/csle_common/dao/training/mixed_ppo_policy.py        | 6 +++---
 .../libs/csle-common/src/csle_common/dao/training/policy.py | 2 +-
 .../csle-common/src/csle_common/dao/training/ppo_policy.py  | 4 ++--
 3 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
index 98beb4a12..e359447bb 100644
--- a/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
+++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/mixed_ppo_policy.py
@@ -37,7 +37,7 @@ def __init__(self, simulation_name: str, player_type: PlayerType, states: List[S
         self.avg_R = avg_R
         self.policy_type = PolicyType.MIXED_PPO_POLICY
 
-    def action(self, o: List[float]) -> NDArray[Any]:
+    def action(self, o: List[float]) -> int:
         """
         Multi-threshold stopping policy
 
@@ -48,7 +48,7 @@ def action(self, o: List[float]) -> int:
         """
         a = policy.action(o=o)
         return a
 
-    def probability(self, o: List[float], a: int) -> int:
+    def probability(self, o: List[float], a: int) -> float:
         """
         Probability of a given action
@@ -56,7 +56,7 @@ def probability(self, o: List[float], a: int) -> float:
         :param a: a given action
         :return: the probability of a
         """
-        return self.action(o=o) == a
+        return float(self.action(o=o) == a)
 
     def to_dict(self) -> Dict[str, Any]:
         """
diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py
index 561287092..40d40fc6e 100644
--- a/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py
+++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/policy.py
@@ -22,7 +22,7 @@ def __init__(self, agent_type: AgentType, player_type: PlayerType) -> None:
         self.player_type = player_type
 
     @abstractmethod
-    def action(self, o: Any) -> NDArray[Any]:
+    def action(self, o: Any) -> Union[int, NDArray[Any]]:
         """
         Calculates the next action
 
diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py b/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py
index 575f844ce..b9d0c928c 100644
--- a/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py
+++ b/simulation-system/libs/csle-common/src/csle_common/dao/training/ppo_policy.py
@@ -51,7 +51,7 @@ def __init__(self, model: Optional[PPO], simulation_name: str, save_path: str, p
         self.avg_R = avg_R
         self.policy_type = PolicyType.PPO
 
-    def action(self, o: Union[List[float], List[int]]) -> NDArray[Any]:
+    def action(self, o: Union[List[float], List[int]]) -> int:
         """
         Multi-threshold stopping policy
 
@@ -129,7 +129,7 @@ def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
                 stage_strategy[i][j] = self.probability(o=o, a=j)
             stage_strategy[i] = iteround.saferound(stage_strategy[i], 2)
             assert round(sum(stage_strategy[i]), 3) == 1
-        return stage_strategy.tolist()
+        return list(stage_strategy.tolist())
 
     def _get_attacker_dist(self, obs) -> List[float]:
         """
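Note on patch 2: the patched probability computes float(self.action(o=o) == a),
a 0/1 indicator derived from one fresh call to action(). For a policy whose
action() is deterministic given o, this is a valid degenerate distribution; if
action() samples among sub-policies, the returned value is itself a random draw
rather than the exact mixture probability, which is presumably what the subject
of patch 1 calls weird. A minimal sketch of the deterministic case; StubPolicy
and its 0.5 threshold are illustrative assumptions, not CSLE code:

    from typing import List

    class StubPolicy:
        """Stand-in for a policy with a deterministic action()."""

        def action(self, o: List[float]) -> int:
            # Stop (1) when the first observation component exceeds 0.5.
            return 1 if o[0] > 0.5 else 0

        def probability(self, o: List[float], a: int) -> float:
            # Same shape as the patched MixedPPOPolicy.probability:
            # an indicator (0.0 or 1.0), not a smooth likelihood.
            return float(self.action(o=o) == a)

    policy = StubPolicy()
    assert policy.probability(o=[0.9], a=1) == 1.0
    assert policy.probability(o=[0.9], a=0) == 0.0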