Skip to content

Commit

Permalink
experiment runner
Browse files Browse the repository at this point in the history
  • Loading branch information
djpasseyjr committed Feb 29, 2024
1 parent 0b3e31a commit 611a08e
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 75 deletions.
60 changes: 35 additions & 25 deletions experiments/exp1/runner/exp_tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Tools for saving progress and picking up where the last experiment ended."""

import os
from pathlib import Path
import pickle as pkl
from typing import Any, Dict, List
Expand All @@ -10,6 +11,22 @@
PARAM_DIR = Path(__file__).parents[1] / "parameters"
DYN_PARAM_FNAME = "dynamic_models"
METH_PARAM_FNAME = "inference_methods"
FILE_PREFIX = "exp1"
SAVE_DIR = Path(os.getcwd())

def save_file_path(idx):
return SAVE_DIR / Path(f"{FILE_PREFIX}_output{idx}.pkl")

def is_incomplete(idx: str):
i = int(idx)
dyn_ps, _ = load_parameters()
results = load_results(i, dyn_ps[i])
print(not results.get("complete", False))


def print_num_jobs():
dyn_ps, _ = load_parameters()
print(len(dyn_ps))


def load_parameters(test=False):
Expand All @@ -27,20 +44,19 @@ def load_parameters(test=False):
return dyn_ps, meth_ps


def save_results_dict(results_dict, outfile: str,):
def save_results_dict(results_dict, outfile_idx: str,):
# Path to results file.
p = Path(outfile)
p = save_file_path(outfile_idx)
# Load the output pickle file
with open(p, "wb") as f:
output = pkl.dump(results_dict, f)
pkl.dump(results_dict, f)

def check_for_existing_results(outfile: str):
def check_for_existing_results(outfile_idx: str):
"""Loads a pickle object or returns None if the file doesn't exists."""

# Check for experiment output file
p = Path(outfile)
p = save_file_path(outfile_idx)
if not p.exists():
print(f"No output file found at: {p}")
return None

# Load the output pickle file
Expand All @@ -49,10 +65,10 @@ def check_for_existing_results(outfile: str):
return output


def load_results(outfile: str, dynamic_model_args: dict):
def load_results(outfile_idx: str, dynamic_model_args: dict):
"""Loads existing results or makes a new template dict"""

results = check_for_existing_results(outfile)
results = check_for_existing_results(outfile_idx)

if results is None:
model_name = dynamic_model_args["model_type"].__name__
Expand All @@ -63,7 +79,7 @@ def load_results(outfile: str, dynamic_model_args: dict):
**dynamic_model_args
}
# Save empty results dict.
save_results_dict(results, outfile)
save_results_dict(results, outfile_idx)

return results

Expand All @@ -89,13 +105,13 @@ def load_dynamic_sim(results_dict):
return Xs, X_dos, t


def run_dynamics(dyn_args: dict, results_dict: dict, outfile: str):
def run_dynamics(dyn_args: dict, results_dict: dict, outfile_idx: str):
"""Checks for existing sim data, and if none exists, runs the simulation."""

if not results_dict["dynamic_sim_complete"]:
Xs, X_dos, t = interfere.generate_counterfactual_forecasts(**dyn_args)
store_dynamic_model_outputs(Xs, X_dos, t, results_dict)
save_results_dict(results_dict, outfile)
save_results_dict(results_dict, outfile_idx)
else:
Xs, X_dos, t = load_dynamic_sim(results_dict)

Expand All @@ -113,12 +129,11 @@ def run_forecasts(
intervention: interfere.interventions.ExogIntervention,
method_args: dict,
results_dict: dict,
outfile: str,
outfile_idx: str,
opt_all: bool = True
):
"""Accepts the output of `run_dynamics` along with `method_args` a
dictionary of inference method arguments, the results_dict, and outfile
path.
dictionary of inference method arguments, the results_dict, and outfile_idx.
The `opt_all` argument controls whether hyperparameter optimization
happens once, for the first simulation, and those parameters are used for
Expand Down Expand Up @@ -163,11 +178,11 @@ def run_forecasts(
# Store progress and save every five time series.
if i % 5 == 0:
method_progress["best_params"] = best_params
save_results_dict(results_dict, outfile)
save_results_dict(results_dict, outfile_idx)

# Mark complete and save
method_progress["complete"] = True
save_results_dict(results_dict, outfile)
save_results_dict(results_dict, outfile_idx)


def load_method_progress(method_name: str, results_dict: dict):
Expand All @@ -192,10 +207,9 @@ def method_progress_templ(method_name):
}


def check_consistency(
outfile: str, dyn_args: dict, exp_idx: int, opt_all: bool):
def check_consistency(dyn_args: dict, exp_idx: int, opt_all: bool):
"""Checks outfile to make sure everything is correct."""
results = load_results(outfile, dyn_args)
results = load_results(exp_idx, dyn_args)
dyn_params_match(results, dyn_args, exp_idx)

# If output contains no dynamics, ensure that there are no methods either.
Expand All @@ -221,10 +235,6 @@ def check_consistency(
"set to false. Should be only one set."
f" Command line arg: {exp_idx}")






def args_are_equal(arg1, arg2):
if type(arg1) != type(arg2):
Expand Down Expand Up @@ -253,7 +263,7 @@ def dyn_params_match(result_dict, dyn_args, exp_idx):
)


def finish(results, outfile):
def finish(results, outfile_idx):
"""Wrap up."""
results["complete"] = True
save_results_dict(results, outfile)
save_results_dict(results, outfile_idx)
12 changes: 5 additions & 7 deletions experiments/exp1/runner/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,11 @@
import exp_tools

DIR_PATH = os.path.dirname(os.path.realpath(__file__))
FILE_PREFIX = "exp1"

# Parse command line argument that designates the index of the hyper parameters.
PARAM_IDX = int(sys.argv[1])

# Save file name.
SAVE_FILE = f"{FILE_PREFIX}_output{PARAM_IDX}.pkl"

# Toggles the amount of hyper parameter optimization. See exp_tools.run_forecast
OPT_ALL = False
Expand All @@ -28,19 +26,19 @@
dyn_args = dyn_args_list[PARAM_IDX]

# Check that the previous save file (if any) is consistent with experiment.
exp_tools.check_consistency(SAVE_FILE, dyn_args, PARAM_IDX, OPT_ALL)
exp_tools.check_consistency(dyn_args, PARAM_IDX, OPT_ALL)

# Load result file from previous runs, or make an empty one.
results = exp_tools.load_results(SAVE_FILE, dyn_args)
results = exp_tools.load_results(PARAM_IDX, dyn_args)

# Run the dyanamic simulations.
dyn_sim_output = exp_tools.run_dynamics(dyn_args, results, SAVE_FILE)
dyn_sim_output = exp_tools.run_dynamics(dyn_args, results, PARAM_IDX)

# Loop over each infernce method.
for margs in method_arg_list:
# Tune hyper parameters, run forecasts and store results.
exp_tools.run_forecasts(
*dyn_sim_output, margs, results, SAVE_FILE, opt_all=OPT_ALL)
*dyn_sim_output, margs, results, PARAM_IDX, opt_all=OPT_ALL)

# Final save.
exp_tools.finish(results, SAVE_FILE)
exp_tools.finish(results, PARAM_IDX)
Loading

0 comments on commit 611a08e

Please sign in to comment.