From 9a86f64aaa8c08550a629e25983d88da1ac09508 Mon Sep 17 00:00:00 2001
From: Gaurav Sheni <gvsheni@gmail.com>
Date: Sat, 8 Jul 2023 17:38:33 -0400
Subject: [PATCH] Add additional unit tests for _check_operations_valid (#93)

* fix badge

* cleanup files:

* cleanup files:

* add more tests

* finish up ops tests

* finish up ops tests

* update changelog
---
 Examples/youtube/meta_youtube.json          |  41 -------
 README.md                                   |   2 +-
 docs/changelog.md                           |   2 +
 tests/integration_tests/meta_chicago.json   |  55 ---------
 tests/integration_tests/meta_covid.json     |  13 ---
 tests/integration_tests/meta_youtube.json   |  41 -------
 tests/{utils => }/test_data_parser.py       |   0
 tests/test_prediction_problem.py            |  28 +++--
 tests/test_utils.py                         | 123 ++++++++++++++++++++-
 tests/{utils => typing}/__init__.py         |   0
 tests/{ => typing}/test_column_schema.py    |   0
 tests/{ => typing}/test_inference.py        |   0
 tests/{ => typing}/test_logical_types.py    |   0
 trane/core/prediction_problem_evaluator.py  |   2 -
 trane/core/prediction_problem_generator.py  |  11 +-
 trane/{utils => typing}/1-1000.txt          |   0
 trane/typing/inference.py                   |   6 +-
 trane/typing/inference_functions.py         |   2 +-
 18 files changed, 154 insertions(+), 172 deletions(-)
 delete mode 100755 Examples/youtube/meta_youtube.json
 delete mode 100755 tests/integration_tests/meta_chicago.json
 delete mode 100644 tests/integration_tests/meta_covid.json
 delete mode 100755 tests/integration_tests/meta_youtube.json
 rename tests/{utils => }/test_data_parser.py (100%)
 rename tests/{utils => typing}/__init__.py (100%)
 rename tests/{ => typing}/test_column_schema.py (100%)
 rename tests/{ => typing}/test_inference.py (100%)
 rename tests/{ => typing}/test_logical_types.py (100%)
 rename trane/{utils => typing}/1-1000.txt (100%)

diff --git a/Examples/youtube/meta_youtube.json b/Examples/youtube/meta_youtube.json
deleted file mode 100755
index b8fc7cee..00000000
--- a/Examples/youtube/meta_youtube.json
+++ /dev/null
@@ -1,41 +0,0 @@
-{
-  "tables": [
-    {
-      "fields": [
-        {
-          "name": "trending_date",
-          "type": "time"
-        },
-        {
-          "name": "channel_title",
-          "type": "id"
-        },
-        {
-          "name": "category_id",
-          "subtype": "categorical",
-          "type": "categorical"
-        },
-        {
-          "name": "views",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "likes",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "dislikes",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "comment_count",
-          "subtype": "integer",
-          "type": "number"
-        }
-      ]
-    }
-  ]
-}
diff --git a/README.md b/README.md
index 64aee096..ef5e02ca 100755
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
         Tests
-
+
         PyPI Version
diff --git a/docs/changelog.md b/docs/changelog.md
index bd51af72..145f37b6 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -8,6 +8,7 @@
 What’s new in 0.4.0 (X, 2023)
     * Add detailed walkthroughts in Examples directory [#60][#60]
     * Add code coverage analysis from Codecov [#77][#77]
     * Clean up input and output for operations [#87][#87]
+    * Add additional unit tests for _check_operations_valid [#93][#93]
 * Fixes
     * Remove TableMeta class and replace with ColumnSchema [#83][#83] [#85][#85]
@@ -16,6 +17,7 @@
 [#83]:
 [#85]:
 [#87]:
+[#93]:

 What’s new in 0.3.0 (February 24, 2023)
 =======================================
diff --git a/tests/integration_tests/meta_chicago.json b/tests/integration_tests/meta_chicago.json
deleted file mode 100755
index 0dc4f008..00000000
--- a/tests/integration_tests/meta_chicago.json
+++ /dev/null
@@ -1,55 +0,0 @@
-{
-  "tables": [
-    {
-      "fields": [
-        {
-          "name": "date",
-          "type": "time"
-        },
-        {
-          "name": "hour",
-          "subtype": "categorical",
-          "type": "categorical"
-        },
-        {
-          "name": "usertype",
-          "subtype": "categorical",
-          "type": "categorical"
-        },
-        {
-          "name": "gender",
-          "subtype": "categorical",
-          "type": "categorical"
-        },
-        {
-          "name": "tripduration",
-          "subtype": "float",
-          "type": "number"
-        },
-        {
-          "name": "temperature",
-          "subtype": "float",
-          "type": "number"
-        },
-        {
-          "name": "from_station_id",
-          "type": "id"
-        },
-        {
-          "name": "dpcapacity_start",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "to_station_id",
-          "type": "id"
-        },
-        {
-          "name": "dpcapacity_end",
-          "subtype": "integer",
-          "type": "number"
-        }
-      ]
-    }
-  ]
-}
diff --git a/tests/integration_tests/meta_covid.json b/tests/integration_tests/meta_covid.json
deleted file mode 100644
index 3972cd16..00000000
--- a/tests/integration_tests/meta_covid.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{"tables":[
-  {"fields":[
-    {"name": "Province/State", "type": "text"},
-    {"name": "Country/Region", "type": "text"},
-    {"name": "Lat", "type": "number", "subtype": "float"},
-    {"name": "Long", "type": "number", "subtype": "float"},
-    {"name": "Date", "type": "datetime"},
-    {"name": "Confirmed", "type": "number", "subtype": "integer"},
-    {"name": "Deaths", "type": "number", "subtype": "integer"},
-    {"name": "Recovered", "type": "number", "subtype": "integer"}
-  ]
-  }
-]}
diff --git a/tests/integration_tests/meta_youtube.json b/tests/integration_tests/meta_youtube.json
deleted file mode 100755
index b8fc7cee..00000000
--- a/tests/integration_tests/meta_youtube.json
+++ /dev/null
@@ -1,41 +0,0 @@
-{
-  "tables": [
-    {
-      "fields": [
-        {
-          "name": "trending_date",
-          "type": "time"
-        },
-        {
-          "name": "channel_title",
-          "type": "id"
-        },
-        {
-          "name": "category_id",
-          "subtype": "categorical",
-          "type": "categorical"
-        },
-        {
-          "name": "views",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "likes",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "dislikes",
-          "subtype": "integer",
-          "type": "number"
-        },
-        {
-          "name": "comment_count",
-          "subtype": "integer",
-          "type": "number"
-        }
-      ]
-    }
-  ]
-}
diff --git a/tests/utils/test_data_parser.py b/tests/test_data_parser.py
similarity index 100%
rename from tests/utils/test_data_parser.py
rename to tests/test_data_parser.py
diff --git a/tests/test_prediction_problem.py b/tests/test_prediction_problem.py
index d5756a22..38e3aac3 100644
--- a/tests/test_prediction_problem.py
+++ b/tests/test_prediction_problem.py
@@ -20,7 +20,7 @@ def make_fake_df():
             datetime(2023, 1, 5),
         ],
         "state": ["MA", "NY", "NY", "NJ", "NJ", "CT"],
-        "amount": [10, 20, 30, 40, 50, 60],
+        "amount": [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
     }
     df = pd.DataFrame(data)
     df["date"] = pd.to_datetime(df["date"])
@@ -31,7 +31,7 @@ def make_fake_df():
 @pytest.fixture()
 def make_fake_meta():
     meta = {
-        "id": ("Integer", {"index"}),
+        "id": ("Integer", {"numeric", "index"}),
         "date": ("Datetime", {}),
         "state": ("Categorical", {"category"}),
         "amount": ("Double", {"numeric"}),
@@ -87,13 +87,9 @@ def min_column(data_slice, column, **kwargs):
     return data_slice[column].min()


-def test_prediction_problem(make_fake_df, make_fake_meta):
-    df = make_fake_df
-    meta = make_fake_meta
-    for column in df.columns:
-        assert column in meta
+@pytest.fixture()
+def make_cutoff_strategy():
     entity_col = "id"
-    time_col = "date"
     window_size = "2d"
     minimum_data = "2023-01-01"
     maximum_data = "2023-01-05"
@@ -103,6 +99,18 @@
         minimum_data=minimum_data,
         maximum_data=maximum_data,
     )
+    return cutoff_strategy
+
+
+def test_prediction_problem(make_fake_df, make_fake_meta, make_cutoff_strategy):
+    entity_col = "id"
+    time_col = "date"
+    df = make_fake_df
+    meta = make_fake_meta
+    for column in df.columns:
+        assert column in meta
+    cutoff_strategy = make_cutoff_strategy
+
     problem_generator = trane.PredictionProblemGenerator(
         df=df,
         table_meta=meta,
@@ -112,6 +120,10 @@
     )

     problems = problem_generator.generate(df, generate_thresholds=True)
+    verify_problems(problems, df, cutoff_strategy)
+
+
+def verify_problems(problems, df, cutoff_strategy):
     problems_verified = 0
     # bad integration testing
     # not ideal but okay to test for now
diff --git a/tests/test_utils.py b/tests/test_utils.py
index fb8faf1c..d8c49954 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,9 +2,28 @@

 import pandas as pd

-from trane.core.utils import _check_operations_valid, _parse_table_meta, clean_date
-from trane.ops.aggregation_ops import CountAggregationOp
-from trane.ops.filter_ops import AllFilterOp, EqFilterOp, GreaterFilterOp
+from trane.core.utils import (
+    _check_operations_valid,
+    _parse_table_meta,
+    clean_date,
+)
+from trane.ops.aggregation_ops import (
+    AvgAggregationOp,
+    CountAggregationOp,
+    MajorityAggregationOp,
+    MaxAggregationOp,
+    MinAggregationOp,
+    SumAggregationOp,
+)
+from trane.ops.filter_ops import (
+    AllFilterOp,
+    EqFilterOp,
+    GreaterFilterOp,
+    LessFilterOp,
+    NeqFilterOp,
+)
+from trane.typing.column_schema import ColumnSchema
+from trane.typing.logical_types import Double


 def test_parse_table_simple():
@@ -22,19 +41,97 @@ def test_parse_table_numeric():
     table_meta = {
         "id": ("Categorical", {"index", "category"}),
-        "amount": ("Double", {"numeric"}),
+        "amount": ("Integer", {"numeric"}),
     }
     table_meta = _parse_table_meta(table_meta)
+
+    # For each <id> predict the number of records with <amount> equal to
+    # Technically could be a valid operation, but we don't support it yet
+    # For categorical columns it makes sense (see below)
     operations = [EqFilterOp("amount"), CountAggregationOp(None)]
     result, modified_meta = _check_operations_valid(operations, table_meta)
     assert result is False
     assert len(modified_meta) == 0

+    # For each <id> predict the number of records with <amount> greater than
     operations = [GreaterFilterOp("amount"), CountAggregationOp(None)]
     result, modified_meta = _check_operations_valid(operations, table_meta)
-    assert result is True
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the number of records with <amount> less than
+    operations = [LessFilterOp("amount"), CountAggregationOp(None)]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the total <amount> in all related records
+    operations = [AllFilterOp("amount"), SumAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
     assert modified_meta["id"] == table_meta["id"]

+    # For each <id> predict the total <amount> in all related records with <amount> greater than
+    operations = [GreaterFilterOp("amount"), SumAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the total <amount> in all related records with <amount> less than
+    operations = [LessFilterOp("amount"), SumAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the average <amount> in all related records
+    operations = [AllFilterOp("amount"), AvgAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the average <amount> in all related records with <amount> greater than
+    operations = [GreaterFilterOp("amount"), AvgAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the average <amount> in all related records with <amount> less than
+    operations = [LessFilterOp("amount"), AvgAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the maximum <amount> in all related records
+    operations = [AllFilterOp("amount"), MaxAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the maximum <amount> in all related records with <amount> greater than
+    operations = [GreaterFilterOp("amount"), MaxAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the maximum <amount> in all related records with <amount> less than
+    operations = [LessFilterOp("amount"), MaxAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the minimum <amount> in all related records
+    operations = [AllFilterOp("amount"), MinAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the minimum <amount> in all related records with <amount> greater than
+    operations = [GreaterFilterOp("amount"), MinAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+    # For each <id> predict the minimum <amount> in all related records with <amount> less than
+    operations = [LessFilterOp("amount"), MinAggregationOp("amount")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    verify_numeric_op(modified_meta, result)
+
+
+def verify_numeric_op(modified_meta, result):
+    assert result is True
+    assert modified_meta["amount"] == ColumnSchema(
+        logical_type=Double,
+        semantic_tags={"numeric"},
+    )


 def test_parse_table_cat():
     table_meta = {
         "id": ("Categorical", {"index", "category"}),
@@ -41,9 +138,26 @@ def test_parse_table_cat():
         "state": ("Categorical", {"category"}),
     }
     table_meta = _parse_table_meta(table_meta)
+
+    # For each <id> predict the number of records with <state> equal to
     operations = [EqFilterOp("state"), CountAggregationOp(None)]
     result, modified_meta = _check_operations_valid(operations, table_meta)
     assert result is True

+    # For each <id> predict the number of records with <state> not equal to
+    operations = [NeqFilterOp("state"), CountAggregationOp(None)]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    assert result is True
+
+    # For each <id> predict the majority <state> in all related records with <state> equal to NY in next 2d days
+    operations = [EqFilterOp("state"), MajorityAggregationOp("state")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    assert result is True
+
+    operations = [AllFilterOp(None), SumAggregationOp("state")]
+    result, modified_meta = _check_operations_valid(operations, table_meta)
+    assert result is False
+

 def test_clean_date():
     assert clean_date("2019-01-01") == pd.Timestamp(
diff --git a/tests/utils/__init__.py b/tests/typing/__init__.py
similarity index 100%
rename from tests/utils/__init__.py
rename to tests/typing/__init__.py
diff --git a/tests/test_column_schema.py b/tests/typing/test_column_schema.py
similarity index 100%
rename from tests/test_column_schema.py
rename to tests/typing/test_column_schema.py
diff --git a/tests/test_inference.py b/tests/typing/test_inference.py
similarity index 100%
rename from tests/test_inference.py
rename to tests/typing/test_inference.py
diff --git a/tests/test_logical_types.py b/tests/typing/test_logical_types.py
similarity index 100%
rename from tests/test_logical_types.py
rename to tests/typing/test_logical_types.py
diff --git a/trane/core/prediction_problem_evaluator.py b/trane/core/prediction_problem_evaluator.py
index c14da540..6ceb70a4 100755
--- a/trane/core/prediction_problem_evaluator.py
+++ b/trane/core/prediction_problem_evaluator.py
@@ -7,8 +7,6 @@
 from sklearn.preprocessing import OneHotEncoder
 from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor

-__all__ = ["PredictionProblemEvaluator"]
-

 class PredictionProblemEvaluator(object):
     """docstring for PredictionProblemEvaluator."""
diff --git a/trane/core/prediction_problem_generator.py b/trane/core/prediction_problem_generator.py
index 2ada07c4..2e38959e 100755
--- a/trane/core/prediction_problem_generator.py
+++ b/trane/core/prediction_problem_generator.py
@@ -44,14 +44,16 @@ def __init__(self, df, entity_col, time_col, table_meta=None, cutoff_strategy=None):
         self.time_col = time_col
         self.cutoff_strategy = cutoff_strategy

+        inferred_table_meta = False
         if table_meta is None:
             self.table_meta = infer_table_meta(df)
+            inferred_table_meta = True
         else:
             self.table_meta = _parse_table_meta(table_meta)
         self.transform_data()
-        self.ensure_valid_inputs()
+        self.ensure_valid_inputs(inferred_table_meta)

-    def ensure_valid_inputs(self):
+    def ensure_valid_inputs(self, inferred_table_meta=False):
         """
         TypeChecking for the problem generator entity_col and label_col.
         Errors if types don't match up.
@@ -64,7 +66,10 @@ def ensure_valid_inputs(self):

         entity_col_type = self.table_meta[self.entity_col]
         assert entity_col_type.logical_type in [Integer, Categorical]
-        assert "index" in entity_col_type.semantic_tags
+        if inferred_table_meta is False:
+            assert "index" in entity_col_type.semantic_tags
+        else:
+            self.table_meta[self.entity_col].semantic_tags.add("index")

         time_col_type = self.table_meta[self.time_col]
         assert time_col_type.logical_type == Datetime
diff --git a/trane/utils/1-1000.txt b/trane/typing/1-1000.txt
similarity index 100%
rename from trane/utils/1-1000.txt
rename to trane/typing/1-1000.txt
diff --git a/trane/typing/inference.py b/trane/typing/inference.py
index ada6e26e..81ff8c94 100644
--- a/trane/typing/inference.py
+++ b/trane/typing/inference.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
 import pandas as pd

 from trane.typing.column_schema import ColumnSchema
@@ -20,7 +22,7 @@
 )


-def _infer_series_schema(series):
+def _infer_series_schema(series: pd.Series) -> ColumnSchema:
     inference_functions = {
         boolean_func: ColumnSchema(logical_type=Boolean),
         categorical_func: ColumnSchema(
@@ -38,7 +40,7 @@
     return ColumnSchema(logical_type=Unknown)


-def infer_table_meta(df: pd.DataFrame):
+def infer_table_meta(df: pd.DataFrame) -> Dict[str, ColumnSchema]:
     table_meta = {}
     for col in df.columns:
         column_schema = _infer_series_schema(df[col])
diff --git a/trane/typing/inference_functions.py b/trane/typing/inference_functions.py
index 8fc2a233..5b4f65ed 100644
--- a/trane/typing/inference_functions.py
+++ b/trane/typing/inference_functions.py
@@ -13,7 +13,7 @@

 COMMON_WORDS_SET = set(
     word.strip().lower()
-    for word in files("trane.utils").joinpath("1-1000.txt").read_text().split("\n")
+    for word in files("trane.typing").joinpath("1-1000.txt").read_text().split("\n")
     if len(word) > 0
 )
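
The valid/invalid combinations exercised by the new tests can also be checked directly against the helper. A minimal sketch, using only names that appear in this patch (the metadata shape mirrors tests/test_utils.py; treat it as an illustration of the internal helpers, not a stable public API):

    from trane.core.utils import _check_operations_valid, _parse_table_meta
    from trane.ops.aggregation_ops import SumAggregationOp
    from trane.ops.filter_ops import GreaterFilterOp

    # Column name -> (logical type, semantic tags), as in the updated tests.
    table_meta = _parse_table_meta(
        {
            "id": ("Categorical", {"index", "category"}),
            "amount": ("Integer", {"numeric"}),
        },
    )

    # "For each <id> predict the total <amount> in all related records with
    # <amount> greater than" -- a numeric filter plus a numeric aggregation.
    operations = [GreaterFilterOp("amount"), SumAggregationOp("amount")]
    result, modified_meta = _check_operations_valid(operations, table_meta)

    # Valid: per verify_numeric_op above, the amount column comes back
    # widened to Double with the "numeric" tag in the returned metadata.
    assert result is True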
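
The prediction_problem_generator change interacts with these checks: when table_meta is omitted, the generator infers it with infer_table_meta, and ensure_valid_inputs now adds the "index" tag to the entity column instead of asserting it is already present. A sketch of that call path, assuming a df and cutoff_strategy like the fixtures built in tests/test_prediction_problem.py:

    import trane

    # table_meta omitted on purpose: it is inferred from df, and the entity
    # column picks up the "index" semantic tag automatically.
    problem_generator = trane.PredictionProblemGenerator(
        df=df,
        entity_col="id",
        time_col="date",
        cutoff_strategy=cutoff_strategy,
    )
    problems = problem_generator.generate(df, generate_thresholds=True)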