diff --git a/CHANGELOG.md b/CHANGELOG.md index db9bac5595..8bae42d46d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276)) +- Changed the transforms in `SemanticSegmentationData` to use albumentations instead of Kornia ([#1313](https://github.com/PyTorchLightning/lightning-flash/pull/1313)) + ### Deprecated ### Removed diff --git a/docs/source/api/data.rst b/docs/source/api/data.rst index f65fc75f4f..397c9f842f 100644 --- a/docs/source/api/data.rst +++ b/docs/source/api/data.rst @@ -119,13 +119,6 @@ __________________________ :template: classtemplate.rst ~flash.core.data.transforms.ApplyToKeys - ~flash.core.data.transforms.KorniaParallelTransforms - -.. autosummary:: - :toctree: generated/ - :nosignatures: - - ~flash.core.data.transforms.kornia_collate flash.core.data.utils _____________________ diff --git a/flash/core/data/transforms.py b/flash/core/data/transforms.py index 459960b471..2f77731cbf 100644 --- a/flash/core/data/transforms.py +++ b/flash/core/data/transforms.py @@ -11,13 +11,51 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Any, Dict, Mapping, Sequence, Union +from typing import Any, Mapping, Sequence, Union -import torch +import numpy as np from torch import nn -from flash.core.data.utilities.collate import default_collate +from flash.core.data.io.input import DataKeys from flash.core.data.utils import convert_to_modules +from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, requires + +if _ALBUMENTATIONS_AVAILABLE: + from albumentations import BasicTransform, Compose +else: + BasicTransform, Compose = object, object + + +class AlbumentationsAdapter(nn.Module): + # mapping from albumentations to Flash + TRANSFORM_INPUT_MAPPING = {"image": DataKeys.INPUT, "mask": DataKeys.TARGET} + + @requires("albumentations") + def __init__( + self, + transform: Union[BasicTransform, Sequence[BasicTransform]], + mapping: dict = None, + ): + super().__init__() + if not isinstance(transform, (list, tuple)): + transform = [transform] + self.transform = Compose(list(transform)) + if not mapping: + mapping = self.TRANSFORM_INPUT_MAPPING + self._mapping_rev = mapping + self._mapping = {v: k for k, v in mapping.items()} + + def forward(self, x: Any) -> Any: + if isinstance(x, dict): + x_ = {self._mapping.get(key, key): np.array(value) for key, value in x.items() if key in self._mapping} + else: + x_ = {"image": x} + x_ = self.transform(**x_) + if isinstance(x, dict): + x.update({self._mapping_rev.get(k, k): x_[k] for k in self._mapping_rev if k in x_}) + else: + x = x_["image"] + return x class ApplyToKeys(nn.Sequential): @@ -49,9 +87,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]: try: outputs = super().forward(inputs) except TypeError as e: - raise Exception( - "Failed to apply transforms to multiple keys at the same time, try using KorniaParallelTransforms." - ) from e + raise Exception("Failed to apply transforms to multiple keys at the same time.") from e for i, key in enumerate(keys): result[key] = outputs[i] @@ -66,54 +102,3 @@ def __repr__(self): transform = transform[0] if len(transform) == 1 else transform return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})" - - -class KorniaParallelTransforms(nn.Sequential): - """The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each - input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when - multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask). - - Args: - args: The transforms, passed to the ``nn.Sequential`` super constructor. - """ - - def __init__(self, *args): - super().__init__(*(convert_to_modules(arg) for arg in args)) - - def forward(self, inputs: Any): - result = list(inputs) if isinstance(inputs, Sequence) else [inputs] - for transform in self.children(): - inputs = result - - # we enforce the first time to sample random params - result[0] = transform(inputs[0]) - - if hasattr(transform, "_params") and bool(transform._params): - params = transform._params - else: - params = None - - # apply transforms from (1, n) - for i, input in enumerate(inputs[1:]): - if params is not None: - result[i + 1] = transform(input, params) - else: # case for non-random transforms - result[i + 1] = transform(input) - if hasattr(transform, "_params") and bool(transform._params): - transform._params = None - return result - - -def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]: - """Kornia transforms add batch dimension which need to be removed. - - This function removes that dimension and then - applies ``torch.utils.data._utils.collate.default_collate``. - """ - if len(samples) == 1 and isinstance(samples[0], list): - samples = samples[0] - for sample in samples: - for key in sample.keys(): - if torch.is_tensor(sample[key]) and sample[key].ndim == 4: - sample[key] = sample[key].squeeze(0) - return default_collate(samples) diff --git a/flash/core/data/utilities/collate.py b/flash/core/data/utilities/collate.py index a5bed216b3..02c1075167 100644 --- a/flash/core/data/utilities/collate.py +++ b/flash/core/data/utilities/collate.py @@ -20,6 +20,10 @@ def _wrap_collate(collate: Callable, batch: List[Any]) -> Any: + # Needed for learn2learn integration + if len(batch) == 1 and isinstance(batch[0], list): + batch = batch[0] + metadata = [sample.pop(DataKeys.METADATA, None) if isinstance(sample, Mapping) else None for sample in batch] metadata = metadata if any(m is not None for m in metadata) else None diff --git a/flash/core/utilities/imports.py b/flash/core/utilities/imports.py index bc22549655..adbc9e6264 100644 --- a/flash/core/utilities/imports.py +++ b/flash/core/utilities/imports.py @@ -156,7 +156,7 @@ class Image: _TORCHVISION_AVAILABLE, _TIMM_AVAILABLE, _PIL_AVAILABLE, - _KORNIA_AVAILABLE, + _ALBUMENTATIONS_AVAILABLE, _PYSTICHE_AVAILABLE, _SEGMENTATION_MODELS_AVAILABLE, ] diff --git a/flash/image/classification/input_transform.py b/flash/image/classification/input_transform.py index 179d60e9cc..050944b64b 100644 --- a/flash/image/classification/input_transform.py +++ b/flash/image/classification/input_transform.py @@ -12,14 +12,14 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Callable, Tuple, Union +from typing import Tuple, Union import torch from torch import nn from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform -from flash.core.data.transforms import ApplyToKeys, kornia_collate +from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires if _TORCHVISION_AVAILABLE: @@ -76,7 +76,3 @@ def train_per_sample_transform(self): ApplyToKeys(DataKeys.TARGET, torch.as_tensor), ] ) - - def collate(self) -> Callable: - # TODO: Remove kornia collate for default_collate - return kornia_collate diff --git a/flash/image/classification/integrations/learn2learn.py b/flash/image/classification/integrations/learn2learn.py index 255da82506..bdeb9f3096 100644 --- a/flash/image/classification/integrations/learn2learn.py +++ b/flash/image/classification/integrations/learn2learn.py @@ -21,9 +21,9 @@ import pytorch_lightning as pl from torch.utils.data import IterableDataset -from torch.utils.data._utils.collate import default_collate from torch.utils.data._utils.worker import get_worker_info +from flash.core.data.utilities.collate import default_collate from flash.core.utilities.imports import requires @@ -109,7 +109,6 @@ def __init__( self.epoch_length = epoch_length self.seed = seed self.iteration = 0 - self.iteration = 0 self.requires_divisible = requires_divisible self.counter = 0 diff --git a/flash/image/instance_segmentation/data.py b/flash/image/instance_segmentation/data.py index 23674436fe..0c7c80ae4a 100644 --- a/flash/image/instance_segmentation/data.py +++ b/flash/image/instance_segmentation/data.py @@ -22,7 +22,12 @@ from flash.core.data.utilities.sort import sorted_alphanumeric from flash.core.integrations.icevision.data import IceVisionInput from flash.core.integrations.icevision.transforms import IceVisionInputTransform -from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, _KORNIA_AVAILABLE +from flash.core.utilities.imports import ( + _ICEVISION_AVAILABLE, + _IMAGE_EXTRAS_TESTING, + _TORCHVISION_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_9, +) from flash.core.utilities.stages import RunningStage from flash.core.utilities.types import INPUT_TRANSFORM_TYPE @@ -34,8 +39,15 @@ VOCMaskParser = object Parser = object -if _KORNIA_AVAILABLE: - import kornia as K +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as T + + if _TORCHVISION_GREATER_EQUAL_0_9: + from torchvision.transforms import InterpolationMode + else: + + class InterpolationMode: + NEAREST = "nearest" # Skip doctests if requirements aren't available @@ -45,7 +57,7 @@ class InstanceSegmentationOutputTransform(OutputTransform): def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="nearest") + resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST) sample[DataKeys.PREDS]["masks"] = [resize(tensor(mask)) for mask in sample[DataKeys.PREDS]["masks"]] return sample[DataKeys.PREDS] diff --git a/flash/image/segmentation/input.py b/flash/image/segmentation/input.py index 84fedff495..7d8455ae63 100644 --- a/flash/image/segmentation/input.py +++ b/flash/image/segmentation/input.py @@ -94,7 +94,7 @@ def load_data( def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]: if DataKeys.TARGET in sample: - sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET]))[:, :, 0] + sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0] return super().load_sample(sample) diff --git a/flash/image/segmentation/input_transform.py b/flash/image/segmentation/input_transform.py index 17cfa186d9..0cb86db53c 100644 --- a/flash/image/segmentation/input_transform.py +++ b/flash/image/segmentation/input_transform.py @@ -12,17 +12,17 @@ # See the License for the specific language governing permissions and # limitations under the License. from dataclasses import dataclass -from typing import Any, Callable, Dict, Tuple, Union - -import torch +from typing import Any, Callable, Dict, Tuple from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform -from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires +from flash.core.data.transforms import AlbumentationsAdapter, ApplyToKeys +from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires -if _KORNIA_AVAILABLE: - import kornia as K +if _ALBUMENTATIONS_AVAILABLE: + import albumentations as alb +else: + alb = None if _TORCHVISION_AVAILABLE: from torchvision import transforms as T @@ -31,16 +31,16 @@ def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]: """Convert the target mask to long and remove the channel dimension.""" if DataKeys.TARGET in batch: - batch[DataKeys.TARGET] = batch[DataKeys.TARGET].long().squeeze(1) + batch[DataKeys.TARGET] = batch[DataKeys.TARGET].squeeze().long() return batch -def target_as_tensor(sample: Dict[str, Any]) -> Dict[str, Any]: +def permute_target(sample: Dict[str, Any]) -> Dict[str, Any]: if DataKeys.TARGET in sample: target = sample[DataKeys.TARGET] if target.ndim == 2: - target = target[:, :, None] - sample[DataKeys.TARGET] = torch.from_numpy(target.transpose((2, 0, 1))).contiguous().squeeze().float() + target = target[None, :, :] + sample[DataKeys.TARGET] = target.transpose((1, 2, 0)) return sample @@ -53,28 +53,28 @@ def remove_extra_dimensions(batch: Dict[str, Any]): @dataclass class SemanticSegmentationInputTransform(InputTransform): + # https://albumentations.ai/docs/examples/pytorch_semantic_segmentation image_size: Tuple[int, int] = (128, 128) - mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406) - std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225) + mean: Tuple[float, float, float] = (0.485, 0.456, 0.406) + std: Tuple[float, float, float] = (0.229, 0.224, 0.225) @requires("image") def train_per_sample_transform(self) -> Callable: return T.Compose( [ + permute_target, + AlbumentationsAdapter( + [ + alb.Resize(*self.image_size), + alb.HorizontalFlip(p=0.5), + alb.Normalize(mean=self.mean, std=self.std), + ] + ), ApplyToKeys( DataKeys.INPUT, T.ToTensor(), ), - target_as_tensor, - ApplyToKeys( - [DataKeys.INPUT, DataKeys.TARGET], - KorniaParallelTransforms( - K.geometry.Resize(self.image_size, interpolation="nearest"), - K.augmentation.RandomHorizontalFlip(p=0.5), - ), - ), - ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)), ] ) @@ -82,33 +82,19 @@ def train_per_sample_transform(self) -> Callable: def per_sample_transform(self) -> Callable: return T.Compose( [ + permute_target, + AlbumentationsAdapter( + [ + alb.Resize(*self.image_size), + alb.Normalize(mean=self.mean, std=self.std), + ] + ), ApplyToKeys( DataKeys.INPUT, T.ToTensor(), ), - target_as_tensor, - ApplyToKeys( - [DataKeys.INPUT, DataKeys.TARGET], - KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")), - ), - ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)), ] ) - @requires("image") - def predict_per_sample_transform(self) -> Callable: - return ApplyToKeys( - DataKeys.INPUT, - T.ToTensor(), - K.geometry.Resize( - self.image_size, - interpolation="nearest", - ), - K.augmentation.Normalize(mean=self.mean, std=self.std), - ) - - def collate(self) -> Callable: - return kornia_collate - def per_batch_transform(self) -> Callable: return T.Compose([prepare_target, remove_extra_dimensions]) diff --git a/flash/image/segmentation/model.py b/flash/image/segmentation/model.py index c738a96cfe..ccd1ce3b16 100644 --- a/flash/image/segmentation/model.py +++ b/flash/image/segmentation/model.py @@ -23,7 +23,12 @@ from flash.core.model import Task from flash.core.registry import FlashRegistry from flash.core.serve import Composition -from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TM_GREATER_EQUAL_0_7_0, requires +from flash.core.utilities.imports import ( + _TM_GREATER_EQUAL_0_7_0, + _TORCHVISION_AVAILABLE, + _TORCHVISION_GREATER_EQUAL_0_9, + requires, +) from flash.core.utilities.isinstance import _isinstance from flash.core.utilities.types import ( INPUT_TRANSFORM_TYPE, @@ -39,8 +44,16 @@ from flash.image.segmentation.input_transform import SemanticSegmentationInputTransform from flash.image.segmentation.output import SEMANTIC_SEGMENTATION_OUTPUTS -if _KORNIA_AVAILABLE: - import kornia as K +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as T + + if _TORCHVISION_GREATER_EQUAL_0_9: + from torchvision.transforms import InterpolationMode + else: + + class InterpolationMode: + NEAREST = "nearest" + if _TM_GREATER_EQUAL_0_7_0: from torchmetrics import JaccardIndex @@ -50,7 +63,7 @@ class SemanticSegmentationOutputTransform(OutputTransform): def per_sample_transform(self, sample: Any) -> Any: - resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="bilinear") + resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST) sample[DataKeys.PREDS] = resize(sample[DataKeys.PREDS]) sample[DataKeys.INPUT] = resize(sample[DataKeys.INPUT]) return super().per_sample_transform(sample) diff --git a/flash/image/segmentation/output.py b/flash/image/segmentation/output.py index 4239543714..cecd2c796c 100644 --- a/flash/image/segmentation/output.py +++ b/flash/image/segmentation/output.py @@ -24,8 +24,8 @@ from flash.core.registry import FlashRegistry from flash.core.utilities.imports import ( _FIFTYONE_AVAILABLE, - _KORNIA_AVAILABLE, _MATPLOTLIB_AVAILABLE, + _TORCHVISION_AVAILABLE, lazy_import, requires, ) @@ -43,10 +43,10 @@ else: plt = None -if _KORNIA_AVAILABLE: - import kornia as K +if _TORCHVISION_AVAILABLE: + from torchvision import transforms as T else: - K = None + T = None SEMANTIC_SEGMENTATION_OUTPUTS = FlashRegistry("outputs") @@ -91,7 +91,7 @@ def create_random_labels_map(num_classes: int) -> Dict[int, Tuple[int, int, int] @requires("matplotlib") def _visualize(self, labels): labels_vis = self.labels_to_image(labels, self.labels_map) - labels_vis = K.utils.tensor_to_image(labels_vis) + labels_vis = T.ToPILImage(labels_vis) plt.imshow(labels_vis) plt.show() diff --git a/flash/image/segmentation/viz.py b/flash/image/segmentation/viz.py index f72e9eb486..bf51a4ead1 100644 --- a/flash/image/segmentation/viz.py +++ b/flash/image/segmentation/viz.py @@ -43,7 +43,9 @@ def __init__(self, labels_map: Dict[int, Tuple[int, int, int]]): @requires("image") def _to_numpy(img: Union[Tensor, Image.Image]) -> np.ndarray: out: np.ndarray - if isinstance(img, Image.Image): + if isinstance(img, np.ndarray): + out = img + elif isinstance(img, Image.Image): out = np.array(img) elif isinstance(img, Tensor): out = img.squeeze(0).permute(1, 2, 0).cpu().numpy() diff --git a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py index b506dff891..f2e132a572 100644 --- a/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py +++ b/flash_examples/integrations/learn2learn/image_classification_imagenette_mini.py @@ -37,7 +37,7 @@ import flash from flash.core.data.io.input import DataKeys from flash.core.data.io.input_transform import InputTransform -from flash.core.data.transforms import ApplyToKeys, kornia_collate +from flash.core.data.transforms import ApplyToKeys from flash.image import ImageClassificationData, ImageClassifier warnings.simplefilter("ignore") @@ -109,9 +109,6 @@ def per_batch_transform_on_device(self): Ka.RandomHorizontalFlip(p=0.25), ) - def collate(self): - return kornia_collate - # construct datamodule diff --git a/requirements/datatype_image.txt b/requirements/datatype_image.txt index 4588c1b327..8e083abf28 100644 --- a/requirements/datatype_image.txt +++ b/requirements/datatype_image.txt @@ -2,6 +2,6 @@ torchvision timm>=0.4.5 lightning-bolts>=0.3.3 Pillow>=7.2 -kornia>=0.5.1 +albumentations>=1.0 pystiche==1.* segmentation-models-pytorch>=0.2.0 diff --git a/requirements/datatype_image_extras.txt b/requirements/datatype_image_extras.txt index 34fed43bee..71ecbd4375 100644 --- a/requirements/datatype_image_extras.txt +++ b/requirements/datatype_image_extras.txt @@ -5,7 +5,7 @@ vissl>=0.1.5 icevision>=0.8 icedata effdet -albumentations +kornia>=0.5.1 learn2learn fastface fairscale diff --git a/tests/core/data/test_transforms.py b/tests/core/data/test_transforms.py index b55bfdee3a..e99025f615 100644 --- a/tests/core/data/test_transforms.py +++ b/tests/core/data/test_transforms.py @@ -15,10 +15,9 @@ import pytest import torch -from torch import nn from flash.core.data.io.input import DataKeys -from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms +from flash.core.data.transforms import ApplyToKeys from flash.core.utilities.imports import _CORE_TESTING @@ -69,46 +68,3 @@ def test_forward(self, sample, keys, expected): ) def test_repr(self, transform, expected): assert repr(transform) == expected - - -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") -@pytest.mark.parametrize("with_params", [True, False]) -def test_kornia_parallel_transforms(with_params): - samples = [torch.rand(1, 3, 10, 10), torch.rand(1, 3, 10, 10)] - transformed_sample = torch.rand(1, 3, 10, 10) - - transform_a = Mock(spec=nn.Module, return_value=transformed_sample) - transform_b = Mock(spec=nn.Module) - - if with_params: - transform_a._params = "test" # initialize params with some value - - parallel_transforms = KorniaParallelTransforms(transform_a, transform_b) - parallel_transforms(samples) - - assert transform_a.call_count == 2 - assert transform_b.call_count == 2 - - if with_params: - assert transform_a.call_args_list[1][0][1] == "test" - # check that after the forward `_params` is set to None - assert transform_a._params == transform_a._params is None - - assert torch.allclose(transform_a.call_args_list[0][0][0], samples[0]) - assert torch.allclose(transform_a.call_args_list[1][0][0], samples[1]) - assert torch.allclose(transform_b.call_args_list[0][0][0], transformed_sample) - assert torch.allclose(transform_b.call_args_list[1][0][0], transformed_sample) - - -@pytest.mark.skipif(not _CORE_TESTING, reason="Not testing core.") -def test_kornia_collate(): - samples = [ - {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 1}, - {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 2}, - {DataKeys.INPUT: torch.zeros(1, 3, 10, 10), DataKeys.TARGET: 3}, - ] - - result = kornia_collate(samples) - assert torch.all(result[DataKeys.TARGET] == torch.tensor([1, 2, 3])) - assert list(result[DataKeys.INPUT].shape) == [3, 3, 10, 10] - assert torch.allclose(result[DataKeys.INPUT], torch.zeros(1)) diff --git a/tests/image/instance_segmentation/test_data.py b/tests/image/instance_segmentation/test_data.py index 81dae6ff7e..f6491ec1aa 100644 --- a/tests/image/instance_segmentation/test_data.py +++ b/tests/image/instance_segmentation/test_data.py @@ -58,8 +58,8 @@ def test_instance_segmentation_output_transform(): ], "labels": [0, 1], "masks": [ - np.random.randint(2, size=(128, 128), dtype=np.uint8), - np.random.randint(2, size=(128, 128), dtype=np.uint8), + np.random.randint(2, size=(1, 128, 128), dtype=np.uint8), + np.random.randint(2, size=(1, 128, 128), dtype=np.uint8), ], "scores": [0.5, 0.5], }, @@ -69,5 +69,5 @@ def test_instance_segmentation_output_transform(): output_transform_cls = InstanceSegmentationOutputTransform() data = output_transform_cls.per_sample_transform(sample) - assert data["masks"][0].size() == (224, 224) - assert data["masks"][1].size() == (224, 224) + assert data["masks"][0].size() == (1, 224, 224) + assert data["masks"][1].size() == (1, 224, 224)