Modality dispatch #136
base: main
Conversation
See comments.
mart/utils/modality_dispatch.py (Outdated)

```python
data: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any],
*,
input: torch.Tensor | dict[str, torch.Tensor] | Iterable[Any],
target: torch.Tensor | Iterable[Any] | None,
```
These types should match whatever is in Enforcer above.
Um, shall we make some changes since it is a recursive function?
Huh? Whatever Enforcer takes should be the same thing this function takes? If you believe that, then why are the type annotations different?
Because modality_dispatch() is a recursive function that would see sub-components of the original input at some depth.
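To make that concrete, here is a minimal sketch of the recursion being described (the function name and body are hypothetical illustrations, not MART's actual code): the top-level call may receive a dict of modalities or a batch, while the leaf-level call sees a plain tensor, which is why annotations at different depths can look different.

```python
# Hypothetical sketch of the recursion described above; names and structure
# are assumptions for illustration, not MART's implementation.
from typing import Any

import torch


def dispatch(data: Any, *, input: Any) -> Any:
    if isinstance(input, torch.Tensor):
        # Leaf level: a single-modality tensor.
        return data + input
    if isinstance(input, dict):
        # One level up, the same function sees dict[str, Tensor]: recurse per modality.
        return {modality: dispatch(data[modality], input=input[modality]) for modality in input}
    if isinstance(input, (list, tuple)):
        # At the top, input may be a batch: recurse element-wise.
        return [dispatch(d, input=i) for d, i in zip(data, input)]
    raise NotImplementedError(f"Unsupported input type: {type(input)}")
```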
But so does Enforcer? I can pass a tensor, an Iterable[tensor], or an Iterable[dict[str, tensor]] to Enforcer. I don't get why the function being recursive matters. The type annotations tell you what the function accepts.
I made it the same as in Enforcer.
Well, it is now different from Enforcer, because other modules such as Initializer also use modality_dispatch(), and they may have different types of arguments.
mart/attack/enforcer.py (Outdated)

```diff
-    input_adv: torch.Tensor | Iterable[torch.Tensor],
+    input_adv: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
     *,
-    input: torch.Tensor | Iterable[torch.Tensor],
-    target: torch.Tensor | Iterable[torch.Tensor | dict[str, Any]],
+    input: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, torch.Tensor]],
+    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
```
Can you be more specific than Any with the types?
Fixed. I think it can be a tensor or a string (a file name, for example).
I changed it back to Any because it doesn't really matter. Other modules that use modality_dispatch() also annotate that as Any.
mart/utils/modality_dispatch.py (Outdated)

```python
if isinstance(input, torch.Tensor):
    if isinstance(modality_func, dict):
        # A dictionary of Callable indexed by modality.
        return modality_func[modality](data, input=input, target=target)
    else:
        # A Callable with modality=? as a keyword argument.
        return modality_func(data, input=input, target=target, modality=modality)
```
I think it's clearer to explicitly check the type of each input instead of having nested if statements. The else branch should trigger an exception.
I made it a singledispatch function.
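For reference, a minimal sketch of what a singledispatch-style modality_dispatch could look like; the registered types and keyword arguments here are assumptions, not necessarily the merged code:

```python
# Hedged sketch of a singledispatch implementation; details are assumed.
from functools import singledispatch
from typing import Any

import torch


@singledispatch
def modality_dispatch(input: Any, *, data, modality_func, modality="default"):
    # Unsupported types fail loudly, as suggested in the review.
    raise NotImplementedError(f"Unsupported data type of input: {type(input)}.")


@modality_dispatch.register
def _(input: torch.Tensor, *, data, modality_func, modality="default"):
    # Leaf: apply the modality-specific callable.
    if isinstance(modality_func, dict):
        return modality_func[modality](data, input=input)
    return modality_func(data, input=input, modality=modality)


@modality_dispatch.register
def _(input: dict, *, data, modality_func, modality="default"):
    # Dict of modalities: recurse with the key as the modality name.
    return {
        m: modality_dispatch(inp, data=data[m], modality_func=modality_func, modality=m)
        for m, inp in input.items()
    }


@modality_dispatch.register
def _(input: list, *, data, modality_func, modality="default"):
    # A list of inputs: recurse element-wise.
    return [
        modality_dispatch(inp, data=d, modality_func=modality_func, modality=modality)
        for inp, d in zip(input, data)
    ]
```

singledispatch resolves on the type of the first positional argument, which is why input comes first here; the base implementation doubles as the explicit "else raises" branch.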
mart/utils/modality_dispatch.py (Outdated)

```python
The function returns an object that is homomorphic to input and data.
"""

assert type(data) == type(input)
```
Note that this will break the universal perturbation idea, since I want to apply the same perturbation to every input. Instead, you should dispatch on the types below and raise a NotImplementedError if you don't support that combination of types.
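In other words (a hedged sketch with a hypothetical apply helper): dispatch on the combination of types, so one universal perturbation tensor can be paired with an iterable of inputs, and any other pairing fails loudly instead of tripping an assertion.

```python
# Hypothetical illustration of type-pair dispatch; not the actual MART code.
import torch


def apply(perturbation, *, input):
    if isinstance(perturbation, torch.Tensor) and isinstance(input, torch.Tensor):
        return input + perturbation
    if isinstance(perturbation, torch.Tensor) and isinstance(input, (list, tuple)):
        # Universal perturbation: reuse the same tensor for every input.
        return [apply(perturbation, input=inp) for inp in input]
    raise NotImplementedError(
        f"Unsupported combination of types: {type(perturbation)} and {type(input)}."
    )
```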
The new singledispatch implementation does not have the assertion.
Did we ever discuss the merits of making Composer and GradientModifier modality-aware instead of making Adversary modality-aware? What if I want to generate a multi-modal universal attack?
mart/attack/composer.py (Outdated)

```diff
-    perturbation: torch.Tensor | Iterable[torch.Tensor],
+    perturbation: torch.Tensor,
     *,
-    input: torch.Tensor | Iterable[torch.Tensor],
-    target: torch.Tensor | Iterable[torch.Tensor] | Iterable[dict[str, Any]],
+    input: torch.Tensor,
+    target: torch.Tensor | dict[str, Any],
```
These changes make me think this change is a regression? Why do we not accept Iterable[torch.Tensor] anymore?
OK. I reverted changes to Composer and created a new Modality(Composer).
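A rough sketch of what such a wrapper could look like (the class and method signatures below are assumptions; the actual Modality composer in the PR may differ):

```python
# Hypothetical sketch; Composer here is a stand-in for MART's base class.
class Composer:
    def compose(self, perturbation, *, input, target):
        raise NotImplementedError


class Modality(Composer):
    """Delegate composition to a per-modality Composer, e.g. Modality(rgb=..., depth=...)."""

    def __init__(self, **modality_composers):
        self.composers = modality_composers

    def compose(self, perturbation, *, input, target, modality="rgb"):
        # Look up and delegate to the composer registered for this modality.
        return self.composers[modality].compose(perturbation, input=input, target=target)
```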
Do you think we should wait to merge this until we change MART to bound inputs between […]?
What does this PR do?
- modality_dispatch() in the singledispatch manner.
- modality_dispatch() in Perturber for Initializer and Projector.
- Initializer and Projector.
- Enforcer and Composer modality-aware with modality_dispatch().
- GradientModifier is different because of the way we handle trainable parameters.
- Optimizer for GradientModifier.

Benign:
│ test_metrics/map_50 │ 0.6349384784698486 │
Adversarial:
```sh
CUDA_VISIBLE_DEVICES=0 \
python -m mart \
  experiment=ArmoryCarlaOverObjDet_TorchvisionFasterRCNN \
  trainer=gpu \
  fit=false \
  +trainer.limit_test_batches=1 \
  +model.load_state_dict.losses_and_detections.model=/home/weilinxu/coder/GARD-with-MART/oscar/model_zoo/carla_rgb_weights_eval6.pt \
  +attack@model.modules.input_adv_test=object_detection_rgb_mask_adversary \
  +model.test_sequence.seq005=input_adv_test \
  model.test_sequence.seq010.preprocessor=["input_adv_test"]
```
│ test_metrics/map_50 │ 0.4633878767490387 │
Type of change
Please check all relevant options.
Testing
Please describe the tests that you ran to verify your changes. Consider listing any relevant details of your test configuration.
- pytest
- CUDA_VISIBLE_DEVICES=0,1 python -m mart experiment=CIFAR10_CNN_Adv trainer=ddp trainer.precision=16 trainer.devices=2 model.optimizer.lr=0.2 trainer.max_steps=2925 datamodule.ims_per_batch=256 datamodule.world_size=2 reports 70% (14 sec/epoch).

Before submitting
- pre-commit run -a command without errors

Did you have fun?
Make sure you had fun coding 🙃