Skip to content

Commit

Permalink
adding testing check_data function to load_data (#458)
Browse files Browse the repository at this point in the history
* adding testing check_data function to load_data

* reformatting

* fixing data checker

* adding test for load_data

* Update observation.py

* Update observation.py

* Update observation.py

Cleaning up data report
  • Loading branch information
sabinala authored Mar 6, 2024
1 parent e47b4ed commit d6838e7
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 5 deletions.
54 changes: 49 additions & 5 deletions pyciemss/integration_utils/observation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Tuple, Union

import pandas as pd
import torch
Expand All @@ -9,19 +9,63 @@


def load_data(
path: str, data_mapping: Dict[str, str] = {}
path: Union[str, pd.DataFrame], data_mapping: Dict[str, str] = {}
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Load data from a CSV file.
Load data from a CSV file, or directly from a DataFrame.
- path: path to the CSV file
- path: path to the CSV file, or DataFrame
- data_mapping: A mapping from column names in the data file to state variable names in the model.
- keys: str name of column in dataset
- values: str name of state/observable in model
- If not provided, we will assume that the column names in the data file match the state variable names.
"""

df = pd.read_csv(path)
def check_data(data_path: Union[str, pd.DataFrame]):
# This function checks a dataset for formatting errors, and returns a DataFrame

# Read the data
if isinstance(data_path, str):
# If data_path is a string, assume it's a file path and read as a csv
data_df = pd.read_csv(data_path)
elif isinstance(data_path, pd.DataFrame):
# If data_path is a DataFrame, use it directly
data_df = data_path
else:
# If data_path is neither a string nor a DataFrame, raise an error
raise ValueError("data_path must be either a file path or a DataFrame.")

# Check that the first column name is "Timestamp"
if data_df.columns[0] != "Timestamp":
raise ValueError(
"The first column must be named 'Timestamp' and contain the time corresponding to each row of data."
)

# Check that there are no NaN values or empty entries
if data_df.isna().any().any():
raise ValueError("Dataset cannot contain NaN or empty entries.")

# Check that there is no missing data in the form of None type or char values
if not data_df.applymap(lambda x: isinstance(x, (int, float))).all().all():
raise ValueError(
"Dataset cannot contain None type or char values. All entries must be of type `int` or `float`."
)

return data_df

def print_data_report(data_df):
# Prints a short report about the data

print(
f"Data printout: This dataset contains {len(data_df) - 1} rows of data. "
f"The first column, {data_df.columns[0]}, begins at {data_df.iloc[0, 0]} "
f"and ends at {data_df.iloc[-1, 0]}. "
f"The subsequent columns are named: "
f"{', '.join(data_df.columns[1:])}"
)

df = check_data(path)
print_data_report(df)

data_timepoints = torch.tensor(df["Timestamp"].values, dtype=torch.float32)
data = {}
Expand Down
35 changes: 35 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional, TypeVar

import numpy as np
import pandas as pd
import torch

from pyciemss.integration_utils.intervention_builder import (
Expand Down Expand Up @@ -139,6 +140,40 @@ def __init__(
torch.tensor(3),
] # bad candidates for num_samples/num_iterations

bad_data1 = {
"Timestamp": {0: 1.1, 1: 2.2, 2: 3.3},
"case": {0: 15.0, 1: "", 2: 20.0},
"hosp": {0: 0.1, 1: 1.0, 2: 2.2},
}
bad_data2 = {
"Timestamp": {0: 1.1, 1: 2.2, 2: 3.3},
"case": {0: 15.0, 1: "apple", 2: 20.0},
"hosp": {0: 0.1, 1: 1.0, 2: 2.2},
}
bad_data3 = {
"Timestamp": {0: 1.1, 1: 2.2, 2: 3.3},
"case": {0: 15.0, 1: " ", 2: 20.0},
"hosp": {0: 0.1, 1: 1.0, 2: 2.2},
}
bad_data4 = {
"Timestamp": {0: 1.1, 1: 2.2, 2: 3.3},
"case": {0: 15.0, 1: None, 2: 20.0},
"hosp": {0: 0.1, 1: 1.0, 2: 2.2},
}
bad_data5 = {
"Timepoints": {0: 1.1, 1: 2.2, 2: 3.3},
"case": {0: 15.0, 1: 18.0, 2: 20.0},
"hosp": {0: 0.1, 1: 1.0, 2: 2.2},
}
BADLY_FORMATTED_DATAFRAMES = [
pd.DataFrame(bad_data1),
pd.DataFrame(bad_data2),
pd.DataFrame(bad_data3),
pd.DataFrame(bad_data4),
pd.DataFrame(bad_data5),
] # improperly formatted datasets
MAPPING_FOR_DATA_TESTS = {"case": "I", "hosp": "H"}


def check_keys_match(obj1: Dict[str, T], obj2: Dict[str, T]):
assert set(obj1.keys()) == set(obj2.keys()), "Objects have different variables."
Expand Down
13 changes: 13 additions & 0 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from pyciemss.interfaces import calibrate, ensemble_sample, optimize, sample

from .fixtures import (
BADLY_FORMATTED_DATAFRAMES,
END_TIMES,
LOGGING_STEP_SIZES,
MAPPING_FOR_DATA_TESTS,
MODEL_URLS,
MODELS,
NON_POS_INTS,
Expand Down Expand Up @@ -611,3 +613,14 @@ def test_non_pos_int_sample(
logging_step_size,
num_samples=bad_num_samples,
)


@pytest.mark.parametrize("bad_data", BADLY_FORMATTED_DATAFRAMES)
@pytest.mark.parametrize("data_mapping", MAPPING_FOR_DATA_TESTS)
def test_load_data(bad_data, data_mapping):
# Assert that a ValueError is raised for improperly formatted data
with pytest.raises(ValueError):
load_data(
bad_data,
data_mapping,
)

0 comments on commit d6838e7

Please sign in to comment.