diff --git a/pyciemss/interfaces.py b/pyciemss/interfaces.py index 4dc62527b..120c44370 100644 --- a/pyciemss/interfaces.py +++ b/pyciemss/interfaces.py @@ -138,7 +138,7 @@ def wrapped_model(): ) samples = pyro.infer.Predictive( - wrapped_model, guide=inferred_parameters, num_samples=num_samples + wrapped_model, guide=inferred_parameters, num_samples=num_samples, parallel=True )() return prepare_interchange_dictionary(samples) @@ -292,8 +292,13 @@ def wrapped_model(): # Adding noise to the model so that we can access the noisy trajectory in the Predictive object. compiled_noise_model(full_trajectory) + parallel = False if len(intervention_handlers) > 0 else True + samples = pyro.infer.Predictive( - wrapped_model, guide=inferred_parameters, num_samples=num_samples + wrapped_model, + guide=inferred_parameters, + num_samples=num_samples, + parallel=parallel, )() return prepare_interchange_dictionary(samples) diff --git a/pyciemss/mira_integration/compiled_dynamics.py b/pyciemss/mira_integration/compiled_dynamics.py index a78463c02..2cb668df6 100644 --- a/pyciemss/mira_integration/compiled_dynamics.py +++ b/pyciemss/mira_integration/compiled_dynamics.py @@ -126,7 +126,7 @@ def _eval_deriv_mira( dX: State[torch.Tensor] = dict() for i, var in enumerate(src.variables.values()): k = get_name(var) - dX[k] = numeric_deriv[i] + dX[k] = numeric_deriv[..., i] return dX @@ -146,7 +146,7 @@ def _eval_initial_state_mira( X: State[torch.Tensor] = dict() for i, var in enumerate(src.variables.values()): k = get_name(var) - X[k] = numeric_initial_state[i] + X[k] = numeric_initial_state[..., i] return X