diff --git a/mart/attack/composer.py b/mart/attack/composer.py index e1df942e..7747f958 100644 --- a/mart/attack/composer.py +++ b/mart/attack/composer.py @@ -110,17 +110,6 @@ def forward(self, perturbation, input, target): return perturbation, input, target -class PerturbationMask(Function): - def __init__(self, *args, key="perturbable_mask", **kwargs): - super().__init__(*args, **kwargs) - self.key = key - - def forward(self, perturbation, input, target): - mask = target[self.key] - perturbation = perturbation * mask - return perturbation, input, target - - # TODO: We may decompose Overlay into: perturbation-mask, input-re-mask, additive. class Overlay(Function): """We assume an adversary overlays a patch to the input.""" @@ -163,6 +152,17 @@ def forward(self, perturbation, input, target): return perturbation, input, target +class PerturbationMask(Function): + def __init__(self, *args, key="perturbable_mask", **kwargs): + super().__init__(*args, **kwargs) + self.key = key + + def forward(self, perturbation, input, target): + mask = target[self.key] + perturbation = perturbation * mask + return perturbation, input, target + + class PerturbationRectangleCrop(Function): def __init__(self, *args, coords_key="patch_coords", **kwargs): super().__init__(*args, **kwargs)