diff --git a/Examples/youtube/meta_youtube.json b/Examples/youtube/meta_youtube.json
deleted file mode 100755
index b8fc7cee..00000000
--- a/Examples/youtube/meta_youtube.json
+++ /dev/null
@@ -1,41 +0,0 @@
-{
- "tables": [
- {
- "fields": [
- {
- "name": "trending_date",
- "type": "time"
- },
- {
- "name": "channel_title",
- "type": "id"
- },
- {
- "name": "category_id",
- "subtype": "categorical",
- "type": "categorical"
- },
- {
- "name": "views",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "likes",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "dislikes",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "comment_count",
- "subtype": "integer",
- "type": "number"
- }
- ]
- }
- ]
-}
diff --git a/README.md b/README.md
index 64aee096..ef5e02ca 100755
--- a/README.md
+++ b/README.md
@@ -7,7 +7,7 @@
-
+
diff --git a/docs/changelog.md b/docs/changelog.md
index bd51af72..145f37b6 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -8,6 +8,7 @@ What’s new in 0.4.0 (X, 2023)
* Add detailed walkthroughts in Examples directory [#60][#60]
* Add code coverage analysis from Codecov [#77][#77]
* Clean up input and output for operations [#87][#87]
+ * Add additional unit tests for _check_operations_valid [#93][#93]
* Fixes
* Remove TableMeta class and replace with ColumnSchema [#83][#83] [#85][#85]
@@ -16,6 +17,7 @@ What’s new in 0.4.0 (X, 2023)
[#83]:
[#85]:
[#87]:
+ [#93]:
What’s new in 0.3.0 (February 24, 2023)
=======================================
diff --git a/tests/integration_tests/meta_chicago.json b/tests/integration_tests/meta_chicago.json
deleted file mode 100755
index 0dc4f008..00000000
--- a/tests/integration_tests/meta_chicago.json
+++ /dev/null
@@ -1,55 +0,0 @@
-{
- "tables": [
- {
- "fields": [
- {
- "name": "date",
- "type": "time"
- },
- {
- "name": "hour",
- "subtype": "categorical",
- "type": "categorical"
- },
- {
- "name": "usertype",
- "subtype": "categorical",
- "type": "categorical"
- },
- {
- "name": "gender",
- "subtype": "categorical",
- "type": "categorical"
- },
- {
- "name": "tripduration",
- "subtype": "float",
- "type": "number"
- },
- {
- "name": "temperature",
- "subtype": "float",
- "type": "number"
- },
- {
- "name": "from_station_id",
- "type": "id"
- },
- {
- "name": "dpcapacity_start",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "to_station_id",
- "type": "id"
- },
- {
- "name": "dpcapacity_end",
- "subtype": "integer",
- "type": "number"
- }
- ]
- }
- ]
-}
diff --git a/tests/integration_tests/meta_covid.json b/tests/integration_tests/meta_covid.json
deleted file mode 100644
index 3972cd16..00000000
--- a/tests/integration_tests/meta_covid.json
+++ /dev/null
@@ -1,13 +0,0 @@
-{"tables":[
- {"fields":[
- {"name": "Province/State", "type": "text"},
- {"name": "Country/Region", "type": "text"},
- {"name": "Lat", "type": "number", "subtype": "float"},
- {"name": "Long", "type": "number", "subtype": "float"},
- {"name": "Date", "type": "datetime"},
- {"name": "Confirmed", "type": "number", "subtype": "integer"},
- {"name": "Deaths", "type": "number", "subtype": "integer"},
- {"name": "Recovered", "type": "number", "subtype": "integer"}
- ]
- }
-]}
diff --git a/tests/integration_tests/meta_youtube.json b/tests/integration_tests/meta_youtube.json
deleted file mode 100755
index b8fc7cee..00000000
--- a/tests/integration_tests/meta_youtube.json
+++ /dev/null
@@ -1,41 +0,0 @@
-{
- "tables": [
- {
- "fields": [
- {
- "name": "trending_date",
- "type": "time"
- },
- {
- "name": "channel_title",
- "type": "id"
- },
- {
- "name": "category_id",
- "subtype": "categorical",
- "type": "categorical"
- },
- {
- "name": "views",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "likes",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "dislikes",
- "subtype": "integer",
- "type": "number"
- },
- {
- "name": "comment_count",
- "subtype": "integer",
- "type": "number"
- }
- ]
- }
- ]
-}
diff --git a/tests/utils/test_data_parser.py b/tests/test_data_parser.py
similarity index 100%
rename from tests/utils/test_data_parser.py
rename to tests/test_data_parser.py
diff --git a/tests/test_prediction_problem.py b/tests/test_prediction_problem.py
index d5756a22..38e3aac3 100644
--- a/tests/test_prediction_problem.py
+++ b/tests/test_prediction_problem.py
@@ -20,7 +20,7 @@ def make_fake_df():
datetime(2023, 1, 5),
],
"state": ["MA", "NY", "NY", "NJ", "NJ", "CT"],
- "amount": [10, 20, 30, 40, 50, 60],
+ "amount": [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
}
df = pd.DataFrame(data)
df["date"] = pd.to_datetime(df["date"])
@@ -31,7 +31,7 @@ def make_fake_df():
@pytest.fixture()
def make_fake_meta():
meta = {
- "id": ("Integer", {"index"}),
+ "id": ("Integer", {"numeric", "index"}),
"date": ("Datetime", {}),
"state": ("Categorical", {"category"}),
"amount": ("Double", {"numeric"}),
@@ -87,13 +87,9 @@ def min_column(data_slice, column, **kwargs):
return data_slice[column].min()
-def test_prediction_problem(make_fake_df, make_fake_meta):
- df = make_fake_df
- meta = make_fake_meta
- for column in df.columns:
- assert column in meta
+@pytest.fixture()
+def make_cutoff_strategy():
entity_col = "id"
- time_col = "date"
window_size = "2d"
minimum_data = "2023-01-01"
maximum_data = "2023-01-05"
@@ -103,6 +99,18 @@ def test_prediction_problem(make_fake_df, make_fake_meta):
minimum_data=minimum_data,
maximum_data=maximum_data,
)
+ return cutoff_strategy
+
+
+def test_prediction_problem(make_fake_df, make_fake_meta, make_cutoff_strategy):
+ entity_col = "id"
+ time_col = "date"
+ df = make_fake_df
+ meta = make_fake_meta
+ for column in df.columns:
+ assert column in meta
+ cutoff_strategy = make_cutoff_strategy
+
problem_generator = trane.PredictionProblemGenerator(
df=df,
table_meta=meta,
@@ -112,6 +120,10 @@ def test_prediction_problem(make_fake_df, make_fake_meta):
)
problems = problem_generator.generate(df, generate_thresholds=True)
+ verify_problems(problems, df, cutoff_strategy)
+
+
+def verify_problems(problems, df, cutoff_strategy):
problems_verified = 0
# bad integration testing
# not ideal but okay to test for now
diff --git a/tests/test_utils.py b/tests/test_utils.py
index fb8faf1c..d8c49954 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,9 +2,28 @@
import pandas as pd
-from trane.core.utils import _check_operations_valid, _parse_table_meta, clean_date
-from trane.ops.aggregation_ops import CountAggregationOp
-from trane.ops.filter_ops import AllFilterOp, EqFilterOp, GreaterFilterOp
+from trane.core.utils import (
+ _check_operations_valid,
+ _parse_table_meta,
+ clean_date,
+)
+from trane.ops.aggregation_ops import (
+ AvgAggregationOp,
+ CountAggregationOp,
+ MajorityAggregationOp,
+ MaxAggregationOp,
+ MinAggregationOp,
+ SumAggregationOp,
+)
+from trane.ops.filter_ops import (
+ AllFilterOp,
+ EqFilterOp,
+ GreaterFilterOp,
+ LessFilterOp,
+ NeqFilterOp,
+)
+from trane.typing.column_schema import ColumnSchema
+from trane.typing.logical_types import Double
def test_parse_table_simple():
@@ -22,19 +41,97 @@ def test_parse_table_simple():
def test_parse_table_numeric():
table_meta = {
"id": ("Categorical", {"index", "category"}),
- "amount": ("Double", {"numeric"}),
+ "amount": ("Integer", {"numeric"}),
}
table_meta = _parse_table_meta(table_meta)
+
+ # For each entity, predict the number of records where amount == value
+ # Technically could be a valid operation, but we don't support it yet
+ # For categorical columns it makes sense (see below)
operations = [EqFilterOp("amount"), CountAggregationOp(None)]
result, modified_meta = _check_operations_valid(operations, table_meta)
assert result is False
assert len(modified_meta) == 0
+ # For each entity, predict the number of records where amount > threshold
operations = [GreaterFilterOp("amount"), CountAggregationOp(None)]
result, modified_meta = _check_operations_valid(operations, table_meta)
- assert result is True
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the number of records where amount < threshold
+ operations = [LessFilterOp("amount"), CountAggregationOp(None)]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the total amount in all related records
+ operations = [AllFilterOp("amount"), SumAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
assert modified_meta["id"] == table_meta["id"]
+ # For each entity, predict the total amount in related records where amount > threshold
+ operations = [GreaterFilterOp("amount"), SumAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the total amount in related records where amount < threshold
+ operations = [LessFilterOp("amount"), SumAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the average amount in all related records
+ operations = [AllFilterOp("amount"), AvgAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the average amount in related records where amount > threshold
+ operations = [GreaterFilterOp("amount"), AvgAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the average amount in related records where amount < threshold
+ operations = [LessFilterOp("amount"), AvgAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the maximum amount in all related records
+ operations = [AllFilterOp("amount"), MaxAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the maximum amount in related records where amount > threshold
+ operations = [GreaterFilterOp("amount"), MaxAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the maximum amount in related records where amount < threshold
+ operations = [LessFilterOp("amount"), MaxAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the minimum amount in all related records
+ operations = [AllFilterOp("amount"), MinAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the minimum amount in related records where amount > threshold
+ operations = [GreaterFilterOp("amount"), MinAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+ # For each entity, predict the minimum amount in related records where amount < threshold
+ operations = [LessFilterOp("amount"), MinAggregationOp("amount")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ verify_numeric_op(modified_meta, result)
+
+
+def verify_numeric_op(modified_meta, result):
+ assert result is True
+ assert modified_meta["amount"] == ColumnSchema(
+ logical_type=Double,
+ semantic_tags={"numeric"},
+ )
+
def test_parse_table_cat():
table_meta = {
@@ -42,10 +139,26 @@ def test_parse_table_cat():
"state": ("Categorical", {"category"}),
}
table_meta = _parse_table_meta(table_meta)
+
+ # For each entity, predict the number of records where state == value
operations = [EqFilterOp("state"), CountAggregationOp(None)]
result, modified_meta = _check_operations_valid(operations, table_meta)
assert result is True
+ # For each entity, predict the number of records where state != value
+ operations = [NeqFilterOp("state"), CountAggregationOp(None)]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ assert result is True
+
+ # For each entity, predict the majority state in related records where state == "NY" in the next 2 days
+ operations = [EqFilterOp("state"), MajorityAggregationOp("state")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ assert result is True
+
+ operations = [AllFilterOp(None), SumAggregationOp("state")]
+ result, modified_meta = _check_operations_valid(operations, table_meta)
+ assert result is False
+
def test_clean_date():
assert clean_date("2019-01-01") == pd.Timestamp(
diff --git a/tests/utils/__init__.py b/tests/typing/__init__.py
similarity index 100%
rename from tests/utils/__init__.py
rename to tests/typing/__init__.py
diff --git a/tests/test_column_schema.py b/tests/typing/test_column_schema.py
similarity index 100%
rename from tests/test_column_schema.py
rename to tests/typing/test_column_schema.py
diff --git a/tests/test_inference.py b/tests/typing/test_inference.py
similarity index 100%
rename from tests/test_inference.py
rename to tests/typing/test_inference.py
diff --git a/tests/test_logical_types.py b/tests/typing/test_logical_types.py
similarity index 100%
rename from tests/test_logical_types.py
rename to tests/typing/test_logical_types.py
diff --git a/trane/core/prediction_problem_evaluator.py b/trane/core/prediction_problem_evaluator.py
index c14da540..6ceb70a4 100755
--- a/trane/core/prediction_problem_evaluator.py
+++ b/trane/core/prediction_problem_evaluator.py
@@ -7,8 +7,6 @@
from sklearn.preprocessing import OneHotEncoder
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
-__all__ = ["PredictionProblemEvaluator"]
-
class PredictionProblemEvaluator(object):
"""docstring for PredictionProblemEvaluator."""
diff --git a/trane/core/prediction_problem_generator.py b/trane/core/prediction_problem_generator.py
index 2ada07c4..2e38959e 100755
--- a/trane/core/prediction_problem_generator.py
+++ b/trane/core/prediction_problem_generator.py
@@ -44,14 +44,16 @@ def __init__(self, df, entity_col, time_col, table_meta=None, cutoff_strategy=No
self.time_col = time_col
self.cutoff_strategy = cutoff_strategy
+ inferred_table_meta = False
if table_meta is None:
self.table_meta = infer_table_meta(df)
+ inferred_table_meta = True
else:
self.table_meta = _parse_table_meta(table_meta)
self.transform_data()
- self.ensure_valid_inputs()
+ self.ensure_valid_inputs(inferred_table_meta)
- def ensure_valid_inputs(self):
+ def ensure_valid_inputs(self, inferred_table_meta=False):
"""
TypeChecking for the problem generator entity_col
and label_col. Errors if types don't match up.
@@ -64,7 +66,10 @@ def ensure_valid_inputs(self):
entity_col_type = self.table_meta[self.entity_col]
assert entity_col_type.logical_type in [Integer, Categorical]
- assert "index" in entity_col_type.semantic_tags
+ if inferred_table_meta is False:
+ assert "index" in entity_col_type.semantic_tags
+ else:
+ self.table_meta[self.entity_col].semantic_tags.add("index")
time_col_type = self.table_meta[self.time_col]
assert time_col_type.logical_type == Datetime
diff --git a/trane/utils/1-1000.txt b/trane/typing/1-1000.txt
similarity index 100%
rename from trane/utils/1-1000.txt
rename to trane/typing/1-1000.txt
diff --git a/trane/typing/inference.py b/trane/typing/inference.py
index ada6e26e..81ff8c94 100644
--- a/trane/typing/inference.py
+++ b/trane/typing/inference.py
@@ -1,3 +1,5 @@
+from typing import Dict
+
import pandas as pd
from trane.typing.column_schema import ColumnSchema
@@ -20,7 +22,7 @@
)
-def _infer_series_schema(series):
+def _infer_series_schema(series: pd.Series) -> ColumnSchema:
inference_functions = {
boolean_func: ColumnSchema(logical_type=Boolean),
categorical_func: ColumnSchema(
@@ -38,7 +40,7 @@ def _infer_series_schema(series):
return ColumnSchema(logical_type=Unknown)
-def infer_table_meta(df: pd.DataFrame):
+def infer_table_meta(df: pd.DataFrame) -> Dict[str, ColumnSchema]:
table_meta = {}
for col in df.columns:
column_schema = _infer_series_schema(df[col])
diff --git a/trane/typing/inference_functions.py b/trane/typing/inference_functions.py
index 8fc2a233..5b4f65ed 100644
--- a/trane/typing/inference_functions.py
+++ b/trane/typing/inference_functions.py
@@ -13,7 +13,7 @@
COMMON_WORDS_SET = set(
word.strip().lower()
- for word in files("trane.utils").joinpath("1-1000.txt").read_text().split("\n")
+ for word in files("trane.typing").joinpath("1-1000.txt").read_text().split("\n")
if len(word) > 0
)