Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
gsheni committed Jun 28, 2023
1 parent 1d6ca22 commit b126d65
Show file tree
Hide file tree
Showing 13 changed files with 195 additions and 368 deletions.
20 changes: 0 additions & 20 deletions scripts/walmart.py

This file was deleted.

83 changes: 0 additions & 83 deletions scripts/yelp.py

This file was deleted.

67 changes: 8 additions & 59 deletions tests/integration_tests/test_examples.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@

import pandas as pd
import pytest
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import (
Categorical,
Datetime,
Double,
Integer,
)

import trane
from trane.datasets.load_functions import (
load_bike_metadata,
load_covid_metadata,
load_youtube_metadata,
)

from .utils import generate_and_verify_prediction_problem

Expand All @@ -34,21 +32,7 @@ def df_youtube(current_dir):

@pytest.fixture
def meta_youtube(current_dir):
    """Provide the column metadata for the YouTube example dataset.

    The metadata is obtained from ``load_youtube_metadata`` instead of a
    hand-written dict literal; the literal that used to be built here was
    overwritten before it was ever returned, so it has been dropped.

    The ``current_dir`` argument is not used in the body; it is kept so
    the fixture's signature stays unchanged.
    """
    return load_youtube_metadata()


Expand All @@ -67,22 +51,7 @@ def df_covid(current_dir):

@pytest.fixture
def meta_covid(current_dir):
    """Provide the column metadata for the COVID example dataset.

    The metadata is obtained from ``load_covid_metadata`` instead of a
    hand-written dict literal; the literal that used to be built here was
    overwritten before it was ever returned, so it has been dropped.

    The ``current_dir`` argument is not used in the body; it is kept so
    the fixture's signature stays unchanged.
    """
    return load_covid_metadata()


Expand All @@ -99,27 +68,7 @@ def df_chicago(current_dir):

@pytest.fixture
def meta_chicago(current_dir):
    """Provide the column metadata for the Chicago bike example dataset.

    The metadata is obtained from ``load_bike_metadata`` instead of a
    hand-written dict literal; the literal that used to be built here was
    overwritten before it was ever returned, so it has been dropped.

    The ``current_dir`` argument is not used in the body; it is kept so
    the fixture's signature stays unchanged.
    """
    return load_bike_metadata()


Expand Down
93 changes: 57 additions & 36 deletions tests/test_datasets.py → tests/test_load_functions.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,22 @@
from woodwork.column_schema import ColumnSchema
from woodwork.logical_types import (
Datetime,
)

from trane.datasets.load_functions import (
load_bike,
load_bike_metadata,
load_covid,
load_covid_metadata,
load_youtube,
load_youtube_metadata,
)


def test_load_covid():
df = load_covid()
for col in [
metadata = load_covid_metadata()
expected_columns = [
"Province/State",
"Country/Region",
"Lat",
Expand All @@ -16,10 +25,55 @@ def test_load_covid():
"Confirmed",
"Deaths",
"Recovered",
]:
assert col in df.columns
]
check_column_schema(expected_columns, df, metadata)
assert len(df) >= 17136
assert df["Date"].dtype == "datetime64[ns]"
assert metadata["Date"] == ColumnSchema(logical_type=Datetime)


def test_load_bike():
    """Check the bike loader and its metadata against the expected columns."""
    bike_df = load_bike()
    bike_meta = load_bike_metadata()

    # Every listed column must be present in both the frame and the metadata.
    check_column_schema(
        [
            "date",
            "hour",
            "usertype",
            "gender",
            "tripduration",
            "temperature",
            "from_station_id",
            "dpcapacity_start",
            "to_station_id",
            "dpcapacity_end",
        ],
        bike_df,
        bike_meta,
    )

    # The date column must be a datetime in the data and in the schema.
    assert bike_df["date"].dtype == "datetime64[ns]"
    assert bike_meta["date"] == ColumnSchema(logical_type=Datetime)


def test_load_youtube():
    """Check the YouTube loader and its metadata against the expected columns."""
    youtube_df = load_youtube()
    youtube_meta = load_youtube_metadata()

    # Every listed column must be present in both the frame and the metadata.
    check_column_schema(
        [
            "trending_date",
            "channel_title",
            "category_id",
            "views",
            "likes",
            "dislikes",
            "comment_count",
        ],
        youtube_df,
        youtube_meta,
    )

    # The trending date must be a datetime in the data and in the schema.
    assert youtube_df["trending_date"].dtype == "datetime64[ns]"
    assert youtube_meta["trending_date"] == ColumnSchema(logical_type=Datetime)


def check_column_schema(columns, df, metadata):
    """Assert that each expected column exists in the data and the metadata.

    Args:
        columns: Iterable of column names that are required.
        df: Object exposing a ``columns`` attribute that supports ``in``
            (the callers in this file pass the loaded DataFrame).
        metadata: Mapping from column name to its ``ColumnSchema``.

    Raises:
        AssertionError: If a column is absent from ``df.columns`` or from
            ``metadata``, or if its metadata entry is not a ``ColumnSchema``.
            The message names the offending column.
    """
    for col in columns:
        assert col in df.columns, f"column {col!r} is missing from the dataframe"
        # Membership on the mapping itself is equivalent to checking .keys().
        assert col in metadata, f"column {col!r} is missing from the metadata"
        assert isinstance(
            metadata[col],
            ColumnSchema,
        ), f"metadata for column {col!r} is not a ColumnSchema"


# def test_load_flight():
Expand Down Expand Up @@ -60,36 +114,3 @@ def test_load_covid():
# assert col in flights_df.columns

# assert flights_df["DATE"].dtype == "datetime64[ns]"


def test_load_bike():
    """Check that the bike loader returns the expected columns."""
    bike_df = load_bike()
    required_columns = [
        "date",
        "hour",
        "usertype",
        "gender",
        "tripduration",
        "temperature",
        "from_station_id",
        "dpcapacity_start",
        "to_station_id",
        "dpcapacity_end",
    ]
    for name in required_columns:
        assert name in bike_df.columns
    # The date column must have been parsed into datetimes.
    assert bike_df["date"].dtype == "datetime64[ns]"


def test_load_youtube():
    """Check that the YouTube loader returns the expected columns."""
    youtube_df = load_youtube()
    required_columns = [
        "trending_date",
        "channel_title",
        "category_id",
        "views",
        "likes",
        "dislikes",
        "comment_count",
    ]
    for name in required_columns:
        assert name in youtube_df.columns
    # The trending date column must have been parsed into datetimes.
    assert youtube_df["trending_date"].dtype == "datetime64[ns]"
2 changes: 1 addition & 1 deletion trane/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from trane.core import * # noqa
from trane.datasets import (
load_covid,
load_covid_tablemeta,
load_covid_metadata,
load_bike,
load_youtube,
load_youtube_metadata,
Expand Down
Loading

0 comments on commit b126d65

Please sign in to comment.