Skip to content

Commit

Permalink
add multi-timeseries dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
khairulislam committed Sep 7, 2023
1 parent 9f67118 commit 7b8ebc8
Show file tree
Hide file tree
Showing 18 changed files with 237 additions and 1,545 deletions.
209 changes: 0 additions & 209 deletions data_provider/base.py

This file was deleted.

28 changes: 11 additions & 17 deletions data_provider/data_factory.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
from data_provider.data_loader import Dataset_Custom, Dataset_Pred
from data_provider.data_loader import Dataset_Custom, Dataset_Pred, MultiTimeSeries
from torch.utils.data import DataLoader

data_dict = {
'custom': Dataset_Custom,
'covid': MultiTimeSeries
}


def data_provider(args, flag):
Data = data_dict[args.data]
timeenc = 0 if args.embed != 'timeF' else 1

if flag == 'test':
shuffle_flag = False
drop_last = False
batch_size = args.batch_size
freq = args.freq
elif flag == 'pred':
shuffle_flag = False
drop_last = False
batch_size = 1
freq = args.freq
if flag == 'pred':
Data = Dataset_Pred
batch_size = 1
else:
shuffle_flag = True
drop_last = True
Data = data_dict[args.data]
batch_size = args.batch_size
freq = args.freq

drop_last = flag == 'train'
shuffle_flag = flag == 'train'
freq = args.freq

data_set = Data(
root_path=args.root_path,
Expand All @@ -35,7 +28,8 @@ def data_provider(args, flag):
features=args.features,
target=args.target,
timeenc=timeenc,
freq=freq
freq=freq,
scale=not args.no_scale
)
print(flag, len(data_set))
data_loader = DataLoader(
Expand Down
Loading

0 comments on commit 7b8ebc8

Please sign in to comment.