Skip to content

Commit

Permalink
Add parallel sampling when there are no interventions (#446)
Browse files Browse the repository at this point in the history
* add parallel sampling with no interventions

* lint
  • Loading branch information
SamWitty authored Jan 4, 2024
1 parent 550902f commit 671e405
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
9 changes: 7 additions & 2 deletions pyciemss/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ def wrapped_model():
)

samples = pyro.infer.Predictive(
wrapped_model, guide=inferred_parameters, num_samples=num_samples
wrapped_model, guide=inferred_parameters, num_samples=num_samples, parallel=True
)()

return prepare_interchange_dictionary(samples)
Expand Down Expand Up @@ -292,8 +292,13 @@ def wrapped_model():
# Adding noise to the model so that we can access the noisy trajectory in the Predictive object.
compiled_noise_model(full_trajectory)

parallel = False if len(intervention_handlers) > 0 else True

samples = pyro.infer.Predictive(
wrapped_model, guide=inferred_parameters, num_samples=num_samples
wrapped_model,
guide=inferred_parameters,
num_samples=num_samples,
parallel=parallel,
)()

return prepare_interchange_dictionary(samples)
Expand Down
4 changes: 2 additions & 2 deletions pyciemss/mira_integration/compiled_dynamics.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def _eval_deriv_mira(
dX: State[torch.Tensor] = dict()
for i, var in enumerate(src.variables.values()):
k = get_name(var)
dX[k] = numeric_deriv[i]
dX[k] = numeric_deriv[..., i]
return dX


Expand All @@ -146,7 +146,7 @@ def _eval_initial_state_mira(
X: State[torch.Tensor] = dict()
for i, var in enumerate(src.variables.values()):
k = get_name(var)
X[k] = numeric_initial_state[i]
X[k] = numeric_initial_state[..., i]
return X


Expand Down

0 comments on commit 671e405

Please sign in to comment.