From d6838e72bdc145b2f87ab9e33e220eb84fd87e87 Mon Sep 17 00:00:00 2001 From: sabinala <130604122+sabinala@users.noreply.github.com> Date: Wed, 6 Mar 2024 15:53:25 -0800 Subject: [PATCH] adding testing check_data function to load_data (#458) * adding testing check_data function to load_data * reformatting * fixing data checker * adding test for load_data * Update observation.py * Update observation.py * Update observation.py Cleaning up data report --- pyciemss/integration_utils/observation.py | 54 ++++++++++++++++++++--- tests/fixtures.py | 35 +++++++++++++++ tests/test_interfaces.py | 13 ++++++ 3 files changed, 97 insertions(+), 5 deletions(-) diff --git a/pyciemss/integration_utils/observation.py b/pyciemss/integration_utils/observation.py index 7e24b6d37..f48061f28 100644 --- a/pyciemss/integration_utils/observation.py +++ b/pyciemss/integration_utils/observation.py @@ -1,4 +1,4 @@ -from typing import Dict, Tuple +from typing import Dict, Tuple, Union import pandas as pd import torch @@ -9,19 +9,63 @@ def load_data( - path: str, data_mapping: Dict[str, str] = {} + path: Union[str, pd.DataFrame], data_mapping: Dict[str, str] = {} ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]: """ - Load data from a CSV file. + Load data from a CSV file, or directly from a DataFrame. - - path: path to the CSV file + - path: path to the CSV file, or DataFrame - data_mapping: A mapping from column names in the data file to state variable names in the model. - keys: str name of column in dataset - values: str name of state/observable in model - If not provided, we will assume that the column names in the data file match the state variable names. 
""" - df = pd.read_csv(path) + def check_data(data_path: Union[str, pd.DataFrame]): + # This function checks a dataset for formatting errors, and returns a DataFrame + + # Read the data + if isinstance(data_path, str): + # If data_path is a string, assume it's a file path and read as a csv + data_df = pd.read_csv(data_path) + elif isinstance(data_path, pd.DataFrame): + # If data_path is a DataFrame, use it directly + data_df = data_path + else: + # If data_path is neither a string nor a DataFrame, raise an error + raise ValueError("data_path must be either a file path or a DataFrame.") + + # Check that the first column name is "Timestamp" + if data_df.columns[0] != "Timestamp": + raise ValueError( + "The first column must be named 'Timestamp' and contain the time corresponding to each row of data." + ) + + # Check that there are no NaN values or empty entries + if data_df.isna().any().any(): + raise ValueError("Dataset cannot contain NaN or empty entries.") + + # Check that there is no missing data in the form of None type or char values + if not data_df.applymap(lambda x: isinstance(x, (int, float))).all().all(): + raise ValueError( + "Dataset cannot contain None type or char values. All entries must be of type `int` or `float`." + ) + + return data_df + + def print_data_report(data_df): + # Prints a short report about the data + + print( + f"Data printout: This dataset contains {len(data_df) - 1} rows of data. " + f"The first column, {data_df.columns[0]}, begins at {data_df.iloc[0, 0]} " + f"and ends at {data_df.iloc[-1, 0]}. 
" + f"The subsequent columns are named: " + f"{', '.join(data_df.columns[1:])}" + ) + + df = check_data(path) + print_data_report(df) data_timepoints = torch.tensor(df["Timestamp"].values, dtype=torch.float32) data = {} diff --git a/tests/fixtures.py b/tests/fixtures.py index dcfe789ba..50ab66f0d 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, TypeVar import numpy as np +import pandas as pd import torch from pyciemss.integration_utils.intervention_builder import ( @@ -139,6 +140,40 @@ def __init__( torch.tensor(3), ] # bad candidates for num_samples/num_iterations +bad_data1 = { + "Timestamp": {0: 1.1, 1: 2.2, 2: 3.3}, + "case": {0: 15.0, 1: "", 2: 20.0}, + "hosp": {0: 0.1, 1: 1.0, 2: 2.2}, +} +bad_data2 = { + "Timestamp": {0: 1.1, 1: 2.2, 2: 3.3}, + "case": {0: 15.0, 1: "apple", 2: 20.0}, + "hosp": {0: 0.1, 1: 1.0, 2: 2.2}, +} +bad_data3 = { + "Timestamp": {0: 1.1, 1: 2.2, 2: 3.3}, + "case": {0: 15.0, 1: " ", 2: 20.0}, + "hosp": {0: 0.1, 1: 1.0, 2: 2.2}, +} +bad_data4 = { + "Timestamp": {0: 1.1, 1: 2.2, 2: 3.3}, + "case": {0: 15.0, 1: None, 2: 20.0}, + "hosp": {0: 0.1, 1: 1.0, 2: 2.2}, +} +bad_data5 = { + "Timepoints": {0: 1.1, 1: 2.2, 2: 3.3}, + "case": {0: 15.0, 1: 18.0, 2: 20.0}, + "hosp": {0: 0.1, 1: 1.0, 2: 2.2}, +} +BADLY_FORMATTED_DATAFRAMES = [ + pd.DataFrame(bad_data1), + pd.DataFrame(bad_data2), + pd.DataFrame(bad_data3), + pd.DataFrame(bad_data4), + pd.DataFrame(bad_data5), +] # improperly formatted datasets +MAPPING_FOR_DATA_TESTS = {"case": "I", "hosp": "H"} + def check_keys_match(obj1: Dict[str, T], obj2: Dict[str, T]): assert set(obj1.keys()) == set(obj2.keys()), "Objects have different variables." 
diff --git a/tests/test_interfaces.py b/tests/test_interfaces.py index f825ae33e..dc70beb65 100644 --- a/tests/test_interfaces.py +++ b/tests/test_interfaces.py @@ -9,8 +9,10 @@ from pyciemss.interfaces import calibrate, ensemble_sample, optimize, sample from .fixtures import ( + BADLY_FORMATTED_DATAFRAMES, END_TIMES, LOGGING_STEP_SIZES, + MAPPING_FOR_DATA_TESTS, MODEL_URLS, MODELS, NON_POS_INTS, @@ -611,3 +613,14 @@ def test_non_pos_int_sample( logging_step_size, num_samples=bad_num_samples, ) + + +@pytest.mark.parametrize("bad_data", BADLY_FORMATTED_DATAFRAMES) +@pytest.mark.parametrize("data_mapping", [MAPPING_FOR_DATA_TESTS]) +def test_load_data(bad_data, data_mapping): + # Assert that a ValueError is raised for improperly formatted data + with pytest.raises(ValueError): + load_data( + bad_data, + data_mapping, + )