diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py
index 47f42d57..19e0c488 100644
--- a/torchmultimodal/modules/layers/transformer.py
+++ b/torchmultimodal/modules/layers/transformer.py
@@ -381,7 +381,7 @@ def _cross_attention_block(
             # TODO: figure out caching for cross-attention
             use_cache=False,
         )
-        assert torch.jit.isinstance(
+        assert isinstance(
             output, Tensor
         ), "cross-attention output must be Tensor."
         attention_output = self.cross_attention_dropout(output)
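
Context for the change (not part of the patch): `torch.jit.isinstance` is TorchScript's type-refinement helper, mainly needed for parametrized container types such as `List[Tensor]`, which the builtin `isinstance` cannot check. For a concrete class like `Tensor`, plain `isinstance` works in both eager and scripted code, so the patch can use it directly. A minimal standalone sketch of the difference (the `output` and `maybe_list` variables here are hypothetical, for illustration only):

```python
import torch
from torch import Tensor
from typing import List

output = torch.randn(2, 4)

# For a concrete class like Tensor, the builtin check is sufficient,
# in eager mode as well as under torch.jit.script:
assert isinstance(output, Tensor), "cross-attention output must be Tensor."

# torch.jit.isinstance exists for parametrized container types, where
# plain isinstance(x, List[Tensor]) would raise a TypeError:
maybe_list: List[Tensor] = [torch.ones(1)]
assert torch.jit.isinstance(maybe_list, List[Tensor])
```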