atol and rtol moved to inside solver_options dict, updated tests
augeorge committed Sep 18, 2024
1 parent d076a2e commit e8929fa
Showing 3 changed files with 29 additions and 56 deletions.
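
In practice the change is to the calling convention: `rtol` and `atol` are no longer top-level keyword arguments of the interface functions, but keys inside the `solver_options` dict. A minimal sketch of a `sample` call under the new convention (the model path, time values, and sample count are placeholders; the argument order follows the tests updated below):

```python
from pyciemss.interfaces import sample

# Old convention (removed by this commit):
#   sample(model, end_time, logging_step_size, num_samples, rtol=1e-7, atol=1e-9)
# New convention: tolerances travel inside solver_options.
result = sample(
    "model.json",  # placeholder path to an AMR model
    100.0,         # end_time (placeholder)
    1.0,           # logging_step_size (placeholder)
    10,            # num_samples (placeholder)
    start_time=0.0,
    solver_options={"rtol": 1e-7, "atol": 1e-9},
)
```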
58 changes: 21 additions & 37 deletions pyciemss/interfaces.py
@@ -57,8 +57,6 @@ def ensemble_sample(
time_unit: Optional[str] = None,
alpha_qs: Optional[List[float]] = DEFAULT_ALPHA_QS,
stacking_order: str = "timepoints",
rtol: float = 1e-7,
atol: float = 1e-9,
) -> Dict[str, Any]:
"""
Load a collection of models from files, compile them into an ensemble probabilistic program,
@@ -95,7 +93,7 @@ def ensemble_sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
@@ -109,10 +107,6 @@
stacking_order: Optional[str]
- The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state.
- Options: "timepoints" or "quantiles"
rtol: float
- The relative tolerance for the solver.
atol: float
- The absolute tolerance for the solver.
Returns:
result: Dict[str, Any]
@@ -127,6 +121,10 @@
"""
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9

with torch.no_grad():
if dirichlet_alpha is None:
dirichlet_alpha = torch.ones(len(model_paths_or_jsons))
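
The extraction itself relies only on `dict.pop(key, default)`: if the caller supplied a tolerance it is used (and removed from the dict so it is not forwarded twice), otherwise the previous defaults of 1e-7 and 1e-9 apply. A standalone sketch of the pattern, separate from the pyciemss code:

```python
def extract_tolerances(solver_options):
    # pop() returns the stored value when the key exists, else the default,
    # and removes the key from the dict in the first case.
    rtol = solver_options.pop("rtol", 1e-7)
    atol = solver_options.pop("atol", 1e-9)
    return rtol, atol

opts = {"rtol": 1e-6, "step_size": 0.1}
assert extract_tolerances(opts) == (1e-6, 1e-9)  # explicit rtol, default atol
assert opts == {"step_size": 0.1}                # tolerance keys are consumed
```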
@@ -203,8 +201,6 @@ def ensemble_calibrate(
num_particles: int = 1,
deterministic_learnable_parameters: List[str] = [],
progress_hook: Callable = lambda i, loss: None,
rtol: float = 1e-7,
atol: float = 1e-9,
) -> Dict[str, Any]:
"""
Infer parameters for an ensemble of DynamicalSystem models conditional on data.
@@ -243,7 +239,7 @@ def ensemble_calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
@@ -264,10 +260,6 @@
- This is called at the beginning of each iteration.
- By default, this is a no-op.
- This can be used to implement custom progress bars.
rtol: float
- The relative tolerance for the solver.
atol: float
- The absolute tolerance for the solver.
Returns:
result: Dict[str, Any]
@@ -293,6 +285,10 @@
# Check that num_iterations is a positive integer
if not (isinstance(num_iterations, int) and num_iterations > 0):
raise ValueError("num_iterations must be a positive integer")

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9

def autoguide(model):
guide = pyro.infer.autoguide.AutoGuideList(model)
@@ -382,8 +378,6 @@ def sample(
Dict[str, Intervention],
] = {},
alpha: float = 0.95,
rtol: float = 1e-7,
atol: float = 1e-9,
) -> Dict[str, Any]:
r"""
Load a model from a file, compile it into a probabilistic program, and sample from it.
@@ -402,7 +396,7 @@
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
@@ -445,10 +439,6 @@ def sample(
:func:`~chirho.interventional.ops.intervene`, including functions.
alpha: float
- Risk level for alpha-superquantile outputs in the results dictionary.
rtol: float
- The relative tolerance for the solver.
atol: float
- The absolute tolerance for the solver.
Returns:
result: Dict[str, Any]
@@ -471,6 +461,10 @@

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9

with torch.no_grad():
model = CompiledDynamics.load(model_path_or_json)

@@ -600,8 +594,6 @@ def calibrate(
num_particles: int = 1,
deterministic_learnable_parameters: List[str] = [],
progress_hook: Callable = lambda i, loss: None,
rtol: float = 1e-7,
atol: float = 1e-9,
) -> Dict[str, Any]:
"""
Infer parameters for a DynamicalSystem model conditional on data.
@@ -628,7 +620,7 @@
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
- solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
- start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
@@ -681,10 +673,6 @@
- This is called at the beginning of each iteration.
- By default, this is a no-op.
- This can be used to implement custom progress bars.
- rtol: float
- The relative tolerance for the solver.
- atol: float
- The absolute tolerance for the solver.
Returns:
result: Dict[str, Any]
@@ -698,6 +686,10 @@

check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9

pyro.clear_param_store()

model = CompiledDynamics.load(model_path_or_json)
@@ -825,8 +817,6 @@ def optimize(
verbose: bool = False,
roundup_decimal: int = 4,
progress_hook: Callable[[torch.Tensor], None] = lambda x: None,
rtol: float = 1e-7,
atol: float = 1e-9,
) -> Dict[str, Any]:
r"""
Load a model from a file, compile it into a probabilistic program, and optimize under uncertainty with risk-based
@@ -868,7 +858,7 @@ def optimize(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
@@ -897,10 +887,6 @@ def optimize(
- A callback function that takes in the current parameter vector as a tensor.
If the function returns StopIteration, the minimization will terminate.
- This can be used to implement custom progress bars and/or early stopping criteria.
rtol: float
- The relative tolerance for the solver.
atol: float
- The absolute tolerance for the solver.
Returns:
result: Dict[str, Any]
Expand Down Expand Up @@ -944,8 +930,6 @@ def optimize(
solver_options=solver_options,
u_bounds=bounds_np,
risk_bound=risk_bound,
rtol=rtol,
atol=atol,
)

# Run one sample to estimate model evaluation time
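The docstrings point to torchdiffeq's `odeint` for the meaning of these options. As a rough, standalone illustration (a toy ODE, not a pyciemss model, and not necessarily the exact forwarding path inside pyciemss), the extracted tolerances map onto `odeint`'s `rtol`/`atol` keywords while the rest of `solver_options` can be passed through its `options` argument:

```python
import torch
from torchdiffeq import odeint

def f(t, y):
    return -y  # dy/dt = -y, so y(t) = exp(-t) for y(0) = 1

solver_options = {"rtol": 1e-6, "atol": 1e-8}
rtol = solver_options.pop("rtol", 1e-7)
atol = solver_options.pop("atol", 1e-9)

t = torch.linspace(0.0, 1.0, 11)
y = odeint(f, torch.tensor([1.0]), t, rtol=rtol, atol=atol,
           method="dopri5", options=solver_options)
```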
6 changes: 2 additions & 4 deletions pyciemss/ouu/ouu.py
@@ -78,8 +78,6 @@ def __init__(
solver_options: Dict[str, Any] = {},
u_bounds: np.ndarray = np.atleast_2d([[0], [1]]),
risk_bound: List[float] = [0.0],
rtol: float = 1e-7,
atol: float = 1e-9
):
self.model = model
self.interventions = interventions
@@ -99,8 +97,8 @@ def __init__(
self.u_bounds = u_bounds
self.risk_bound = risk_bound # used for defining penalty
warnings.simplefilter("always", UserWarning)
self.rtol = rtol
self.atol = atol
self.rtol = self.solver_options.pop('rtol', 1e-7) # default = 1e-7
self.atol = self.solver_options.pop('atol', 1e-9) # default = 1e-9

def __call__(self, x):
if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
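In `ouu.py` the same pattern runs inside a constructor, so the tolerances become instance attributes and the remaining options stay on `self.solver_options` for later use. A hedged sketch with a hypothetical class name (the actual class in `ouu.py` is outside the visible hunk):

```python
class SolverConfig:  # hypothetical name, for illustration only
    def __init__(self, solver_options=None):
        self.solver_options = solver_options if solver_options is not None else {}
        self.rtol = self.solver_options.pop("rtol", 1e-7)  # default = 1e-7
        self.atol = self.solver_options.pop("atol", 1e-9)  # default = 1e-9

cfg = SolverConfig({"rtol": 1e-5, "step_size": 0.1})
assert cfg.rtol == 1e-5 and cfg.atol == 1e-9
assert cfg.solver_options == {"step_size": 0.1}
```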
21 changes: 6 additions & 15 deletions tests/test_interfaces.py
@@ -118,8 +118,7 @@ def test_sample_no_interventions(
logging_step_size,
num_samples,
start_time=start_time,
rtol=rtol,
atol=atol,
solver_options = {'rtol':rtol, 'atol':atol}
)["unprocessed_result"]
with pyro.poutine.seed(rng_seed=0):
result2 = sample_method(
@@ -128,8 +127,7 @@
logging_step_size,
num_samples,
start_time=start_time,
rtol=rtol,
atol=atol,
solver_options = {'rtol':rtol, 'atol':atol}
)["unprocessed_result"]

result3 = sample_method(
@@ -138,8 +136,7 @@
logging_step_size,
num_samples,
start_time=start_time,
rtol=rtol,
atol=atol,
solver_options = {'rtol':rtol, 'atol':atol}
)["unprocessed_result"]

for result in [result1, result2, result3]:
@@ -413,8 +410,7 @@ def test_calibrate_deterministic(
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"deterministic_learnable_parameters": deterministic_learnable_parameters,
"rtol": rtol,
"atol": atol,
"solver_options": {'rtol':rtol, 'atol':atol},
**CALIBRATE_KWARGS,
}

@@ -437,8 +433,7 @@
*sample_args,
**sample_kwargs,
inferred_parameters=inferred_parameters,
rtol=rtol,
atol=atol,
solver_options={'rtol':rtol, 'atol':atol}
)["unprocessed_result"]

check_result_sizes(result, start_time, end_time, logging_step_size, 1)
@@ -621,14 +616,12 @@ def __call__(self, x):
optimize_kwargs = {
**model_fixture.optimize_kwargs,
"solver_method": "euler",
"solver_options": {"step_size": 0.1},
"solver_options": {"step_size": 0.1, "rtol": rtol, "atol": atol},
"start_time": start_time,
"n_samples_ouu": int(2),
"maxiter": 1,
"maxfeval": 2,
"progress_hook": progress_hook,
"rtol": rtol,
"atol": atol,
}
bounds_interventions = optimize_kwargs["bounds_interventions"]
opt_result = optimize(
@@ -667,8 +660,6 @@ def __call__(self, x):
static_parameter_interventions=opt_intervention,
solver_method=optimize_kwargs["solver_method"],
solver_options=optimize_kwargs["solver_options"],
rtol=rtol,
atol=atol,
)["unprocessed_result"]

intervened_result_subset = {
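One consequence of extracting tolerances with `pop` is worth keeping in mind when a single `solver_options` dict is reused across several calls, as the optimize test above does: the keys are consumed by the first consumer unless the dict is copied first (whether pyciemss copies it internally is not shown in this diff). A standalone illustration of the Python behaviour:

```python
shared = {"step_size": 0.1, "rtol": 1e-5, "atol": 1e-7}

def consume(opts):
    return opts.pop("rtol", 1e-7), opts.pop("atol", 1e-9)

first = consume(shared)   # (1e-5, 1e-7): explicit values used and removed
second = consume(shared)  # (1e-7, 1e-9): falls back to defaults on reuse
assert first == (1e-5, 1e-7) and second == (1e-7, 1e-9)
assert shared == {"step_size": 0.1}

# Passing a copy keeps the caller's dict intact:
original = {"rtol": 1e-5}
consume(dict(original))
assert original == {"rtol": 1e-5}
```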
