diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index d070436b..c614a3c5 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch +from torchvision.transforms.functional import to_pil_image from mart.nn import SequentialDict @@ -19,7 +20,7 @@ class Composer(torch.nn.Module): - def __init__(self, perturber: Perturber, modules, sequence) -> None: + def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None: """_summary_ Args: @@ -34,6 +35,7 @@ def __init__(self, perturber: Perturber, modules, sequence) -> None: if isinstance(sequence, dict): sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) + self.visualize = visualize def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -76,6 +78,12 @@ def _compose( input=input, target=target, perturbation=perturbation, step="composer" ) + # Visualize intermediate images. + if self.visualize: + for key, value in output.items(): + if isinstance(value, torch.Tensor): + to_pil_image(value / 255).save(f"{key}.png") + # SequentialDict returns a dictionary DotDict, # but we only need the return value of the most recently executed module. last_added_key = next(reversed(output))