Skip to content

Commit

Permalink
Fix text dataset stats and benchmark materialize return (#380)
Browse files Browse the repository at this point in the history
Fix text dataset stats and benchmark dataset materialize return
  • Loading branch information
zechengz authored Mar 16, 2024
1 parent 5345eea commit 2acfc8c
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 12 deletions.
3 changes: 2 additions & 1 deletion torch_frame/datasets/data_frame_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -797,9 +797,10 @@ def __repr__(self) -> str:
f' cls={self.cls_str}\n'
f')')

- def materialize(self, *args, **kwargs):
+ def materialize(self, *args, **kwargs) -> torch_frame.data.Dataset:
super().materialize(*args, **kwargs)
if self.task_type != self._task_type:
raise RuntimeError(f"task type does not match. It should be "
f"{self.task_type.value} but specified as "
f"{self._task_type.value}.")
+ return self
13 changes: 7 additions & 6 deletions torch_frame/datasets/data_frame_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ class DataFrameTextBenchmark(torch_frame.data.Dataset):
- 1
- 6
- MultimodalTextBenchmark(name='data_scientist_salary')
- - 0.0%
+ - 12.3%
* - multiclass_classification
- small
- 3
Expand All @@ -130,7 +130,7 @@ class DataFrameTextBenchmark(torch_frame.data.Dataset):
- 3
- 10
- MultimodalTextBenchmark(name='melbourne_airbnb')
- - 0.0%
+ - 9.6%
* - multiclass_classification
- medium
- 0
Expand Down Expand Up @@ -240,7 +240,7 @@ class DataFrameTextBenchmark(torch_frame.data.Dataset):
- 3
- 1
- MultimodalTextBenchmark(name='ae_price_prediction')
- - 0.0%
+ - 6.1%
* - regression
- small
- 7
Expand All @@ -251,7 +251,7 @@ class DataFrameTextBenchmark(torch_frame.data.Dataset):
- 11
- 1
- MultimodalTextBenchmark(name='california_house_price')
- - 0.0%
+ - 13.8%
* - regression
- medium
- 0
Expand All @@ -262,7 +262,7 @@ class DataFrameTextBenchmark(torch_frame.data.Dataset):
- 1
- 1
- MultimodalTextBenchmark(name='mercari_price_suggestion100K')
- - 0.0%
+ - 3.4%
* - regression
- large
- 0
Expand Down Expand Up @@ -495,9 +495,10 @@ def __repr__(self) -> str:
f' cls={self.cls_str}\n'
f')')

- def materialize(self, *args, **kwargs):
+ def materialize(self, *args, **kwargs) -> torch_frame.data.Dataset:
super().materialize(*args, **kwargs)
if self.task_type != self._task_type:
raise RuntimeError(f"task type does not match. It should be "
f"{self.task_type.value} but specified as "
f"{self._task_type.value}.")
+ return self
10 changes: 5 additions & 5 deletions torch_frame/datasets/multimodal_text_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- 1
- 6
- multiclass_classification
- - 0.0%
+ - 12.3%
* - melbourne_airbnb
- 22,895
- 26
Expand All @@ -90,7 +90,7 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- 3
- 10
- multiclass_classification
- - 0.0%
+ - 9.6%
* - imdb_genre_prediction
- 1,000
- 7
Expand Down Expand Up @@ -180,7 +180,7 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- 3
- 1
- regression
- - 0.0%
+ - 6.1%
* - california_house_price
- 47,439
- 18
Expand All @@ -189,7 +189,7 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- 11
- 1
- regression
- - 0.0%
+ - 13.8%
* - mercari_price_suggestion100K
- 125,000
- 0
Expand All @@ -198,7 +198,7 @@ class MultimodalTextBenchmark(torch_frame.data.Dataset):
- 1
- 1
- regression
- - 0.0%
+ - 3.4%
"""
base_url = 'https://automl-mm-bench.s3.amazonaws.com'

Expand Down

0 comments on commit 2acfc8c

Please sign in to comment.