Skip to content

Commit

Permalink
Merge pull request #2361 from huggingface/grodino-dataset_trust_remote
Browse files Browse the repository at this point in the history
Dataset trust remote tweaks
  • Loading branch information
rwightman authored Dec 6, 2024
2 parents 9eee47d + 7573096 commit ea23107
Show file tree
Hide file tree
Showing 6 changed files with 40 additions and 21 deletions.
2 changes: 2 additions & 0 deletions timm/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ def __init__(
transform=None,
target_transform=None,
max_steps=None,
**kwargs,
):
assert reader is not None
if isinstance(reader, str):
Expand All @@ -121,6 +122,7 @@ def __init__(
input_key=input_key,
target_key=target_key,
max_steps=max_steps,
**kwargs,
)
else:
self.reader = reader
Expand Down
43 changes: 24 additions & 19 deletions timm/data/dataset_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,34 +74,37 @@ def create_dataset(
seed: int = 42,
repeats: int = 0,
input_img_mode: str = 'RGB',
trust_remote_code: bool = False,
**kwargs,
):
""" Dataset factory method
In parentheses after each arg are the type of dataset supported for each arg, one of:
* folder - default, timm folder (or tar) based ImageDataset
* torch - torchvision based datasets
* Folder - default, timm folder (or tar) based ImageDataset
* Torch - torchvision based datasets
* HFDS - Hugging Face Datasets
* HFIDS - Hugging Face Datasets Iterable (streaming mode, with IterableDataset)
* TFDS - Tensorflow-datasets wrapper in IterabeDataset interface via IterableImageDataset
* WDS - Webdataset
* all - any of the above
* All - any of the above
Args:
name: dataset name, empty is okay for folder based datasets
root: root folder of dataset (all)
split: dataset split (all)
search_split: search for split specific child fold from root so one can specify
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (folder, torch/folder)
class_map: specify class -> index mapping via text file or dict (folder)
load_bytes: load data, return images as undecoded bytes (folder)
download: download dataset if not present and supported (HFDS, TFDS, torch)
is_training: create dataset in train mode, this is different from the split.
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS)
batch_size: batch size hint for (TFDS, WDS)
seed: seed for iterable datasets (TFDS, WDS)
repeats: dataset repeats per iteration i.e. epoch (TFDS, WDS)
input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS)
**kwargs: other args to pass to dataset
name: Dataset name, empty is okay for folder based datasets
root: Root folder of dataset (All)
split: Dataset split (All)
search_split: Search for split specific child fold from root so one can specify
`imagenet/` instead of `/imagenet/val`, etc on cmd line / config. (Folder, Torch)
class_map: Specify class -> index mapping via text file or dict (Folder)
load_bytes: Load data, return images as undecoded bytes (Folder)
download: Download dataset if not present and supported (HFIDS, TFDS, Torch)
is_training: Create dataset in train mode, this is different from the split.
For Iterable / TDFS it enables shuffle, ignored for other datasets. (TFDS, WDS, HFIDS)
batch_size: Batch size hint for iterable datasets (TFDS, WDS, HFIDS)
seed: Seed for iterable datasets (TFDS, WDS, HFIDS)
repeats: Dataset repeats per iteration i.e. epoch (TFDS, WDS, HFIDS)
input_img_mode: Input image color conversion mode e.g. 'RGB', 'L' (folder, TFDS, WDS, HFDS, HFIDS)
trust_remote_code: Trust remote code in Hugging Face Datasets if True (HFDS, HFIDS)
**kwargs: Other args to pass through to underlying Dataset and/or Reader classes
Returns:
Dataset object
Expand Down Expand Up @@ -162,6 +165,7 @@ def create_dataset(
split=split,
class_map=class_map,
input_img_mode=input_img_mode,
trust_remote_code=trust_remote_code,
**kwargs,
)
elif name.startswith('hfids/'):
Expand All @@ -177,7 +181,8 @@ def create_dataset(
repeats=repeats,
seed=seed,
input_img_mode=input_img_mode,
**kwargs
trust_remote_code=trust_remote_code,
**kwargs,
)
elif name.startswith('tfds/'):
ds = IterableImageDataset(
Expand Down
2 changes: 1 addition & 1 deletion timm/data/readers/reader_hfds.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def __init__(
self.dataset = datasets.load_dataset(
name, # 'name' maps to path arg in hf datasets
split=split,
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path
cache_dir=self.root, # timm doesn't expect hidden cache dir for datasets, specify a path if root set
trust_remote_code=trust_remote_code
)
# leave decode for caller, plus we want easy access to original path names...
Expand Down
7 changes: 6 additions & 1 deletion timm/data/readers/reader_hfids.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ def __init__(
target_img_mode: str = '',
shuffle_size: Optional[int] = None,
num_samples: Optional[int] = None,
trust_remote_code: bool = False
):
super().__init__()
self.root = root
Expand All @@ -60,7 +61,11 @@ def __init__(
self.target_key = target_key
self.target_img_mode = target_img_mode

self.builder = datasets.load_dataset_builder(name, cache_dir=root)
self.builder = datasets.load_dataset_builder(
name,
cache_dir=root,
trust_remote_code=trust_remote_code,
)
if download:
self.builder.download_and_prepare()

Expand Down
4 changes: 4 additions & 0 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@
help='Dataset key for input images.')
group.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
group.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

# Model parameters
group = parser.add_argument_group('Model parameters')
Expand Down Expand Up @@ -653,6 +655,7 @@ def main():
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.train_num_samples,
trust_remote_code=args.dataset_trust_remote_code,
)

if args.val_split:
Expand All @@ -668,6 +671,7 @@ def main():
input_key=args.input_key,
target_key=args.target_key,
num_samples=args.val_num_samples,
trust_remote_code=args.dataset_trust_remote_code,
)

# setup mixup / cutmix
Expand Down
3 changes: 3 additions & 0 deletions validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,8 @@
help='Dataset image conversion mode for input images.')
parser.add_argument('--target-key', default=None, type=str,
help='Dataset key for target labels.')
parser.add_argument('--dataset-trust-remote-code', action='store_true', default=False,
help='Allow huggingface dataset import to execute code downloaded from the dataset\'s repo.')

parser.add_argument('--model', '-m', metavar='NAME', default='dpn92',
help='model architecture (default: dpn92)')
Expand Down Expand Up @@ -268,6 +270,7 @@ def validate(args):
input_key=args.input_key,
input_img_mode=input_img_mode,
target_key=args.target_key,
trust_remote_code=args.dataset_trust_remote_code,
)

if args.valid_labels:
Expand Down

0 comments on commit ea23107

Please sign in to comment.