Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding testing check_data function to load_data #458

Merged
merged 12 commits into from
Mar 6, 2024
54 changes: 49 additions & 5 deletions pyciemss/integration_utils/observation.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, Tuple
from typing import Dict, Tuple, Union

import pandas as pd
import torch
Expand All @@ -9,19 +9,63 @@


def load_data(
path: str, data_mapping: Dict[str, str] = {}
path: Union[str, pd.DataFrame], data_mapping: Dict[str, str] = {}
) -> Tuple[torch.Tensor, Dict[str, torch.Tensor]]:
"""
Load data from a CSV file.
Load data from a CSV file, or directly from a DataFrame.

- path: path to the CSV file
- path: path to the CSV file, or DataFrame
- data_mapping: A mapping from column names in the data file to state variable names in the model.
- keys: str name of column in dataset
- values: str name of state/observable in model
- If not provided, we will assume that the column names in the data file match the state variable names.
"""

df = pd.read_csv(path)
def check_data(data_path: Union[str, pd.DataFrame]) -> pd.DataFrame:
    """
    Check a dataset for formatting errors and return it as a DataFrame.

    - data_path: path to a CSV file, or an already-loaded DataFrame.
        - A string is read with `pd.read_csv`.
        - A DataFrame is validated and returned as-is (no copy is made).

    Raises ValueError if:
    - data_path is neither a string nor a DataFrame,
    - the first column is missing or is not named "Timestamp",
    - any entry is NaN or empty,
    - any entry is not of type `int` or `float`.
    """

    # Read the data
    if isinstance(data_path, str):
        # If data_path is a string, assume it's a file path and read as a csv
        data_df = pd.read_csv(data_path)
    elif isinstance(data_path, pd.DataFrame):
        # If data_path is a DataFrame, use it directly
        data_df = data_path
    else:
        # If data_path is neither a string nor a DataFrame, raise an error
        raise ValueError("data_path must be either a file path or a DataFrame.")

    # Check that the first column name is "Timestamp". A frame without any
    # columns has no first column, so report it with the same ValueError
    # instead of letting `columns[0]` fail with an IndexError.
    if len(data_df.columns) == 0 or data_df.columns[0] != "Timestamp":
        raise ValueError(
            "The first column must be named 'Timestamp' and contain the time corresponding to each row of data."
        )

    # Check that there are no NaN values or empty entries.
    # NOTE: ragged data is rejected on purpose for now. Per the review
    # discussion on this change, it is unclear how missing entries would
    # propagate to calibrate, so support is deferred to a follow-up issue.
    if data_df.isna().any().any():
        raise ValueError("Dataset cannot contain NaN or empty entries.")

    # Check that there is no missing data in the form of None type or char values.
    # `DataFrame.applymap` is deprecated in favour of `DataFrame.map`
    # (pandas >= 2.1); use whichever this pandas version provides so the
    # check neither warns on new versions nor breaks on old ones.
    elementwise = getattr(pd.DataFrame, "map", None) or pd.DataFrame.applymap
    is_number = elementwise(data_df, lambda x: isinstance(x, (int, float)))
    if not is_number.all().all():
        raise ValueError(
            "Dataset cannot contain None type or char values. All entries must be of type `int` or `float`."
        )

    return data_df

def print_data_report(data_df: pd.DataFrame) -> None:
    """
    Print a short, one-paragraph report about the data to stdout.

    - data_df: DataFrame whose first column holds the time of each row.

    The report states the number of rows, the first and last value of the
    first column, and the names of the remaining columns.
    """

    # len() of a DataFrame already excludes the header, so it is the row count.
    n_rows = len(data_df)
    time_column = data_df.columns[0]
    # Column labels are not guaranteed to be strings; convert before joining.
    other_columns = ", ".join(map(str, data_df.columns[1:]))

    if n_rows:
        time_span = (
            f"begins at {data_df.iloc[0, 0]} and ends at {data_df.iloc[-1, 0]}. "
        )
    else:
        # A frame with no rows has no first/last entry to index into.
        time_span = "is empty. "

    print(
        f"Data printout: This dataset contains {n_rows} rows of data. "
        f"The first column, {time_column}, {time_span}"
        f"The subsequent columns are named: "
        f"{other_columns}"
    )

df = check_data(path)
print_data_report(df)

data_timepoints = torch.tensor(df["Timestamp"].values, dtype=torch.float32)
data = {}
Expand Down
35 changes: 35 additions & 0 deletions tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from typing import Any, Dict, Optional, TypeVar

import numpy as np
import pandas as pd
import torch

from pyciemss.ouu.qoi import obs_nday_average_qoi
Expand Down Expand Up @@ -107,6 +108,40 @@ def __init__(
torch.tensor(3),
] # bad candidates for num_samples/num_iterations

def _indexed(*values):
    """Return a dict mapping row positions 0..n-1 to the given cell values."""
    return dict(enumerate(values))


# Each table below is malformed in a different way.

# "case" holds an empty string.
bad_data1 = {
    "Timestamp": _indexed(1.1, 2.2, 3.3),
    "case": _indexed(15.0, "", 20.0),
    "hosp": _indexed(0.1, 1.0, 2.2),
}
# "case" holds a word.
bad_data2 = {
    "Timestamp": _indexed(1.1, 2.2, 3.3),
    "case": _indexed(15.0, "apple", 20.0),
    "hosp": _indexed(0.1, 1.0, 2.2),
}
# "case" holds a single space.
bad_data3 = {
    "Timestamp": _indexed(1.1, 2.2, 3.3),
    "case": _indexed(15.0, " ", 20.0),
    "hosp": _indexed(0.1, 1.0, 2.2),
}
# "case" holds None.
bad_data4 = {
    "Timestamp": _indexed(1.1, 2.2, 3.3),
    "case": _indexed(15.0, None, 20.0),
    "hosp": _indexed(0.1, 1.0, 2.2),
}
# All entries are numeric, but the first column is named "Timepoints".
bad_data5 = {
    "Timepoints": _indexed(1.1, 2.2, 3.3),
    "case": _indexed(15.0, 18.0, 20.0),
    "hosp": _indexed(0.1, 1.0, 2.2),
}

# improperly formatted datasets
BADLY_FORMATTED_DATAFRAMES = [
    pd.DataFrame(table)
    for table in (bad_data1, bad_data2, bad_data3, bad_data4, bad_data5)
]
MAPPING_FOR_DATA_TESTS = dict(case="I", hosp="H")


def check_keys_match(obj1: Dict[str, T], obj2: Dict[str, T]):
assert set(obj1.keys()) == set(obj2.keys()), "Objects have different variables."
Expand Down
13 changes: 13 additions & 0 deletions tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,10 @@
from pyciemss.interfaces import calibrate, ensemble_sample, optimize, sample

from .fixtures import (
BADLY_FORMATTED_DATAFRAMES,
END_TIMES,
LOGGING_STEP_SIZES,
MAPPING_FOR_DATA_TESTS,
MODEL_URLS,
MODELS,
NON_POS_INTS,
Expand Down Expand Up @@ -516,3 +518,14 @@ def test_non_pos_int_sample(
logging_step_size,
num_samples=bad_num_samples,
)


@pytest.mark.parametrize("bad_data", BADLY_FORMATTED_DATAFRAMES)
# Wrap the mapping in a list: parametrizing over the dict itself iterates its
# keys, which would pass the strings "case" and "hosp" as `data_mapping`.
@pytest.mark.parametrize("data_mapping", [MAPPING_FOR_DATA_TESTS])
def test_load_data(bad_data, data_mapping):
    """Assert that `load_data` raises a ValueError for improperly formatted data."""
    with pytest.raises(ValueError):
        load_data(
            bad_data,
            data_mapping,
        )
Loading