Skip to content

Commit

Permalink
Merge branch 'main' into v1.0.0
Browse files Browse the repository at this point in the history
  • Loading branch information
andrewilyas committed Mar 1, 2023
2 parents 25c9700 + e32280b commit 5085773
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 44 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ jobs:
mv docs docs_src
cd docs_src
pip install -U sphinx karma-sphinx-theme
pip install -U numpy==1.20 numba tqdm
pip install -U numpy numba tqdm
pip install --upgrade -U pygments
make html
cp -r _build/html ../docs
Expand Down
2 changes: 1 addition & 1 deletion ffcv/fields/rgb_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Alloca
min_height = heights.min()
min_width = widths.min()
if min_width != max_width or max_height != min_height:
msg = """SimpleRGBImageDecoder ony supports constant image,
msg = """SimpleRGBImageDecoder only supports constant image,
consider RandomResizedCropRGBImageDecoder or CenterCropRGBImageDecoder
instead."""
raise TypeError(msg)
Expand Down
2 changes: 1 addition & 1 deletion ffcv/loader/epoch_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def run(self):
event = ch.cuda.Event()
event.record(self.current_stream)
events[just_finished_slot] = event
b_ix += 1
b_ix += 1

except StopIteration:
self.output_queue.put(None)
Expand Down
9 changes: 6 additions & 3 deletions ffcv/transforms/random_resized_crop.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from ..pipeline.compiler import Compiler

class RandomResizedCrop(Operation):
"""Crop a random portion of image with random aspect ratio and resize it to a given size.
"""Crop a random portion of image with random aspect ratio and resize it to
a given size. Chances are you do not want to use this augmentation and
instead want to include RRC as part of the decoder, by using the
:cla:`~ffcv.fields.rgb_image.ResizedCropRGBImageDecoder` class.
Parameters
----------
Expand Down Expand Up @@ -49,7 +52,7 @@ def random_resized_crop(images, dst):
return random_resized_crop

def declare_state_and_memory(self, previous_state: State) -> Tuple[State, Optional[AllocationQuery]]:
assert previous_state.jit_mode
return replace(previous_state, shape=(self.size, self.size, 3)), AllocationQuery((self.size, self.size, 3), dtype=np.dtype('uint8'))
return replace(previous_state, jit_mode=True, shape=(self.size, self.size, 3)), \
AllocationQuery((self.size, self.size, 3), dtype=previous_state.dtype)


1 change: 1 addition & 0 deletions ffcv/transforms/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def translate(images, dst):
dst[:] = fill
dst[:, pad:pad+h, pad:pad+w] = images
for i in my_range(n):
dst[i] = 0
y_coord = randint(low=0, high=2 * pad + 1)
x_coord = randint(low=0, high=2 * pad + 1)
images[i] = dst[i, y_coord:y_coord+h, x_coord:x_coord+w]
Expand Down
56 changes: 18 additions & 38 deletions tests/test_augmentations.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ToTorchImage()
]

def run_test(length, pipeline, compile=False):
def run_test(length, pipeline, should_compile=False, aug_name=''):
my_dataset = Subset(CIFAR10(root='/tmp', train=True, download=True), range(length))

with NamedTemporaryFile() as handle:
Expand All @@ -42,7 +42,7 @@ def run_test(length, pipeline, compile=False):

writer.from_indexed_dataset(my_dataset, chunksize=10)

Compiler.set_enabled(compile)
Compiler.set_enabled(should_compile)

loader = Loader(name, batch_size=7, num_workers=2, pipelines={
'image': pipeline,
Expand All @@ -57,18 +57,16 @@ def run_test(length, pipeline, compile=False):

tot_indices = 0
tot_images = 0
for (images, labels), (original_images, original_labels) in zip(loader, unaugmented_loader):
print(images.shape, original_images.shape)
for it_num, ((images, labels), (original_images, original_labels)) in enumerate(zip(loader, unaugmented_loader)):
tot_indices += labels.shape[0]
tot_images += images.shape[0]

for label, original_label in zip(labels, original_labels):
assert_that(label).is_equal_to(original_label)

if SAVE_IMAGES:
if SAVE_IMAGES and it_num == 0:
save_image(make_grid(ch.concat([images, original_images])/255., images.shape[0]),
os.path.join(IMAGES_TMP_PATH, str(uuid.uuid4()) + '.jpeg')
)
os.path.join(IMAGES_TMP_PATH, aug_name + '-' + str(uuid.uuid4()) + '.jpeg'))

assert_that(tot_indices).is_equal_to(len(my_dataset))
assert_that(tot_images).is_equal_to(len(my_dataset))
Expand All @@ -80,7 +78,7 @@ def test_cutout():
Cutout(8),
ToTensor(),
ToTorchImage()
], comp)
], comp, 'cutout')


def test_flip():
Expand All @@ -90,7 +88,7 @@ def test_flip():
RandomHorizontalFlip(1.0),
ToTensor(),
ToTorchImage()
], comp)
], comp, 'flip')


def test_module_wrapper():
Expand All @@ -100,7 +98,7 @@ def test_module_wrapper():
ToTensor(),
ToTorchImage(),
ModuleWrapper(tvt.Grayscale(3)),
], comp)
], comp, 'module')


def test_mixup():
Expand All @@ -110,7 +108,7 @@ def test_mixup():
ImageMixup(.5, False),
ToTensor(),
ToTorchImage()
], comp)
], comp, 'mixup')


def test_poison():
Expand All @@ -125,8 +123,7 @@ def test_poison():
Poison(mask, alpha, list(range(100))),
ToTensor(),
ToTorchImage()
], comp)

], comp, 'poison')

def test_random_resized_crop():
for comp in [True, False]:
Expand All @@ -137,7 +134,7 @@ def test_random_resized_crop():
size=32),
ToTensor(),
ToTorchImage()
], comp)
], comp, 'rrc')


def test_translate():
Expand All @@ -147,7 +144,7 @@ def test_translate():
RandomTranslate(padding=10),
ToTensor(),
ToTorchImage()
], comp)
], comp, 'translate')


## Torchvision Transforms
Expand All @@ -157,7 +154,7 @@ def test_torchvision_greyscale():
ToTensor(),
ToTorchImage(),
tvt.Grayscale(3),
])
], aug_name='tv_grey')

def test_torchvision_centercrop_pad():
run_test(100, [
Expand All @@ -166,15 +163,15 @@ def test_torchvision_centercrop_pad():
ToTorchImage(),
tvt.CenterCrop(10),
tvt.Pad(11)
])
], aug_name='tv_crop_pad')

def test_torchvision_random_affine():
run_test(100, [
SimpleRGBImageDecoder(),
ToTensor(),
ToTorchImage(),
tvt.RandomAffine(25),
])
], aug_name='tv_random_affine')

def test_torchvision_random_crop():
run_test(100, [
Expand All @@ -183,29 +180,12 @@ def test_torchvision_random_crop():
ToTorchImage(),
tvt.Pad(10),
tvt.RandomCrop(size=32),
])
], aug_name='tv_randcrop')

def test_torchvision_color_jitter():
run_test(100, [
SimpleRGBImageDecoder(),
ToTensor(),
ToTorchImage(),
tvt.ColorJitter(.5, .5, .5, .5),
])


if __name__ == '__main__':
test_cutout()
test_flip()
test_module_wrapper()
test_mixup()
test_poison()
test_random_resized_crop()
test_translate()

## Torchvision Transforms
test_torchvision_greyscale()
test_torchvision_centercrop_pad()
test_torchvision_random_affine()
test_torchvision_random_crop()
test_torchvision_color_jitter()
tvt.ColorJitter(.5, .5, .5, .5)
], aug_name='tv_colorjitter')

0 comments on commit 5085773

Please sign in to comment.