Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reinclude ensemble-calibrate and use new version of PyCIEMSS #73

Merged
merged 3 commits into from
Mar 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ httpx = "^0.24.1"


[tool.poe.tasks]
install-pyciemss = "pip install --no-cache-dir pyro-ppl==1.8.6 git+https://github.com/ciemss/pyciemss.git@d6838e72bdc145b2f87ab9e33e220eb84fd87e87 --use-pep517"
install-pyciemss = "pip install --no-cache-dir git+https://github.com/ciemss/pyciemss.git@1fc62b0d4b0870ca992514ad7a9b7a09a175ce44 --use-pep517"


[tool.pytest.ini_options]
Expand Down
13 changes: 3 additions & 10 deletions service/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import logging
import os
from fastapi import FastAPI, Depends, HTTPException
from fastapi import FastAPI, Depends
from fastapi.middleware.cors import CORSMiddleware

from service.models import (
Expand All @@ -11,6 +11,7 @@
Calibrate,
Simulate,
EnsembleSimulate,
EnsembleCalibrate,
Optimize,
StatusSimulationIdGetResponse,
)
Expand All @@ -21,6 +22,7 @@
"simulate": Simulate,
"calibrate": Calibrate,
"ensemble-simulate": EnsembleSimulate,
"ensemble-calibrate": EnsembleCalibrate,
"optimize": Optimize,
}

Expand Down Expand Up @@ -111,12 +113,3 @@ def operate(

operate = make_operate(operation_name)
registrar(operate)


@app.get("/ensemble-calibrate", response_model=StatusSimulationIdGetResponse)
def ensemble_calibrate_not_yet_implemented():
"""
DO NOT USE. Placeholder for `ensemble-calibrate` endpoint.
This will be reimplemented in the future.
"""
raise HTTPException(status=501, detail="Not yet reimplemented")
1 change: 1 addition & 0 deletions service/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
sample,
calibrate,
ensemble_sample,
ensemble_calibrate,
optimize,
)

Expand Down
1 change: 1 addition & 0 deletions service/models/operations/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from models.operations.simulate import Simulate
from models.operations.calibrate import Calibrate
from models.operations.ensemble_simulate import EnsembleSimulate
from models.operations.ensemble_calibrate import EnsembleCalibrate
from models.operations.optimize import Optimize
80 changes: 80 additions & 0 deletions service/models/operations/ensemble_calibrate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from __future__ import annotations

from typing import ClassVar, List, Dict, Any

import socket
import logging
from pydantic import BaseModel, Field, Extra
import torch # TODO: Do not use Torch in PyCIEMSS Library interface

from pika.exceptions import AMQPConnectionError

from models.base import Dataset, OperationRequest, Timespan, ModelConfig
from models.converters import convert_to_solution_mapping
from utils.rabbitmq import gen_rabbitmq_hook
from utils.tds import fetch_dataset, fetch_model


class EnsembleCalibrateExtra(BaseModel):
noise_model: str = "normal"
noise_model_kwargs: Dict[str, Any] = {"scale": 0.1}
solver_method: str = "dopri5"
solver_options: Dict[str, Any] = {}
num_iterations: int = 1000
lr: float = 0.03
verbose: bool = False
num_particles: int = 1
deterministic_learnable_parameters: List[str] = []


class EnsembleCalibrate(OperationRequest):
pyciemss_lib_function: ClassVar[str] = "ensemble_calibrate"
model_configs: List[ModelConfig] = Field(
[],
example=[],
)
timespan: Timespan
dataset: Dataset = None

step_size: float = 1.0

extra: EnsembleCalibrateExtra = Field(
None,
description="optional extra system specific arguments for advanced use cases",
)

def gen_pyciemss_args(self, job_id):
weights = torch.tensor([config.weight for config in self.model_configs])
solution_mappings = [
convert_to_solution_mapping(config) for config in self.model_configs
]
amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs]
dataset_path = fetch_dataset(self.dataset.dict(), job_id)

try:
hook = gen_rabbitmq_hook(job_id)
except (socket.gaierror, AMQPConnectionError):
logging.warning(
"%s: Failed to connect to RabbitMQ. Unable to log progress", job_id
)

def hook(progress, _loss):
progress = progress / 10 # TODO: Fix magnitude of progress upstream
if progress == int(progress):
logging.info(f"Calibration is {progress}% complete")
return None

return {
"model_paths_or_jsons": amr_paths,
"solution_mappings": solution_mappings,
"data_path": dataset_path,
"start_time": self.timespan.start,
# "end_time": self.timespan.end,
"dirichlet_alpha": weights,
"progress_hook": hook,
# "visual_options": True,
**self.extra.dict(),
}

class Config:
extra = Extra.forbid
17 changes: 14 additions & 3 deletions service/models/operations/ensemble_simulate.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
from __future__ import annotations

from typing import ClassVar, List
from typing import ClassVar, List, Optional

from pydantic import BaseModel, Field, Extra
import torch # TODO: Do not use Torch in PyCIEMSS Library interface

from models.base import OperationRequest, Timespan, ModelConfig
from models.converters import convert_to_solution_mapping
from utils.tds import fetch_model
from utils.tds import fetch_model, fetch_inferred_parameters


class EnsembleSimulateExtra(BaseModel):
num_samples: int = Field(
100, description="number of samples for a CIEMSS simulation", example=100
)
inferred_parameters: Optional[str] = Field(
None,
description="id from a previous calibration",
example=None,
)


class EnsembleSimulate(OperationRequest):
Expand All @@ -38,15 +43,21 @@ def gen_pyciemss_args(self, job_id):
]
amr_paths = [fetch_model(config.id, job_id) for config in self.model_configs]

extra_options = self.extra.dict()
inferred_parameters = fetch_inferred_parameters(
extra_options.pop("inferred_parameters"), job_id
)

return {
"model_paths_or_jsons": amr_paths,
"solution_mappings": solution_mappings,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"logging_step_size": self.step_size,
"dirichlet_alpha": weights,
"inferred_parameters": inferred_parameters,
# "visual_options": True,
**self.extra.dict(),
**extra_options,
}

class Config:
Expand Down
Loading
Loading