Skip to content

Commit

Permalink
Merge pull request #56 from deel-ai/fix/homogenize_preprocess
Browse files Browse the repository at this point in the history
Homogenize preprocess_fn between torch and tf
  • Loading branch information
y-prudent authored Jul 19, 2023
2 parents 204f1db + 29f6787 commit 48a7f1e
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 20 deletions.
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_dknn_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@
"\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\"\"\"\n",
" x = inputs[0] / 255.0\n",
" return tuple([x] + list(inputs[1:]))\n",
Expand Down Expand Up @@ -339,7 +339,7 @@
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_energy_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -310,7 +310,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_entropy_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -310,7 +310,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_mahalanobis_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -311,7 +311,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_mls_msp_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -386,7 +386,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_odin_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -316,7 +316,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
4 changes: 2 additions & 2 deletions docs/notebooks/torch/demo_vim_torch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@
"oods_in, oods_out = mnist_test.assign_ood_labels_by_class(in_labels=in_labels)\n",
"\n",
"# 3- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Simple preprocessing function to normalize images in [0, 1].\n",
" \"\"\"\n",
" x = inputs[0] / 255.0\n",
Expand Down Expand Up @@ -318,7 +318,7 @@
"\n",
"# 2- prepare data (preprocess, shuffle, batch) => torch dataloaders\n",
"\n",
"def preprocess_fn(inputs):\n",
"def preprocess_fn(*inputs):\n",
" \"\"\"Preprocessing function from\n",
" https://github.com/chenyaofo/pytorch-cifar-models\n",
" \"\"\"\n",
Expand Down
12 changes: 8 additions & 4 deletions oodeel/datasets/torch_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -601,22 +601,26 @@ def prepare_for_training(
Returns:
DataLoader: dataloader
"""
preprocess_fn = preprocess_fn or (lambda x: x)
augment_fn = augment_fn or (lambda x: x)
output_keys = output_keys or cls.get_ds_feature_keys(dataset)

def collate_fn(batch: List[dict]):
if dict_based_fns:
# preprocess + DA: List[dict] -> List[dict]
batch = [augment_fn(preprocess_fn(d)) for d in batch]
preprocess_func = preprocess_fn or (lambda x: x)
augment_func = augment_fn or (lambda x: x)
batch = [augment_func(preprocess_func(d)) for d in batch]
# to tuple of batchs
return tuple(
default_collate([d[key] for d in batch]) for key in output_keys
)
else:
# preprocess + DA: List[dict] -> List[tuple]
preprocess_func = preprocess_fn or (lambda *x: x)
augment_func = augment_fn or (lambda *x: x)
batch = [
augment_fn(preprocess_fn(tuple(d[key] for key in output_keys)))
augment_func(
*preprocess_func(*tuple(d[key] for key in output_keys))
)
for d in batch
]
# to tuple of batchs
Expand Down
4 changes: 2 additions & 2 deletions tests/tests_torch/datasets/test_torch_ooddataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,11 +204,11 @@ def test_prepare(shuffle, with_labels, with_ood_labels, expected_output):
backend="torch",
)

def preprocess_fn(inputs):
def preprocess_fn(*inputs):
x = inputs[0] / 255
return tuple([x] + list(inputs[1:]))

def augment_fn_(inputs):
def augment_fn_(*inputs):
x = torchvision.transforms.RandomHorizontalFlip()(inputs[0])
return tuple([x] + list(inputs[1:]))

Expand Down

0 comments on commit 48a7f1e

Please sign in to comment.