Skip to content

Commit

Permalink
Merge pull request #235 from Limmen/dev
Browse files Browse the repository at this point in the history
possible modification of typing
  • Loading branch information
Limmen authored Aug 18, 2023
2 parents f6e55f3 + 8e7f8da commit 05240df
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, emulation_env_name: str, emulation_statistic_id: int,
self.emulation_env_name = emulation_env_name
self.emulation_statistic_id = emulation_statistic_id
self.id = -1
self.conditionals_kl_divergences: Dict[str, Dict[str, Dict[str, Union[str, float]]]] = {}
self.conditionals_kl_divergences: Dict[str, Dict[str, Dict[str, float]]] = {}
self.compute_kl_divergences()

def compute_kl_divergences(self) -> None:
Expand All @@ -58,7 +58,7 @@ def compute_kl_divergences(self) -> None:
metric_distributions_condition_2[0].conditional_name][metric_dist.metric_name]):
self.conditionals_kl_divergences[metric_distributions_condition_1[0].conditional_name][
metric_distributions_condition_2[0].conditional_name][
metric_dist.metric_name] = "inf"
metric_dist.metric_name] = math.inf

@staticmethod
def from_dict(d: Dict[str, Any]) -> "GPSystemModel":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ def from_dict(d: Dict[str, Any]) -> "FNNWithSoftmaxPolicy":
obj.id = d["id"]
return obj

def stage_policy(self, o: Union[List[Union[int, float]], int, float]) -> List[List[float]]:
def stage_policy(self, o: Union[List[int], List[float]]) -> List[List[float]]:
"""
Gets the stage policy, i.e a |S|x|A| policy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class HParam(JSONSerializable):
DTO class representing a hyperparameter
"""

def __init__(self, value: Union[int, float, str, List], name: str, descr: str):
def __init__(self, value: Union[int, float, str, List[Any]], name: str, descr: str):
"""
Initializes the DTO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A FNN model with Softmax output defined in PyTorch
"""
import torch

from typing import Union, Any, List

class FNNwithSoftmax(torch.nn.Module):
"""
Expand All @@ -11,8 +11,9 @@ class FNNwithSoftmax(torch.nn.Module):
Sub-classing the torch.nn.Module to be able to use high-level API for creating the custom network
"""

def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int = 2,
hidden_activation: str = "ReLU"):
def __init__(self, input_dim: int, output_dim: int, hidden_dim: Union[int, float, str, List[Any]],
num_hidden_layers: Union[int, float, str, List[Any]] = 2,
hidden_activation: Union[int, float, str, List[Any]] = "ReLU"):
"""
Builds the model
Expand Down

0 comments on commit 05240df

Please sign in to comment.