diff --git a/mart/attack/adversary.py b/mart/attack/adversary.py index 7f4cbefb..92de0198 100644 --- a/mart/attack/adversary.py +++ b/mart/attack/adversary.py @@ -149,7 +149,7 @@ def configure_gradient_clipping( if self.gradient_modifier: for group in optimizer.param_groups: - self.gradient_modifier(group["params"]) + self.gradient_modifier(group) @silent() def fit(self, input, target, *, model: Callable): diff --git a/mart/attack/composer.py b/mart/attack/composer.py index 6b40950a..afe36e3e 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -11,6 +11,8 @@ import torch +from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch + class Composer(abc.ABC): def __call__( @@ -78,3 +80,34 @@ def compose(self, perturbation, *, input, target): masked_perturbation = perturbation * mask return input + masked_perturbation + + +class Modality(Composer): + def __init__(self, **modality_method): + self.modality_method = modality_method + + def __call__( + self, + perturbation: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + *, + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + **kwargs, + ) -> torch.Tensor | Iterable[torch.Tensor]: + return modality_dispatch( + input, + data=perturbation, + target=target, + modality_func=self.compose, + modality=DEFAULT_MODALITY, + ) + + def compose( + self, + perturbation: torch.Tensor, + *, + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], + modality: str, + ) -> torch.Tensor: + return self.modality_method[modality](perturbation, input=input, target=target) diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index c0f57ca6..7d4b27fd 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -11,6 +11,8 @@ import torch +from ..utils.modality_dispatch import modality_dispatch + __all__ = ["Enforcer"] @@ -96,36 +98,41 @@ def verify(self, input_adv, *, input, target): class Enforcer: - def __init__(self, constraints: dict[str, Constraint]) -> None: - self.constraints = list(constraints.values()) # intentionally ignore keys + def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None: + self.modality_constraints = {} + + for modality, constraints in modality_constraints.items(): + # Intentionally ignore keys after modality. + # The keys are there for combining constraints easily in Hydra. + self.modality_constraints[modality] = constraints.values() @torch.no_grad() def __call__( self, - input_adv: torch.Tensor | Iterable[torch.Tensor], + input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], *, - input: torch.Tensor | Iterable[torch.Tensor], + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], **kwargs, ): - if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor): - self.enforce(input_adv, input=input, target=target) - - elif ( - isinstance(input_adv, Iterable) - and isinstance(input, Iterable) # noqa: W503 - and isinstance(target, Iterable) # noqa: W503 - ): - for input_adv_i, input_i, target_i in zip(input_adv, input, target): - self.enforce(input_adv_i, input=input_i, target=target_i) + # The default modality is set to "constraints", so that it is backward compatible with existing configs. 
+ modality_dispatch( + input, + data=input_adv, + target=target, + modality_func=self.enforce, + modality="constraints", + ) @torch.no_grad() def enforce( self, - input_adv: torch.Tensor, + input_adv: torch.Tensor | Iterable[torch.Tensor], *, - input: torch.Tensor, - target: torch.Tensor | dict[str, Any], + input: torch.Tensor | Iterable[torch.Tensor], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + modality: str, ): - for constraint in self.constraints: + + for constraint in self.modality_constraints[modality]: constraint(input_adv, input=input, target=target) diff --git a/mart/attack/gradient_modifier.py b/mart/attack/gradient_modifier.py index b2882574..a5fd68c6 100644 --- a/mart/attack/gradient_modifier.py +++ b/mart/attack/gradient_modifier.py @@ -6,17 +6,20 @@ from __future__ import annotations -from typing import Iterable +from typing import Any import torch +from ..utils.modality_dispatch import DEFAULT_MODALITY + __all__ = ["GradientModifier"] class GradientModifier: """Gradient modifier base class.""" - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: + def __call__(self, param_group: dict[str, Any]) -> None: + parameters = param_group["params"] if isinstance(parameters, torch.Tensor): parameters = [parameters] @@ -43,3 +46,15 @@ def __init__(self, p: int | float): def modify_(self, parameter: torch.Tensor) -> None: p_norm = torch.norm(parameter.grad.detach(), p=self.p) parameter.grad.detach().div_(p_norm) + + +class Modality(GradientModifier): + def __init__(self, **modality_method): + if len(modality_method) == 0: + modality_method = {DEFAULT_MODALITY: self.modify_} + + self.modality_method_ = modality_method + + def __call__(self, param_group: dict[str, Any]) -> None: + modality = param_group["modality"] if "modality" in param_group else DEFAULT_MODALITY + self.modality_method_[modality](param_group) diff --git a/mart/attack/initializer.py b/mart/attack/initializer.py index 9b38e6a1..e197000a 100644 --- a/mart/attack/initializer.py +++ b/mart/attack/initializer.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Iterable +from typing import Any import torch import torchvision @@ -21,11 +21,15 @@ class Initializer: """Initializer base class.""" @torch.no_grad() - def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None: - if isinstance(parameters, torch.Tensor): - parameters = [parameters] - - [self.initialize_(parameter) for parameter in parameters] + def __call__( + self, + parameter: torch.Tensor, + *, + input: torch.Tensor | None = None, + target: torch.Tensor | dict[str, Any] | None = None, + ) -> None: + # Accept input and target from modality_dispatch(). 
+ self.initialize_(parameter) @torch.no_grad() def initialize_(self, parameter: torch.Tensor) -> None: diff --git a/mart/attack/perturber.py b/mart/attack/perturber.py index 75c565ad..f3ee09d7 100644 --- a/mart/attack/perturber.py +++ b/mart/attack/perturber.py @@ -11,6 +11,11 @@ import torch from lightning.pytorch.utilities.exceptions import MisconfigurationException +from ..utils.modality_dispatch import ( + DEFAULT_MODALITY, + ModalityParameterDict, + modality_dispatch, +) from .projector import Projector if TYPE_CHECKING: @@ -23,8 +28,8 @@ class Perturber(torch.nn.Module): def __init__( self, *, - initializer: Initializer, - projector: Projector | None = None, + initializer: Initializer | dict[str, Initializer], + projector: Projector | dict[str, Projector] | None = None, ): """_summary_ @@ -34,8 +39,17 @@ def __init__( """ super().__init__() + projector = projector or Projector() + + # Modality-specific objects. + # Backward compatibility, in case modality is unknown, and not given in input. + if not isinstance(initializer, dict): + initializer = {DEFAULT_MODALITY: initializer} + if not isinstance(projector, dict): + projector = {DEFAULT_MODALITY: projector} + self.initializer_ = initializer - self.projector_ = projector or Projector() + self.projector_ = projector self.perturbation = None @@ -65,6 +79,10 @@ def create_from_tensor(tensor): return torch.nn.Parameter( torch.empty_like(tensor, dtype=torch.float, requires_grad=True) ) + elif isinstance(tensor, dict): + return ModalityParameterDict( + {modality: create_from_tensor(t) for modality, t in tensor.items()} + ) elif isinstance(tensor, Iterable): return torch.nn.ParameterList([create_from_tensor(t) for t in tensor]) else: @@ -76,7 +94,13 @@ def create_from_tensor(tensor): self.perturbation = create_from_tensor(input) # Always (re)initialize perturbation. - self.initializer_(self.perturbation) + modality_dispatch( + input, + data=self.perturbation, + target=None, + modality_func=self.initializer_, + modality=DEFAULT_MODALITY, + ) def named_parameters(self, *args, **kwargs): if self.perturbation is None: @@ -90,12 +114,18 @@ def parameters(self, *args, **kwargs): return super().parameters(*args, **kwargs) - def forward(self, **batch): + def forward(self, *, input, target, **batch): if self.perturbation is None: raise MisconfigurationException( "You need to call the configure_perturbation before forward." 
) - self.projector_(self.perturbation, **batch) + modality_dispatch( + input, + data=self.perturbation, + target=target, + modality_func=self.projector_, + modality=DEFAULT_MODALITY, + ) return self.perturbation diff --git a/mart/attack/projector.py b/mart/attack/projector.py index f9887354..58af6a7f 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -6,7 +6,7 @@ from __future__ import annotations -from typing import Any, Iterable +from typing import Any import torch @@ -17,33 +17,21 @@ class Projector: @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | Iterable[torch.Tensor], + perturbation: torch.Tensor, *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], **kwargs, ) -> None: - if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor): - self.project_(perturbation, input=input, target=target) - - elif ( - isinstance(perturbation, Iterable) - and isinstance(input, Iterable) # noqa: W503 - and isinstance(target, Iterable) # noqa: W503 - ): - for perturbation_i, input_i, target_i in zip(perturbation, input, target): - self.project_(perturbation_i, input=input_i, target=target_i) - - else: - raise NotImplementedError + self.project_(perturbation, input=input, target=target) @torch.no_grad() def project_( self, - perturbation: torch.Tensor | Iterable[torch.Tensor], + perturbation: torch.Tensor, *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], ) -> None: pass @@ -57,10 +45,10 @@ def __init__(self, projectors: list[Projector]): @torch.no_grad() def __call__( self, - perturbation: torch.Tensor | Iterable[torch.Tensor], + perturbation: torch.Tensor, *, - input: torch.Tensor | Iterable[torch.Tensor], - target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]], + input: torch.Tensor, + target: torch.Tensor | dict[str, Any], **kwargs, ) -> None: for projector in self.projectors: diff --git a/mart/configs/attack/composer/modality.yaml b/mart/configs/attack/composer/modality.yaml new file mode 100644 index 00000000..34955313 --- /dev/null +++ b/mart/configs/attack/composer/modality.yaml @@ -0,0 +1 @@ +_target_: mart.attack.composer.Modality diff --git a/mart/configs/attack/enforcer/default.yaml b/mart/configs/attack/enforcer/default.yaml index 46fc0bb1..a59d8d3e 100644 --- a/mart/configs/attack/enforcer/default.yaml +++ b/mart/configs/attack/enforcer/default.yaml @@ -1,2 +1,3 @@ _target_: mart.attack.Enforcer -constraints: ??? +# FIXME: Hydra does not detect modality-aware constraints defined as sub-components. +# constraints: ??? 
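Reviewer note (illustration only, not part of the patch): a minimal sketch of the new keyword-based Enforcer API, assuming the enforcer.py and modality_dispatch changes above are applied. NonNegative is a hypothetical Constraint subclass invented for this example; real configs would reference the existing constraints (e.g. mask, pixel_range) instead, and Constraint subclasses are assumed to implement verify() the way the ones in mart/attack/enforcer.py do.

    import torch

    from mart.attack.enforcer import Constraint, Enforcer

    # Hypothetical constraint, defined only for this illustration.
    class NonNegative(Constraint):
        def verify(self, input_adv, *, input, target):
            assert (input_adv >= 0).all(), "adversarial input must be non-negative"

    # Constraints are now keyed by modality; the inner dict keys are arbitrary labels
    # (kept so constraints can be merged in Hydra) and are ignored by Enforcer.
    enforcer = Enforcer(rgb={"non_negative": NonNegative()})

    # Dict inputs are routed per modality by modality_dispatch; plain tensor inputs
    # fall back to the "constraints" key, which keeps existing configs working.
    input_adv = {"rgb": torch.rand(3, 4, 4)}
    input = {"rgb": torch.rand(3, 4, 4)}
    enforcer(input_adv, input=input, target=None)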
diff --git a/mart/configs/attack/gradient_modifier/modality.yaml b/mart/configs/attack/gradient_modifier/modality.yaml new file mode 100644 index 00000000..a5596dfd --- /dev/null +++ b/mart/configs/attack/gradient_modifier/modality.yaml @@ -0,0 +1 @@ +_target_: mart.attack.gradient_modifier.Modality diff --git a/mart/configs/attack/object_detection_rgb_mask_adversary.yaml b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml new file mode 100644 index 00000000..30634925 --- /dev/null +++ b/mart/configs/attack/object_detection_rgb_mask_adversary.yaml @@ -0,0 +1,28 @@ +defaults: + - adversary + - perturber: default + - perturber/initializer@perturber.initializer.rgb: constant + - perturber/projector@perturber.projector.rgb: mask_range + - /optimizer@optimizer: sgd + - gain: rcnn_training_loss + - objective: zero_ap + - gradient_modifier: modality + - gradient_modifier@gradient_modifier.rgb: sign + - composer: modality + - composer@composer.rgb: overlay + - enforcer: default + - enforcer/constraints@enforcer.rgb: [mask, pixel_range] + +# Make a 5-step attack for the demonstration purpose. +optimizer: + # Though we only use modality-aware hyper-params, the config requires a value for optimizer.lr. + lr: 0 + rgb: + lr: 55 + +max_iters: 5 + +perturber: + initializer: + rgb: + constant: 127 diff --git a/mart/optim/optimizer.py b/mart/optim/optimizer.py index 3cc57131..0fb3b694 100644 --- a/mart/optim/optimizer.py +++ b/mart/optim/optimizer.py @@ -5,6 +5,7 @@ # import logging +from collections import defaultdict logger = logging.getLogger(__name__) @@ -25,6 +26,9 @@ def __init__(self, optimizer, **kwargs): self.bias_decay = kwargs.pop("bias_decay", weight_decay) self.norm_decay = kwargs.pop("norm_decay", weight_decay) self.optimizer = optimizer + + # Separate modality-wise params from kwargs, because optimizers do not recognize them. + self.modality_wise_params = kwargs.pop("modality_wise", {}) self.kwargs = kwargs def __call__(self, module): @@ -32,6 +36,7 @@ def __call__(self, module): bias_params = [] norm_params = [] weight_params = [] + modality_params = defaultdict(list) for param_name, param in module.named_parameters(): if not param.requires_grad: @@ -42,7 +47,11 @@ def __call__(self, module): _, param_module = next(filter(lambda nm: nm[0] == module_name, module.named_modules())) module_kind = param_module.__class__.__name__ - if "Norm" in module_kind: + if module_kind == "ModalityParameterDict": + # Identify modality-aware parameters for adversary. + modality = param_name.split(".")[-1] + modality_params[modality].append(param) + elif "Norm" in module_kind: assert len(param.shape) == 1 norm_params.append(param) elif isinstance(param, torch.nn.UninitializedParameter): @@ -53,8 +62,20 @@ def __call__(self, module): else: # Assume weights weight_params.append(param) - # Set decay for bias and norm parameters params = [] + + # Set modality-aware weight params. + if len(modality_params) > 0: + for modality, param in modality_params.items(): + # Take notes of modality for gradient modifier later. + # Add modality-specific optim params. 
+ if modality in self.modality_wise_params: + modality_params = self.modality_wise_params[modality] + else: + modality_params = {} + params.append({"params": param, "modality": modality} | modality_params) + + # Set decay for bias and norm parameters if len(weight_params) > 0: params.append({"params": weight_params}) # use default weight decay if len(bias_params) > 0: diff --git a/mart/utils/modality_dispatch.py b/mart/utils/modality_dispatch.py new file mode 100644 index 00000000..4f0757e8 --- /dev/null +++ b/mart/utils/modality_dispatch.py @@ -0,0 +1,106 @@ +# +# Copyright (C) 2022 Intel Corporation +# +# SPDX-License-Identifier: BSD-3-Clause +# + +from __future__ import annotations + +import functools +from itertools import cycle +from typing import Any, Callable, Iterable + +import torch + +DEFAULT_MODALITY = "default" + + +@functools.singledispatch +def modality_dispatch( + input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + *, + data: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]], + target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]] | None, + modality_func: Callable | dict[str, Callable], + modality: str = DEFAULT_MODALITY, +): + """Recursively dispatch data and input/target to functions of the same modality. + + The function returns an object that is homomorphic to input. We make input the first non- + keyword argument for singledispatch to work. + """ + + raise ValueError(f"Unsupported data type of input: type(input)={type(input)}.") + + +@modality_dispatch.register +def _(input: torch.Tensor, *, data, target, modality_func, modality=DEFAULT_MODALITY): + # Take action when input is a tensor. + if isinstance(modality_func, dict): + # A dictionary of Callable indexed by modality. + return modality_func[modality](data, input=input, target=target) + elif isinstance(modality_func, Callable): + # A Callable with modality=? as a keyword argument. + return modality_func(data, input=input, target=target, modality=modality) + + +@modality_dispatch.register +def _(input: dict, *, data, target, modality_func, modality=DEFAULT_MODALITY): + # The dict input has modalities specified in keys, passing them recursively. + output = {} + for modality in input.keys(): + output[modality] = modality_dispatch( + input[modality], + data=data[modality], + target=target, + modality_func=modality_func, + modality=modality, + ) + return output + + +@modality_dispatch.register +def _(input: list, *, data, target, modality_func, modality=DEFAULT_MODALITY): + # The list input implies a collection of sub-input and sub-target. + if not isinstance(target, Iterable): + # Make target zip well with input. + target = cycle([target]) + if not isinstance(data, Iterable): + # Make data zip well with input. + # Besides list and tuple, data could be ParameterList too. + # Data is shared for all input, e.g. universal perturbation. + data = cycle([data]) + + output = [] + for data_i, input_i, target_i in zip(data, input, target): + output_i = modality_dispatch( + input_i, + data=data_i, + target=target_i, + modality_func=modality_func, + modality=modality, + ) + output.append(output_i) + + return output + + +@modality_dispatch.register +def _(input: tuple, *, data, target, modality_func, modality=DEFAULT_MODALITY): + # The tuple input is similar with the list input. + output = modality_dispatch( + list(input), + data=data, + target=target, + modality_func=modality_func, + modality=modality, + ) + # Make the output a tuple, the same as input. 
+ output = tuple(output) + return output + + +class ModalityParameterDict(torch.nn.ParameterDict): + """Get a new name so we know when parameters are associated with modality.""" + + pass diff --git a/tests/test_adversary.py b/tests/test_adversary.py index 2113d3f6..b599cdf3 100644 --- a/tests/test_adversary.py +++ b/tests/test_adversary.py @@ -15,6 +15,7 @@ import mart from mart.attack import Adversary, Composer, Perturber from mart.attack.gradient_modifier import Sign +from mart.attack.initializer import Constant def test_with_model(input_data, target_data, perturbation): @@ -188,11 +189,8 @@ def gain(logits): ) # Perturbation initialized as zero. - def initializer(x): - torch.nn.init.constant_(x, 0) - perturber = Perturber( - initializer=initializer, + initializer=Constant(0), projector=None, ) diff --git a/tests/test_gradient.py b/tests/test_gradient.py index a4ad49ee..fe366342 100644 --- a/tests/test_gradient.py +++ b/tests/test_gradient.py @@ -14,9 +14,10 @@ def test_gradient_sign(input_data): # Don't share input_data with other tests, because the gradient would be changed. input_data = torch.tensor([1.0, 2.0, 3.0]) input_data.grad = torch.tensor([-1.0, 3.0, 0.0]) + param_group = {"params": input_data} grad_modifier = Sign() - grad_modifier(input_data) + grad_modifier(param_group) expected_grad = torch.tensor([-1.0, 1.0, 0.0]) torch.testing.assert_close(input_data.grad, expected_grad) @@ -25,9 +26,10 @@ def test_gradient_lp_normalizer(): # Don't share input_data with other tests, because the gradient would be changed. input_data = torch.tensor([1.0, 2.0, 3.0]) input_data.grad = torch.tensor([-1.0, 3.0, 0.0]) + param_group = {"params": input_data} p = 1 grad_modifier = LpNormalizer(p) - grad_modifier(input_data) + grad_modifier(param_group) expected_grad = torch.tensor([-0.25, 0.75, 0.0]) torch.testing.assert_close(input_data.grad, expected_grad)
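Reviewer note (illustration only, not part of the patch): a minimal sketch of how the new modality_dispatch recurses over list- and dict-structured batches, assuming mart/utils/modality_dispatch.py above is in place. The compose function, modality names, and tensor shapes are made up for this example.

    import torch

    from mart.utils.modality_dispatch import modality_dispatch

    # A hypothetical per-modality operation. Because modality_func is a plain callable
    # here (not a dict of callables), the current modality arrives as a keyword argument.
    def compose(perturbation, *, input, target, modality):
        return input + perturbation

    # A batch of two multi-modal samples, each a dict keyed by modality.
    input = [
        {"rgb": torch.rand(3, 8, 8), "depth": torch.rand(1, 8, 8)},
        {"rgb": torch.rand(3, 8, 8), "depth": torch.rand(1, 8, 8)},
    ]
    perturbation = [
        {"rgb": torch.zeros(3, 8, 8), "depth": torch.zeros(1, 8, 8)},
        {"rgb": torch.zeros(3, 8, 8), "depth": torch.zeros(1, 8, 8)},
    ]
    target = [{"label": 0}, {"label": 1}]

    # The output mirrors the structure of input: a list of dicts of tensors.
    output = modality_dispatch(input, data=perturbation, target=target, modality_func=compose)
    assert output[0]["rgb"].shape == (3, 8, 8)

Passing a dict such as {"rgb": fn_rgb, "depth": fn_depth} as modality_func instead selects the callable by modality key, which is how the modality-keyed initializers, projectors, and the Modality composer introduced in this patch are dispatched.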