Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample should allow interventions on constant variables #588

Closed
sabinala opened this issue Jul 9, 2024 · 3 comments · Fixed by #597
Closed

Sample should allow interventions on constant variables #588

sabinala opened this issue Jul 9, 2024 · 3 comments · Fixed by #597

Comments

@sabinala
Copy link
Contributor

sabinala commented Jul 9, 2024

This intervention came from a policy optimization: static_parameter_interventions(optimization_result["policy"]) = {27.3639: {'NPI_mult': tensor(0.5000)}}. However, when I sample from the model with this intervention applied, I get the following error if NPI_mult is a constant parameter (but not when it is defined with uncertainty).

opt_intervention_result = pyciemss.sample(
    model,
    end_time,
    logging_step_size,
    num_samples,
    start_time=start_time,
    inferred_parameters=parameter_estimates,
    static_parameter_interventions=static_parameter_interventions(optimization_result["policy"]),
)

# Check risk estimate used in constraints
print("Risk associated with QoI:", opt_intervention_result["risk"][observed_params[0]]["risk"])

# Plot results for observables
schema = plots.trajectories(opt_intervention_result["data"], keep=".*observable_state", qlow=0.0, qhigh=1.0)
plots.ipy_display(schema, dpi=150)

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
Cell In[59], line 1
----> 1 opt_intervention_result = pyciemss.sample(
      2     model,
      3     end_time,
      4     logging_step_size,
      5     num_samples,
      6     start_time=start_time,
      7     inferred_parameters=parameter_estimates,
      8     static_parameter_interventions=static_parameter_interventions(optimization_result["policy"]),
      9 )
     11 # Check risk estimate used in constraints
     12 print("Risk associated with QoI:", opt_intervention_result["risk"][observed_params[0]]["risk"])

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:29, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
     17 log_message = """
     18     ###############################
     19 
   (...)
     26     ################################
     27 """
     28 logging.exception(log_message, function.__name__, function.__doc__)
---> 29 raise e

File ~/Projects/pyciemss/pyciemss/integration_utils/custom_decorators.py:10, in pyciemss_logging_wrapper.<locals>.wrapped(*args, **kwargs)
      8 try:
      9     start_time = time.perf_counter()
---> 10     result = function(*args, **kwargs)
     11     end_time = time.perf_counter()
     12     logging.info(
     13         "Elapsed time for %s: %f", function.__name__, end_time - start_time
     14     )

File ~/Projects/pyciemss/pyciemss/interfaces.py:542, in sample(model_path_or_json, end_time, logging_step_size, num_samples, noise_model, noise_model_kwargs, solver_method, solver_options, start_time, time_unit, inferred_parameters, static_state_interventions, static_parameter_interventions, dynamic_state_interventions, dynamic_parameter_interventions, alpha)
    538         sq_est = alpha_superquantile(qoi_sample, alpha=alpha)
    539         risk_results.update({k: {"risk": [sq_est], "qoi": qoi_sample}})
    541 return {
--> 542     **prepare_interchange_dictionary(
    543         samples, timepoints=logging_times, time_unit=time_unit
    544     ),
    545     "risk": risk_results,
    546 }

File ~/Projects/pyciemss/pyciemss/integration_utils/result_processing.py:49, in prepare_interchange_dictionary(samples, time_unit, timepoints, visual_options, ensemble_quantiles, alpha_qs, stacking_order)
     38 def prepare_interchange_dictionary(
     39     samples: Dict[str, torch.Tensor],
     40     time_unit: Optional[str] = None,
   (...)
     45     stacking_order: str = "timepoints",
     46 ) -> Dict[str, Any]:
     47     samples = {k: (v.squeeze() if len(v.shape) > 2 else v) for k, v in samples.items()}
---> 49     processed_samples, quantile_results = convert_to_output_format(
     50         samples,
     51         time_unit=time_unit,
     52         timepoints=timepoints,
     53         ensemble_quantiles=ensemble_quantiles,
     54         alpha_qs=alpha_qs,
     55         stacking_order=stacking_order,
     56     )
     58     result = {"data": processed_samples, "unprocessed_result": samples}
     59     if ensemble_quantiles:

File ~/Projects/pyciemss/pyciemss/integration_utils/result_processing.py:159, in convert_to_output_format(samples, time_unit, timepoints, ensemble_quantiles, alpha_qs, stacking_order)
    155 # Set values to reflect interventions in time-order
    156 for name, values in sorted(
    157     intervention_data.items(), key=lambda item: int(item[0].split("_")[-1])
    158 ):
--> 159     result = set_intervention_values(result, name, values, intervention_times)
    161 if ensemble_quantiles:
    162     result_quantiles = make_quantiles(
    163         pyciemss_results,
    164         alpha_qs=alpha_qs,
   (...)
    167         stacking_order=stacking_order,
    168     )

File ~/Projects/pyciemss/pyciemss/integration_utils/result_processing.py:374, in set_intervention_values(df, intervention, intervention_values, intervention_times)
    371 target_col = f"persistent_{target_var}_param"
    373 if target_col not in df.columns:
--> 374     raise KeyError(f"Could not find target column for '{target_var}'")
    376 time_col = [
    377     c for c in df.columns if c.startswith("timepoint_") and c != "timepoint_id"
    378 ][0]
    380 def rework(group):

KeyError: "Could not find target column for 'NPI_mult'"
@SamWitty
Copy link
Contributor

@sabinala , can you create a draft PR with a test that fails because of this issue?

I suspect this is actually an issue in the output processing when the parameter is deterministic, and that the intervention is being correctly applied in the simulation.

@sabinala
Copy link
Contributor Author

@SamWitty I think you're correct that this is an issue with output processing.

@sabinala sabinala linked a pull request Jul 30, 2024 that will close this issue
@sabinala
Copy link
Contributor Author

@SamWitty see here: #594

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
2 participants