From 0dc3c216632a4ba0e71ea399b300b2239df085d4 Mon Sep 17 00:00:00 2001 From: Evan Smothers Date: Mon, 14 Aug 2023 15:32:04 -0700 Subject: [PATCH] Transformer decoder (#445) Summary: Pull Request resolved: https://github.com/facebookresearch/multimodal/pull/445 Add class for stack of transformer decoder layers Reviewed By: rck-meta, ankitade Differential Revision: D47891813 fbshipit-source-id: d58ebd5673247315cdaf18a2583f87d6e78fa274 --- torchmultimodal/modules/layers/transformer.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/transformer.py b/torchmultimodal/modules/layers/transformer.py index 3f2a3389..3e45545d 100644 --- a/torchmultimodal/modules/layers/transformer.py +++ b/torchmultimodal/modules/layers/transformer.py @@ -7,7 +7,7 @@ # Code for some of the transformers components in this file are initialized # from their counterparts in Hugging Face Transformers library. -from typing import List, NamedTuple, Optional +from typing import List, NamedTuple, Optional, Tuple from torch import Tensor @@ -18,3 +18,4 @@ class TransformerOutput(NamedTuple): hidden_states: Optional[List[Tensor]] = None attentions: Optional[List[Tensor]] = None image_labels: Optional[Tensor] = None + current_key_values: Optional[List[Tuple[Tensor, Tensor]]] = None