diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 3f2a3389..3e45545d 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -7,7 +7,7 @@ # Code for some of the transformers components in this file are initialized # from their counterparts in Hugging Face Transformers library. -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Tuple from torch import Tensor @@ -18,3 +18,4 @@ class TransformerOutput(NamedTuple): hidden_states: Optional[List[Tensor]] = None attentions: Optional[List[Tensor]] = None image_labels: Optional[Tensor] = None + current_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None