Modality dispatch #136

Open · wants to merge 41 commits into base: main

Changes from 9 commits

Commits (41)
56014e6
Replace tuple with Iterable[torch.Tensor]
dxoigmn Apr 14, 2023
1c47cc0
Fix tests
dxoigmn Apr 14, 2023
70cc36a
Cleanup
dxoigmn Apr 14, 2023
53ee7f4
Make GradientModifier accept Iterable[torch.Tensor]
dxoigmn Apr 14, 2023
3f399fa
Pull the modality_dispatch code from PR 115.
mzweilin Apr 18, 2023
fe87864
Add a constant DEFAULT_MODALITY.
mzweilin Apr 18, 2023
ad1f372
Add modality aware enforcer.
mzweilin Apr 18, 2023
5acc632
Type annotation.
mzweilin Apr 18, 2023
8861436
Type annotation.
mzweilin Apr 18, 2023
9c03718
Merge branch 'main' into modality_dispatch
mzweilin Jun 12, 2023
e46151d
Make a single-level if-else in modality_dispatch().
mzweilin Jun 12, 2023
7bb3321
Remove unused keys early.
mzweilin Jun 12, 2023
ca92767
Merge branch 'main' into modality_dispatch
mzweilin Jun 12, 2023
20ffada
Make it fancy with singledispatch.
mzweilin Jun 20, 2023
e77236f
Rename back to Enforcer.enforce().
mzweilin Jun 20, 2023
3737c98
Comment.
mzweilin Jun 20, 2023
3833c34
Comment.
mzweilin Jun 20, 2023
f08510b
Loosen data type requirement.
mzweilin Jun 21, 2023
4038610
Modality-aware adversary.
mzweilin Jun 21, 2023
7f47ab6
Backward compatible with existing non-modality configs of adversary.
mzweilin Jun 21, 2023
a121c80
Fix test.
mzweilin Jun 21, 2023
c508266
Type annotation for modality-aware components.
mzweilin Jun 21, 2023
2676954
Make a new name ModalityParameterDict for modality-aware parameters.
mzweilin Jun 21, 2023
e386276
Fix function arguments and type annotations.
mzweilin Jun 21, 2023
7f100d9
Make modality an optional keyword argument.
mzweilin Jun 21, 2023
c523907
Fix type annotation.
mzweilin Jun 21, 2023
8866c53
Fix type annotation.
mzweilin Jun 21, 2023
5611b3a
Simplify composer, initializer and projector with modality_dispatch.
mzweilin Jun 21, 2023
3fbd048
Simplify type annotation with modality_dispatch().
mzweilin Jun 21, 2023
b675d3f
Update type annotation.
mzweilin Jun 21, 2023
2492232
Make explicit function arguments from modality_dispatch().
mzweilin Jun 21, 2023
430ed3f
Fix test.
mzweilin Jun 21, 2023
7bfb9cb
Simplify type annotation.
mzweilin Jun 21, 2023
3607491
Revert changes in Composer and make a new Modality(Composer).
mzweilin Jun 23, 2023
8c2a676
Merge branch 'main' into modality_dispatch
mzweilin Jun 23, 2023
086bda8
Add Modality(GradientModifier) and change the usage of GradientModifi…
mzweilin Jun 23, 2023
c57465f
Fix test on gradient modifier.
mzweilin Jun 23, 2023
bff59bc
Cleanup.
mzweilin Jun 23, 2023
ec648be
Merge branch 'main' into modality_dispatch
mzweilin Jun 23, 2023
ef3a055
Merge branch 'main' into modality_dispatch
mzweilin Sep 6, 2023
a06db23
Keep modality-wise params for weights for later iterations.
mzweilin Sep 20, 2023
12 changes: 7 additions & 5 deletions mart/attack/adversary_in_art.py
@@ -4,7 +4,7 @@
# SPDX-License-Identifier: BSD-3-Clause
#

from typing import Any, List, Optional
from typing import Any, Iterable, List, Optional

import hydra
import numpy
@@ -82,17 +82,18 @@ def convert_input_art_to_mart(self, x: numpy.ndarray):
x (np.ndarray): NHWC, [0, 1]

Returns:
tuple: a tuple of tensors in CHW, [0, 255].
Iterable[torch.Tensor]: an Iterable of tensors in CHW, [0, 255].
"""
input = torch.tensor(x).permute((0, 3, 1, 2)).to(self._device) * 255
# FIXME: replace tuple with whatever input's type is
input = tuple(inp_ for inp_ in input)
return input

def convert_input_mart_to_art(self, input: tuple):
def convert_input_mart_to_art(self, input: Iterable[torch.Tensor]):
"""Convert MART input to the ART's format.

Args:
input (tuple): a tuple of tensors in CHW, [0, 255].
input (Iterable[torch.Tensor]): an Iterable of tensors in CHW, [0, 255].

Returns:
np.ndarray: NHWC, [0, 1]
@@ -112,7 +113,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
y_patch_metadata (_type_): _description_

Returns:
tuple: a tuple of target dictionaies.
Iterable[dict[str, Any]]: an Iterable of target dictionaries.
"""
# Copy y to target, and convert ndarray to pytorch tensors accordingly.
target = []
@@ -132,6 +133,7 @@ def convert_target_art_to_mart(self, y: numpy.ndarray, y_patch_metadata: List):
target_i["file_name"] = f"{yi['image_id'][0]}.jpg"
target.append(target_i)

# FIXME: replace tuple with input type?
target = tuple(target)

return target
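
As a reference for these conversion helpers, here is a minimal round-trip sketch of the documented contract: ART passes NHWC arrays in [0, 1], MART consumes an Iterable of CHW tensors in [0, 255]. The standalone function names and the omission of device handling are simplifications, not the wrapper's actual methods.

```python
from __future__ import annotations

import numpy as np
import torch


def art_to_mart(x: np.ndarray) -> tuple[torch.Tensor, ...]:
    """NHWC array in [0, 1] -> tuple of CHW tensors in [0, 255]."""
    batch = torch.from_numpy(x).permute(0, 3, 1, 2) * 255
    return tuple(image for image in batch)


def mart_to_art(images) -> np.ndarray:
    """Iterable of CHW tensors in [0, 255] -> NHWC array in [0, 1]."""
    batch = torch.stack(list(images)) / 255
    return batch.permute(0, 2, 3, 1).cpu().numpy()


x = np.random.rand(2, 32, 32, 3).astype(np.float32)
assert np.allclose(mart_to_art(art_to_mart(x)), x, atol=1e-5)
```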
11 changes: 7 additions & 4 deletions mart/attack/adversary_wrapper.py
@@ -6,10 +6,13 @@

from __future__ import annotations

from typing import Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Iterable

import torch

if TYPE_CHECKING:
from .enforcer import Enforcer

__all__ = ["NormalizedAdversaryAdapter"]


@@ -22,7 +25,7 @@ class NormalizedAdversaryAdapter(torch.nn.Module):
def __init__(
self,
adversary: Callable[[Callable], Callable],
enforcer: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], None],
enforcer: Enforcer,
):
"""

@@ -37,8 +40,8 @@ def __init__(

def forward(
self,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module | None = None,
**kwargs,
):
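
For context on the unchanged `adversary: Callable[[Callable], Callable]` parameter above: it is presumably a factory that receives a normalized forward callable and returns an attack callable. A toy stand-in for that contract (hypothetical, neither MART nor ART code):

```python
import torch


def toy_adversary(model_fn):
    """Hypothetical factory matching Callable[[Callable], Callable]."""

    def attack(input_01: torch.Tensor) -> torch.Tensor:
        # A real attack would use gradients from model_fn; this one only probes and clamps.
        _ = model_fn(input_01)
        return input_01.clamp(0.0, 1.0)

    return attack
```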
26 changes: 13 additions & 13 deletions mart/attack/callbacks/base.py
@@ -7,7 +7,7 @@
from __future__ import annotations

import abc
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Iterable

import torch

@@ -24,8 +24,8 @@ def on_run_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
@@ -35,8 +35,8 @@ def on_examine_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
@@ -46,8 +46,8 @@ def on_examine_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
@@ -57,8 +57,8 @@ def on_advance_start(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
@@ -68,8 +68,8 @@ def on_advance_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
@@ -79,8 +79,8 @@ def on_run_end(
self,
*,
adversary: Adversary,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
model: torch.nn.Module,
**kwargs,
):
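
All six hooks now share this widened signature. A minimal sketch of a concrete callback written against it; the class here is hypothetical and would subclass the base class defined in this file:

```python
from __future__ import annotations

from typing import Any, Iterable

import torch


class IterationCounter:  # would subclass the callback base class defined in this file
    """Hypothetical callback: count advance steps for either input form."""

    def __init__(self):
        self.steps = 0

    def on_advance_end(
        self,
        *,
        adversary,
        input: torch.Tensor | Iterable[torch.Tensor],
        target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
        model: torch.nn.Module,
        **kwargs,
    ):
        self.steps += 1
        n = 1 if isinstance(input, torch.Tensor) else len(list(input))
        print(f"advance {self.steps}: {n} example(s)")
```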
27 changes: 17 additions & 10 deletions mart/attack/composer.py
@@ -7,29 +7,36 @@
from __future__ import annotations

import abc
from typing import Any
from typing import Any, Iterable

import torch


class Composer(abc.ABC):
def __call__(
self,
perturbation: torch.Tensor | tuple,
perturbation: torch.Tensor | Iterable[torch.Tensor],
*,
input: torch.Tensor | tuple,
target: torch.Tensor | dict[str, Any] | tuple,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
**kwargs,
) -> torch.Tensor | tuple:
if isinstance(perturbation, tuple):
input_adv = tuple(
) -> torch.Tensor | Iterable[torch.Tensor]:
if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
return self.compose(perturbation, input=input, target=target)

elif (
isinstance(perturbation, Iterable)
and isinstance(input, Iterable) # noqa: W503
and isinstance(target, Iterable) # noqa: W503
):
# FIXME: replace tuple with whatever input's type is
return tuple(
self.compose(perturbation_i, input=input_i, target=target_i)
for perturbation_i, input_i, target_i in zip(perturbation, input, target)
)
else:
input_adv = self.compose(perturbation, input=input, target=target)

return input_adv
else:
raise NotImplementedError

@abc.abstractmethod
def compose(
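
A toy subclass to illustrate the dispatch above: a single tensor goes straight to compose(), while iterables of perturbations, inputs and targets are zipped element-wise. The additive composer here is only an example, not necessarily one shipped by MART.

```python
import torch

from mart.attack.composer import Composer


class Add(Composer):
    """Hypothetical composer: input + perturbation, clamped to the pixel range."""

    def compose(self, perturbation, *, input, target):
        return (input + perturbation).clamp(0, 255)


composer = Add()

# Single-tensor path: compose() is called once.
adv = composer(torch.ones(3, 8, 8), input=torch.zeros(3, 8, 8), target={})

# Iterable path: perturbations, inputs and targets are zipped element-wise.
adv_batch = composer(
    (torch.ones(3, 8, 8), torch.ones(3, 4, 4)),
    input=(torch.zeros(3, 8, 8), torch.zeros(3, 4, 4)),
    target=({}, {}),
)
```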
47 changes: 18 additions & 29 deletions mart/attack/enforcer.py
@@ -7,10 +7,12 @@
from __future__ import annotations

import abc
from typing import Any
from typing import Any, Iterable

import torch

from ..utils.modality_dispatch import modality_dispatch

__all__ = ["Enforcer"]


@@ -98,6 +100,20 @@ class Enforcer:
def __init__(self, **modality_constraints: dict[str, dict[str, Constraint]]) -> None:
self.modality_constraints = modality_constraints

@torch.no_grad()
def __call__(
self,
input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
*,
input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
**kwargs,
):
# The default modality is set to "constraints", so that it is backward compatible with existing configs.
modality_dispatch(
self._enforce, input_adv, input=input, target=target, modality="constraints"
)

@torch.no_grad()
def _enforce(
self,
@@ -107,33 +123,6 @@ def _enforce(
target: torch.Tensor | dict[str, Any],
modality: str,
):
# intentionally ignore keys after modality.
for constraint in self.modality_constraints[modality].values():
constraint(input_adv, input=input, target=target)

def __call__(
self,
input_adv: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
*,
input: torch.Tensor | tuple | list[torch.Tensor] | dict[str, torch.Tensor],
target: torch.Tensor | dict[str, Any],
modality: str = "constraints",
**kwargs,
):
assert type(input_adv) == type(input)

if isinstance(input_adv, torch.Tensor):
# Finally we can verify constraints on tensor, per its modality.
# Set modality="constraints" by default, so that it is backward compatible with existing configs without modalities.
self._enforce(input_adv, input=input, target=target, modality=modality)
elif isinstance(input_adv, dict):
# The dict input has modalities specified in keys, passing them recursively.
for modality in input_adv:
self(input_adv[modality], input=input[modality], target=target, modality=modality)
elif isinstance(input_adv, (list, tuple)):
# We assume a modality-dictionary only contains tensors, but not list/tuple.
assert modality == "constraints"
# The list or tuple input is a collection of sub-input and sub-target.
for input_adv_i, input_i, target_i in zip(input_adv, input, target):
self(input_adv_i, input=input_i, target=target_i, modality=modality)
else:
raise ValueError(f"Unsupported data type of input_adv: {type(input_adv)}.")
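
The per-type branching deleted above now lives behind modality_dispatch(), imported from mart/utils/modality_dispatch.py. A simplified sketch of the semantics that dispatcher is expected to preserve (an illustration, not the imported implementation):

```python
import torch


def dispatch(fn, data, *, input, target, modality="default"):
    """Recursively pair data/input/target and call fn(data_i, input=..., target=..., modality=...)."""
    if isinstance(data, torch.Tensor):
        # Leaf: a plain tensor is handled under the current modality.
        return fn(data, input=input, target=target, modality=modality)
    if isinstance(data, dict):
        # Modality dict, e.g. {"rgb": ..., "depth": ...}: keys select the modality.
        return {
            mod: dispatch(fn, data[mod], input=input[mod], target=target, modality=mod)
            for mod in data
        }
    # List/tuple: a collection of per-example sub-inputs and sub-targets.
    return type(data)(
        dispatch(fn, d_i, input=in_i, target=t_i, modality=modality)
        for d_i, in_i, t_i in zip(data, input, target)
    )
```

Called with modality="constraints", a plain tensor ends up in _enforce(..., modality="constraints"), which keeps existing non-modality configs working, while a dict keyed by modality routes each tensor to its own constraint set.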
36 changes: 16 additions & 20 deletions mart/attack/gradient_modifier.py
@@ -6,44 +6,40 @@

from __future__ import annotations

import abc
from typing import Iterable

import torch

__all__ = ["GradientModifier"]


class GradientModifier(abc.ABC):
class GradientModifier:
"""Gradient modifier base class."""

def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
pass


class Sign(GradientModifier):
def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]

parameters = [p for p in parameters if p.grad is not None]
[self.modify_(parameter) for parameter in parameters]

@torch.no_grad()
def modify_(self, parameter: torch.Tensor) -> None:
pass


for p in parameters:
p.grad.detach().sign_()
class Sign(GradientModifier):
@torch.no_grad()
def modify_(self, parameter: torch.Tensor) -> None:
parameter.grad.sign_()


class LpNormalizer(GradientModifier):
"""Scale gradients by a certain L-p norm."""

def __init__(self, p: int | float):
self.p = p

def __call__(self, parameters: torch.Tensor | Iterable[torch.Tensor]) -> None:
if isinstance(parameters, torch.Tensor):
parameters = [parameters]

parameters = [p for p in parameters if p.grad is not None]
self.p = float(p)

for p in parameters:
p_norm = torch.norm(p.grad.detach(), p=self.p)
p.grad.detach().div_(p_norm)
@torch.no_grad()
def modify_(self, parameter: torch.Tensor) -> None:
p_norm = torch.norm(parameter.grad.detach(), p=self.p)
parameter.grad.detach().div_(p_norm)
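
A short usage sketch of the refactored modifiers: the base __call__ accepts either a single tensor or an Iterable[torch.Tensor], filters out parameters without gradients, and applies modify_() in place. The toy parameter and loss below are illustrative only.

```python
import torch

from mart.attack.gradient_modifier import LpNormalizer, Sign

# A toy parameter with a gradient to modify in place (illustrative values only).
delta = torch.zeros(4, requires_grad=True)
loss = (delta - torch.tensor([0.5, -2.0, 0.0, 3.0])).pow(2).sum()
loss.backward()

Sign()(delta)               # a bare tensor is wrapped into a list by the base __call__
print(delta.grad)           # tensor([-1., 1., 0., -1.])

LpNormalizer(p=2)([delta])  # an Iterable[torch.Tensor] works too; grad is divided by its L2 norm
```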