
Commit

refactor: code refactor and release v0.1;
WenjieDu committed Aug 5, 2023
1 parent 1e30262 commit 256133a
Showing 6 changed files with 18 additions and 12 deletions.
1 change: 1 addition & 0 deletions MANIFEST.in
@@ -0,0 +1 @@
+prune tests
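
Note: "prune tests" keeps the entire tests/ directory out of the source distribution; combined with find_packages(exclude=["tests"]) in setup.py below, the test suite ships in neither the sdist nor the wheel.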
9 changes: 6 additions & 3 deletions setup.py
@@ -22,14 +22,17 @@
         "Download": "https://github.com/WenjieDu/TSDB/archive/main.zip",
     },
     keywords=[
-        "time-series analysis",
-        "data mining",
+        "time series",
+        "time-series analysis",
+        "time-series database",
+        "time-series datasets",
+        "datasets",
         "database",
-        "datasets",
         "dataset downloading",
+        "data mining",
         "imputation",
         "classification",
         "forecasting",
     ],
     packages=find_packages(exclude=["tests"]),
     include_package_data=True,
File renamed without changes.
File renamed without changes.
2 changes: 1 addition & 1 deletion tsdb/__init__.py
@@ -21,7 +21,7 @@
 #
 # Dev branch marker is: 'X.Y.dev' or 'X.Y.devN' where N is an integer.
 # 'X.Y.dev0' is the canonical version of 'X.Y.dev'
-__version__ = "0.0.9"
+__version__ = "0.1"


 try:
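
Note: under PEP 440, the bare "0.1" compares equal to "0.1.0", so installers resolve this tag as the 0.1.0 release; the dev-marker convention in the comment above applies only to development builds.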
18 changes: 10 additions & 8 deletions tsdb/data_processing.py
@@ -47,7 +47,7 @@ def window_truncate(feature_vectors, seq_len):
     start_indices = numpy.asarray(range(feature_vectors.shape[0] // seq_len)) * seq_len
     sample_collector = []
     for idx in start_indices:
-        sample_collector.append(feature_vectors[idx: idx + seq_len])
+        sample_collector.append(feature_vectors[idx : idx + seq_len])

     return numpy.asarray(sample_collector).astype("float32")
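
The change above only adjusts slice spacing (formatting, no behavior change). For context, window_truncate splits a [T, D] feature array into non-overlapping windows of length seq_len, dropping the trailing remainder. A quick illustration:

    import numpy

    feature_vectors = numpy.arange(20).reshape(10, 2)  # T=10 time steps, D=2 features
    samples = window_truncate(feature_vectors, seq_len=3)
    print(samples.shape)  # (3, 3, 2): 10 // 3 = 3 full windows; the step at index 9 is dropped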

@@ -105,7 +105,7 @@ def _download_and_extract(url, saving_path):
     logger.info(f"Successfully downloaded data to {raw_data_saving_path}.")

     if (
-    suffix in supported_compression_format
+        suffix in supported_compression_format
     ):  # if the file is compressed, then unpack it
         try:
             os.makedirs(saving_path, exist_ok=True)
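
The change above is an indentation-only fix inside _download_and_extract. As a rough sketch of the download-then-unpack pattern this function implements (not TSDB's exact code; the helper name and archive-suffix list here are illustrative):

    import os
    import shutil
    import urllib.request

    def fetch_and_unpack(url, raw_path, saving_path):
        # download the raw file first
        urllib.request.urlretrieve(url, raw_path)
        # unpack only if the file is a compressed archive
        if raw_path.endswith((".zip", ".tar.gz", ".tgz")):
            os.makedirs(saving_path, exist_ok=True)
            shutil.unpack_archive(raw_path, saving_path)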
@@ -177,7 +177,7 @@ def delete_cached_data(dataset_name=None):
     # if CACHED_DATASET_DIR exists, then purge
     if dataset_name is not None:
         assert (
-        dataset_name in AVAILABLE_DATASETS
+            dataset_name in AVAILABLE_DATASETS
         ), f"{dataset_name} is not available in TSDB, so it has no cache. Please check your dataset name."
         dir_to_delete = os.path.join(CACHED_DATASET_DIR, dataset_name)
         if not os.path.exists(dir_to_delete):
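
The change above is likewise indentation-only. A usage sketch, assuming delete_cached_data is re-exported at the package level:

    import tsdb

    tsdb.delete_cached_data("physionet_2012")  # purge one cached dataset; the name must be in AVAILABLE_DATASETS
    tsdb.delete_cached_data()  # dataset_name defaults to None, which purges the whole cache directory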
@@ -290,13 +290,15 @@ def load_dataset(dataset_name, use_cache=True):
     )

     profile_dir = dataset_name if "ucr_uea_" not in dataset_name else "ucr_uea_datasets"
-    logger.info(f"You're using dataset {dataset_name}, please cite it properly in your work.\n"
-                f"You can find its reference information at the below link: \n"
-                f"https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/{profile_dir}")
+    logger.info(
+        f"You're using dataset {dataset_name}, please cite it properly in your work. "
+        f"You can find its reference information at the below link: \n"
+        f"https://github.com/WenjieDu/TSDB/tree/main/dataset_profiles/{profile_dir}"
+    )

     dataset_saving_path = os.path.join(CACHED_DATASET_DIR, dataset_name)
     if not os.path.exists(
-    dataset_saving_path
+        dataset_saving_path
     ):  # if the dataset is not cached, then download it
         download_and_extract(dataset_name, dataset_saving_path)
     else:
@@ -320,7 +322,7 @@ def load_dataset(dataset_name, use_cache=True):
     try:
         if dataset_name == "physionet_2012":
             result = load_physionet2012(dataset_saving_path)
-        if dataset_name == "physionet_2019":
+        elif dataset_name == "physionet_2019":
             result = load_physionet2019(dataset_saving_path)
         elif dataset_name == "electricity_load_diagrams":
             result = load_electricity(dataset_saving_path)
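
The if → elif swap above is the one substantive fix in this hunk: it turns the dataset dispatch into a single if/elif chain, so only one branch is tested after a match. A usage sketch, assuming load_dataset is importable from the tsdb package:

    import tsdb

    data = tsdb.load_dataset("physionet_2012")  # first call downloads and extracts into CACHED_DATASET_DIR
    data = tsdb.load_dataset("physionet_2012")  # later calls read the local cache (use_cache=True is the default)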
