From eaac60c3c4e6fe401c90efea6e6bbfe7a0dc9831 Mon Sep 17 00:00:00 2001 From: Weilin Xu Date: Sun, 4 Feb 2024 08:47:10 -0800 Subject: [PATCH] Visualize intermediate images of Composer (#244) * Create a folder for attack.composer. * Add composer modules for unbounded patch adversary. * Add config of Adam optimizer. * Add LoadCoords for patch adversary. * Add a config of unbounded patch adversary. * Add a datamodule config for carla patch adversary. * Fix the simple Linf projection. * Add composer module PertImageBase for Lp bounded patch adversary. * Add config of lp-bounded patch adversary. * Add a fake renderer composer module. * Teardown a test dataset gracefully for the rendering-in-loop adversary. * Add configs of simulation-in-loop adversary. * Add a datamodule config for CARLA patch rendering. * Update CarlaDataset config. * Add a composer.visualize switch to see intermediate images. * Revert "Teardown a test dataset gracefully for the rendering-in-loop adversary." This reverts commit a5ffef3b3bb812f25e2f44e224c051d8dcb1b617. * Revert "Add a composer.visualize switch to see intermediate images." This reverts commit a17e224d425a637d233ca7415d986b8d8ca8b2d7. * Add a composer.visualize switch to see intermediate images. --- mart/attack/composer/modular.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/mart/attack/composer/modular.py b/mart/attack/composer/modular.py index d070436b..c614a3c5 100644 --- a/mart/attack/composer/modular.py +++ b/mart/attack/composer/modular.py @@ -9,6 +9,7 @@ from typing import TYPE_CHECKING, Any, Iterable import torch +from torchvision.transforms.functional import to_pil_image from mart.nn import SequentialDict @@ -19,7 +20,7 @@ class Composer(torch.nn.Module): - def __init__(self, perturber: Perturber, modules, sequence) -> None: + def __init__(self, perturber: Perturber, modules, sequence, visualize: bool = False) -> None: """_summary_ Args: @@ -34,6 +35,7 @@ def __init__(self, perturber: Perturber, modules, sequence) -> None: if isinstance(sequence, dict): sequence = [sequence[key] for key in sorted(sequence)] self.functions = SequentialDict(modules, {"composer": sequence}) + self.visualize = visualize def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]): return self.perturber.configure_perturbation(input) @@ -76,6 +78,12 @@ def _compose( input=input, target=target, perturbation=perturbation, step="composer" ) + # Visualize intermediate images. + if self.visualize: + for key, value in output.items(): + if isinstance(value, torch.Tensor): + to_pil_image(value / 255).save(f"{key}.png") + # SequentialDict returns a dictionary DotDict, # but we only need the return value of the most recently executed module. last_added_key = next(reversed(output))