Skip to content

Commit

Permalink
Merge branch 'main' into patch_composer
Browse files Browse the repository at this point in the history
  • Loading branch information
mzweilin committed Sep 28, 2023
2 parents 0975372 + 6fb35a5 commit d3f7869
Show file tree
Hide file tree
Showing 21 changed files with 101 additions and 102 deletions.
22 changes: 8 additions & 14 deletions mart/attack/adversary.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@
from mart.utils import silent

from ..optim import OptimizerFactory
from ..utils import MonkeyPatch

if TYPE_CHECKING:
from .composer import Composer
from .enforcer import Enforcer
from .gain import Gain
from .gradient_modifier import GradientModifier
from .objective import Objective
from .perturber import Perturber

__all__ = ["Adversary"]

Expand All @@ -35,7 +33,6 @@ class Adversary(pl.LightningModule):
def __init__(
self,
*,
perturber: Perturber,
composer: Composer,
optimizer: OptimizerFactory | Callable[[Any], torch.optim.Optimizer],
gain: Gain,
Expand All @@ -48,7 +45,6 @@ def __init__(
"""_summary_
Args:
perturber (Perturber): A MART Perturber.
composer (Composer): A MART Composer.
optimizer (OptimizerFactory | Callable[[Any], torch.optim.Optimizer]): A MART OptimizerFactory or partial that returns an Optimizer when given params.
gain (Gain): An adversarial gain function, which is a differentiable estimate of adversarial objective.
Expand All @@ -65,10 +61,9 @@ def __init__(
lambda state_dict, *args, **kwargs: state_dict.clear()
)

# Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
# Hide the composer module in a list, so that perturbation is not exported as a parameter in the model checkpoint.
# and DDP won't try to get the uninitialized parameters of perturbation.
self._perturber = [perturber]
self.composer = composer
self._composer = [composer]
self.optimizer = optimizer
if not isinstance(self.optimizer, OptimizerFactory):
self.optimizer = OptimizerFactory(self.optimizer)
Expand Down Expand Up @@ -103,13 +98,13 @@ def __init__(
assert self._attacker.limit_train_batches > 0

@property
def perturber(self) -> Perturber:
# Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint,
def composer(self) -> Composer:
# Hide the composer module in a list, so that perturbation is not exported as a parameter in the model checkpoint,
# and DDP won't try to get the uninitialized parameters of perturbation.
return self._perturber[0]
return self._composer[0]

def configure_optimizers(self):
return self.optimizer(self.perturber)
return self.optimizer(self.composer)

def training_step(self, batch_and_model, batch_idx):
input, target, model = batch_and_model
Expand Down Expand Up @@ -157,7 +152,7 @@ def fit(self, input, target, *, model: Callable):
batch_and_model = (input, target, model)

# Configure and reset perturbation for current inputs
self.perturber.configure_perturbation(input)
self.composer.configure_perturbation(input)

# Attack, aka fit a perturbation, for one epoch by cycling over the same input batch.
# We use Trainer.limit_train_batches to control the number of attack iterations.
Expand All @@ -166,8 +161,7 @@ def fit(self, input, target, *, model: Callable):

def forward(self, input, target):
"""Compose adversarial examples and enforce the threat model."""
perturbation = self.perturber(input=input, target=target)
input_adv = self.composer(perturbation, input=input, target=target)
input_adv = self.composer(input=input, target=target)

if self.enforcer is not None:
self.enforcer(input_adv, input=input, target=target)
Expand Down
27 changes: 22 additions & 5 deletions mart/attack/composer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import abc
from collections import OrderedDict
from typing import Any, Iterable
from typing import TYPE_CHECKING, Any, Iterable

import torch
import torchvision
Expand All @@ -18,6 +18,9 @@

logger = pylogger.get_pylogger(__name__)

if TYPE_CHECKING:
from .perturber import Perturber


class Function(torch.nn.Module):
def __init__(self, *args, order=0, **kwargs) -> None:
Expand All @@ -38,22 +41,36 @@ def forward(
pass


class Composer:
def __init__(self, functions: dict[str, Function]) -> None:
class Composer(torch.nn.Module):
def __init__(self, perturber: Perturber, functions: dict[str, Function]) -> None:
"""_summary_
Args:
perturber (Perturber): Manage perturbations.
functions (dict[str, Function]): A dictionary of functions for composing pertured input.
"""
super().__init__()

self.perturber = perturber

# Sort functions by function.order and the name.
self.functions_dict = OrderedDict(
sorted(functions.items(), key=lambda name_fn: (name_fn[1].order, name_fn[0]))
)
self.functions = list(self.functions_dict.values())

def __call__(
def configure_perturbation(self, input: torch.Tensor | Iterable[torch.Tensor]):
return self.perturber.configure_perturbation(input)

def forward(
self,
perturbation: torch.Tensor | Iterable[torch.Tensor],
*,
input: torch.Tensor | Iterable[torch.Tensor],
target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
**kwargs,
) -> torch.Tensor | Iterable[torch.Tensor]:
perturbation = self.perturber(input=input, target=target)

if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
return self._compose(perturbation, input=input, target=target)

Expand Down
1 change: 0 additions & 1 deletion mart/configs/attack/adversary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ defaults:

_target_: mart.attack.Adversary
_convert_: all
perturber: ???
optimizer:
maximize: True
gain: ???
Expand Down
3 changes: 3 additions & 0 deletions mart/configs/attack/composer/default.yaml
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
defaults:
- perturber: default

_target_: mart.attack.Composer
functions: ???
File renamed without changes.
13 changes: 7 additions & 6 deletions mart/configs/attack/fgm.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defaults:
- perturber/initializer: constant
- composer/perturber/initializer: constant
- /optimizer@optimizer: sgd

max_iters: 1
Expand All @@ -8,11 +8,12 @@ eps: ???
optimizer:
lr: ${..eps}

perturber:
initializer:
constant: 0
projector:
eps: ${...eps}
composer:
perturber:
initializer:
constant: 0
projector:
eps: ${....eps}

# We can turn off progress bar for one-step attack.
callbacks:
Expand Down
3 changes: 1 addition & 2 deletions mart/configs/attack/linf.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
defaults:
- perturber: default
- perturber/projector: linf_additive_range
- composer/perturber/projector: linf_additive_range
- enforcer: default
- enforcer/constraints: lp

Expand Down
3 changes: 1 addition & 2 deletions mart/configs/attack/mask.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
defaults:
- perturber: default
- perturber/projector: mask_range
- composer/perturber/projector: mask_range
- enforcer: default
- enforcer/constraints: [mask, pixel_range]
9 changes: 5 additions & 4 deletions mart/configs/attack/object_detection_mask_adversary.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- adversary
- gradient_ascent
- mask
- perturber/initializer: constant
- composer/perturber/initializer: constant
- composer/functions: overlay
- gradient_modifier: sign
- gain: rcnn_training_loss
Expand All @@ -12,6 +12,7 @@ max_iters: ???
lr: ???

# Start with grey perturbation in the overlay mode.
perturber:
initializer:
constant: 127
composer:
perturber:
initializer:
constant: 127
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ defaults:
- adversary
- gradient_ascent
- mask
- perturber/initializer: constant
- composer/perturber/initializer: constant
- composer/functions: overlay
- gradient_modifier: sign
- gain: rcnn_class_background
Expand All @@ -12,6 +12,7 @@ max_iters: ???
lr: ???

# Start with grey perturbation in the overlay mode.
perturber:
initializer:
constant: 127
composer:
perturber:
initializer:
constant: 127
15 changes: 8 additions & 7 deletions mart/configs/attack/pgd.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
defaults:
- perturber/initializer: uniform
- composer/perturber/initializer: uniform
- /optimizer@optimizer: sgd

max_iters: ???
Expand All @@ -9,9 +9,10 @@ lr: ???
optimizer:
lr: ${..lr}

perturber:
initializer:
min: ${negate:${...eps}}
max: ${...eps}
projector:
eps: ${...eps}
composer:
perturber:
initializer:
min: ${negate:${....eps}}
max: ${....eps}
projector:
eps: ${....eps}
Loading

0 comments on commit d3f7869

Please sign in to comment.