Skip to content

Commit

Permalink
Fix minor bugs in time series forecasting (#600)
Browse files Browse the repository at this point in the history
* Fix potential bugs in seasonal error calculation

* Fix dtype test
  • Loading branch information
shchur authored Oct 23, 2023
1 parent 6872b63 commit 53b0eea
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 10 deletions.
4 changes: 3 additions & 1 deletion amlb/datasets/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,8 @@ def __init__(self, path, fold, target, features, cache_dir, config):
self.id_column = config['id_column']
self.timestamp_column = config['timestamp_column']

# Ensure that id_column is parsed as string to avoid incorrect sorting
full_data[self.id_column] = full_data[self.id_column].astype(str)
full_data[self.timestamp_column] = pd.to_datetime(full_data[self.timestamp_column])
if config['name'] is not None:
file_name = config['name']
Expand All @@ -353,7 +355,7 @@ def __init__(self, path, fold, target, features, cache_dir, config):

self._train = CsvDatasplit(self, train_path, timestamp_column=self.timestamp_column)
self._test = CsvDatasplit(self, test_path, timestamp_column=self.timestamp_column)
self._dtypes = None
self._dtypes = full_data.dtypes

# Store repeated item_id & in-sample seasonal error for each time step in the forecast horizon - needed later for metrics like MASE.
# We need to store this information here because Result object has no access to past time series values.
Expand Down
22 changes: 14 additions & 8 deletions frameworks/AutoGluon/exec_ts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,24 @@
from joblib.externals.loky import get_reusable_executor

from frameworks.shared.callee import call_run, result, output_subdir
from frameworks.shared.utils import Timer, zip_path
from frameworks.shared.utils import Timer, zip_path, load_timeseries_dataset

log = logging.getLogger(__name__)


def run(dataset, config):
log.info(f"\n**** AutoGluon TimeSeries [v{__version__}] ****\n")
prediction_length = dataset.forecast_horizon_in_steps
train_df, test_df = load_timeseries_dataset(dataset)

train_data = TimeSeriesDataFrame.from_path(
dataset.train_path,
train_data = TimeSeriesDataFrame.from_data_frame(
train_df,
id_column=dataset.id_column,
timestamp_column=dataset.timestamp_column,
)

test_data = TimeSeriesDataFrame.from_data_frame(
test_df,
id_column=dataset.id_column,
timestamp_column=dataset.timestamp_column,
)
Expand All @@ -45,14 +52,14 @@ def run(dataset, config):
predictor.fit(
train_data=train_data,
time_limit=config.max_runtime_seconds,
random_seed=config.seed,
**{k: v for k, v in config.framework_params.items() if not k.startswith('_')},
)

with Timer() as predict:
predictions = pd.DataFrame(predictor.predict(train_data))

# Add columns necessary for the metric computation + quantile forecast to `optional_columns`
test_data_future = pd.read_csv(dataset.test_path, parse_dates=[dataset.timestamp_column])
optional_columns = dict(
repeated_item_id=np.load(dataset.repeated_item_id),
repeated_abs_seasonal_error=np.load(dataset.repeated_abs_seasonal_error),
Expand All @@ -61,13 +68,12 @@ def run(dataset, config):
optional_columns[str(q)] = predictions[str(q)].values

predictions_only = get_point_forecast(predictions, config.metric)
truth_only = test_data_future[dataset.target].values
truth_only = test_df[dataset.target].values

# Sanity check - make sure predictions are ordered correctly
future_index = pd.MultiIndex.from_frame(test_data_future[[dataset.id_column, dataset.timestamp_column]])
assert predictions.index.equals(future_index), "Predictions and test data index do not match"
assert predictions.index.equals(test_data.index), "Predictions and test data index do not match"

test_data_full = pd.concat([train_data, test_data_future.set_index([dataset.id_column, dataset.timestamp_column])])
test_data_full = pd.concat([train_data, test_data])
leaderboard = predictor.leaderboard(test_data_full, silent=True)

with pd.option_context('display.max_rows', None, 'display.max_columns', None, 'display.width', 1000):
Expand Down
8 changes: 8 additions & 0 deletions frameworks/shared/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import importlib.util
import logging
import os
import pandas as pd
import sys


Expand Down Expand Up @@ -42,6 +43,13 @@ def load_amlb_module(mod, amlb_path=None):
return import_module(mod)


def load_timeseries_dataset(dataset):
    """Load the train and test splits of a time series task as DataFrames.

    Reads the CSV files at ``dataset.train_path`` and ``dataset.test_path``.
    The id column is forced to ``str`` (so item identifiers sort
    lexicographically rather than numerically) and the timestamp column is
    parsed into datetimes.
    """
    # Ensure that id_column is loaded as string to avoid incorrect sorting
    read_kwargs = {
        "dtype": {dataset.id_column: str},
        "parse_dates": [dataset.timestamp_column],
    }
    train_data = pd.read_csv(dataset.train_path, **read_kwargs)
    test_data = pd.read_csv(dataset.test_path, **read_kwargs)
    return train_data, test_data


utils = load_amlb_module("amlb.utils")
# unorthodox, but it is only now that we can safely import those functions
from amlb.utils import *
Expand Down
4 changes: 4 additions & 0 deletions requirements.in
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,7 @@ scikit-learn>=1.0,<2.0

pyarrow>=11.0
# tables>=3.6

# Allow loading datasets from S3
fsspec
s3fs
7 changes: 7 additions & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ boto3==1.26.98
botocore==1.29.98
# via
# boto3
# s3fs
# s3transfer
certifi==2022.12.7
# via
Expand All @@ -18,6 +19,10 @@ charset-normalizer==3.1.0
# via requests
filelock==3.12.0
# via -r requirements.in
fsspec==2023.6.0
# via
# -r requirements.in
# s3fs
idna==3.4
# via requests
jmespath==1.0.1
Expand Down Expand Up @@ -65,6 +70,8 @@ ruamel-yaml==0.17.21
# via -r requirements.in
ruamel-yaml-clib==0.2.7
# via ruamel-yaml
s3fs==0.4.2
# via -r requirements.in
s3transfer==0.6.0
# via boto3
scikit-learn==1.2.2
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/amlb/datasets/file/test_file_dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,7 @@ def test_load_timeseries_task_csv(file_loader):
assert len(ds.repeated_abs_seasonal_error) == len(ds.test.data)
assert len(ds.repeated_item_id) == len(ds.test.data)

assert pat.is_categorical_dtype(ds._dtypes[ds.id_column])
assert pat.is_string_dtype(ds._dtypes[ds.id_column])
assert pat.is_datetime64_dtype(ds._dtypes[ds.timestamp_column])
assert pat.is_float_dtype(ds._dtypes[ds.target.name])

Expand Down

0 comments on commit 53b0eea

Please sign in to comment.