Modality dispatch #136

Open · wants to merge 41 commits into main from modality_dispatch

Commits (41)
56014e6
Replace tuple with Iterable[torch.Tensor]
dxoigmn Apr 14, 2023
1c47cc0
Fix tests
dxoigmn Apr 14, 2023
70cc36a
Cleanup
dxoigmn Apr 14, 2023
53ee7f4
Make GradientModifier accept Iterable[torch.Tensor]
dxoigmn Apr 14, 2023
3f399fa
Pull the modality_dispatch code from PR 115.
mzweilin Apr 18, 2023
fe87864
Add a constant DEFAULT_MODALITY.
mzweilin Apr 18, 2023
ad1f372
Add modality aware enforcer.
mzweilin Apr 18, 2023
5acc632
Type annotation.
mzweilin Apr 18, 2023
8861436
Type annotation.
mzweilin Apr 18, 2023
9c03718
Merge branch 'main' into modality_dispatch
mzweilin Jun 12, 2023
e46151d
Make a single-level if-else in modality_dispatch().
mzweilin Jun 12, 2023
7bb3321
Remove unused keys early.
mzweilin Jun 12, 2023
ca92767
Merge branch 'main' into modality_dispatch
mzweilin Jun 12, 2023
20ffada
Make it fancy with singledispatch.
mzweilin Jun 20, 2023
e77236f
Rename back to Enforcer.enforce().
mzweilin Jun 20, 2023
3737c98
Comment.
mzweilin Jun 20, 2023
3833c34
Comment.
mzweilin Jun 20, 2023
f08510b
Loosen data type requirement.
mzweilin Jun 21, 2023
4038610
Modality-aware adversary.
mzweilin Jun 21, 2023
7f47ab6
Backward compatible with existing non-modality configs of adversary.
mzweilin Jun 21, 2023
a121c80
Fix test.
mzweilin Jun 21, 2023
c508266
Type annotation for modality-aware components.
mzweilin Jun 21, 2023
2676954
Make a new name ModalityParameterDict for modality-aware parameters.
mzweilin Jun 21, 2023
e386276
Fix function arguments and type annotations.
mzweilin Jun 21, 2023
7f100d9
Make modality an optional keyword argument.
mzweilin Jun 21, 2023
c523907
Fix type annotation.
mzweilin Jun 21, 2023
8866c53
Fix type annotation.
mzweilin Jun 21, 2023
5611b3a
Simplify composer, initializer and projector with modality_dispatch.
mzweilin Jun 21, 2023
3fbd048
Simplify type annotation with modality_dispatch().
mzweilin Jun 21, 2023
b675d3f
Update type annotation.
mzweilin Jun 21, 2023
2492232
Make explicit function arguments from modality_dispatch().
mzweilin Jun 21, 2023
430ed3f
Fix test.
mzweilin Jun 21, 2023
7bfb9cb
Simplify type annotation.
mzweilin Jun 21, 2023
3607491
Revert changes in Composer and make a new Modality(Composer).
mzweilin Jun 23, 2023
8c2a676
Merge branch 'main' into modality_dispatch
mzweilin Jun 23, 2023
086bda8
Add Modality(GradientModifier) and change the usage of GradientModifi…
mzweilin Jun 23, 2023
c57465f
Fix test on gradient modifier.
mzweilin Jun 23, 2023
bff59bc
Cleanup.
mzweilin Jun 23, 2023
ec648be
Merge branch 'main' into modality_dispatch
mzweilin Jun 23, 2023
ef3a055
Merge branch 'main' into modality_dispatch
mzweilin Sep 6, 2023
a06db23
Keep modality-wise params for weights for later iterations.
mzweilin Sep 20, 2023
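
Every file below routes per-modality work through a new modality_dispatch helper in mart/utils/modality_dispatch.py, which this diff view does not show. As a reading aid, here is a minimal Python sketch of the dispatch behavior implied by the call sites in this PR; the real implementation (built on functools.singledispatch, per the commit history) may differ, and the value of DEFAULT_MODALITY here is an assumption.

from __future__ import annotations

from typing import Callable

import torch

DEFAULT_MODALITY = "default"  # assumed value of the constant added in commit fe87864


def modality_dispatch(
    input,
    *,
    data,
    target,
    modality_func: Callable | dict[str, Callable],
    modality: str = DEFAULT_MODALITY,
):
    """Recursively walk input/data/target and apply a per-modality function at tensor leaves."""
    if isinstance(input, torch.Tensor):
        if isinstance(modality_func, dict):
            # Per-modality callables (e.g. {"rgb": Initializer(...)}) take no modality kwarg.
            return modality_func[modality](data, input=input, target=target)
        # A single callable (e.g. Enforcer.enforce or Modality.compose) receives the modality.
        return modality_func(data, input=input, target=target, modality=modality)
    if isinstance(input, dict):
        # A dict keyed by modality, e.g. {"rgb": ..., "depth": ...}; recurse with that modality.
        return {
            mod: modality_dispatch(
                inp, data=data[mod], target=target, modality_func=modality_func, modality=mod
            )
            for mod, inp in input.items()
        }
    if isinstance(input, (list, tuple)):
        # A sequence of examples; data and target are assumed to align element-wise.
        target = target if target is not None else [None] * len(input)
        return [
            modality_dispatch(
                inp, data=dat, target=tgt, modality_func=modality_func, modality=modality
            )
            for inp, dat, tgt in zip(input, data, target)
        ]
    raise NotImplementedError(f"Unsupported input type: {type(input)}")
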
2 changes: 1 addition & 1 deletion mart/attack/adversary.py
@@ -149,7 +149,7 @@ def configure_gradient_clipping(

if self.gradient_modifier:
for group in optimizer.param_groups:
self.gradient_modifier(group["params"])
self.gradient_modifier(group)

@silent()
def fit(self, input, target, *, model: Callable):
33 changes: 33 additions & 0 deletions mart/attack/composer.py
@@ -11,6 +11,8 @@

import torch

from ..utils.modality_dispatch import DEFAULT_MODALITY, modality_dispatch


class Composer(abc.ABC):
def __call__(
@@ -78,3 +80,34 @@ def compose(self, perturbation, *, input, target):
masked_perturbation = perturbation * mask

return input + masked_perturbation


class Modality(Composer):
def __init__(self, **modality_method):
self.modality_method = modality_method

def __call__(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
*,
input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
**kwargs,
) -> torch.Tensor | Iterable[torch.Tensor]:
return modality_dispatch(
input,
data=perturbation,
target=target,
modality_func=self.compose,
modality=DEFAULT_MODALITY,
)

def compose(
self,
perturbation: torch.Tensor,
*,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
modality: str,
) -> torch.Tensor:
return self.modality_method[modality](perturbation, input=input, target=target)
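
A usage sketch of the new Modality composer: each keyword argument names a modality and supplies the composer used for it. Additive is assumed to be one of the existing composers in this module, and the modality names and tensor shapes are illustrative only.

import torch

from mart.attack.composer import Additive, Modality  # Additive assumed to exist alongside Overlay

# One composing method per modality; keyword names become modality keys.
composer = Modality(rgb=Additive(), depth=Additive())

# A batch of one example whose input is a dict keyed by modality.
input = [{"rgb": torch.rand(3, 32, 32), "depth": torch.rand(1, 32, 32)}]
perturbation = [{"rgb": torch.zeros(3, 32, 32), "depth": torch.zeros(1, 32, 32)}]
target = [{"labels": torch.tensor([1])}]

# modality_dispatch routes each tensor to compose(..., modality="rgb") or (..., modality="depth").
input_adv = composer(perturbation, input=input, target=target)
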
43 changes: 25 additions & 18 deletions mart/attack/enforcer.py
@@ -11,6 +11,8 @@

import torch

from ..utils.modality_dispatch import modality_dispatch

__all__ = ["Enforcer"]


@@ -96,36 +98,41 @@ def verify(self, input_adv, *, input, target):


class Enforcer:
def __init__(self, constraints: dict[str, Constraint]) -> None:
self.constraints = list(constraints.values()) # intentionally ignore keys
def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
self.modality_constraints = {}

for modality, constraints in modality_constraints.items():
# Intentionally ignore keys after modality.
# The keys are there for combining constraints easily in Hydra.
self.modality_constraints[modality] = constraints.values()

@torch.no_grad()
def __call__(
self,
input_adv: torch.Tensor | Iterable[torch.Tensor],
input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
*,
input: torch.Tensor | Iterable[torch.Tensor],
input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
**kwargs,
):
if isinstance(input_adv, torch.Tensor) and isinstance(input, torch.Tensor):
self.enforce(input_adv, input=input, target=target)

elif (
isinstance(input_adv, Iterable)
and isinstance(input, Iterable) # noqa: W503
and isinstance(target, Iterable) # noqa: W503
):
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self.enforce(input_adv_i, input=input_i, target=target_i)
# The default modality is set to "constraints", so that it is backward compatible with existing configs.
modality_dispatch(
input,
data=input_adv,
target=target,
modality_func=self.enforce,
modality="constraints",
)

@torch.no_grad()
def enforce(
self,
input_adv: torch.Tensor,
input_adv: torch.Tensor | Iterable[torch.Tensor],
*,
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
modality: str,
):
for constraint in self.constraints:

for constraint in self.modality_constraints[modality]:
constraint(input_adv, input=input, target=target)
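
A construction sketch for the reworked Enforcer. The PixelRange constraint below is hypothetical (the configs in this PR reference pixel_range and mask constraints), and it assumes Constraint's abstract hook is verify(); the point is the keyword-per-modality constructor and its backward-compatible default.

import torch

from mart.attack.enforcer import Constraint, Enforcer


class PixelRange(Constraint):
    """Hypothetical constraint: the adversarial input must stay within [0, 255]."""

    def verify(self, input_adv, *, input, target):
        if input_adv.min() < 0 or input_adv.max() > 255:
            raise ValueError("Perturbed input is out of the valid pixel range.")


# Modality-aware: one constraint dict per modality keyword; inner keys are ignored.
enforcer = Enforcer(rgb={"pixel_range": PixelRange()})

# Backward compatible: the old constraints= keyword simply becomes the "constraints"
# modality, which __call__ uses as the default when dispatching plain tensors.
legacy_enforcer = Enforcer(constraints={"pixel_range": PixelRange()})

# Modality-aware inputs are per-example dicts keyed by modality.
input = [{"rgb": torch.rand(3, 8, 8) * 255}]
input_adv = [{"rgb": input[0]["rgb"].clone()}]
target = [{"labels": torch.tensor([1])}]
enforcer(input_adv, input=input, target=target)
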
19 changes: 17 additions & 2 deletions mart/attack/gradient_modifier.py
@@ -6,17 +6,20 @@

from __future__ import annotations

from typing import Iterable
from typing import Any

import torch

from ..utils.modality_dispatch import DEFAULT_MODALITY

__all__ = ["GradientModifier"]


class GradientModifier:
"""Gradient modifier base class."""

def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
def __call__(self, param_group: dict[str, Any]) -> None:
parameters = param_group["params"]
if isinstance(parameters, torch.Tensor):
parameters = [parameters]

@@ -43,3 +46,15 @@ def __init__(self, p: int | float):
def modify_(self, parameter: torch.Tensor) -> None:
p_norm = torch.norm(parameter.grad.detach(), p=self.p)
parameter.grad.detach().div_(p_norm)


class Modality(GradientModifier):
def __init__(self, **modality_method):
if len(modality_method) == 0:
modality_method = {DEFAULT_MODALITY: self.modify_}

self.modality_method_ = modality_method

def __call__(self, param_group: dict[str, Any]) -> None:
modality = param_group["modality"] if "modality" in param_group else DEFAULT_MODALITY
self.modality_method_[modality](param_group)
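
A sketch of the per-modality gradient modifier in use. Sign is assumed to be the existing sign-of-gradient modifier that the "sign" config refers to; the adversary now passes the whole optimizer param group, and Modality picks a method by the group's optional "modality" entry.

import torch

from mart.attack.gradient_modifier import Modality, Sign  # Sign assumed from the "sign" config

grad_modifier = Modality(rgb=Sign())

# A param group as the adversary's optimizer might hold it; the "modality" key is
# what Modality reads, falling back to DEFAULT_MODALITY when it is absent.
param = torch.zeros(3, 32, 32, requires_grad=True)
param.grad = torch.randn(3, 32, 32)
group = {"params": [param], "modality": "rgb", "lr": 55}

grad_modifier(group)  # applies Sign to the gradient of every parameter in the group
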
16 changes: 10 additions & 6 deletions mart/attack/initializer.py
@@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Iterable
from typing import Any

import torch
import torchvision
@@ -21,11 +21,15 @@ class Initializer:
"""Initializer base class."""

@torch.no_grad()
def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]

[self.initialize_(parameter) for parameter in parameters]
def __call__(
self,
parameter: torch.Tensor,
*,
input: torch.Tensor | None = None,
target: torch.Tensor | dict[str, Any] | None = None,
) -> None:
# Accept input and target from modality_dispatch().
self.initialize_(parameter)

@torch.no_grad()
def initialize_(self, parameter: torch.Tensor) -> None:
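
The initializer now receives one tensor at a time, with input and target passed through from modality_dispatch and ignored by the base class. A small sketch, assuming the Constant initializer referenced by the configs exists and takes its fill value as the constructor argument:

import torch

from mart.attack.initializer import Constant  # assumed; the configs set "constant: 127"

init = Constant(127)
perturbation = torch.empty(3, 32, 32)

# New signature: a per-tensor call, with optional input/target from modality_dispatch.
init(perturbation, input=torch.rand(3, 32, 32), target=None)
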
42 changes: 36 additions & 6 deletions mart/attack/perturber.py
@@ -11,6 +11,11 @@
import torch
from lightning.pytorch.utilities.exceptions import MisconfigurationException

from ..utils.modality_dispatch import (
DEFAULT_MODALITY,
ModalityParameterDict,
modality_dispatch,
)
from .projector import Projector

if TYPE_CHECKING:
@@ -23,8 +28,8 @@ class Perturber(torch.nn.Module):
def __init__(
self,
*,
initializer: Initializer,
projector: Projector | None = None,
initializer: Initializer | dict[str, Initializer],
projector: Projector | dict[str, Projector] | None = None,
):
"""_summary_

@@ -34,8 +39,17 @@
"""
super().__init__()

projector = projector or Projector()

# Modality-specific objects.
# Backward compatibility, in case modality is unknown, and not given in input.
if not isinstance(initializer, dict):
initializer = {DEFAULT_MODALITY: initializer}
if not isinstance(projector, dict):
projector = {DEFAULT_MODALITY: projector}

self.initializer_ = initializer
self.projector_ = projector or Projector()
self.projector_ = projector

self.perturbation = None

@@ -65,6 +79,10 @@ def create_from_tensor(tensor):
return torch.nn.Parameter(
torch.empty_like(tensor, dtype=torch.float, requires_grad=True)
)
elif isinstance(tensor, dict):
return ModalityParameterDict(
{modality: create_from_tensor(t) for modality, t in tensor.items()}
)
elif isinstance(tensor, Iterable):
return torch.nn.ParameterList([create_from_tensor(t) for t in tensor])
else:
@@ -76,7 +94,13 @@ def create_from_tensor(tensor):
self.perturbation = create_from_tensor(input)

# Always (re)initialize perturbation.
self.initializer_(self.perturbation)
modality_dispatch(
input,
data=self.perturbation,
target=None,
modality_func=self.initializer_,
modality=DEFAULT_MODALITY,
)

def named_parameters(self, *args, **kwargs):
if self.perturbation is None:
@@ -90,12 +114,18 @@ def parameters(self, *args, **kwargs):

return super().parameters(*args, **kwargs)

def forward(self, **batch):
def forward(self, *, input, target, **batch):
if self.perturbation is None:
raise MisconfigurationException(
"You need to call the configure_perturbation before forward."
)

self.projector_(self.perturbation, **batch)
modality_dispatch(
input,
data=self.perturbation,
target=target,
modality_func=self.projector_,
modality=DEFAULT_MODALITY,
)

return self.perturbation
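
How the modality-aware Perturber might be driven end to end. Constant is the assumed initializer class, and the batch structure (a list of per-example dicts keyed by modality) mirrors what create_from_tensor now handles:

import torch

from mart.attack.initializer import Constant  # assumed initializer class
from mart.attack.perturber import Perturber
from mart.attack.projector import Projector

perturber = Perturber(
    initializer={"rgb": Constant(127)},  # per-modality initializer
    projector={"rgb": Projector()},  # the base Projector is a no-op
)

# Two examples, each a dict keyed by modality, become a ParameterList of
# ModalityParameterDict parameters.
input = [{"rgb": torch.rand(3, 32, 32)}, {"rgb": torch.rand(3, 32, 32)}]
target = [{"labels": torch.tensor([1])}, {"labels": torch.tensor([2])}]

perturber.configure_perturbation(input)
perturbation = perturber(input=input, target=target)  # projects, then returns the perturbation
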
34 changes: 11 additions & 23 deletions mart/attack/projector.py
@@ -6,7 +6,7 @@

from __future__ import annotations

from typing import Any, Iterable
from typing import Any

import torch

@@ -17,33 +17,21 @@ class Projector:
@torch.no_grad()
def __call__(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
perturbation: torch.Tensor,
*,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
**kwargs,
) -> None:
if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
self.project_(perturbation, input=input, target=target)

elif (
isinstance(perturbation, Iterable)
and isinstance(input, Iterable) # noqa: W503
and isinstance(target, Iterable) # noqa: W503
):
for perturbation_i, input_i, target_i in zip(perturbation, input, target):
self.project_(perturbation_i, input=input_i, target=target_i)

else:
raise NotImplementedError
self.project_(perturbation, input=input, target=target)

@torch.no_grad()
def project_(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
perturbation: torch.Tensor,
*,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
) -> None:
pass

@@ -57,10 +45,10 @@ def __init__(self, projectors: list[Projector]):
@torch.no_grad()
def __call__(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
perturbation: torch.Tensor,
*,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
input: torch.Tensor,
target: torch.Tensor | dict[str, Any],
**kwargs,
) -> None:
for projector in self.projectors:
1 change: 1 addition & 0 deletions mart/configs/attack/composer/modality.yaml
@@ -0,0 +1 @@
_target_: mart.attack.composer.Modality
3 changes: 2 additions & 1 deletion mart/configs/attack/enforcer/default.yaml
@@ -1,2 +1,3 @@
_target_: mart.attack.Enforcer
constraints: ???
# FIXME: Hydra does not detect modality-aware constraints defined as sub-components.
# constraints: ???
1 change: 1 addition & 0 deletions mart/configs/attack/gradient_modifier/modality.yaml
@@ -0,0 +1 @@
_target_: mart.attack.gradient_modifier.Modality
28 changes: 28 additions & 0 deletions mart/configs/attack/object_detection_rgb_mask_adversary.yaml
@@ -0,0 +1,28 @@
defaults:
- adversary
- perturber: default
- perturber/initializer@perturber.initializer.rgb: constant
- perturber/projector@perturber.projector.rgb: mask_range
- /optimizer@optimizer: sgd
- gain: rcnn_training_loss
- objective: zero_ap
- gradient_modifier: modality
- gradient_modifier@gradient_modifier.rgb: sign
- composer: modality
- composer@composer.rgb: overlay
- enforcer: default
- enforcer/constraints@enforcer.rgb: [mask, pixel_range]

# Make a 5-step attack for the demonstration purpose.
optimizer:
# Though we only use modality-aware hyper-params, the config requires a value for optimizer.lr.
lr: 0
rgb:
lr: 55

max_iters: 5

perturber:
initializer:
rgb:
constant: 127
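
The optimizer section needs a top-level lr even though only the modality-specific value is used. One plausible reading (an assumption, since the adversary's param-group construction is not part of this diff) is that the rgb override becomes a param-group entry that also carries the "modality" key consumed by Modality(GradientModifier):

import torch

# Hypothetical illustration: the rgb-specific lr becomes a param-group override, and the
# extra "modality" key is what the modality-aware gradient modifier dispatches on.
# PyTorch optimizers keep unknown param-group keys as-is, so this is legal.
rgb_param = torch.zeros(3, 32, 32, requires_grad=True)
optimizer = torch.optim.SGD(
    [{"params": [rgb_param], "modality": "rgb", "lr": 55}],
    lr=0,  # the default required by the config, overridden per group
)
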