diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py
index 7ade53a6..9ffe13f2 100644
--- a/pyciemss/interfaces.py
+++ b/pyciemss/interfaces.py
@@ -93,7 +93,7 @@ def ensemble_sample(
         - If performance is incredibly slow, we suggest using `euler` to debug.
           If using `euler` results in faster simulation, the issue is likely that the model is stiff.
     solver_options: Dict[str, Any]
-        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
+        - Options to pass to the solver (including atol and rtol). See torchdiffeq's `odeint` method for more details.
     start_time: float
         - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
@@ -121,6 +121,10 @@ def ensemble_sample(
     """
     check_solver(solver_method, solver_options)
 
+    # Get tolerances for solver
+    rtol = solver_options.pop("rtol", 1e-7)  # default = 1e-7
+    atol = solver_options.pop("atol", 1e-9)  # default = 1e-9
+
     with torch.no_grad():
         if dirichlet_alpha is None:
             dirichlet_alpha = torch.ones(len(model_paths_or_jsons))
@@ -138,7 +142,9 @@ def ensemble_sample(
             raise ValueError("num_samples must be a positive integer")
 
         def wrapped_model():
-            with TorchDiffEq(method=solver_method, options=solver_options):
+            with TorchDiffEq(
+                rtol=rtol, atol=atol, method=solver_method, options=solver_options
+            ):
                 solution = model(
                     torch.as_tensor(start_time),
                     torch.as_tensor(end_time),
@@ -233,7 +239,7 @@ def ensemble_calibrate(
         - If performance is incredibly slow, we suggest using `euler` to debug.
           If using `euler` results in faster simulation, the issue is likely that the model is stiff.
     solver_options: Dict[str, Any]
-        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
+        - Options to pass to the solver (including atol and rtol). See torchdiffeq's `odeint` method for more details.
     start_time: float
         - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
@@ -280,6 +286,10 @@ def ensemble_calibrate(
     if not (isinstance(num_iterations, int) and num_iterations > 0):
         raise ValueError("num_iterations must be a positive integer")
 
+    # Get tolerances for solver
+    rtol = solver_options.pop("rtol", 1e-7)  # default = 1e-7
+    atol = solver_options.pop("atol", 1e-9)  # default = 1e-9
+
     def autoguide(model):
         guide = pyro.infer.autoguide.AutoGuideList(model)
         guide.append(
@@ -314,7 +324,9 @@ def autoguide(model):
     def wrapped_model():
         obs = condition(data=_data)(_noise_model)
 
-        with TorchDiffEq(method=solver_method, options=solver_options):
+        with TorchDiffEq(
+            rtol=rtol, atol=atol, method=solver_method, options=solver_options
+        ):
             solution = model(
                 torch.as_tensor(start_time),
                 torch.as_tensor(data_timepoints[-1]),
@@ -384,7 +396,8 @@ def sample(
         - If performance is incredibly slow, we suggest using `euler` to debug.
           If using `euler` results in faster simulation, the issue is likely that the model is stiff.
     solver_options: Dict[str, Any]
-        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
+        - Options to pass to the solver (including atol and rtol).
+          See torchdiffeq's `odeint` method for more details.
     start_time: float
         - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
@@ -449,6 +462,10 @@ def sample(
 
     check_solver(solver_method, solver_options)
 
+    # Get tolerances for solver
+    rtol = solver_options.pop("rtol", 1e-7)  # default = 1e-7
+    atol = solver_options.pop("atol", 1e-9)  # default = 1e-9
+
     with torch.no_grad():
         model = CompiledDynamics.load(model_path_or_json)
 
@@ -492,7 +509,9 @@ def sample(
 
     def wrapped_model():
         with ParameterInterventionTracer():
-            with TorchDiffEq(method=solver_method, options=solver_options):
+            with TorchDiffEq(
+                rtol=rtol, atol=atol, method=solver_method, options=solver_options
+            ):
                 with contextlib.ExitStack() as stack:
                     for handler in intervention_handlers:
                         stack.enter_context(handler)
@@ -602,7 +621,8 @@ def calibrate(
         - If performance is incredibly slow, we suggest using `euler` to debug.
           If using `euler` results in faster simulation, the issue is likely that the model is stiff.
     solver_options: Dict[str, Any]
-        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
+        - Options to pass to the solver (including atol and rtol).
+          See torchdiffeq's `odeint` method for more details.
     start_time: float
         - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
@@ -668,6 +688,10 @@ def calibrate(
 
     check_solver(solver_method, solver_options)
 
+    # Get tolerances for solver
+    rtol = solver_options.pop("rtol", 1e-7)  # default = 1e-7
+    atol = solver_options.pop("atol", 1e-9)  # default = 1e-9
+
     pyro.clear_param_store()
 
     model = CompiledDynamics.load(model_path_or_json)
@@ -740,7 +764,9 @@ def wrapped_model():
         obs = condition(data=_data)(_noise_model)
 
         with StaticBatchObservation(data_timepoints, observation=obs):
-            with TorchDiffEq(method=solver_method, options=solver_options):
+            with TorchDiffEq(
+                rtol=rtol, atol=atol, method=solver_method, options=solver_options
+            ):
                 with contextlib.ExitStack() as stack:
                     for handler in intervention_handlers:
                         stack.enter_context(handler)
@@ -834,7 +860,8 @@ def optimize(
         - If performance is incredibly slow, we suggest using `euler` to debug.
           If using `euler` results in faster simulation, the issue is likely that the model is stiff.
     solver_options: Dict[str, Any]
-        - Options to pass to the solver. See torchdiffeq' `odeint` method for more details.
+        - Options to pass to the solver (including atol and rtol).
+          See torchdiffeq's `odeint` method for more details.
     start_time: float
         - The start time of the model. This is used to align the `start_state` from the AMR model with the simulation timepoints.
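Taken together, the `interfaces.py` changes above let callers control integration accuracy through `solver_options`. A minimal sketch of the resulting call pattern follows, mirroring the usage exercised in the tests below; the model path is a placeholder, the numeric arguments are illustrative, and the top-level `from pyciemss import sample` re-export is assumed:

```python
from pyciemss import sample

# Placeholder model source; substitute a real AMR model path or URL.
MODEL_PATH = "model.json"

# "rtol" and "atol" are popped out of solver_options by the new code and
# passed to TorchDiffEq directly; any remaining keys are still forwarded
# as the solver's `options` dict, exactly as before.
result = sample(
    MODEL_PATH,
    end_time=100.0,
    logging_step_size=1.0,
    num_samples=10,
    start_time=0.0,
    solver_options={"rtol": 1e-6, "atol": 1e-8},
)["unprocessed_result"]
```

Omitting either key falls back to torchdiffeq's `odeint` defaults (`rtol=1e-7`, `atol=1e-9`), so existing callers are unaffected.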
diff --git a/pyciemss/ouu/ouu.py b/pyciemss/ouu/ouu.py
index 09f1cd50..28629d4f 100644
--- a/pyciemss/ouu/ouu.py
+++ b/pyciemss/ouu/ouu.py
@@ -97,6 +97,8 @@ def __init__(
         self.u_bounds = u_bounds
         self.risk_bound = risk_bound  # used for defining penalty
         warnings.simplefilter("always", UserWarning)
+        self.rtol = self.solver_options.pop("rtol", 1e-7)  # default = 1e-7
+        self.atol = self.solver_options.pop("atol", 1e-9)  # default = 1e-9
 
     def __call__(self, x):
         if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
@@ -144,7 +146,10 @@ def propagate_uncertainty(self, x):
         def wrapped_model():
             with ParameterInterventionTracer():
                 with TorchDiffEq(
-                    method=self.solver_method, options=self.solver_options
+                    rtol=self.rtol,
+                    atol=self.atol,
+                    method=self.solver_method,
+                    options=self.solver_options,
                 ):
                     with contextlib.ExitStack() as stack:
                         for handler in static_parameter_intervention_handlers:
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py
index 1f653b06..e4563b8b 100644
--- a/tests/test_interfaces.py
+++ b/tests/test_interfaces.py
@@ -87,6 +87,9 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
         "num_iterations": 2,
     }
 
+RTOL = [1e-6, 1e-4]
+ATOL = [1e-8, 1e-6]
+
 
 @pytest.mark.parametrize("sample_method", SAMPLE_METHODS)
 @pytest.mark.parametrize("model", MODELS)
@@ -94,22 +97,46 @@ def setup_calibrate(model_fixture, start_time, end_time, logging_step_size):
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
 @pytest.mark.parametrize("num_samples", NUM_SAMPLES)
+@pytest.mark.parametrize("rtol", RTOL)
+@pytest.mark.parametrize("atol", ATOL)
 def test_sample_no_interventions(
-    sample_method, model, start_time, end_time, logging_step_size, num_samples
+    sample_method,
+    model,
+    start_time,
+    end_time,
+    logging_step_size,
+    num_samples,
+    rtol,
+    atol,
 ):
     model_url = model.url
     with pyro.poutine.seed(rng_seed=0):
         result1 = sample_method(
-            model_url, end_time, logging_step_size, num_samples, start_time=start_time
+            model_url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            start_time=start_time,
+            solver_options={"rtol": rtol, "atol": atol},
         )["unprocessed_result"]
     with pyro.poutine.seed(rng_seed=0):
         result2 = sample_method(
-            model_url, end_time, logging_step_size, num_samples, start_time=start_time
+            model_url,
+            end_time,
+            logging_step_size,
+            num_samples,
+            start_time=start_time,
+            solver_options={"rtol": rtol, "atol": atol},
         )["unprocessed_result"]
 
     result3 = sample_method(
-        model_url, end_time, logging_step_size, num_samples, start_time=start_time
+        model_url,
+        end_time,
+        logging_step_size,
+        num_samples,
+        start_time=start_time,
+        solver_options={"rtol": rtol, "atol": atol},
     )["unprocessed_result"]
 
     for result in [result1, result2, result3]:
@@ -364,8 +391,10 @@ def test_calibrate_no_kwargs(
 @pytest.mark.parametrize("start_time", START_TIMES)
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("logging_step_size", LOGGING_STEP_SIZES)
+@pytest.mark.parametrize("rtol", RTOL)
+@pytest.mark.parametrize("atol", ATOL)
 def test_calibrate_deterministic(
-    model_fixture, start_time, end_time, logging_step_size
+    model_fixture, start_time, end_time, logging_step_size, rtol, atol
 ):
     model_url = model_fixture.url
     (
@@ -381,6 +410,7 @@ def test_calibrate_deterministic(
         "data_mapping": model_fixture.data_mapping,
         "start_time": start_time,
         "deterministic_learnable_parameters": deterministic_learnable_parameters,
+        "solver_options": {"rtol": rtol, "atol": atol},
         **CALIBRATE_KWARGS,
     }
 
@@ -400,7 +430,10 @@ def test_calibrate_deterministic(
         assert torch.allclose(param_value, param_sample_2[param_name])
 
     result = sample(
-        *sample_args, **sample_kwargs, inferred_parameters=inferred_parameters
+        *sample_args,
+        **sample_kwargs,
+        inferred_parameters=inferred_parameters,
+        solver_options={"rtol": rtol, "atol": atol},
     )["unprocessed_result"]
 
     check_result_sizes(result, start_time, end_time, logging_step_size, 1)
@@ -563,7 +596,9 @@ def test_output_format(
 @pytest.mark.parametrize("start_time", START_TIMES)
 @pytest.mark.parametrize("end_time", END_TIMES)
 @pytest.mark.parametrize("num_samples", NUM_SAMPLES)
-def test_optimize(model_fixture, start_time, end_time, num_samples):
+@pytest.mark.parametrize("rtol", RTOL)
+@pytest.mark.parametrize("atol", ATOL)
+def test_optimize(model_fixture, start_time, end_time, num_samples, rtol, atol):
     logging_step_size = 1.0
     model_url = model_fixture.url
 
@@ -581,7 +616,7 @@ def __call__(self, x):
     optimize_kwargs = {
         **model_fixture.optimize_kwargs,
         "solver_method": "euler",
-        "solver_options": {"step_size": 0.1},
+        "solver_options": {"step_size": 0.1, "rtol": rtol, "atol": atol},
         "start_time": start_time,
         "n_samples_ouu": int(2),
         "maxiter": 1,
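One design note on the pattern repeated throughout this diff: `dict.pop` removes the tolerance keys from the caller-supplied `solver_options`, so they are not passed a second time through TorchDiffEq's `options`, and the fallback values mirror torchdiffeq's `odeint` defaults. A standalone sketch of that behavior (plain stdlib semantics, no pyciemss imports needed):

```python
# Standalone illustration of the tolerance-extraction pattern used above.
solver_options = {"rtol": 1e-6, "atol": 1e-8, "step_size": 0.1}

rtol = solver_options.pop("rtol", 1e-7)  # falls back to torchdiffeq's odeint default
atol = solver_options.pop("atol", 1e-9)  # falls back to torchdiffeq's odeint default

# The tolerances are consumed; only solver-specific options remain to be
# forwarded as TorchDiffEq's `options` (here, the fixed step size).
assert solver_options == {"step_size": 0.1}
assert (rtol, atol) == (1e-6, 1e-8)
```

Because `pop` mutates the dict in place, a caller reusing a single `solver_options` dict across calls will find `rtol`/`atol` stripped after the first call, with the defaults applied thereafter; the tests above avoid this by building a fresh dict per call.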