diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index 47f42d57..e2eea49e 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -381,9 +381,8 @@ def _cross_attention_block(
             # TODO: figure out caching for cross-attention
             use_cache=False,
         )
-        assert torch.jit.isinstance(
-            output, Tensor
-        ), "cross-attention output must be Tensor."
+        assert isinstance(output, Tensor), "cross-attention output must be Tensor."
+
         attention_output = self.cross_attention_dropout(output)
         return attention_output
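
Side note on the change above (an explanatory sketch, not part of the patch): `torch.jit.isinstance` exists to refine parametrized container types such as `List[Tensor]` inside TorchScript, where the builtin `isinstance` cannot inspect generic arguments. For a plain, non-generic type like `Tensor`, the builtin check is sufficient and behaves the same in eager and scripted code. A minimal illustration (the `values` variable is hypothetical, added here only for contrast):

```python
from typing import List

import torch
from torch import Tensor

# Builtin isinstance handles plain types like Tensor in both eager and
# TorchScript code -- this is the check the diff switches to.
output = torch.randn(2, 4)
assert isinstance(output, Tensor), "cross-attention output must be Tensor."

# torch.jit.isinstance is only needed for parametrized container types,
# which builtin isinstance cannot check (hypothetical example).
values: List[Tensor] = [torch.zeros(1)]
assert torch.jit.isinstance(values, List[Tensor])
```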