diff --git a/mart/attack/enforcer.py b/mart/attack/enforcer.py index c0f57ca6..250537a7 100644 --- a/mart/attack/enforcer.py +++ b/mart/attack/enforcer.py @@ -83,10 +83,14 @@ def verify(self, input_adv, *, input, target): class Mask(Constraint): + def __init__(self, key="perturbable_mask"): + super().__init__() + self.key = key + def verify(self, input_adv, *, input, target): # True/1 is mutable, False/0 is immutable. # mask.shape=(H, W) - mask = target["perturbable_mask"] + mask = target[self.key] # Immutable boolean mask, True is immutable. imt_mask = (1 - mask).bool() diff --git a/mart/attack/projector.py b/mart/attack/projector.py index f9887354..8f88c0dd 100644 --- a/mart/attack/projector.py +++ b/mart/attack/projector.py @@ -154,9 +154,13 @@ def project_(self, perturbation, *, input, target): class Mask(Projector): + def __init__(self, key="perturbable_mask"): + super().__init__() + self.key = key + @torch.no_grad() def project_(self, perturbation, *, input, target): - perturbation.mul_(target["perturbable_mask"]) + perturbation.mul_(target[self.key]) def __repr__(self): return f"{self.__class__.__name__}()"