Move batch_c15n out of Adversary (#220)
* Move batch_c15n to AdversaryConnector.

* Move configs from /attack/batch_c15n to /batch_c15n.

* Move mart/attack/batch_c15n.py to mart/transforms.

* Fix test.

* Use non-keyword arguments in Adversary.forward() for simplicity.

* Comment.

* Use non-keyword arguments for input and target in Adversary.fit() for simplicity.

* Fix a bug.

* Add test.

* Simplify if-else for LightningModule.

* Update test.

* Set a tuple batch_c15n by default in the adversary connector callback, since datamodules in MART typically yield batches of the form (input, target).
mzweilin authored Sep 1, 2023
1 parent 15e9090 commit 4db45e8
Showing 18 changed files with 112 additions and 87 deletions.
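
"c15n" is MART's shorthand for canonicalization: a batch canonicalizer maps whatever batch format a datamodule yields into a uniform (input, target) pair, and can revert such a pair back to the original format. A minimal sketch of the idea for the simplest tuple-shaped batch (illustrative only, not the actual mart.transforms implementation):

# Minimal sketch of a batch canonicalizer for (input, target) tuples.
# Illustrative only; the real classes live in mart/transforms/batch_c15n.py.
class TupleBatchC15n:
    def __call__(self, batch):
        # Canonicalize: split the batch into an (input, target) pair.
        input, target = batch
        return input, target

    def revert(self, input, target):
        # Revert: rebuild the original batch format from the pair.
        return (input, target)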
1 change: 0 additions & 1 deletion mart/attack/__init__.py
@@ -1,6 +1,5 @@
 from .adversary import *
 from .adversary_wrapper import *
-from .batch_c15n import *
 from .composer import *
 from .enforcer import *
 from .gain import *
35 changes: 8 additions & 27 deletions mart/attack/adversary.py
@@ -43,7 +43,6 @@ def __init__(
         objective: Objective | None = None,
         enforcer: Enforcer | None = None,
         attacker: pl.Trainer | None = None,
-        batch_c15n: Callable,
         **kwargs,
     ):
         """_summary_
@@ -57,7 +56,6 @@ def __init__(
             objective (Objective): A function for computing adversarial objective, which returns True or False. Optional.
             enforcer (Enforcer): A Callable that enforce constraints on the adversarial input.
             attacker (Trainer): A PyTorch-Lightning Trainer object used to fit the perturbation.
-            batch_c15n (Callable): Canonicalize batch into convenient format and revert to the original format.
         """
         super().__init__()

@@ -104,8 +102,6 @@ def __init__(
         assert self._attacker.max_epochs == 0
         assert self._attacker.limit_train_batches > 0

-        self.batch_c15n = batch_c15n
-
     @property
     def perturber(self) -> Perturber:
         # Hide the perturber module in a list, so that perturbation is not exported as a parameter in the model checkpoint,
@@ -116,20 +112,13 @@ def configure_optimizers(self):
         return self.optimizer(self.perturber)

     def training_step(self, batch_and_model, batch_idx):
-        batch, model = batch_and_model
+        input, target, model = batch_and_model

         # Compose adversarial examples.
-        batch_adv = self.forward(batch=batch)
+        input_adv, target_adv = self.forward(input, target)

-        # A model that returns output dictionary.
-        if hasattr(model, "attack_step"):
-            outputs = model.attack_step(batch_adv, batch_idx)
-        elif hasattr(model, "training_step"):
-            # Disable logging if we have to reuse training_step() of the target model.
-            with MonkeyPatch(model, "log", lambda *args, **kwargs: None):
-                outputs = model.training_step(batch_adv, batch_idx)
-        else:
-            outputs = model(batch_adv)
+        outputs = model(input_adv, target_adv)

         # FIXME: This should really be just `return outputs`. But this might require a new sequence?
         # FIXME: Everything below here should live in the model as modules.
@@ -163,12 +152,9 @@ def configure_gradient_clipping(
             self.gradient_modifier(group["params"])

     @silent()
-    def fit(self, *, batch: torch.Tensor | list | dict, model: Callable):
-        # Extract and canonicalize input for initializing perturbation.
-        # TODO: Get rid of batch_c15n() here by converting perturbation to UninitializedParameter.
-        input, _target = self.batch_c15n(batch)
+    def fit(self, input, target, *, model: Callable):
         # The attack also needs access to the model at every iteration.
-        batch_and_model = (batch, model)
+        batch_and_model = (input, target, model)

         # Configure and reset perturbation for current inputs
         self.perturber.configure_perturbation(input)
@@ -178,20 +164,15 @@ def fit(self, *, batch: torch.Tensor | list | dict, model: Callable):
         self.attacker.fit_loop.max_epochs += 1
         self.attacker.fit(self, train_dataloaders=cycle([batch_and_model]))

-    def forward(self, *, batch):
-        """Compose adversarial examples and revert to the original input format."""
-        input, target = self.batch_c15n(batch)
-
-        # Get the canonicalized input_adv for enforcer checking.
+    def forward(self, input, target):
+        """Compose adversarial examples and enforce the threat model."""
         perturbation = self.perturber(input=input, target=target)
         input_adv = self.composer(perturbation, input=input, target=target)

         if self.enforcer is not None:
             self.enforcer(input_adv, input=input, target=target)

-        # Target model expects input in the original format.
-        batch_adv = self.batch_c15n.revert(input_adv, target)
-        return batch_adv
+        return input_adv, target

     @property
     def attacker(self):
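
With batch canonicalization gone, Adversary.fit() and Adversary.forward() now take an already-canonicalized pair of positional arguments, as the diff above shows. A usage sketch of the new interface, assuming adversary, input, target, and model are already constructed by the caller:

# Sketch of the new call pattern (variable names are assumed).
adversary.fit(input, target, model=model)         # optimize the perturbation
input_adv, target_adv = adversary(input, target)  # compose adversarial examples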
29 changes: 26 additions & 3 deletions mart/callbacks/adversary_connector.py
@@ -11,6 +11,8 @@

 from lightning.pytorch.callbacks import Callback

+from ..utils import MonkeyPatch
+
 __all__ = ["AdversaryConnector"]


@@ -23,6 +25,7 @@ def __init__(
         train_adversary: Callable = None,
         val_adversary: Callable = None,
         test_adversary: Callable = None,
+        batch_c15n: Callable = None,
     ):
         """A pl.Trainer callback which perturbs input to be adversarial in training/validation/test
         phase.
@@ -32,10 +35,12 @@
             train_adversary (Callable, optional): Adversary in the training phase. Defaults to None.
             val_adversary (Callable, optional): Adversary in the validation phase. Defaults to None.
             test_adversary (Callable, optional): Adversary in the test phase. Defaults to None.
+            batch_c15n (Callable): Canonicalize batch into convenient format and revert to the original format.
         """
         self.train_adversary = train_adversary or adversary
         self.val_adversary = val_adversary or adversary
         self.test_adversary = test_adversary or adversary
+        self.batch_c15n = batch_c15n

     def setup(self, trainer, pl_module, stage=None):
         self._on_after_batch_transfer = pl_module.on_after_batch_transfer
@@ -66,8 +71,26 @@ def on_after_batch_transfer(self, pl_module, batch, dataloader_idx):
         # Move adversary to same device as pl_module and run attack
         adversary.to(pl_module.device)

-        # Directly pass batch instead of assuming it has a structure.
-        adversary.fit(batch=batch, model=pl_module)
-        batch_adv = adversary(batch=batch)
+        # Make a simple model interface that outputs=model(input, target)
+        def model(input, target):
+            batch = self.batch_c15n.revert(input, target)
+
+            if hasattr(pl_module, "attack_step"):
+                outputs = pl_module.attack_step(batch, dataloader_idx)
+            else:
+                # LightningModule must have "training_step".
+                # Disable logging if we have to reuse training_step() of the target model.
+                with MonkeyPatch(pl_module, "log", lambda *args, **kwargs: None):
+                    outputs = pl_module.training_step(batch, dataloader_idx)
+            return outputs
+
+        # Canonicalize the batch to work with Adversary.
+        input, target = self.batch_c15n(batch)
+
+        adversary.fit(input, target, model=model)
+        input_adv, target_adv = adversary(input, target)
+
+        # Revert to the original batch format.
+        batch_adv = self.batch_c15n.revert(input_adv, target_adv)

         return batch_adv
1 change: 0 additions & 1 deletion mart/configs/attack/adversary.yaml
@@ -12,4 +12,3 @@ gradient_modifier: null
 objective: null
 enforcer: ???
 attacker: null
-batch_c15n: ???
1 change: 0 additions & 1 deletion mart/configs/attack/batch_c15n/input_only.yaml

This file was deleted.

1 change: 0 additions & 1 deletion mart/configs/attack/classification_fgsm_linf.yaml
@@ -6,7 +6,6 @@ defaults:
   - gradient_modifier: sign
   - gain: cross_entropy
   - objective: misclassification
-  - batch_c15n: list

 eps: ???
 max_iters: 1
1 change: 0 additions & 1 deletion mart/configs/attack/classification_pgd_linf.yaml
@@ -6,7 +6,6 @@ defaults:
   - gradient_modifier: sign
   - gain: cross_entropy
   - objective: misclassification
-  - batch_c15n: list

 eps: ???
 lr: ???
1 change: 0 additions & 1 deletion mart/configs/attack/object_detection_mask_adversary.yaml
@@ -7,7 +7,6 @@ defaults:
   - gradient_modifier: sign
   - gain: rcnn_training_loss
   - objective: zero_ap
-  - batch_c15n: tuple

 max_iters: ???
 lr: ???
@@ -7,7 +7,6 @@ defaults:
   - gradient_modifier: sign
   - gain: rcnn_class_background
   - objective: object_detection_missed
-  - batch_c15n: tuple

 max_iters: ???
 lr: ???
mart/configs/batch_c15n/dict.yaml
@@ -1,3 +1,3 @@
 # We expect the original batch looks like `{"input": tensor, ...}` with the default parameters.
-_target_: mart.attack.batch_c15n.DictBatchC15n
+_target_: mart.transforms.DictBatchC15n
 input_key: input
1 change: 1 addition & 0 deletions mart/configs/batch_c15n/input_only.yaml
@@ -0,0 +1 @@
+_target_: mart.transforms.InputOnlyBatchC15n
mart/configs/batch_c15n/list.yaml
@@ -1,4 +1,4 @@
 # We expect the original batch looks like `[input, target]` with the default parameters.
-_target_: mart.attack.batch_c15n.ListBatchC15n
+_target_: mart.transforms.ListBatchC15n
 input_key: 0
 target_size: 1
mart/configs/batch_c15n/tuple.yaml
@@ -1,4 +1,4 @@
 # We expect the original batch looks like `(input, target)` with the default parameters.
-_target_: mart.attack.batch_c15n.TupleBatchC15n
+_target_: mart.transforms.TupleBatchC15n
 input_key: 0
 target_size: 1
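
Each config names a class in mart.transforms; reading the parameters, input_key appears to select where the input sits in the batch and target_size how many trailing elements form the target (an assumption based on the config names, not verified against the implementation). Under the tuple config a batch would round-trip like this:

# Illustrative round-trip under the tuple config (input_key=0, target_size=1),
# assuming batch_c15n is an instance of mart.transforms.TupleBatchC15n.
batch = (images, labels)                     # typical MART datamodule batch
input, target = batch_c15n(batch)            # canonicalize
restored = batch_c15n.revert(input, target)  # back to (images, labels)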
3 changes: 3 additions & 0 deletions mart/configs/callbacks/adversary_connector.yaml
@@ -1,3 +1,6 @@
+defaults:
+  - /batch_c15n@adversary_connector.batch_c15n: tuple
+
 adversary_connector:
   _target_: mart.callbacks.AdversaryConnector
   adversary: null
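
The new default wires a tuple canonicalizer into the callback through Hydra's defaults list. Roughly what that instantiation amounts to (a sketch; the TupleBatchC15n keyword arguments are taken from the tuple.yaml config above):

# Sketch of the instantiation Hydra performs for the config above.
from hydra.utils import instantiate

batch_c15n = instantiate(
    {"_target_": "mart.transforms.TupleBatchC15n", "input_key": 0, "target_size": 1}
)
connector = instantiate(
    {"_target_": "mart.callbacks.AdversaryConnector"}, batch_c15n=batch_c15n
)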
1 change: 1 addition & 0 deletions mart/transforms/__init__.py
@@ -4,5 +4,6 @@
 # SPDX-License-Identifier: BSD-3-Clause
 #

+from .batch_c15n import *  # noqa: F403
 from .extended import *  # noqa: F403
 from .transforms import *  # noqa: F403
mart/attack/batch_c15n.py → mart/transforms/batch_c15n.py: file renamed without changes.