Skip to content

Commit

Permalink
Add func to init RNG and config for all scripts
Browse files Browse the repository at this point in the history
  • Loading branch information
sgreenbury committed Sep 17, 2024
1 parent 5383a3f commit c9d3b40
Show file tree
Hide file tree
Showing 8 changed files with 37 additions and 20 deletions.
9 changes: 3 additions & 6 deletions scripts/1_prep_synthpop.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import numpy as np
from uatk_spc.builder import Builder

import acbm
from acbm.assigning.cli import acbm_cli
from acbm.utils import get_config
from acbm.utils import get_config, init_rng


@acbm_cli
def main(config_file):
config = get_config(config_file)
seed = config["parameters"]["seed"]
region = config["parameters"]["region"]
init_rng(config)

# Seed RNG
np.random.seed(seed)
region = config["parameters"]["region"]

# Pick a region with SPC output saved
path = acbm.root_path / "data/external/spc_output/raw/"
Expand Down
8 changes: 3 additions & 5 deletions scripts/2_match_households_and_individuals.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,15 @@
transform_by_group,
truncate_values,
)
from acbm.utils import get_config
from acbm.utils import get_config, init_rng


@acbm_cli
def main(config_file):
config = get_config(config_file)

# Seed RNG
SEED = config["parameters"]["seed"]
np.random.seed(SEED)
init_rng(config)

pd.set_option("display.max_columns", None)

Expand Down Expand Up @@ -86,13 +85,12 @@ def get_interim_path(

# Identify unique households
unique_households = spc["household"].unique()
# Sample a subset of households
# Sample a subset of households; the RNG is seeded above with `init_rng`
sampled_households = pd.Series(unique_households).sample(
n=min(
config["parameters"]["number_of_households"],
unique_households.shape[0],
),
random_state=SEED,
)
# Filter the original DataFrame based on the sampled households
spc = spc[spc["household"].isin(sampled_households)]
Expand Down
5 changes: 2 additions & 3 deletions scripts/3.1_assign_primary_feasible_zones.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import pickle as pkl

import geopandas as gpd
import numpy as np
import pandas as pd

import acbm
Expand All @@ -15,13 +14,13 @@
)
from acbm.logger_config import assigning_primary_feasible_logger as logger
from acbm.preprocessing import add_locations_to_activity_chains
from acbm.utils import get_config
from acbm.utils import get_config, init_rng


@acbm_cli
def main(config_file):
config = get_config(config_file)
np.random.seed(config["parameters"]["seed"])
init_rng(config)

#### LOAD DATA ####

Expand Down
7 changes: 6 additions & 1 deletion scripts/3.2.1_assign_primary_zone_edu.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,15 @@
)
from acbm.logger_config import assigning_primary_zones_logger as logger
from acbm.preprocessing import add_location
from acbm.utils import get_config, init_rng


@acbm_cli
def main():
def main(config_file):
config = get_config(config_file)
# TODO: consider if RNG seed needs to be distinct for different assignments
init_rng(config)

#### LOAD DATA ####

logger.info("Loading data")
Expand Down
7 changes: 5 additions & 2 deletions scripts/3.2.2_assign_primary_zone_work.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@
from acbm.assigning.utils import filter_matrix_to_boundary
from acbm.logger_config import assigning_primary_zones_logger as logger
from acbm.preprocessing import add_locations_to_activity_chains
from acbm.utils import calculate_rmse
from acbm.utils import calculate_rmse, get_config, init_rng


@acbm_cli
def main(_config_file):
def main(config_file):
config = get_config(config_file)
init_rng(config)

#### LOAD DATA ####

# --- Possible zones for each activity (calculated in 3.1_assign_possible_zones.py)
Expand Down
7 changes: 5 additions & 2 deletions scripts/3.2.3_assign_secondary_zone.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,15 @@
)
from acbm.logger_config import assigning_secondary_zones_logger as logger
from acbm.preprocessing import add_location
from acbm.utils import get_config, init_rng


@acbm_cli
def main(_config_file):
# --- Load in the data
def main(config_file):
config = get_config(config_file)
init_rng(config)

# --- Load in the data
logger.info("Loading: activity chains")

activity_chains = pd.read_parquet(
Expand Down
6 changes: 5 additions & 1 deletion scripts/3.3_assign_facility_all.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,14 @@
from acbm.assigning.plots import plot_desire_lines, plot_scatter_actual_reported
from acbm.assigning.select_facility import map_activity_locations, select_facility
from acbm.logger_config import assigning_facility_locations_logger as logger
from acbm.utils import get_config, init_rng


@acbm_cli
def main(_config_file):
def main(config_file):
config = get_config(config_file)
init_rng(config)

# --- Load data: activity chains
logger.info("Loading activity chains")

Expand Down
8 changes: 8 additions & 0 deletions src/acbm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,14 @@ def get_config(config: str) -> dict[Any, Any]:
return tomlkit.load(f)


def init_rng(config: dict) -> None:
    """Seed NumPy's global random state from the loaded configuration.

    The seed is read from ``config["parameters"]["seed"]`` and passed to
    ``np.random.seed``.

    Args:
        config: Configuration mapping; must contain a ``"parameters"``
            mapping with a ``"seed"`` entry accepted by ``np.random.seed``.

    Raises:
        ValueError: If the seed is missing from ``config`` or is rejected by
            ``np.random.seed``. The underlying error is chained as the cause.
    """
    try:
        np.random.seed(config["parameters"]["seed"])
    except (KeyError, TypeError, ValueError) as err:
        # KeyError: key missing; TypeError: config/seed of the wrong type;
        # ValueError: seed out of the range NumPy accepts.
        msg = f"config does not provide a rng seed with err: {err}"
        # The exception must actually be raised: merely constructing it would
        # let the scripts continue silently with an unseeded RNG.
        raise ValueError(msg) from err


def prepend_datetime(s: str, delimiter: str = "_") -> str:
    """Return ``s`` prefixed with the current local date and ``delimiter``.

    The date is formatted as ``YYYY-MM-DD``; the result has the form
    ``<date><delimiter><s>``.
    """
    stamp = datetime.now().strftime("%Y-%m-%d")
    return "{}{}{}".format(stamp, delimiter, s)
Expand Down

0 comments on commit c9d3b40

Please sign in to comment.