capture ReadError in download #86

Merged · 7 commits · Nov 13, 2023
pygmtools/dataset.py — 7 changes: 6 additions & 1 deletion
@@ -240,7 +240,12 @@
 os.remove(filename)
 return self.download(url, name, retries - 1)

-file_names = tar.getnames()
+try:
+    file_names = tar.getnames()
+except tarfile.ReadError as err:
+    print('Warning: Content error. Retrying...\n', err)
+    os.remove(filename)
+    return self.download(url, name, retries - 1)

Codecov (codecov/patch) check warning on pygmtools/dataset.py#L245-L248: added lines #L245 - L248 were not covered by tests.
 print('Unzipping files...')
 sleep(0.5)
 for file_name in tqdm(file_names):
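For context on what the new try/except guards against: when a download is interrupted or the saved archive is otherwise corrupted, tarfile raises tarfile.ReadError, either when the archive is opened or, for some kinds of damage, only once the member list is read (hence wrapping getnames()). A minimal, self-contained illustration of the pattern, not taken from the PR, using a zero-byte file to stand in for a broken download:

```python
import os
import tarfile

# Stand-in for an interrupted download: the file exists but holds no data.
filename = 'corrupt_download.tar'
open(filename, 'wb').close()

try:
    with tarfile.open(filename) as tar:
        file_names = tar.getnames()
except tarfile.ReadError as err:
    # A corrupted or empty archive surfaces as tarfile.ReadError; the patched
    # download() prints a warning, deletes the file and retries instead of crashing.
    print('Warning: Content error. Retrying...\n', err)
    os.remove(filename)
```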
tests/test_dataset.py — 26 changes: 17 additions & 9 deletions
@@ -13,6 +13,9 @@
 from random import choice
 import os

+import platform
+os_name = platform.system()
+
 # Test dataset download and preprocess, and data fetch and evaluation
 def _test_benchmark(name, sets, problem, filter, **ds_dict):
     benchmark = pygm.benchmark.Benchmark(name=name, sets=sets, problem=problem, filter=filter, **ds_dict)
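A quick note on the new import: platform.system() returns the host OS name as a string ('Linux', 'Darwin' for macOS, 'Windows'), which is what the Darwin checks later in this file compare against. For illustration only:

```python
import platform

# 'Linux' on Linux runners, 'Darwin' on macOS, 'Windows' on Windows.
os_name = platform.system()
print('Running tests on:', os_name)
```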
@@ -57,19 +60,14 @@ def _test_get_data(benchmark, num):

 # Entry function
 def test_dataset_and_benchmark():
-    dataset_name_list = ['PascalVOC', 'WillowObject', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']
+    dataset_name_list = ['WillowObject', 'PascalVOC', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']
+
+    if os_name == 'Darwin':
+        dataset_name_list = ['WillowObject', 'SPair71k', 'IMC_PT_SparseGM', 'CUB2011']
     problem_type_list = ['2GM', 'MGM']
     set_list = ['train', 'test']
     filter_list = ['intersection', 'inclusion', 'unfiltered']
     dict_list = []
-    voc_cfg_dict = dict()
-    voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
-    voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
-    voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
-    voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
-    voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
-    voc_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?id=1TLBN4dnf_THmN3kNINMO0DpHhcBDqzcI&export=download'
-    dict_list.append(voc_cfg_dict)

     willow_cfg_dict = dict()
     willow_cfg_dict['CLASSES'] = dataset_cfg.WillowObject.CLASSES
@@ -82,6 +80,16 @@ def test_dataset_and_benchmark():
     willow_cfg_dict['URL'] = 'https://drive.google.com/u/0/uc?export=download&confirm=Z-AR&id=18AvGwkuhnih5bFDjfJK5NYM16LvDfwW_'
     dict_list.append(willow_cfg_dict)

+    if os_name != 'Darwin':
+        voc_cfg_dict = dict()
+        voc_cfg_dict['KPT_ANNO_DIR'] = dataset_cfg.PascalVOC.KPT_ANNO_DIR
+        voc_cfg_dict['ROOT_DIR'] = dataset_cfg.PascalVOC.ROOT_DIR
+        voc_cfg_dict['SET_SPLIT'] = dataset_cfg.PascalVOC.SET_SPLIT
+        voc_cfg_dict['CLASSES'] = dataset_cfg.PascalVOC.CLASSES
+        voc_cfg_dict['CACHE_PATH'] = dataset_cfg.CACHE_PATH
+        voc_cfg_dict['URL'] = 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true'
+        dict_list.append(voc_cfg_dict)
+
     spair_cfg_dict = dict()
     spair_cfg_dict['TRAIN_DIFF_PARAMS'] = {'mirror': 0}
     spair_cfg_dict['EVAL_DIFF_PARAMS'] = dataset_cfg.SPair.EVAL_DIFF_PARAMS
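To make the intent of these config dicts concrete: the extra keys are forwarded into the Benchmark constructor via **ds_dict (see the pygm.benchmark.Benchmark(...) call in _test_benchmark near the top of the file), so the PascalVOC entry above simply redirects the test's download to the smaller Hugging Face-hosted tarball instead of the old Google Drive link. A hedged sketch of a standalone call in the same shape; the directory paths and the class subset are placeholders for illustration, not values from the repository:

```python
import pygmtools as pygm

voc_cfg = {
    'KPT_ANNO_DIR': 'data/PascalVOC/annotations/',    # hypothetical local path
    'ROOT_DIR': 'data/PascalVOC/VOC2011/',            # hypothetical local path
    'SET_SPLIT': 'data/PascalVOC/voc2011_pairs.npz',  # hypothetical local path
    'CLASSES': ['aeroplane', 'bicycle'],              # small subset, illustration only
    'CACHE_PATH': 'data/cache/',                      # hypothetical local path
    # The small_voc tarball referenced in this diff:
    'URL': 'https://huggingface.co/datasets/ziaoguo/small_VOC/resolve/main/small_voc.tar?download=true',
}

# Mirrors the call in _test_benchmark(); constructing the Benchmark triggers the
# dataset download/preprocess path that this PR hardens against tarfile.ReadError.
benchmark = pygm.benchmark.Benchmark(name='PascalVOC', sets='train',
                                     problem='2GM', filter='intersection',
                                     **voc_cfg)
```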