Skip to content

Commit

Permalink
possible modification of typing
Browse files Browse the repository at this point in the history
  • Loading branch information
nforsg committed Aug 18, 2023
1 parent c562e4c commit fa90933
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ def __init__(self, emulation_env_name: str, emulation_statistic_id: int,
self.emulation_env_name = emulation_env_name
self.emulation_statistic_id = emulation_statistic_id
self.id = -1
self.conditionals_kl_divergences: Dict[str, Dict[str, Dict[str, Union[str, float]]]] = {}
self.conditionals_kl_divergences: Dict[str, Dict[str, Dict[str, float]]] = {}
self.compute_kl_divergences()

def compute_kl_divergences(self) -> None:
Expand All @@ -58,7 +58,7 @@ def compute_kl_divergences(self) -> None:
metric_distributions_condition_2[0].conditional_name][metric_dist.metric_name]):
self.conditionals_kl_divergences[metric_distributions_condition_1[0].conditional_name][
metric_distributions_condition_2[0].conditional_name][
metric_dist.metric_name] = "inf"
metric_dist.metric_name] = math.inf

@staticmethod
def from_dict(d: Dict[str, Any]) -> "GPSystemModel":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ class HParam(JSONSerializable):
DTO class representing a hyperparameter
"""

def __init__(self, value: Union[int, float, str, List], name: str, descr: str):
def __init__(self, value: Union[int, float, str, List[Any]], name: str, descr: str):
"""
Initializes the DTO
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
A FNN model with Softmax output defined in PyTorch
"""
import torch

from typing import Union, Any, List

class FNNwithSoftmax(torch.nn.Module):
"""
Expand All @@ -11,8 +11,8 @@ class FNNwithSoftmax(torch.nn.Module):
Sub-classing the torch.nn.Module to be able to use high-level API for creating the custom network
"""

def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, num_hidden_layers: int = 2,
hidden_activation: str = "ReLU"):
def __init__(self, input_dim: int, output_dim: int, hidden_dim: Union[int, float, str, List[Any]],
num_hidden_layers: int = 2, hidden_activation: str = "ReLU"):
"""
Builds the model
Expand Down

0 comments on commit fa90933

Please sign in to comment.