Skip to content

Commit

Permalink
formatting and linting passing
Browse files Browse the repository at this point in the history
  • Loading branch information
augeorge committed Sep 18, 2024
1 parent e8929fa commit 710acda
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 20 deletions.
29 changes: 16 additions & 13 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,9 +122,9 @@ def ensemble_sample(
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
if dirichlet_alpha is None:
dirichlet_alpha = torch.ones(len(model_paths_or_jsons))
Expand Down Expand Up @@ -285,10 +285,10 @@ def ensemble_calibrate(
# Check that num_iterations is a positive integer
if not (isinstance(num_iterations, int) and num_iterations > 0):
raise ValueError("num_iterations must be a positive integer")

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

def autoguide(model):
guide = pyro.infer.autoguide.AutoGuideList(model)
Expand Down Expand Up @@ -396,7 +396,8 @@ def sample(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -462,8 +463,8 @@ def sample(
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

with torch.no_grad():
model = CompiledDynamics.load(model_path_or_json)
Expand Down Expand Up @@ -620,7 +621,8 @@ def calibrate(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
- solver_options: Dict[str, Any]
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
- start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down Expand Up @@ -687,8 +689,8 @@ def calibrate(
check_solver(solver_method, solver_options)

# Get tolerances for solver
rtol = solver_options.pop('rtol', 1e-7) # default = 1e-7
atol = solver_options.pop('atol', 1e-9) # default = 1e-9
rtol = solver_options.pop("rtol", 1e-7) # default = 1e-7
atol = solver_options.pop("atol", 1e-9) # default = 1e-9

pyro.clear_param_store()

Expand Down Expand Up @@ -858,7 +860,8 @@ def optimize(
- If performance is incredibly slow, we suggest using `euler` to debug.
If using `euler` results in faster simulation, the issue is likely that the model is stiff.
solver_options: Dict[str, Any]
- Options to pass to the solver (including atol and rtol). See torchdiffeq' `odeint` method for more details.
- Options to pass to the solver (including atol and rtol).
See torchdiffeq' `odeint` method for more details.
start_time: float
- The start time of the model. This is used to align the `start_state` from the
AMR model with the simulation timepoints.
Expand Down
4 changes: 2 additions & 2 deletions pyciemss/ouu/ouu.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(
self.u_bounds = u_bounds
self.risk_bound = risk_bound # used for defining penalty
warnings.simplefilter("always", UserWarning)
self.rtol = self.solver_options.pop('rtol', 1e-7) # default = 1e-7
self.atol = self.solver_options.pop('atol', 1e-9) # default = 1e-9
self.rtol = self.solver_options.pop("rtol", 1e-7) # default = 1e-7
self.atol = self.solver_options.pop("atol", 1e-9) # default = 1e-9

def __call__(self, x):
if np.any(x - self.u_bounds[0, :] < 0.0) or np.any(
Expand Down
10 changes: 5 additions & 5 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_sample_no_interventions(
logging_step_size,
num_samples,
start_time=start_time,
solver_options = {'rtol':rtol, 'atol':atol}
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]
with pyro.poutine.seed(rng_seed=0):
result2 = sample_method(
Expand All @@ -127,7 +127,7 @@ def test_sample_no_interventions(
logging_step_size,
num_samples,
start_time=start_time,
solver_options = {'rtol':rtol, 'atol':atol}
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

result3 = sample_method(
Expand All @@ -136,7 +136,7 @@ def test_sample_no_interventions(
logging_step_size,
num_samples,
start_time=start_time,
solver_options = {'rtol':rtol, 'atol':atol}
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

for result in [result1, result2, result3]:
Expand Down Expand Up @@ -410,7 +410,7 @@ def test_calibrate_deterministic(
"data_mapping": model_fixture.data_mapping,
"start_time": start_time,
"deterministic_learnable_parameters": deterministic_learnable_parameters,
"solver_options": {'rtol':rtol, 'atol':atol},
"solver_options": {"rtol": rtol, "atol": atol},
**CALIBRATE_KWARGS,
}

Expand All @@ -433,7 +433,7 @@ def test_calibrate_deterministic(
*sample_args,
**sample_kwargs,
inferred_parameters=inferred_parameters,
solver_options={'rtol':rtol, 'atol':atol}
solver_options={"rtol": rtol, "atol": atol},
)["unprocessed_result"]

check_result_sizes(result, start_time, end_time, logging_step_size, 1)
Expand Down

0 comments on commit 710acda

Please sign in to comment.