From 081d0970de86fbb368d3249f7362087394f4c75e Mon Sep 17 00:00:00 2001 From: nforsg Date: Wed, 16 Aug 2023 16:23:06 +0200 Subject: [PATCH] reward_function_config.py, transistion_operator_config.py --- .../dao/simulation_config/reward_function_config.py | 7 ++++--- .../dao/simulation_config/transition_operator_config.py | 7 ++++--- 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/reward_function_config.py b/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/reward_function_config.py index b1c5f374f..e75a459d3 100644 --- a/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/reward_function_config.py +++ b/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/reward_function_config.py @@ -1,6 +1,7 @@ from typing import Dict, Any, List import numpy as np +from numpy.typing import NDArray from csle_base.json_serializable import JSONSerializable @@ -10,7 +11,7 @@ class RewardFunctionConfig(JSONSerializable): DTO containing the reward tensor of a simulation """ - def __init__(self, reward_tensor: List): + def __init__(self, reward_tensor: NDArray[Any]): """ Initalizes the DTO @@ -33,8 +34,8 @@ def to_dict(self) -> Dict[str, Any]: """ :return: a dict representation of the object """ - d = {} - if isinstance(self.reward_tensor, np.ndarray): + d: Dict[str, Any] = {} + if isinstance(self.reward_tensor, type(NDArray[Any])): tensor = self.reward_tensor.tolist() else: tensor = self.reward_tensor diff --git a/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/transition_operator_config.py b/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/transition_operator_config.py index f00c525e7..46d7fdf22 100644 --- a/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/transition_operator_config.py +++ b/simulation-system/libs/csle-common/src/csle_common/dao/simulation_config/transition_operator_config.py @@ -1,5 +1,6 @@ from typing import List, Dict, Any import numpy as np +from numpy.typing import NDArray from csle_base.json_serializable import JSONSerializable @@ -8,7 +9,7 @@ class TransitionOperatorConfig(JSONSerializable): DTO representing the transition operator definition of a simulation """ - def __init__(self, transition_tensor: List): + def __init__(self, transition_tensor: NDArray[Any]): """ Initializes the DTO @@ -33,8 +34,8 @@ def to_dict(self) -> Dict[str, Any]: :return: a dict representation of the object """ - d = {} - if isinstance(self.transition_tensor, np.ndarray): + d: Dict[str, Any] = {} + if isinstance(self.transition_tensor, type(NDArray[Any])): tensor = self.transition_tensor.tolist() else: tensor = self.transition_tensor