Skip to content
This repository has been archived by the owner on Oct 9, 2023. It is now read-only.

Commit

Permalink
refactoring Img Segm augmentation with albumentations (#1313)
Browse files Browse the repository at this point in the history
Co-authored-by: Kushashwa Ravi Shrimali <kushashwaravishrimali@gmail.com>
Co-authored-by: Ethan Harris <ethanwharris@gmail.com>
  • Loading branch information
3 people authored Sep 5, 2022
1 parent c05a3ea commit 6742323
Show file tree
Hide file tree
Showing 18 changed files with 131 additions and 186 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Changed the `ImageEmbedder` dependency on VISSL to optional ([#1276](https://github.com/PyTorchLightning/lightning-flash/pull/1276))

- Changed the transforms in `SemanticSegmentationData` to use albumentations instead of Kornia ([#1313](https://github.com/PyTorchLightning/lightning-flash/pull/1313))

### Deprecated

### Removed
Expand Down
7 changes: 0 additions & 7 deletions docs/source/api/data.rst
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,6 @@ __________________________
:template: classtemplate.rst

~flash.core.data.transforms.ApplyToKeys
~flash.core.data.transforms.KorniaParallelTransforms

.. autosummary::
:toctree: generated/
:nosignatures:

~flash.core.data.transforms.kornia_collate

flash.core.data.utils
_____________________
Expand Down
99 changes: 42 additions & 57 deletions flash/core/data/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,51 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Mapping, Sequence, Union
from typing import Any, Mapping, Sequence, Union

import torch
import numpy as np
from torch import nn

from flash.core.data.utilities.collate import default_collate
from flash.core.data.io.input import DataKeys
from flash.core.data.utils import convert_to_modules
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, requires

if _ALBUMENTATIONS_AVAILABLE:
from albumentations import BasicTransform, Compose
else:
BasicTransform, Compose = object, object


class AlbumentationsAdapter(nn.Module):
# mapping from albumentations to Flash
TRANSFORM_INPUT_MAPPING = {"image": DataKeys.INPUT, "mask": DataKeys.TARGET}

@requires("albumentations")
def __init__(
self,
transform: Union[BasicTransform, Sequence[BasicTransform]],
mapping: dict = None,
):
super().__init__()
if not isinstance(transform, (list, tuple)):
transform = [transform]
self.transform = Compose(list(transform))
if not mapping:
mapping = self.TRANSFORM_INPUT_MAPPING
self._mapping_rev = mapping
self._mapping = {v: k for k, v in mapping.items()}

def forward(self, x: Any) -> Any:
if isinstance(x, dict):
x_ = {self._mapping.get(key, key): np.array(value) for key, value in x.items() if key in self._mapping}
else:
x_ = {"image": x}
x_ = self.transform(**x_)
if isinstance(x, dict):
x.update({self._mapping_rev.get(k, k): x_[k] for k in self._mapping_rev if k in x_})
else:
x = x_["image"]
return x


class ApplyToKeys(nn.Sequential):
Expand Down Expand Up @@ -49,9 +87,7 @@ def forward(self, x: Mapping[str, Any]) -> Mapping[str, Any]:
try:
outputs = super().forward(inputs)
except TypeError as e:
raise Exception(
"Failed to apply transforms to multiple keys at the same time, try using KorniaParallelTransforms."
) from e
raise Exception("Failed to apply transforms to multiple keys at the same time.") from e

for i, key in enumerate(keys):
result[key] = outputs[i]
Expand All @@ -66,54 +102,3 @@ def __repr__(self):
transform = transform[0] if len(transform) == 1 else transform

return f"{self.__class__.__name__}(keys={repr(keys)}, transform={repr(transform)})"


class KorniaParallelTransforms(nn.Sequential):
"""The ``KorniaParallelTransforms`` class is an ``nn.Sequential`` which will apply the given transforms to each
input (to ``.forward``) in parallel, whilst sharing the random state (``._params``). This should be used when
multiple elements need to be augmented in the same way (e.g. an image and corresponding segmentation mask).
Args:
args: The transforms, passed to the ``nn.Sequential`` super constructor.
"""

def __init__(self, *args):
super().__init__(*(convert_to_modules(arg) for arg in args))

def forward(self, inputs: Any):
result = list(inputs) if isinstance(inputs, Sequence) else [inputs]
for transform in self.children():
inputs = result

# we enforce the first time to sample random params
result[0] = transform(inputs[0])

if hasattr(transform, "_params") and bool(transform._params):
params = transform._params
else:
params = None

# apply transforms from (1, n)
for i, input in enumerate(inputs[1:]):
if params is not None:
result[i + 1] = transform(input, params)
else: # case for non-random transforms
result[i + 1] = transform(input)
if hasattr(transform, "_params") and bool(transform._params):
transform._params = None
return result


def kornia_collate(samples: Sequence[Dict[str, Any]]) -> Dict[str, Any]:
"""Kornia transforms add batch dimension which need to be removed.
This function removes that dimension and then
applies ``torch.utils.data._utils.collate.default_collate``.
"""
if len(samples) == 1 and isinstance(samples[0], list):
samples = samples[0]
for sample in samples:
for key in sample.keys():
if torch.is_tensor(sample[key]) and sample[key].ndim == 4:
sample[key] = sample[key].squeeze(0)
return default_collate(samples)
4 changes: 4 additions & 0 deletions flash/core/data/utilities/collate.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@


def _wrap_collate(collate: Callable, batch: List[Any]) -> Any:
# Needed for learn2learn integration
if len(batch) == 1 and isinstance(batch[0], list):
batch = batch[0]

metadata = [sample.pop(DataKeys.METADATA, None) if isinstance(sample, Mapping) else None for sample in batch]
metadata = metadata if any(m is not None for m in metadata) else None

Expand Down
2 changes: 1 addition & 1 deletion flash/core/utilities/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ class Image:
_TORCHVISION_AVAILABLE,
_TIMM_AVAILABLE,
_PIL_AVAILABLE,
_KORNIA_AVAILABLE,
_ALBUMENTATIONS_AVAILABLE,
_PYSTICHE_AVAILABLE,
_SEGMENTATION_MODELS_AVAILABLE,
]
Expand Down
8 changes: 2 additions & 6 deletions flash/image/classification/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Callable, Tuple, Union
from typing import Tuple, Union

import torch
from torch import nn

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate
from flash.core.data.transforms import ApplyToKeys
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires

if _TORCHVISION_AVAILABLE:
Expand Down Expand Up @@ -76,7 +76,3 @@ def train_per_sample_transform(self):
ApplyToKeys(DataKeys.TARGET, torch.as_tensor),
]
)

def collate(self) -> Callable:
# TODO: Remove kornia collate for default_collate
return kornia_collate
3 changes: 1 addition & 2 deletions flash/image/classification/integrations/learn2learn.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@

import pytorch_lightning as pl
from torch.utils.data import IterableDataset
from torch.utils.data._utils.collate import default_collate
from torch.utils.data._utils.worker import get_worker_info

from flash.core.data.utilities.collate import default_collate
from flash.core.utilities.imports import requires


Expand Down Expand Up @@ -109,7 +109,6 @@ def __init__(
self.epoch_length = epoch_length
self.seed = seed
self.iteration = 0
self.iteration = 0
self.requires_divisible = requires_divisible
self.counter = 0

Expand Down
20 changes: 16 additions & 4 deletions flash/image/instance_segmentation/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,12 @@
from flash.core.data.utilities.sort import sorted_alphanumeric
from flash.core.integrations.icevision.data import IceVisionInput
from flash.core.integrations.icevision.transforms import IceVisionInputTransform
from flash.core.utilities.imports import _ICEVISION_AVAILABLE, _IMAGE_EXTRAS_TESTING, _KORNIA_AVAILABLE
from flash.core.utilities.imports import (
_ICEVISION_AVAILABLE,
_IMAGE_EXTRAS_TESTING,
_TORCHVISION_AVAILABLE,
_TORCHVISION_GREATER_EQUAL_0_9,
)
from flash.core.utilities.stages import RunningStage
from flash.core.utilities.types import INPUT_TRANSFORM_TYPE

Expand All @@ -34,8 +39,15 @@
VOCMaskParser = object
Parser = object

if _KORNIA_AVAILABLE:
import kornia as K
if _TORCHVISION_AVAILABLE:
from torchvision import transforms as T

if _TORCHVISION_GREATER_EQUAL_0_9:
from torchvision.transforms import InterpolationMode
else:

class InterpolationMode:
NEAREST = "nearest"


# Skip doctests if requirements aren't available
Expand All @@ -45,7 +57,7 @@

class InstanceSegmentationOutputTransform(OutputTransform):
def per_sample_transform(self, sample: Any) -> Any:
resize = K.geometry.Resize(sample[DataKeys.METADATA]["size"], interpolation="nearest")
resize = T.Resize(sample[DataKeys.METADATA]["size"], interpolation=InterpolationMode.NEAREST)
sample[DataKeys.PREDS]["masks"] = [resize(tensor(mask)) for mask in sample[DataKeys.PREDS]["masks"]]
return sample[DataKeys.PREDS]

Expand Down
2 changes: 1 addition & 1 deletion flash/image/segmentation/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def load_data(

def load_sample(self, sample: Dict[str, Any]) -> Dict[str, Any]:
if DataKeys.TARGET in sample:
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET]))[:, :, 0]
sample[DataKeys.TARGET] = np.array(load_image(sample[DataKeys.TARGET])).transpose((2, 0, 1))[:, :, 0]
return super().load_sample(sample)


Expand Down
72 changes: 29 additions & 43 deletions flash/image/segmentation/input_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,17 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple, Union

import torch
from typing import Any, Callable, Dict, Tuple

from flash.core.data.io.input import DataKeys
from flash.core.data.io.input_transform import InputTransform
from flash.core.data.transforms import ApplyToKeys, kornia_collate, KorniaParallelTransforms
from flash.core.utilities.imports import _KORNIA_AVAILABLE, _TORCHVISION_AVAILABLE, requires
from flash.core.data.transforms import AlbumentationsAdapter, ApplyToKeys
from flash.core.utilities.imports import _ALBUMENTATIONS_AVAILABLE, _TORCHVISION_AVAILABLE, requires

if _KORNIA_AVAILABLE:
import kornia as K
if _ALBUMENTATIONS_AVAILABLE:
import albumentations as alb
else:
alb = None

if _TORCHVISION_AVAILABLE:
from torchvision import transforms as T
Expand All @@ -31,16 +31,16 @@
def prepare_target(batch: Dict[str, Any]) -> Dict[str, Any]:
"""Convert the target mask to long and remove the channel dimension."""
if DataKeys.TARGET in batch:
batch[DataKeys.TARGET] = batch[DataKeys.TARGET].long().squeeze(1)
batch[DataKeys.TARGET] = batch[DataKeys.TARGET].squeeze().long()
return batch


def target_as_tensor(sample: Dict[str, Any]) -> Dict[str, Any]:
def permute_target(sample: Dict[str, Any]) -> Dict[str, Any]:
if DataKeys.TARGET in sample:
target = sample[DataKeys.TARGET]
if target.ndim == 2:
target = target[:, :, None]
sample[DataKeys.TARGET] = torch.from_numpy(target.transpose((2, 0, 1))).contiguous().squeeze().float()
target = target[None, :, :]
sample[DataKeys.TARGET] = target.transpose((1, 2, 0))
return sample


Expand All @@ -53,62 +53,48 @@ def remove_extra_dimensions(batch: Dict[str, Any]):

@dataclass
class SemanticSegmentationInputTransform(InputTransform):
# https://albumentations.ai/docs/examples/pytorch_semantic_segmentation

image_size: Tuple[int, int] = (128, 128)
mean: Union[float, Tuple[float, float, float]] = (0.485, 0.456, 0.406)
std: Union[float, Tuple[float, float, float]] = (0.229, 0.224, 0.225)
mean: Tuple[float, float, float] = (0.485, 0.456, 0.406)
std: Tuple[float, float, float] = (0.229, 0.224, 0.225)

@requires("image")
def train_per_sample_transform(self) -> Callable:
return T.Compose(
[
permute_target,
AlbumentationsAdapter(
[
alb.Resize(*self.image_size),
alb.HorizontalFlip(p=0.5),
alb.Normalize(mean=self.mean, std=self.std),
]
),
ApplyToKeys(
DataKeys.INPUT,
T.ToTensor(),
),
target_as_tensor,
ApplyToKeys(
[DataKeys.INPUT, DataKeys.TARGET],
KorniaParallelTransforms(
K.geometry.Resize(self.image_size, interpolation="nearest"),
K.augmentation.RandomHorizontalFlip(p=0.5),
),
),
ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
]
)

@requires("image")
def per_sample_transform(self) -> Callable:
return T.Compose(
[
permute_target,
AlbumentationsAdapter(
[
alb.Resize(*self.image_size),
alb.Normalize(mean=self.mean, std=self.std),
]
),
ApplyToKeys(
DataKeys.INPUT,
T.ToTensor(),
),
target_as_tensor,
ApplyToKeys(
[DataKeys.INPUT, DataKeys.TARGET],
KorniaParallelTransforms(K.geometry.Resize(self.image_size, interpolation="nearest")),
),
ApplyToKeys([DataKeys.INPUT], K.augmentation.Normalize(mean=self.mean, std=self.std)),
]
)

@requires("image")
def predict_per_sample_transform(self) -> Callable:
return ApplyToKeys(
DataKeys.INPUT,
T.ToTensor(),
K.geometry.Resize(
self.image_size,
interpolation="nearest",
),
K.augmentation.Normalize(mean=self.mean, std=self.std),
)

def collate(self) -> Callable:
return kornia_collate

def per_batch_transform(self) -> Callable:
return T.Compose([prepare_target, remove_extra_dimensions])
Loading

0 comments on commit 6742323

Please sign in to comment.