Skip to content

Commit

Permalink
Enhance documentation for attack APIs and add pretrained weights desc…
Browse files Browse the repository at this point in the history
…riptions for BIA, CDA, GAMA, and LTP attacks
  • Loading branch information
spencerwooo committed Dec 24, 2024
1 parent b26f985 commit a256c45
Show file tree
Hide file tree
Showing 5 changed files with 58 additions and 13 deletions.
20 changes: 17 additions & 3 deletions docs/populate_attack_apis.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,22 @@

import torchattack as ta

for attack_name in ta.SUPPORTED_ATTACKS:
filename = os.path.join('attacks', f'{attack_name.lower()}.md')
for attack in ta.SUPPORTED_ATTACKS:
filename = os.path.join('attacks', f'{attack.lower()}.md')
with mkdocs_gen_files.open(filename, 'w') as f:
f.write(f'# {attack_name}\n\n::: torchattack.{attack_name}\n')
if attack in ta.GENERATIVE_ATTACKS:
# For generative attacks, we need to document
# both the attack and its weights enum as well
attack_mod = getattr(ta, attack.lower())
weights_enum = getattr(attack_mod, f'{attack}Weights')
weights_doc = [f'- `{w}`\n' for w in weights_enum.__members__]
f.write(
f'# {attack}\n\n'
f'::: torchattack.{attack}\n'
f'::: torchattack.{attack.lower()}.{attack}Weights\n\n'
f'Available weights:\n\n'
f'{"".join(weights_doc)}\n'
)
else:
f.write(f'# {attack}\n\n::: torchattack.{attack}\n')
mkdocs_gen_files.set_edit_path(filename, 'docs/populate_attack_apis.py')
7 changes: 7 additions & 0 deletions torchattack/bia.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@


class BIAWeights(GeneratorWeightsEnum):
"""
Pretrained weights for the BIA attack generator are sourced from [the original
implementation of the BIA
attack](https://github.com/Alibaba-AAIG/Beyond-ImageNet-Attack#pretrained-generators).
RN stands for Random Normalization, and DA stands for Domain-Agnostic.
"""

RESNET152 = GeneratorWeights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/bia_resnet152_0.pth',
)
Expand Down
6 changes: 6 additions & 0 deletions torchattack/cda.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@


class CDAWeights(GeneratorWeightsEnum):
"""
Pretrained weights for the CDA attack generator are sourced from [the original
implementation of the CDA
attack](https://github.com/Muzammal-Naseer/CDA#Pretrained-Generators).
"""

RESNET152_IMAGENET = GeneratorWeights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/cda_res152_imagenet_0_rl.pth',
)
Expand Down
7 changes: 7 additions & 0 deletions torchattack/gama.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,13 @@


class GAMAWeights(GeneratorWeightsEnum):
"""
We provide pretrained weights of the GAMA attack generator with training steps
identical to the described settings in the paper and appendix. Specifically, we use
ViT-B/16 as the backend of CLIP. Training epochs are set to 5 and 10 for the COCO
and VOC datasets, respectively.
"""

DENSENET169_COCO = GeneratorWeights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/gama_dense169_coco_w_vitb16_epoch4.pth'
)
Expand Down
31 changes: 21 additions & 10 deletions torchattack/ltp.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@


class LTPWeights(GeneratorWeightsEnum):
"""
Pretrained weights for the LTP attack generator are sourced from [the original
implementation of the LTP
attack](https://github.com/krishnakanthnakka/Transferable_Perturbations#installation).
"""

DENSENET121_IMAGENET = GeneratorWeights(
url='https://github.com/spencerwooo/torchattack/releases/download/v1.0-weights/ltp_densenet121_1_net_g.pth'
)
Expand Down Expand Up @@ -48,6 +54,7 @@ def __init__(
eps: float = 10 / 255,
weights: LTPWeights | str | None = LTPWeights.DEFAULT,
checkpoint_path: str | None = None,
inception: bool | None = None,
clip_min: float = 0.0,
clip_max: float = 1.0,
) -> None:
Expand All @@ -59,21 +66,25 @@ def __init__(
self.clip_min = clip_min
self.clip_max = clip_max

# Initialize the generator and its weights
self.generator = ResNetGenerator()
self.weights = LTPWeights.verify(weights)

# Whether is inception or not (crop layer 3x300x300 to 3x299x299)
is_inception = (
inception
if inception is not None
else (self.weights.inception if self.weights is not None else False)
)

# Initialize the generator
self.generator = ResNetGenerator(inception=is_inception)

# Prioritize checkpoint path over provided weights enum
# Load the weights
if self.checkpoint_path is not None:
self.generator.load_state_dict(
torch.load(self.checkpoint_path, weights_only=True)
)
else:
# Verify and load weights from enum if checkpoint path is not provided
self.weights: LTPWeights = LTPWeights.verify(weights)
if self.weights is not None:
self.generator.load_state_dict(
self.weights.get_state_dict(check_hash=True)
)
elif self.weights is not None:
self.generator.load_state_dict(self.weights.get_state_dict(check_hash=True))

self.generator.eval().to(self.device)

Expand Down

0 comments on commit a256c45

Please sign in to comment.