From 8fd1bfb691b4263a6058946a0485f6692c8c2a55 Mon Sep 17 00:00:00 2001 From: Five Grant <5@fivegrant.com> Date: Thu, 7 Mar 2024 11:17:00 -0600 Subject: [PATCH] Use new `optimize` interventions --- pyproject.toml | 2 +- service/models/base.py | 5 --- service/models/converters.py | 7 ---- service/models/operations/optimize.py | 45 +++++++++++++++++----- tests/examples/optimize/input/request.json | 16 ++++---- 5 files changed, 45 insertions(+), 30 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 9dc1e6e..c62ab0e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -40,7 +40,7 @@ httpx = "^0.24.1" [tool.poe.tasks] -install-pyciemss = "pip install --no-cache-dir pyro-ppl==1.8.6 git+https://github.com/ciemss/pyciemss.git@f624385e0cba5236d93ccc83375b7670836d17bd --use-pep517" +install-pyciemss = "pip install --no-cache-dir pyro-ppl==1.8.6 git+https://github.com/ciemss/pyciemss.git@d6838e72bdc145b2f87ab9e33e220eb84fd87e87 --use-pep517" [tool.pytest.ini_options] diff --git a/service/models/base.py b/service/models/base.py index ff6f0c3..3be2fb2 100644 --- a/service/models/base.py +++ b/service/models/base.py @@ -37,11 +37,6 @@ class InterventionObject(BaseModel): value: float -class OptimizeInterventionObject(BaseModel): - timestep: float - name: str - - class InterventionSelection(BaseModel): timestep: float name: str diff --git a/service/models/converters.py b/service/models/converters.py index 4eba42c..772354d 100644 --- a/service/models/converters.py +++ b/service/models/converters.py @@ -11,13 +11,6 @@ def convert_to_static_interventions(interventions): return static_interventions -def convert_optimize_to_static_interventions(interventions): - static_interventions = defaultdict(dict) - for i in interventions: - static_interventions[torch.tensor(i.timestep)] = i.name - return static_interventions - - def convert_to_solution_mapping(config): individual_to_ensemble = { individual_state: ensemble_state diff --git a/service/models/operations/optimize.py b/service/models/operations/optimize.py index 9486a2c..c564c83 100644 --- a/service/models/operations/optimize.py +++ b/service/models/operations/optimize.py @@ -1,13 +1,15 @@ from __future__ import annotations -# from enum import Enum from typing import ClassVar, Dict, List, Optional import numpy as np import torch from pydantic import BaseModel, Field, Extra -from models.base import OperationRequest, Timespan, OptimizeInterventionObject -from models.converters import convert_optimize_to_static_interventions +from models.base import OperationRequest, Timespan +from pyciemss.integration_utils.intervention_builder import ( + param_value_objective, + start_time_objective, +) from utils.tds import fetch_model, fetch_inferred_parameters @@ -40,7 +42,15 @@ def objfun(x, is_minimized): return -np.sum(np.abs(x)) -# qoi_implementations = {QOIMethod.obs_nday_average.value: obs_nday_average_qoi} +class InterventionObjective(BaseModel): + selection: str = Field( + "param_value", + description="The intervention objective to use", + example="param_value", + ) + param_names: list[str] + param_values: Optional[list[Optional[float]]] = None + start_time: Optional[list[float]] = None class OptimizeExtra(BaseModel): @@ -53,7 +63,7 @@ class OptimizeExtra(BaseModel): ) inferred_parameters: Optional[str] = Field( None, - description="id from a previous calibration", + description="ID from a previous calibration", example=None, ) maxiter: int = 5 @@ -65,9 +75,7 @@ class Optimize(OperationRequest): pyciemss_lib_function: ClassVar[str] = "optimize" model_config_id: str = Field(..., example="ba8da8d4-047d-11ee-be56") timespan: Timespan = Timespan(start=0, end=90) - interventions: List[OptimizeInterventionObject] = Field( - default_factory=list, example=[{"timestep": 1, "name": "beta"}] - ) + interventions: InterventionObjective step_size: float = 1.0 qoi: List[str] # QOIMethod risk_bound: float @@ -82,7 +90,26 @@ def gen_pyciemss_args(self, job_id): # Get model from TDS amr_path = fetch_model(self.model_config_id, job_id) - interventions = convert_optimize_to_static_interventions(self.interventions) + intervention_type = self.interventions.selection + if intervention_type == "param_value": + assert self.interventions.start_time is not None + start_time = [torch.tensor(time) for time in self.interventions.start_time] + param_value = [None] * len(self.interventions.param_names) + + interventions = param_value_objective( + start_time=start_time, + param_name=self.interventions.param_names, + param_value=param_value, + ) + else: + assert self.interventions.param_values is not None + param_value = [ + torch.tensor(value) for value in self.interventions.param_values + ] + interventions = start_time_objective( + param_name=self.interventions.param_names, + param_value=param_value, + ) extra_options = self.extra.dict() inferred_parameters = fetch_inferred_parameters( diff --git a/tests/examples/optimize/input/request.json b/tests/examples/optimize/input/request.json index 88dee95..50f71eb 100644 --- a/tests/examples/optimize/input/request.json +++ b/tests/examples/optimize/input/request.json @@ -2,20 +2,20 @@ "engine": "ciemss", "user_id": "not_provided", "model_config_id": "sidarthe", - "interventions": [ - { - "timestep": 1.0, - "name": "beta" - } - ], + "interventions": { + "selection": "param_value", + "start_time": [2], + "param_names": ["beta"], + "param_values": [0.02] + }, "timespan": { "start": 0, "end": 90 }, "qoi": ["Infected"], "risk_bound": 10.0, - "initial_guess_interventions": [1.0], - "bounds_interventions": [[0.0], [3.0]], + "initial_guess_interventions": [1.0], + "bounds_interventions": [[0.0], [3.0]], "extra": { "num_samples": 4, "is_minimized": true