From 08bf80d65061a18d8452b506b3da157a0dc3b93e Mon Sep 17 00:00:00 2001 From: Tom Szendrey Date: Tue, 16 Jul 2024 16:28:10 -0400 Subject: [PATCH] updating converter to correctly type. (#102) --- service/models/converters.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/service/models/converters.py b/service/models/converters.py index 542d00c..0616494 100644 --- a/service/models/converters.py +++ b/service/models/converters.py @@ -11,7 +11,16 @@ def fetch_and_convert_static_interventions(policy_intervention_id, job_id): if not (policy_intervention_id): return defaultdict(dict) policy_intervention = fetch_interventions(policy_intervention_id, job_id) - interventionList = policy_intervention["interventions"] + interventionList: list[HMIIntervention] = [] + for inter in policy_intervention["interventions"]: + intervention = HMIIntervention( + name=inter["name"], + applied_to=inter["applied_to"], + type=inter["type"], + static_interventions=inter["static_interventions"], + dynamic_interventions=inter["dynamic_interventions"], + ) + interventionList.append(intervention) return convert_static_interventions(interventionList) @@ -21,10 +30,10 @@ def convert_static_interventions(interventions: list[HMIIntervention]): return defaultdict(dict) static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict) for inter in interventions: - for static_inter in inter["static_interventions"]: - time = torch.tensor(float(static_inter["timestep"])) - parameter_name = inter["applied_to"] - value = torch.tensor(float(static_inter["value"])) + for static_inter in inter.static_interventions: + time = torch.tensor(float(static_inter.timestep)) + parameter_name = inter.applied_to + value = torch.tensor(float(static_inter.value)) static_interventions[time][parameter_name] = value return static_interventions