Skip to content

Commit

Permalink
Optimize: add step size explicity. (#110)
Browse files Browse the repository at this point in the history
* add step size explicity.

* Good call Dan
  • Loading branch information
Tom-Szendrey authored Jul 30, 2024
1 parent 1cb85a1 commit 5893dce
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions service/models/operations/optimize.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from typing import Any, ClassVar, Dict, List, Optional
from typing import ClassVar, List, Optional
from enum import Enum

import numpy as np
Expand Down Expand Up @@ -77,7 +77,12 @@ class OptimizeExtra(BaseModel):
maxfeval: int = 25
alpha: float = 0.95
solver_method: str = "dopri5"
solver_options: Dict[str, Any] = {}
# https://github.com/ciemss/pyciemss/blob/main/pyciemss/integration_utils/interface_checks.py
solver_step_size: float = Field(
None,
description="Step size required if solver method is euler.",
example=1.0,
)


class Optimize(OperationRequest):
Expand All @@ -88,7 +93,7 @@ class Optimize(OperationRequest):
fixed_static_parameter_interventions: list[HMIIntervention] = Field(
None
) # Theses are interventions provided that will not be optimized
step_size: float = 1.0
logging_step_size: float = 1.0
qoi: QOI
risk_bound: float
bounds_interventions: List[List[float]]
Expand Down Expand Up @@ -137,10 +142,17 @@ def gen_pyciemss_args(self, job_id):
extra_options.pop("inferred_parameters"), job_id
)
n_samples_ouu = extra_options.pop("num_samples")
solver_options = {}
step_size = extra_options.pop(
"solver_step_size"
) # Need to pop this out of extra.
solver_method = extra_options.pop("solver_method")
if step_size is not None and solver_method == "euler":
solver_options["step_size"] = step_size

return {
"model_path_or_json": amr_path,
"logging_step_size": self.step_size,
"logging_step_size": self.logging_step_size,
"start_time": self.timespan.start,
"end_time": self.timespan.end,
"objfun": lambda x: objfun(
Expand All @@ -156,6 +168,8 @@ def gen_pyciemss_args(self, job_id):
"fixed_static_parameter_interventions": fixed_static_parameter_interventions,
"inferred_parameters": inferred_parameters,
"n_samples_ouu": n_samples_ouu,
"solver_method": solver_method,
"solver_options": solver_options,
**extra_options,
}

Expand Down

0 comments on commit 5893dce

Please sign in to comment.