Skip to content

Commit

Permalink
Fix adversarial image visualizer with canonical batches (#227)
Browse files Browse the repository at this point in the history
* Fetch canonical input and target.

* Invoke Adversary.forward().

* Fix test.
  • Loading branch information
mzweilin authored Sep 28, 2023
1 parent b89b67b commit 8b70e88
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 12 deletions.
9 changes: 4 additions & 5 deletions mart/callbacks/visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,14 @@ def __init__(self, folder):
os.makedirs(self.folder)

def on_train_batch_end(self, trainer, model, outputs, batch, batch_idx):
# Save input and target for on_train_end
self.input = batch["input"]
self.target = batch["target"]
# Save canonical input and target for on_train_end
self.input = batch[0]
self.target = batch[1]

def on_train_end(self, trainer, model):
# FIXME: We should really just save this to outputs instead of recomputing adv_input
with torch.no_grad():
perturbation = model.perturber(input=self.input, target=self.target)
adv_input = model.composer(perturbation, input=self.input, target=self.target)
adv_input, _target = model(self.input, self.target)

for img, tgt in zip(adv_input, self.target):
fname = tgt["file_name"]
Expand Down
16 changes: 9 additions & 7 deletions tests/test_visualizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,21 @@ def test_visualizer_run_end(input_data, target_data, perturbation, tmp_path):
target_list = [target_data]

# simulate an addition perturbation
def perturb(input):
def perturb(input, target):
result = [sample + perturbation for sample in input]
return result
return result, target

adversary = Mock(spec=Adversary, side_effect=perturb)
trainer = Mock()
model = Mock(composer=Mock(return_value=perturb(input_list)))
outputs = Mock()
batch = {"input": input_list, "target": target_list}
adversary = Mock(spec=Adversary, side_effect=perturb)
target_model = Mock()

# Canonical batch in Adversary.
batch = (input_list, target_list, target_model)

visualizer = PerturbedImageVisualizer(folder)
visualizer.on_train_batch_end(trainer, model, outputs, batch, 0)
visualizer.on_train_end(trainer, model)
visualizer.on_train_batch_end(trainer, adversary, outputs, batch, 0)
visualizer.on_train_end(trainer, adversary)

# verify that the visualizer created the JPG file
expected_output_path = folder / target_data["file_name"]
Expand Down

0 comments on commit 8b70e88

Please sign in to comment.