From 032fcf0710973a76fcbf1476659b1ba36449b1a5 Mon Sep 17 00:00:00 2001 From: Sam Greenbury Date: Fri, 13 Sep 2024 17:58:58 +0100 Subject: [PATCH] Fix use of config, rename CLI args --- scripts/1_prep_synthpop.py | 16 ++++++++-------- scripts/2_match_households_and_individuals.py | 15 ++++++++++----- 2 files changed, 18 insertions(+), 13 deletions(-) diff --git a/scripts/1_prep_synthpop.py b/scripts/1_prep_synthpop.py index 0a1a9c9..7a14203 100644 --- a/scripts/1_prep_synthpop.py +++ b/scripts/1_prep_synthpop.py @@ -1,21 +1,21 @@ import click import numpy as np -import tomlkit from uatk_spc.builder import Builder import acbm +from acbm.utils import get_config @click.command() # TODO: add override for case when seed provided from CLI # @click.option("--seed", default=1, help="Seed for random state", type=int) -@click.option("--config", prompt="Filepath relative to repo root of config", type=str) -def main(config): - # Read config - with open(acbm.root_path / config, "rb") as f: - config_dict = tomlkit.load(f) - seed = config_dict["parameters"]["seed"] - region = config_dict["parameters"]["region"] +@click.option( + "--config_file", prompt="Filepath relative to repo root of config", type=str +) +def main(config_file): + config = get_config(config_file) + seed = config["parameters"]["seed"] + region = config["parameters"]["region"] # Seed RNG np.random.seed(seed) diff --git a/scripts/2_match_households_and_individuals.py b/scripts/2_match_households_and_individuals.py index 4598db8..cf3ac0b 100644 --- a/scripts/2_match_households_and_individuals.py +++ b/scripts/2_match_households_and_individuals.py @@ -23,12 +23,14 @@ @click.command() -@click.option("--config", prompt="Filepath relative to repo root of config", type=str) -def main(config): - config_loaded = get_config(config) +@click.option( + "--config_file", prompt="Filepath relative to repo root of config", type=str +) +def main(config_file): + config = get_config(config_file) # Seed RNG - SEED = config_loaded["seed"] + SEED = config["parameters"]["seed"] np.random.seed(SEED) pd.set_option("display.max_columns", None) @@ -89,7 +91,10 @@ def get_interim_path( unique_households = spc["household"].unique() # Sample a subset of households sampled_households = pd.Series(unique_households).sample( - n=min(config_loaded["number_of_households"], unique_households), + n=min( + config["parameters"]["number_of_households"], + unique_households.shape[0], + ), random_state=SEED, ) # Filter the original DataFrame based on the sampled households