Skip to content

Commit

Permalink
Merge pull request #209 from Limmen/dev5
Browse files Browse the repository at this point in the history
observation_function_config.py
  • Loading branch information
Limmen authored Aug 16, 2023
2 parents a55e0c0 + 6200790 commit f4f224d
Showing 1 changed file with 4 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from typing import List, Dict, Any
import numpy as np
from numpy.typing import NDArray
from csle_base.json_serializable import JSONSerializable


Expand All @@ -8,7 +9,7 @@ class ObservationFunctionConfig(JSONSerializable):
DTO representing the configuration of the observation function of a simulation
"""

def __init__(self, observation_tensor: List, component_observation_tensors: Dict[str, List]):
def __init__(self, observation_tensor: NDArray[Any], component_observation_tensors: Dict[str, NDArray[Any]]):
"""
Initializes the DTO
:param observation_tensor: the observation tensor
Expand All @@ -34,8 +35,8 @@ def to_dict(self) -> Dict[str, Any]:
:return: a dict representation of the object
"""
d = {}
if isinstance(self.observation_tensor, np.ndarray):
d: Dict[str, Any] = {}
if isinstance(self.observation_tensor, type(NDArray[Any])):
tensor = self.observation_tensor.tolist()
else:
tensor = self.observation_tensor
Expand Down

0 comments on commit f4f224d

Please sign in to comment.