From ec9d7280cbeb31f1699fe1b114fced0f7fb521ba Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 27 Aug 2023 15:59:59 +0000 Subject: [PATCH 1/7] add parallel_attention.py --- .../modules/layers/parallel_attention.py | 269 ++++++++++++++++++ 1 file changed, 269 insertions(+) create mode 100644 torchmultimodal/modules/layers/parallel_attention.py diff --git a/torchmultimodal/modules/layers/parallel_attention.py b/torchmultimodal/modules/layers/parallel_attention.py new file mode 100644 index 00000000..c3d5e27a --- /dev/null +++ b/torchmultimodal/modules/layers/parallel_attention.py @@ -0,0 +1,269 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from normalizations import RMSNorm +from packaging import version + +# from position_embedding import RotaryEmbedding + + +def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor: + """expands kv_heads to match q num_heads + via + torch.repeat_interleave(x, dim=2, repeats=n_rep)""" + + bs, slen, n_kv_heads, head_dim = x.shape + + if n_rep == 1: + return x + return ( + x[:, :, :, None, :] + .expand(bs, slen, n_kv_heads, n_rep, head_dim) + .reshape(bs, slen, n_kv_heads * n_rep, head_dim) + ) + + +class ParallelAttentionBlock(nn.Module): + """ + Transformer layer multi-head attention and MLP, in a parallelized fashion rather than sequential, + with optional attention masking. + Inspired by PaLM: https://arxiv.org/abs/2204.02311 + + * We use SwiGLU for the activation function + * SwiGLU will approximate same total num params as traditional MLP with GELU + * Cross Attention is not enabled here + + * MQA and GQA are enabled - modify heads via 'num_heads_group_query_attn' + * MQA is num_heads_group_query_attn = 1 + * GQA is num_heads_group_query_attn < num_heads, and must be evenly divisible into num_heads + + * Bias is enabled by default. Experiment with removing via use...bias = False, for your application. + + * Parallel blocks have automated weight initialization via _init_weights. + * Please pass in num_layers of your model in num_layers for the weight initialization. + """ + + def __init__( + self, + emb_dimension, + num_heads, + head_dimension=None, + mlp_expansion_ratio: float = 2.6875, # 8/3 is param matching + qk_normalization: bool = True, + projection_dropout: float = 0.0, + attention_dropout: float = 0.0, + use_group_query_attention: bool = True, + num_heads_group_query_attention: int = 1, + use_in_projection_bias: bool = True, + use_out_projection_bias: bool = True, + use_weight_init: bool = True, + num_layers: int = 1, + use_rms_norm: bool = True, + use_rotary_embeddings: bool = False, + max_expected_seq_len: int = 2048, # needed only if using rotary + ): + super().__init__() + + version_check = not version.parse(torch.__version__) < version.parse("2.0.0") + assert ( + version_check + ), f"Parallel Attention Blocks requires PT 2.0+, you are running {torch.__version__}.\nPlease upgrade your PyTorch version." 
+ + self.num_heads = num_heads + self.emb_dim = emb_dimension + self.head_dim = head_dimension if head_dimension else emb_dimension // num_heads + assert ( + self.emb_dim % self.num_heads == 0 + ), f"dimensions {self.emb_dim.shape} must be evenly divisible by num_heads {num_heads=}" + + # group query attn + if use_group_query_attention: + assert ( + self.num_heads % num_heads_group_query_attention == 0 + ), f"{self.num_heads=} not evenly divisible by {num_heads_group_query_attention=}" + + self.use_variable_kv = use_group_query_attention + self.group_num_kv = num_heads_group_query_attention + self.num_kv = self.group_num_kv if self.use_variable_kv else self.num_heads + self.kv_head_dims = self.head_dim * self.num_kv + self.kv_expansion_factor = int(self.num_heads / self.group_num_kv) + assert ( + self.kv_expansion_factor > 0 + ), f"kv expansion factor must be positive integer, got {self.kv_expansion_factor=}" + + self.mlp_hidden_dim = int(mlp_expansion_ratio * self.emb_dim) + + self.qk_norm: bool = qk_normalization + + self.attention_dropout = nn.Dropout(attention_dropout) + self.mlp_dropout = nn.Dropout(projection_dropout) + + # weight init + self.num_layers = num_layers + self.use_weight_init = use_weight_init + if self.use_weight_init: + assert ( + self.num_layers > 1 + ), f"Need to pass in global num layers for weight init, {self.num_layers=}" + + self.weight_init_standard_dev = 0.02 / math.sqrt(2 * self.num_layers) + + self.use_in_projection_bias = use_in_projection_bias + self.use_out_projection_bias = use_out_projection_bias + + # previous init params, moved to internal defaults for streamlining + normalization_layer = RMSNorm if use_rms_norm else nn.LayerNorm + self.mlp_activation = nn.SiLU() + + self.num_q = 1 + + self.in_proj_dims = [ + self.head_dim * num_heads * self.num_q, + self.kv_head_dims, + self.kv_head_dims, + self.mlp_hidden_dim, + self.mlp_hidden_dim, + ] # q, k, v, mlp, gate + + # layer objects + self.in_norm = normalization_layer(emb_dimension) + self.in_proj = nn.Linear( + emb_dimension, sum(self.in_proj_dims), bias=use_in_projection_bias + ) + + self.q_norm = normalization_layer(self.head_dim) + self.k_norm = normalization_layer(self.head_dim) + + # fused out projection + fused_out_input_dim = emb_dimension + self.mlp_hidden_dim + self.out_fused_proj = nn.Linear( + fused_out_input_dim, emb_dimension, bias=use_out_projection_bias + ) + + # rotary embeddings + if use_rotary_embeddings: + raise AssertionError("RotaryEmbeddings has not been merged yet...") + # self.rotary_emb = RotaryEmbedding(emb_dimension, max_expected_seq_len) + else: + self.rotary_emb = None + + # init weights + if use_weight_init: + self.apply(self._init_weights) + + def _init_weights(self, module: nn.Module): + """init weights using trunc + llama style depth scaling""" + if isinstance(module, nn.Linear): + torch.nn.init.trunc_normal_( + module.weight, + mean=0.0, + # std_dev = 0.02 / math.sqrt(2 * self.num_layers) + std=self.weight_init_standard_dev, + ) + + if module.bias is not None: + nn.init.zeros_(module.bias) + + def forward( + self, + x: torch.Tensor, + cross_x: Optional[torch.Tensor] = None, + mask: Optional[torch.Tensor] = None, + cross_mask: Optional[torch.Tensor] = None, + rel_pos_bias: Optional[torch.Tensor] = None, + has_causal_mask: bool = False, + ): + """TODO: No KV cache support yet""" + + assert not ( + rel_pos_bias is not None and self.rotary_emb is not None + ), "Rotary and additive biases are exclusive" + assert not ( + (rel_pos_bias is not None or attn_mask is not None) and 
has_causal_mask + ), "Causal mask optimization only valid without attn_mask or rel_pos_bias" + + batch_size, seq_len, channels = x.shape + + y = self.in_norm(x) + y = self.in_proj(y) + + q, k, v, inner_mlp, gate = torch.split(y, self.in_proj_dims, dim=-1) + + # b n nq h d + q = q.view(batch_size, seq_len, self.num_q, self.num_heads, self.head_dim) + + q = q[:, :, 0].transpose(2, 1) + + if self.rotary_emb: + start_pos = 0 # TODO: No kv-cache yet, when that happens this is seqlen saved in kv-cache + q, k = self.rotary_emb(q, k, start_pos) + + # group query expansion + def kv_expansion(head): + head = head.view( + batch_size, seq_len, self.num_kv, self.head_dim + ) # b n hnum dimh + # bs, slen, n_kv_heads, head_dim = x.shape + if self.use_variable_kv and self.num_kv > 1: + head = repeat_kv(head, n_rep=self.kv_expansion_factor) + return head.transpose(2, 1) # b hnum n dimh + + k = kv_expansion(k) + v = kv_expansion(v) + + if self.qk_norm: + q = self.q_norm(q) + k = self.k_norm(k) + + # Merge rel pos bias and mask into single float mask + if rel_pos_bias is None: + # Given SDPA API, we expect users to either provide a boolean mask if + # they expect masked_fill to be done inside SDPA, or provide the float + # mask already with the correct -inf + attn_mask = mask # b? ...? nq nk + else: + attn_mask = rel_pos_bias # b? ...? nq nk + + # We expect the shapes of mask and rel_pos_bias to be at least broadcastable + if mask is not None: + # Can't do in-place op in case broadcast makes attn_mask bigger + attn_mask = attn_mask.masked_fill(mask == 0, -float("inf")) + + final_attn = F.scaled_dot_product_attention( + q, + k, + v, + attn_mask=attn_mask, + dropout_p=self.attention_dropout.p, + is_causal=has_causal_mask, + ) + + final_attn = ( + final_attn.transpose(2, 1) + .contiguous() + .view(batch_size, seq_len, self.head_dim * self.num_heads) + ) + + # swiglu + activated_mlp = self.mlp_activation(inner_mlp) * gate + + if self.mlp_dropout.p: + activated_mlp = self.mlp_dropout(activated_mlp) + + y = torch.cat((final_attn, activated_mlp), dim=2) + + y = self.out_fused_proj(y) + + # Add residual + x = x + y + return x From 615225d821ae3ad74c6dea40f7e1b1d95cc20aff Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 27 Aug 2023 17:04:34 +0000 Subject: [PATCH 2/7] add first unit tests for parallel_attention_blocks --- .../modules/layers/test_parallel_attention.py | 100 ++++++++++++++++++ .../modules/layers/parallel_attention.py | 10 +- 2 files changed, 105 insertions(+), 5 deletions(-) create mode 100644 tests/modules/layers/test_parallel_attention.py diff --git a/tests/modules/layers/test_parallel_attention.py b/tests/modules/layers/test_parallel_attention.py new file mode 100644 index 00000000..4642dbd2 --- /dev/null +++ b/tests/modules/layers/test_parallel_attention.py @@ -0,0 +1,100 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
+ +import pytest + +# import torch +from tests.test_utils import assert_expected # , init_weights_with_constant + +# from torch import nn +from torchmultimodal.modules.layers.parallel_attention import ParallelAttentionBlock + + +class TestParallelAttentionBlocks: + @pytest.fixture + def embedding_dim(self): + return 64 + + @pytest.fixture + def total_layers(self): + return 1 + + @pytest.fixture + def mqa_num_heads(self): + return 1 + + @pytest.fixture + def gqa_num_heads(self): + return 2 + + @pytest.fixture + def num_heads(self): + return 16 + + @pytest.fixture + def mha_parallel_attention(self, embedding_dim, num_heads, total_layers): + print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") + pab_mha = ParallelAttentionBlock( + emb_dimension=embedding_dim, + num_heads=num_heads, + use_group_query_attention=False, + num_layers=total_layers, + use_weight_init=True, + ) + pab_mha.eval() + return pab_mha + + @pytest.fixture + def gqa_parallel_attention( + self, embedding_dim, num_heads, total_layers, gqa_num_heads + ): + print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") + pab_gqa = ParallelAttentionBlock( + emb_dimension=embedding_dim, + num_heads=num_heads, + use_group_query_attention=True, + num_heads_group_query_attention=gqa_num_heads, + num_layers=total_layers, + use_weight_init=True, + ) + pab_gqa.eval() + return pab_gqa + + @pytest.fixture + def mqa_parallel_attention( + self, embedding_dim, num_heads, total_layers, mqa_num_heads + ): + print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") + pab_mqa = ParallelAttentionBlock( + emb_dimension=embedding_dim, + num_heads=num_heads, + use_group_query_attention=True, + num_heads_group_query_attention=mqa_num_heads, + num_layers=total_layers, + use_weight_init=True, + ) + pab_mqa.eval() + return pab_mqa + + def test_mha_parallel_attention(self, mha_parallel_attention, num_heads): + # confirm all K and V keys match Q (i.e. MHA) + assert_expected(num_heads, mha_parallel_attention.num_kv) + + def test_mqa_parallel_attention( + self, mqa_parallel_attention, num_heads, mqa_num_heads + ): + print("in test") + + # confirm all K and V keys match MQA num heads (i.e. MQA == 1) + assert_expected(mqa_num_heads, mqa_parallel_attention.num_kv) + + def test_gqa_parallel_attention( + self, gqa_parallel_attention, num_heads, gqa_num_heads + ): + print("in test") + + # confirm all K and V keys match GQA num heads (i.e. 
GQA >= 2) + assert_expected(gqa_num_heads, gqa_parallel_attention.num_kv) diff --git a/torchmultimodal/modules/layers/parallel_attention.py b/torchmultimodal/modules/layers/parallel_attention.py index c3d5e27a..633169bd 100644 --- a/torchmultimodal/modules/layers/parallel_attention.py +++ b/torchmultimodal/modules/layers/parallel_attention.py @@ -10,10 +10,10 @@ import torch import torch.nn as nn import torch.nn.functional as F - -from normalizations import RMSNorm from packaging import version +from torchmultimodal.modules.layers.normalizations import RMSNorm + # from position_embedding import RotaryEmbedding @@ -67,7 +67,7 @@ def __init__( use_in_projection_bias: bool = True, use_out_projection_bias: bool = True, use_weight_init: bool = True, - num_layers: int = 1, + num_layers: int = 0, use_rms_norm: bool = True, use_rotary_embeddings: bool = False, max_expected_seq_len: int = 2048, # needed only if using rotary @@ -113,7 +113,7 @@ def __init__( self.use_weight_init = use_weight_init if self.use_weight_init: assert ( - self.num_layers > 1 + self.num_layers > 0 ), f"Need to pass in global num layers for weight init, {self.num_layers=}" self.weight_init_standard_dev = 0.02 / math.sqrt(2 * self.num_layers) @@ -189,7 +189,7 @@ def forward( rel_pos_bias is not None and self.rotary_emb is not None ), "Rotary and additive biases are exclusive" assert not ( - (rel_pos_bias is not None or attn_mask is not None) and has_causal_mask + (rel_pos_bias is not None or mask is not None) and has_causal_mask ), "Causal mask optimization only valid without attn_mask or rel_pos_bias" batch_size, seq_len, channels = x.shape From 290ce85a502957e7027dec07546d7a4fb91a8a30 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 27 Aug 2023 17:19:54 +0000 Subject: [PATCH 3/7] confirm num Q heads matches num_heads --- tests/modules/layers/test_parallel_attention.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/tests/modules/layers/test_parallel_attention.py b/tests/modules/layers/test_parallel_attention.py index 4642dbd2..eca257ab 100644 --- a/tests/modules/layers/test_parallel_attention.py +++ b/tests/modules/layers/test_parallel_attention.py @@ -36,7 +36,6 @@ def num_heads(self): @pytest.fixture def mha_parallel_attention(self, embedding_dim, num_heads, total_layers): - print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") pab_mha = ParallelAttentionBlock( emb_dimension=embedding_dim, num_heads=num_heads, @@ -51,7 +50,6 @@ def mha_parallel_attention(self, embedding_dim, num_heads, total_layers): def gqa_parallel_attention( self, embedding_dim, num_heads, total_layers, gqa_num_heads ): - print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") pab_gqa = ParallelAttentionBlock( emb_dimension=embedding_dim, num_heads=num_heads, @@ -67,7 +65,6 @@ def gqa_parallel_attention( def mqa_parallel_attention( self, embedding_dim, num_heads, total_layers, mqa_num_heads ): - print(f"{embedding_dim=}, {num_heads=}, {total_layers=}") pab_mqa = ParallelAttentionBlock( emb_dimension=embedding_dim, num_heads=num_heads, @@ -82,6 +79,8 @@ def mqa_parallel_attention( def test_mha_parallel_attention(self, mha_parallel_attention, num_heads): # confirm all K and V keys match Q (i.e. 
MHA) assert_expected(num_heads, mha_parallel_attention.num_kv) + # confirm num Q matches num_heads + assert_expected(num_heads, mha_parallel_attention.num_heads) def test_mqa_parallel_attention( self, mqa_parallel_attention, num_heads, mqa_num_heads @@ -90,6 +89,8 @@ def test_mqa_parallel_attention( # confirm all K and V keys match MQA num heads (i.e. MQA == 1) assert_expected(mqa_num_heads, mqa_parallel_attention.num_kv) + # confirm num Q matches num_heads + assert_expected(num_heads, mqa_parallel_attention.num_heads) def test_gqa_parallel_attention( self, gqa_parallel_attention, num_heads, gqa_num_heads @@ -98,3 +99,5 @@ def test_gqa_parallel_attention( # confirm all K and V keys match GQA num heads (i.e. GQA >= 2) assert_expected(gqa_num_heads, gqa_parallel_attention.num_kv) + # confirm num Q matches num_heads + assert_expected(num_heads, gqa_parallel_attention.num_heads) From 7972d4dbf57629b4ffe0eb4f2b25ceb8b253acb8 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Sun, 27 Aug 2023 23:23:02 +0000 Subject: [PATCH 4/7] extend unit tests to check output numerics and attn_output shape --- .../modules/layers/test_parallel_attention.py | 129 ++++++++++++++++-- 1 file changed, 118 insertions(+), 11 deletions(-) diff --git a/tests/modules/layers/test_parallel_attention.py b/tests/modules/layers/test_parallel_attention.py index eca257ab..1ee1e4d1 100644 --- a/tests/modules/layers/test_parallel_attention.py +++ b/tests/modules/layers/test_parallel_attention.py @@ -5,18 +5,21 @@ # LICENSE file in the root directory of this source tree. import pytest +import torch -# import torch from tests.test_utils import assert_expected # , init_weights_with_constant - -# from torch import nn from torchmultimodal.modules.layers.parallel_attention import ParallelAttentionBlock +@pytest.fixture(autouse=True) +def random(): + torch.manual_seed(2023) + + class TestParallelAttentionBlocks: @pytest.fixture def embedding_dim(self): - return 64 + return 16 @pytest.fixture def total_layers(self): @@ -34,6 +37,10 @@ def gqa_num_heads(self): def num_heads(self): return 16 + @pytest.fixture + def max_seq_len(self): + return 32 + @pytest.fixture def mha_parallel_attention(self, embedding_dim, num_heads, total_layers): pab_mha = ParallelAttentionBlock( @@ -76,28 +83,128 @@ def mqa_parallel_attention( pab_mqa.eval() return pab_mqa - def test_mha_parallel_attention(self, mha_parallel_attention, num_heads): + def test_mha_parallel_attention( + self, + mha_parallel_attention, + num_heads, + embedding_dim, + max_seq_len, + ): # confirm all K and V keys match Q (i.e. 
MHA) assert_expected(num_heads, mha_parallel_attention.num_kv) # confirm num Q matches num_heads assert_expected(num_heads, mha_parallel_attention.num_heads) + # input_ones = torch.ones(dims, dtype=torch.float) + + x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)) # bs =1, + attn_output = mha_parallel_attention(x) + + fixed_result_firstrow = torch.tensor( + [ + 15.9989, + 119.0005, + 32.0014, + 119.9999, + 113.9993, + 8.9996, + 141.0015, + 200.0015, + 136.9985, + 238.9991, + 236.0013, + 144.9993, + 224.9991, + 165.9994, + 193.9994, + 93.0001, + ], + dtype=torch.float32, + ) + fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim]) + + assert_expected(fixed_result_firstrow, attn_output[0][0], rtol=0, atol=1e-4) + assert_expected(fixed_output_shape, attn_output.shape) + def test_mqa_parallel_attention( - self, mqa_parallel_attention, num_heads, mqa_num_heads + self, + mqa_parallel_attention, + num_heads, + mqa_num_heads, + max_seq_len, + embedding_dim, ): - print("in test") - # confirm all K and V keys match MQA num heads (i.e. MQA == 1) assert_expected(mqa_num_heads, mqa_parallel_attention.num_kv) # confirm num Q matches num_heads assert_expected(num_heads, mqa_parallel_attention.num_heads) + x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)) + attn_output = mqa_parallel_attention(x) + + fixed_result_firstrow = torch.tensor( + [ + 91.9992, + 24.0038, + 237.9937, + 74.0036, + 186.0031, + 53.0041, + 106.0050, + 179.9931, + 190.9989, + 178.9995, + 82.0005, + 190.9972, + 213.0028, + 213.9989, + 12.0008, + 190.9990, + ], + dtype=torch.float32, + ) + fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim]) + assert_expected(fixed_output_shape, attn_output.shape) + # print(f"{attn_output[0][0]}") + assert_expected(fixed_result_firstrow, attn_output[0][0], rtol=0, atol=1e-4) + def test_gqa_parallel_attention( - self, gqa_parallel_attention, num_heads, gqa_num_heads + self, + gqa_parallel_attention, + num_heads, + gqa_num_heads, + max_seq_len, + embedding_dim, ): - print("in test") - # confirm all K and V keys match GQA num heads (i.e. 
GQA >= 2) assert_expected(gqa_num_heads, gqa_parallel_attention.num_kv) # confirm num Q matches num_heads assert_expected(num_heads, gqa_parallel_attention.num_heads) + + x = torch.randint(0, 256, (1, max_seq_len, embedding_dim)) + attn_output = gqa_parallel_attention(x) + + fixed_result_firstrow = torch.tensor( + [ + 201.0000, + 138.0011, + 196.9992, + 82.9997, + 4.9996, + 211.9985, + 103.9994, + 15.9996, + 177.0008, + 140.9993, + 213.9985, + 199.0000, + 146.9993, + 207.0003, + 109.0001, + 3.0005, + ], + dtype=torch.float32, + ) + fixed_output_shape = torch.Size([1, max_seq_len, embedding_dim]) + assert_expected(fixed_output_shape, attn_output.shape) + assert_expected(fixed_result_firstrow, attn_output[0][0], rtol=0, atol=1e-4) From 5d83d48b06753e4c4131d99faf34fb238f2bc8b6 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 28 Aug 2023 23:38:38 +0000 Subject: [PATCH 5/7] fix typing for missing function returns --- torchmultimodal/modules/layers/parallel_attention.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchmultimodal/modules/layers/parallel_attention.py b/torchmultimodal/modules/layers/parallel_attention.py index 633169bd..b7b2d47a 100644 --- a/torchmultimodal/modules/layers/parallel_attention.py +++ b/torchmultimodal/modules/layers/parallel_attention.py @@ -71,7 +71,7 @@ def __init__( use_rms_norm: bool = True, use_rotary_embeddings: bool = False, max_expected_seq_len: int = 2048, # needed only if using rotary - ): + ) -> None: super().__init__() version_check = not version.parse(torch.__version__) < version.parse("2.0.0") @@ -161,7 +161,7 @@ def __init__( if use_weight_init: self.apply(self._init_weights) - def _init_weights(self, module: nn.Module): + def _init_weights(self, module: nn.Module) -> None: """init weights using trunc + llama style depth scaling""" if isinstance(module, nn.Linear): torch.nn.init.trunc_normal_( @@ -182,7 +182,7 @@ def forward( cross_mask: Optional[torch.Tensor] = None, rel_pos_bias: Optional[torch.Tensor] = None, has_causal_mask: bool = False, - ): + ) -> torch.Tensor: """TODO: No KV cache support yet""" assert not ( From 2f8e6ffb573ad966e6bbe0b0a59d0928866848e9 Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Mon, 28 Aug 2023 23:59:42 +0000 Subject: [PATCH 6/7] add typing for inline function kv_expansion params --- torchmultimodal/modules/layers/parallel_attention.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchmultimodal/modules/layers/parallel_attention.py b/torchmultimodal/modules/layers/parallel_attention.py index b7b2d47a..b801b8e1 100644 --- a/torchmultimodal/modules/layers/parallel_attention.py +++ b/torchmultimodal/modules/layers/parallel_attention.py @@ -55,9 +55,9 @@ class ParallelAttentionBlock(nn.Module): def __init__( self, - emb_dimension, - num_heads, - head_dimension=None, + emb_dimension: int, + num_heads: int, + head_dimension: int = None, mlp_expansion_ratio: float = 2.6875, # 8/3 is param matching qk_normalization: bool = True, projection_dropout: float = 0.0, @@ -209,7 +209,7 @@ def forward( q, k = self.rotary_emb(q, k, start_pos) # group query expansion - def kv_expansion(head): + def kv_expansion(head: torch.Tensor) -> torch.Tensor: head = head.view( batch_size, seq_len, self.num_kv, self.head_dim ) # b n hnum dimh From 8b0608b1aa8eb77810ac06cb31d149bb0875267e Mon Sep 17 00:00:00 2001 From: lessw2020 Date: Tue, 29 Aug 2023 00:28:10 +0000 Subject: [PATCH 7/7] remove .shape from item assert attribute --- torchmultimodal/modules/layers/parallel_attention.py | 2 +- 1 file 
changed, 1 insertion(+), 1 deletion(-) diff --git a/torchmultimodal/modules/layers/parallel_attention.py b/torchmultimodal/modules/layers/parallel_attention.py index b801b8e1..ac099d24 100644 --- a/torchmultimodal/modules/layers/parallel_attention.py +++ b/torchmultimodal/modules/layers/parallel_attention.py @@ -84,7 +84,7 @@ def __init__( self.head_dim = head_dimension if head_dimension else emb_dimension // num_heads assert ( self.emb_dim % self.num_heads == 0 - ), f"dimensions {self.emb_dim.shape} must be evenly divisible by num_heads {num_heads=}" + ), f"dimensions {self.emb_dim} must be evenly divisible by num_heads {num_heads=}" # group query attn if use_group_query_attention:
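
---
Reviewer note (not part of the patches): a minimal usage sketch of the block added in this series, mirroring the unit-test fixtures above (emb_dimension=16, num_heads=16, num_layers=1). It assumes the module is importable from torchmultimodal.modules.layers.parallel_attention as introduced in patch 1/7, and is illustrative only.

import torch
from torchmultimodal.modules.layers.parallel_attention import ParallelAttentionBlock

# Standard multi-head attention: every query head has its own K/V head (num_kv == num_heads).
mha_block = ParallelAttentionBlock(
    emb_dimension=16,
    num_heads=16,
    use_group_query_attention=False,
    num_layers=1,  # required to be > 0 when use_weight_init=True (the default)
)

# Grouped-query attention: 16 query heads share 2 K/V heads (num_kv == 2).
gqa_block = ParallelAttentionBlock(
    emb_dimension=16,
    num_heads=16,
    use_group_query_attention=True,
    num_heads_group_query_attention=2,
    num_layers=1,
)

# Multi-query attention: all query heads share a single K/V head (num_kv == 1).
mqa_block = ParallelAttentionBlock(
    emb_dimension=16,
    num_heads=16,
    use_group_query_attention=True,
    num_heads_group_query_attention=1,
    num_layers=1,
)

x = torch.randn(2, 32, 16)                 # (batch, seq_len, emb_dimension)
out = gqa_block(x, has_causal_mask=True)   # causal self-attention; no explicit mask or rel_pos_bias
assert out.shape == x.shape                # residual output preserves the input shape

The forward pass is identical for all three variants; only the number of K/V heads changes (expanded back to num_heads via repeat_kv before SDPA), which is exactly what the num_kv assertions in the unit tests check.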