Skip to content

Commit

Permalink
utilize new interventions
Browse files Browse the repository at this point in the history
  • Loading branch information
Tom-Szendrey committed Jul 4, 2024
1 parent 876530f commit 1757a06
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 30 deletions.
24 changes: 20 additions & 4 deletions service/models/converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,28 @@

# TODO: Do not use Torch in PyCIEMSS Library interface
import torch
from utils.tds import fetch_interventions
import logging
from typing import Dict


def convert_to_static_interventions(interventions):
static_interventions = defaultdict(dict)
for i in interventions:
static_interventions[i.timestep][i.name] = torch.tensor(i.value)
def fetch_and_convert_static_interventions(policy_intervention_id, job_id):
logging.error("Fetching and converting started:")
static_interventions: Dict[torch.Tensor, Dict[str, any]] = defaultdict(dict)
if not (policy_intervention_id):
logging.error("No intervention id")
return static_interventions
policy_intervention = fetch_interventions(policy_intervention_id, job_id)
logging.error("interventions: ")
logging.error(str(policy_intervention))
for inter in policy_intervention["interventions"]:
for static_inter in inter["static_interventions"]:
time = torch.tensor(float(static_inter["timestep"]))
parameter_name = inter["applied_to"]
value = torch.tensor(float(static_inter["value"]))
static_interventions[time][parameter_name] = value
logging.error("static_interventions: ")
logging.error(str(static_interventions))
return static_interventions


Expand Down
17 changes: 8 additions & 9 deletions service/models/operations/calibrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@
import socket
import logging

from typing import ClassVar, Optional, List
from typing import ClassVar, Optional
from pydantic import BaseModel, Field, Extra

from pika.exceptions import AMQPConnectionError


from models.base import Dataset, OperationRequest, Timespan, InterventionObject
from models.converters import convert_to_static_interventions
from models.base import Dataset, OperationRequest, Timespan
from models.converters import fetch_and_convert_static_interventions
from utils.rabbitmq import gen_rabbitmq_hook
from utils.tds import fetch_dataset, fetch_model

Expand Down Expand Up @@ -45,10 +45,7 @@ class Calibrate(OperationRequest):
model_config_id: str = Field(..., example="c1cd941a-047d-11ee-be56")
dataset: Dataset = None
timespan: Optional[Timespan] = None
interventions: List[InterventionObject] = Field(
default_factory=list, example=[{"timestep": 1, "name": "beta", "value": 0.4}]
)

policy_intervention_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
extra: CalibrateExtra = Field(
None,
description="optional extra system specific arguments for advanced use cases",
Expand All @@ -59,7 +56,9 @@ def gen_pyciemss_args(self, job_id):

dataset_path = fetch_dataset(self.dataset.dict(), job_id)

interventions = convert_to_static_interventions(self.interventions)
static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id
)

# TODO: Test RabbitMQ
try:
Expand All @@ -81,7 +80,7 @@ def hook(progress, _loss):
# TODO: Is this intentionally missing from `calibrate`?
# "end_time": self.timespan.end,
"data_path": dataset_path,
"static_parameter_interventions": interventions,
"static_parameter_interventions": static_interventions,
"progress_hook": hook,
# "visual_options": True,
**self.extra.dict(),
Expand Down
16 changes: 7 additions & 9 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import numpy as np
import torch
from pydantic import BaseModel, Field, Extra
from models.base import OperationRequest, Timespan, InterventionObject
from models.base import OperationRequest, Timespan
from pyciemss.integration_utils.intervention_builder import (
param_value_objective,
start_time_objective,
)

from pyciemss.ouu.qoi import obs_nday_average_qoi, obs_max_qoi
from models.converters import convert_to_static_interventions
from models.converters import fetch_and_convert_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


Expand Down Expand Up @@ -90,17 +90,15 @@ class Optimize(OperationRequest):
None,
description="optional extra system specific arguments for advanced use cases",
)
fixed_static_parameter_interventions: List[InterventionObject] = Field(
default_factory=list,
description="The interventions provided via the model config which are not being optimized",
)
policy_intervention_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")

def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)
fixed_static_parameter_interventions = convert_to_static_interventions(
self.fixed_static_parameter_interventions
static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id
)

intervention_type = self.policy_interventions.selection
if intervention_type == "param_value":
assert self.policy_interventions.start_time is not None
Expand Down Expand Up @@ -142,7 +140,7 @@ def gen_pyciemss_args(self, job_id):
"initial_guess_interventions": self.initial_guess_interventions,
"bounds_interventions": self.bounds_interventions,
"static_parameter_interventions": policy_interventions,
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
"fixed_static_parameter_interventions": static_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
**extra_options,
Expand Down
16 changes: 8 additions & 8 deletions service/models/operations/simulate.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from __future__ import annotations

from typing import ClassVar, List, Optional
from typing import ClassVar, Optional
from pydantic import BaseModel, Field, Extra


from models.base import OperationRequest, Timespan, InterventionObject
from models.converters import convert_to_static_interventions
from models.base import OperationRequest, Timespan
from models.converters import fetch_and_convert_static_interventions
from utils.tds import fetch_model, fetch_inferred_parameters


Expand All @@ -24,9 +24,7 @@ class Simulate(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "sample"
model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
timespan: Timespan = Timespan(start=0, end=90)
interventions: List[InterventionObject] = Field(
default_factory=list, example=[{"timestep": 1, "name": "beta", "value": 0.4}]
)
policy_intervention_id: str = Field(..., example="ba8da8d4-047d-11ee-be56")
step_size: float = 1.0
extra: SimulateExtra = Field(
None,
Expand All @@ -37,7 +35,9 @@ def gen_pyciemss_args(self, job_id):
# Get model from TDS
amr_path = fetch_model(self.model_config_id, job_id)

interventions = convert_to_static_interventions(self.interventions)
static_interventions = fetch_and_convert_static_interventions(
self.policy_intervention_id, job_id
)

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
Expand All @@ -49,7 +49,7 @@ def gen_pyciemss_args(self, job_id):
"logging_step_size": self.step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"static_parameter_interventions": interventions,
"static_parameter_interventions": static_interventions,
"inferred_parameters": inferred_parameters,
**extra_options,
}
Expand Down
22 changes: 22 additions & 0 deletions service/utils/tds.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
TDS_SIMULATIONS = "/simulations"
TDS_DATASETS = "/datasets"
TDS_CONFIGURATIONS = "/model-configurations/as-configured-model"
TDS_INTERVENTIONS = "/interventions"


#
Expand Down Expand Up @@ -318,3 +319,24 @@ def attach_files(output: dict, job_id, status="complete"):
job_id, status=status, result_files=list(files.values()), finish=True
)
logging.info("uploaded files to %s", job_id)


def fetch_interventions(policy_intervention_id: str, job_id):
job_dir = get_job_dir(job_id)
logging.error(f"Fetching interventions {policy_intervention_id}")

intervention_url = TDS_URL + TDS_INTERVENTIONS + "/" + policy_intervention_id
logging.error(intervention_url)
intervention_response = tds_session().get(intervention_url)
logging.error(intervention_response)
if intervention_response.status_code == 404:
raise HTTPException(status_code=404, detail="Intervention not found")

intervention_path = os.path.join(job_dir, f"./{policy_intervention_id}.json")
with open(intervention_path, "w") as file:
# Ensure we don't have null observables which can be problematic downstream, if so convert
# to empty list
intervention_json = intervention_response.json()
json.dump(intervention_json, file)

return intervention_response.json()

0 comments on commit 1757a06

Please sign in to comment.