Skip to content

Commit

Permalink
resolving conflicts on interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
anirban-chaudhuri committed Jul 13, 2023
1 parent 5643749 commit e87fb9c
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 100 deletions.
107 changes: 67 additions & 40 deletions src/pyciemss/Ensemble/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,13 @@
import pandas as pd

from pyro.infer import Predictive
from pyro import poutine

from pyciemss.interfaces import (
setup_model,
reset_model,
intervene,
sample,
calibrate,
optimize,
DynamicalSystem,
)

Expand All @@ -26,13 +24,13 @@

from typing import Iterable, Optional, Tuple, Callable, Union
import copy
from pyciemss.visuals import plots

# TODO: probably refactor this out later.
from pyciemss.PetriNetODE.events import (
StartEvent,
ObservationEvent,
LoggingEvent,
StaticParameterInterventionEvent,
)

EnsembleSolution = Iterable[dict[str, torch.Tensor]]
Expand All @@ -54,8 +52,8 @@ def load_and_sample_petri_ensemble(
dirichlet_concentration: float = 1.0,
start_time: float = -1e-10,
method="dopri5",
alpha_qs: Optional[Iterable[float]] = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99],
stacking_order: Optional[str] = "timepoints",
time_unit: Optional[str] = None,
visual_options: Union[None, bool, dict[str, any]] = None,
) -> pd.DataFrame:
"""
Load a petri net from a file, compile it into a probabilistic program, and sample from it.
Expand All @@ -70,12 +68,14 @@ def load_and_sample_petri_ensemble(
- By convention these weights should sum to 1.0.
solution_mappings: Iterable[Callable]
- A list of functions that map the output of the model to the output of the shared state space.
- Each element of the iterable is a function that takes in a model output and returns a dict of the form {variable_name: value}.
- Each element of the iterable is a function that takes in a model output
and returns a dict of the form {variable_name: value}.
- The order of the functions should match the order of the models.
num_samples: int
- The number of samples to draw from the model.
timepoints: [Iterable[float]]
- The timepoints to simulate the model from. Backcasting and/or forecasting is reflected in the choice of timepoints.
- The timepoints to simulate the model from. Backcasting and/or forecasting is reflected
in the choice of timepoints.
start_states: Optional[Iterable[dict[str, float]]]
- Each element of the iterable is the initial state of the component model.
- If None, the initial state is taken from each of the mira models.
Expand All @@ -90,19 +90,23 @@ def load_and_sample_petri_ensemble(
- Larger values of dirichlet_concentration correspond to more certainty about the weights.
start_time: float
- The start time of the model. This is used to align the `start_state` with the `timepoints`.
- By default we set the `start_time` to be a small negative number to avoid numerical issues w/ collision with the `timepoints` which typically start at 0.
- By default we set the `start_time` to be a small negative number to avoid numerical
issues w/ collision with the `timepoints` which typically start at 0.
method: str
- The method to use for solving the ODE. See torchdiffeq's `odeint` method for more details.
- If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff.
alpha_qs: Optional[Iterable[float]]
- The quantiles required for estimating weighted interval score to test ensemble forecasting accuracy.
stacking_order: Optional[str]
- The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state.
- Options: "timepoints" or "quantiles"
- If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results
in faster simulation, the issue is likely that the model is stiff.
time_unit: str
- Time unit (used for labeling outputs)
visual_options: None, bool, dict[str, any]
- True: output a visual with default options
- False or None: do not output a visual
- dict: output a visual, passing the dictionary to the visualization as kwargs
Returns:
samples: PetriSolution
- The samples from the model as a pandas DataFrame.
samples:
- PetriSolution: The samples from the model as a pandas DataFrame. (for falsy visual_options)
- dict{data: <samples>, visual: <visual>}: The PetriSolution and a visualization (for truthy visual_options)
"""
models = [
load_petri_model(
Expand Down Expand Up @@ -137,9 +141,17 @@ def load_and_sample_petri_ensemble(
num_samples,
method=method,
)
processed_samples, q_ensemble = convert_to_output_format(samples, timepoints, ensemble_quantiles=True, alpha_qs=alpha_qs, stacking_order=stacking_order)
processed_samples = convert_to_output_format(
samples, timepoints, time_unit=time_unit
)

if visual_options:
visual_options = {} if visual_options is True else visual_options
schema = plots.trajectories(processed_samples, **visual_options)
return {"data": processed_samples, "visual": schema}
else:
return processed_samples

return processed_samples, q_ensemble

def load_and_calibrate_and_sample_ensemble_model(
petri_model_or_paths: Iterable[
Expand All @@ -163,8 +175,8 @@ def load_and_calibrate_and_sample_ensemble_model(
num_particles: int = 1,
autoguide=pyro.infer.autoguide.AutoLowRankMultivariateNormal,
method="dopri5",
alpha_qs: Optional[Iterable[float]] = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99],
stacking_order: Optional[str] = "timepoints",
time_unit: Optional[str] = None,
visual_options: Union[None, bool, dict[str, any]] = None,
) -> pd.DataFrame:
"""
Load a collection of petri nets from files, compile them into an ensemble probabilistic program, calibrate it on data,
Expand All @@ -176,20 +188,24 @@ def load_and_calibrate_and_sample_ensemble_model(
- This path can be a URL or a local path to a mira model or AMR model.
- Alternatively, this can be a mira template model directly.
data_path: str
- The path to the data to calibrate the model to. See notebook/integration_demo/data.csv for an example of the format.
- The path to the data to calibrate the model to. See notebook/integration_demo/data.csv
for an example of the format.
- The data should be a csv with one column for "time" and remaining columns for each state variable.
- Each state variable must exactly align with the state variables in the shared ensemble representation. (See `solution_mappings` for more details.)
- Each state variable must exactly align with the state variables in the shared ensemble representation.
(See `solution_mappings` for more details.)
weights: Iterable[float]
- Weights representing prior belief about which models are more likely to be correct.
- By convention these weights should sum to 1.0.
solution_mappings: Iterable[Callable]
- A list of functions that map the output of the model to the output of the shared state space.
- Each element of the iterable is a function that takes in a model output and returns a dict of the form {variable_name: value}.
- Each element of the iterable is a function that takes in a model output and returns a dict of
the form {variable_name: value}.
- The order of the functions should match the order of the models.
num_samples: int
- The number of samples to draw from the model.
timepoints: [Iterable[float]]
- The timepoints to simulate the model from. Backcasting and/or forecasting is reflected in the choice of timepoints.
- The timepoints to simulate the model from. Backcasting and/or forecasting is reflected
in the choice of timepoints.
start_states: Optional[Iterable[dict[str, float]]]
- Each element of the iterable is the initial state of the component model.
- If None, the initial state is taken from each of the mira models.
Expand All @@ -204,35 +220,39 @@ def load_and_calibrate_and_sample_ensemble_model(
- Larger values of dirichlet_concentration correspond to more certainty about the weights.
start_time: float
- The start time of the model. This is used to align the `start_state` with the `timepoints`.
- By default we set the `start_time` to be a small negative number to avoid numerical issues w/ collision with the `timepoints` which typically start at 0.
- By default we set the `start_time` to be a small negative number to avoid numerical issues
w/ collision with the `timepoints` which typically start at 0.
num_iterations: int
- The number of iterations to run the calibration for.
lr: float
- The learning rate to use for the calibration.
verbose: bool
- Whether to print out the calibration progress. This will include summaries of the evidence lower bound (ELBO) and the parameters.
- Whether to print out the calibration progress. This will include summaries of the evidence lower
bound (ELBO) and the parameters.
verbose_every: int
- How often to print out the loss during calibration.
num_particles: int
- The number of particles to use for the calibration. Increasing this value will result in lower variance gradient estimates, but will also increase the computational cost per gradient step.
- The number of particles to use for the calibration. Increasing this value will result in lower variance
gradient estimates, but will also increase the computational cost per gradient step.
autoguide: pyro.infer.autoguide.AutoGuide
- The autoguide to use for the calibration.
method: str
- The method to use for the ODE solver. See `torchdiffeq.odeint` for more details.
- If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results in faster simulation, the issue is likely that the model is stiff.
alpha_qs: Optional[Iterable[float]]
- The quantiles required for estimating weighted interval score to test ensemble forecasting accuracy.
stacking_order: Optional[str]
- The stacking order requested for the ensemble quantiles to keep the selected quantity together for each state.
- Options: "timepoints" or "quantiles"
- If performance is incredibly slow, we suggest using `euler` to debug. If using `euler` results
in faster simulation, the issue is likely that the model is stiff.
time_unit: str
- Time unit (used for labeling outputs)
visual_options: None, bool, dict[str, any]
- True: output a visual with default options
- False or None: do not output a visual
- dict: output a visual, passing the dictionary to the visualization as kwargs
Returns:
samples: pd.DataFrame
- A dataframe containing the samples from the calibrated model.
samples:
- PetriSolution: The samples from the model as a pandas DataFrame. (for falsy visual_options)
- dict{data: <samples>, visual: <visual>}: The PetriSolution and a visualization (for truthy visual_options)
"""


data = csv_to_list(data_path)

models = [
Expand Down Expand Up @@ -282,10 +302,17 @@ def load_and_calibrate_and_sample_ensemble_model(
method=method,
)

processed_samples, q_ensemble = convert_to_output_format(samples, timepoints, ensemble_quantiles=True, alpha_qs=alpha_qs, stacking_order=stacking_order)
processed_samples = convert_to_output_format(
samples, timepoints, time_unit=time_unit
)

if visual_options:
visual_options = {} if visual_options is True else visual_options
schema = plots.trajectories(processed_samples, **visual_options)
return {"data": processed_samples, "visual": schema}
else:
return processed_samples

return processed_samples, q_ensemble


##############################################################################
# Internal Interfaces Below - TA4 above
Expand Down
102 changes: 42 additions & 60 deletions src/pyciemss/utils/interface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@ def convert_to_output_format(
samples: Dict[str, torch.Tensor],
timepoints: Iterable[float],
interventions: Optional[Dict[str, torch.Tensor]] = None,
ensemble_quantiles: Optional[bool] = False,
alpha_qs: Optional[Iterable[float]] = [0.01, 0.025, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55, 0.6, 0.65, 0.7, 0.75, 0.8, 0.85, 0.9, 0.95, 0.975, 0.99],
num_ensemble_quantiles: Optional[int] = 0,
stacking_order: Optional[str] = "timepoints"
*,
time_unit: Optional[str] = "(unknown)",
) -> pd.DataFrame:
"""
Convert the samples from the Pyro model to a DataFrame in the TA4 requested format.
time_unit -- Label timepoints in a semantically relevant way `timepoint_<time_unit>`.
If None, a `timepoint_<time_unit>` field is not provided.
"""

pyciemss_results = {"parameters": {}, "states": {}}
Expand All @@ -34,11 +35,7 @@ def convert_to_output_format(
n_models = sample.shape[1]
for i in range(n_models):
pyciemss_results["parameters"][f"model_{i}_weight"] = (
sample[:, i]
.data.detach()
.cpu()
.numpy()
.astype(np.float64)
sample[:, i].data.detach().cpu().numpy().astype(np.float64)
)
else:
pyciemss_results["states"][name] = (
Expand All @@ -63,7 +60,9 @@ def convert_to_output_format(
else:
d = {
**d,
**assign_interventions_to_timepoints(interventions, timepoints, pyciemss_results["parameters"])
**assign_interventions_to_timepoints(
interventions, timepoints, pyciemss_results["parameters"]
),
}

# Solution (state variables)
Expand All @@ -75,39 +74,12 @@ def convert_to_output_format(
},
}

if ensemble_quantiles:
key_list = ["timepoint_id", "target", "type", "quantile", "value"]
q = {k: [] for k in key_list}
if alpha_qs is None:
alpha_qs = np.linspace(0, 1, num_ensemble_quantiles)
alpha_qs[0] = 0.01
alpha_qs[-1] = 0.99
else:
num_ensemble_quantiles = len(alpha_qs)

# Solution (state variables)
for k, v in pyciemss_results["states"].items():
q_vals = np.quantile(v, alpha_qs, axis=0)
k = k.replace("_sol","")
if stacking_order == "timepoints":
# Keeping timepoints together
q["timepoint_id"].extend(list(np.repeat(np.array(range(num_timepoints)), num_ensemble_quantiles)))
q["target"].extend([k]*num_timepoints*num_ensemble_quantiles)
q["type"].extend(["quantile"]*num_timepoints*num_ensemble_quantiles)
q["quantile"].extend(list(np.tile(alpha_qs, num_timepoints)))
q["value"].extend(list(np.squeeze(q_vals.T.reshape((num_timepoints * num_ensemble_quantiles, 1)))))
elif stacking_order == "quantiles":
# Keeping quantiles together
q["timepoint_id"].extend(list(np.tile(np.array(range(num_timepoints)), num_ensemble_quantiles)))
q["target"].extend([k]*num_timepoints*num_ensemble_quantiles)
q["type"].extend(["quantile"]*num_timepoints*num_ensemble_quantiles)
q["quantile"].extend(list(np.repeat(alpha_qs, num_timepoints)))
q["value"].extend(list(np.squeeze(q_vals.reshape((num_timepoints * num_ensemble_quantiles, 1)))))
else:
raise Exception("Incorrect input for stacking_order.")
return pd.DataFrame(d), pd.DataFrame(q)
else:
return pd.DataFrame(d)
result = pd.DataFrame(d)
if time_unit is not None:
all_timepoints = result["timepoint_id"].map(lambda v: timepoints[v])
result = result.assign(**{f"timepoint_{time_unit}": all_timepoints})

return result


def csv_to_list(filename):
Expand All @@ -124,35 +96,43 @@ def csv_to_list(filename):
return result


def interventions_and_sampled_params_to_interval(
    interventions: Iterable, sampled_params: dict
) -> dict:
    """Convert interventions and sampled parameters to a dict of intervals.

    Every sampled parameter starts with a single interval spanning
    (-inf, inf). Each intervention, processed in order of its time, closes the
    most recently added interval of its parameter at the intervention time and
    opens a new interval running to +inf that holds the intervention value
    repeated once per sampled value.

    :param interventions: iterable of (intervention_time, parameter_name, value)
        tuples; the parameter is looked up in ``sampled_params`` under the key
        ``f"{parameter_name}_param"``
    :param sampled_params: dict keyed by param where each value is a sized
        collection of sampled parameter values
    :return: dict keyed by param where each value is a list of interval dicts
        (keys ``start``, ``end``, ``param_values``) sorted by start time
    :raises KeyError: if an intervention names a parameter with no matching
        ``<parameter_name>_param`` entry in ``sampled_params``
    """
    # assign each sampled parameter to an infinite interval
    param_dict = {
        param: [dict(start=-np.inf, end=np.inf, param_values=value)]
        for param, value in sampled_params.items()
    }

    # Process interventions in order of start time. Sorting on the time alone
    # (the sort is stable) avoids ever comparing intervention values, which may
    # not be orderable (e.g. multi-element tensors or arrays) when two
    # interventions share the same time and parameter.
    for start, param, intervention_value in sorted(
        interventions, key=lambda intervention: intervention[0]
    ):
        key = f"{param}_param"

        # update the end time of the previous interval
        param_dict[key][-1]["end"] = start

        # add new interval and broadcast the intervention value to the size of
        # the sampled parameters
        param_dict[key].append(
            dict(
                start=start,
                end=np.inf,
                param_values=[intervention_value] * len(sampled_params[key]),
            )
        )

    # sort intervals by start time
    return {k: sorted(v, key=lambda x: x["start"]) for k, v in param_dict.items()}


def assign_interventions_to_timepoints(interventions: dict, timepoints: Iterable[float], sampled_params: dict) -> dict:
def assign_interventions_to_timepoints(
interventions: dict, timepoints: Iterable[float], sampled_params: dict
) -> dict:
"""Assign the value of each parameter to every timepoint, taking into account interventions.
:param interventions: iterable of (intervention_time, parameter_name, value) tuples
Expand All @@ -161,11 +141,13 @@ def assign_interventions_to_timepoints(interventions: dict, timepoints: Iterable
:return: dict keyed by param where the values are sorted by sample then timepoint
"""
# transform interventions and sampled parameters into intervals
param_interval_dict = interventions_and_sampled_params_to_interval(interventions, sampled_params)
param_interval_dict = interventions_and_sampled_params_to_interval(
interventions, sampled_params
)
result = {}
for param, interval_dict in param_interval_dict.items():
intervals = [(d['start'], d['end']) for d in interval_dict]
param_values = [d['param_values'] for d in interval_dict]
intervals = [(d["start"], d["end"]) for d in interval_dict]
param_values = [d["param_values"] for d in interval_dict]

# generate list of parameter values at each timepoint
result[param] = []
Expand Down

0 comments on commit e87fb9c

Please sign in to comment.