Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ensemble functionality #18

Merged
merged 11 commits into from
Jul 18, 2023
69 changes: 59 additions & 10 deletions api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,11 +82,47 @@ class Config:
)


class EnsembleSimulateExtra(BaseModel):
    # Extra, engine-specific arguments accepted by the ensemble-simulate
    # request. Documented with comments (not a docstring) so the generated
    # schema description stays exactly as it was.

    # How many samples to draw; defaults to 100 when the caller omits it.
    num_samples: int = Field(
        default=100,
        example=100,
        description="number of samples for a CIEMSS simulation",
    )

    class Config:
        # NOTE(review): presumably pydantic's `Extra` enum, where `allow`
        # keeps undeclared keys on the model -- confirm against the import.
        extra = ExtraEnum.allow


class EnsembleCalibrateExtra(BaseModel):
    # Extra, engine-specific arguments accepted by the ensemble-calibrate
    # request. Every field has a default, so callers may omit any of them.

    class Config:
        # NOTE(review): presumably pydantic's `Extra` enum, where `allow`
        # keeps undeclared keys on the model -- confirm against the import.
        extra = ExtraEnum.allow

    num_samples: int = Field(
        100, description="number of samples for a CIEMSS simulation", example=100
    )

    total_population: int = Field(
        1000, description="total population", example=1000
    )

    num_iterations: int = Field(
        350, description="number of iterations", example=1000
    )

    # The default and the example are both the string "days", so the field
    # must be typed `str`; with `int`, a caller sending "days" is rejected
    # by validation.
    time_unit: str = Field(
        "days", description="units in numbers of days", example="days"
    )


class ModelConfig(BaseModel):
    # One member of an ensemble: which model configuration to use, how its
    # outputs map onto the shared solution names, and its weight.

    # Identifier of the model configuration.
    id: str = Field(..., example="cd339570-047d-11ee-be55")
    # String-to-string mapping; see the example for the expected shape.
    solution_mappings: dict[str, str] = Field(..., example={"Infected": "Cases", "Hospitalizations": "hospitalized_population"})
    # The example must be numeric to match the declared float type (the
    # previous example was a copy-pasted id string).
    weight: float = Field(..., example=0.5)


class Dataset(BaseModel):
id: str = Field(None, example="cd339570-047d-11ee-be55")
filename: str = Field(None, example="dataset.csv")
mappings: Optional[Dict[str, str]] = Field(
None,
mappings: Dict[str, str] = Field(
default_factory=dict,
description="Mappings from the dataset column names to the model names they should be replaced with.",
example={'postive_tests': 'infected'},
)
Expand Down Expand Up @@ -120,18 +156,31 @@ class CalibratePostRequest(BaseModel):
)


class EnsemblePostRequest(BaseModel):
class EnsembleSimulatePostRequest(BaseModel):
engine: Engine = Field(..., example="ciemss")
model_configuration_ids: Optional[List[str]] = Field(
username: str = Field("not_provided", example="not_provided")
model_configs: List[ModelConfig] = Field(
[],
example=[],
)
timespan: Timespan

extra: EnsembleSimulateExtra = Field(
None,
example=[
"ba8da8d4-047d-11ee-be56",
"c1cd941a-047d-11ee-be56",
"c4b9f88a-047d-11ee-be56",
],
description="optional extra system specific arguments for advanced use cases",
)


class EnsembleCalibratePostRequest(BaseModel):
engine: Engine = Field(..., example="ciemss")
username: str = Field("not_provided", example="not_provided")
model_configs: List[ModelConfig] = Field(
[],
example=[],
)
timespan: Timespan
extra: Optional[Dict[str, Any]] = Field(
dataset: Dataset
extra: EnsembleCalibrateExtra = Field(
None,
description="optional extra system specific arguments for advanced use cases",
)
Expand Down
75 changes: 67 additions & 8 deletions api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
Status,
JobResponse,
CalibratePostRequest,
EnsemblePostRequest,
SimulatePostRequest,
EnsembleSimulatePostRequest,
EnsembleCalibratePostRequest,
StatusSimulationIdGetResponse,
)

Expand Down Expand Up @@ -138,15 +139,73 @@ def calibrate_model(body: CalibratePostRequest) -> JobResponse:
return response


@app.post("/ensemble", response_model=JobResponse)
def create_ensemble(body: EnsemblePostRequest) -> JobResponse:
@app.post("/ensemble-simulate", response_model=JobResponse)
def create_simulate_ensemble(body: EnsembleSimulatePostRequest) -> JobResponse:
"""
Perform an ensemble simulation
Perform ensemble simulate
"""
return Response(
status_code=status.HTTP_501_NOT_IMPLEMENTED,
content="Ensemble is not yet implemented",
)
from utils import create_job

# Parse request body
engine = str(body.engine).lower()
model_configs = [config.dict() for config in body.model_configs]
start = body.timespan.start
end = body.timespan.end
username = body.username
extra = body.extra.dict()


operation_name = "operations.ensemble_simulate"
options = {
"engine": engine,
"model_configs": model_configs,
"start": start,
"end": end,
"username": username,
"extra": extra,
"visual_options": True
}

resp = create_job(operation_name=operation_name, options=options)

response = {"simulation_id": resp["id"]}

return response


@app.post("/ensemble-calibrate", response_model=JobResponse)
def create_calibrate_ensemble(body: EnsembleCalibratePostRequest) -> JobResponse:
    """
    Perform ensemble calibrate

    Flattens the request body into an options dict, passes it to
    ``create_job`` under the ``operations.ensemble_calibrate`` operation
    name, and returns the resulting job id as ``simulation_id``.
    """
    from utils import create_job

    # Parse request body
    engine = str(body.engine).lower()
    username = body.username
    dataset = body.dataset.dict()
    model_configs = [config.dict() for config in body.model_configs]
    start = body.timespan.start
    end = body.timespan.end
    # `extra` defaults to None on the request model, so calling `.dict()`
    # unconditionally raised AttributeError for requests that omit it.
    extra = body.extra.dict() if body.extra is not None else {}

    operation_name = "operations.ensemble_calibrate"
    options = {
        "engine": engine,
        "model_configs": model_configs,
        "dataset": dataset,
        "start": start,
        "end": end,
        "username": username,
        "extra": extra,
        "visual_options": True,
    }

    resp = create_job(operation_name=operation_name, options=options)

    response = {"simulation_id": resp["id"]}

    return response


16 changes: 10 additions & 6 deletions api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,17 @@ def create_job(operation_name: str, options: Optional[Dict[Any, Any]] = None):

if STANDALONE:
logging.info(f"OPTIONS: {options}")
# TODO: Allow extras on payload and simply put full object here
ex_payload = {
"engine": "ciemss",
"model_config_id": options.get("model_config_id"),
"model_config_id": options.get("model_config_id", "not_provided"),
"timespan": {
"start": options.get("start"),
"end": options.get("end"),
"start": options.get("start", 0),
"end": options.get("end", 1),
},
"extra": options.get("extra"),
"extra": options.get("extra", None),
}
post_url = TDS_API + TDS_SIMULATIONS + job_id
post_url = TDS_API + TDS_SIMULATIONS #+ job_id
payload = {
"id": job_id,
"execution_payload": ex_payload,
Expand All @@ -75,7 +76,10 @@ def create_job(operation_name: str, options: Optional[Dict[Any, Any]] = None):
}
logging.info(payload)
sys.stdout.flush()
logging.info(requests.put(post_url, json=json.loads(json.dumps(payload))).content)
response = requests.post(post_url, json=payload)
if response.status_code >= 300:
raise Exception(f"Failed to create simulation on TDS (status: {response.status_code}): {json.dumps(payload)}")
logging.info(response.content)

if job and force_restart:
job.cleanup(ttl=0) # Cleanup/remove data immediately
Expand Down
85 changes: 85 additions & 0 deletions workers/operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
load_and_sample_petri_model,
)

from pyciemss.Ensemble.interfaces import (
load_and_sample_petri_ensemble, load_and_calibrate_and_sample_ensemble_model
)


TDS_CONFIGURATIONS = "/model_configurations/"
TDS_SIMULATIONS = "/simulations/"
OUTPUT_FILENAME = os.getenv("PYCIEMSS_OUTPUT_FILEPATH")
Expand Down Expand Up @@ -89,3 +94,83 @@ def calibrate_then_simulate(*args, **kwargs):


return True


@catch_job_status
def ensemble_simulate(*args, **kwargs):
    """Sample from a weighted ensemble of models and attach the outputs to the job.

    Required keyword arguments: ``model_configs``, ``start``, ``end``,
    ``num_samples``, ``username`` and ``job_id``. Any remaining keyword
    arguments are forwarded unchanged to ``load_and_sample_petri_ensemble``.
    Returns ``True`` once the result files have been handed to ``attach_files``.
    """
    configs = kwargs.pop("model_configs")
    first_step = kwargs.pop("start")
    last_step = kwargs.pop("end")
    sample_count = kwargs.pop("num_samples")
    # Removed from kwargs so it is not forwarded to the sampler below.
    kwargs.pop("username")
    job_id = kwargs.pop("job_id")

    # Flag the simulation record as running before the heavy work starts.
    update_tds_status(
        TDS_API + TDS_SIMULATIONS + job_id, status="running", start=True
    )

    # Per-model inputs, kept in the same order as the incoming configs.
    weights = [entry["weight"] for entry in configs]
    solution_mappings = [entry["solution_mappings"] for entry in configs]
    amr_paths = [
        fetch_model(entry["id"], TDS_API, TDS_CONFIGURATIONS) for entry in configs
    ]

    # Timepoints run 1..(end - start), inclusive.
    timepoints = list(range(1, (last_step - first_step) + 1))

    result = load_and_sample_petri_ensemble(
        amr_paths,
        weights,
        solution_mappings,
        sample_count,
        timepoints,
        **kwargs
    )
    frame = result.get('data')
    visual_spec = result.get('visual')

    # Write the visualization spec first, then the sampled data.
    with open("visualization.json", "w") as handle:
        json.dump(visual_spec, handle, indent=2)
    frame.to_csv(OUTPUT_FILENAME, index=False)

    attach_files(
        {OUTPUT_FILENAME: "simulation.csv", "visualization.json": "visualization.json"},
        TDS_API,
        TDS_SIMULATIONS,
        job_id,
    )
    return True


@catch_job_status
def ensemble_calibrate(*args, **kwargs):
    """Calibrate a weighted ensemble against a dataset, sample it, and attach the outputs.

    Required keyword arguments: ``model_configs``, ``start``, ``end``,
    ``num_samples``, ``dataset``, ``username`` and ``job_id``. Any remaining
    keyword arguments are forwarded unchanged to
    ``load_and_calibrate_and_sample_ensemble_model``. Returns ``True`` once
    the result files have been handed to ``attach_files``.
    """
    configs = kwargs.pop("model_configs")
    first_step = kwargs.pop("start")
    last_step = kwargs.pop("end")
    sample_count = kwargs.pop("num_samples")
    dataset = kwargs.pop("dataset")
    # Removed from kwargs so it is not forwarded to the calibrator below.
    kwargs.pop("username")
    job_id = kwargs.pop("job_id")

    # Flag the simulation record as running before the heavy work starts.
    update_tds_status(
        TDS_API + TDS_SIMULATIONS + job_id, status="running", start=True
    )

    # Per-model inputs, kept in the same order as the incoming configs.
    weights = [entry["weight"] for entry in configs]
    solution_mappings = [entry["solution_mappings"] for entry in configs]
    amr_paths = [
        fetch_model(entry["id"], TDS_API, TDS_CONFIGURATIONS) for entry in configs
    ]

    dataset_path = fetch_dataset(dataset, TDS_API)

    # Timepoints run 1..(end - start), inclusive.
    timepoints = list(range(1, (last_step - first_step) + 1))

    result = load_and_calibrate_and_sample_ensemble_model(
        amr_paths,
        weights,
        dataset_path,
        solution_mappings,
        sample_count,
        timepoints,
        **kwargs
    )
    frame = result.get('data')
    visual_spec = result.get('visual')

    # Write the visualization spec first, then the sampled data.
    with open("visualization.json", "w") as handle:
        json.dump(visual_spec, handle, indent=2)
    frame.to_csv(OUTPUT_FILENAME, index=False)

    attach_files(
        {OUTPUT_FILENAME: "simulation.csv", "visualization.json": "visualization.json"},
        TDS_API,
        TDS_SIMULATIONS,
        job_id,
    )
    return True
9 changes: 5 additions & 4 deletions workers/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def fetch_model(model_config_id, tds_api, config_endpoint):
for component in url_components:
model_url = urllib.parse.urljoin(model_url, component)
model_response = requests.get(model_url)
amr_path = os.path.abspath("./amr.json")
amr_path = os.path.abspath(f"./{model_config_id}.json")
with open(amr_path, "w") as file:
json.dump(model_response.json()["configuration"], file)
return amr_path
Expand All @@ -146,7 +146,8 @@ def attach_files(files: dict, tds_api, simulation_endpoint, job_id, status='comp
presigned_upload_url = upload_response.json()["url"]
with open(location, "rb") as f:
upload_response = requests.put(presigned_upload_url, f)

else:
logging.info(f"{job_id} ran into error")

# Update simulation object with status and filepaths.
update_tds_status(
Expand All @@ -164,7 +165,7 @@ def wrapped(*args, **kwargs):
result = function(*args, **kwargs)
end_time = time.perf_counter()
logging.info(
f"Elapsed time for {function.__name__} for {kwargs["username"]}:",
f"Elapsed time for {function.__name__} for {kwargs['username']}:",
end_time - start_time
)
return result
Expand All @@ -183,7 +184,7 @@ def wrapped(*args, **kwargs):

Error occured in function: {function.__name__}

Username: {kwargs["username"]}
Username: {kwargs['username']}

################################
"""
Expand Down