diff --git a/pyproject.toml b/pyproject.toml
index 24fcad2d..4e19127d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "zetascale"
-version = "2.5.6"
+version = "2.5.8"
 description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
 authors = ["Zeta Team "]
 license = "MIT"
diff --git a/zeta/nn/attention/attend.py b/zeta/nn/attention/attend.py
index 54915248..b57050e0 100644
--- a/zeta/nn/attention/attend.py
+++ b/zeta/nn/attention/attend.py
@@ -305,7 +305,7 @@ def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None):
             Intermediates: Intermediate values during attention computation.
         """
 
-        n, heads, kv_heads, device = (
+        _n, heads, kv_heads, device = (
             q.shape[-2],
             q.shape[1],
             k.shape[1],
diff --git a/zeta/nn/attention/cross_attention.py b/zeta/nn/attention/cross_attention.py
index 6d557cfa..62992128 100644
--- a/zeta/nn/attention/cross_attention.py
+++ b/zeta/nn/attention/cross_attention.py
@@ -69,7 +69,7 @@ def forward(self, x, context, mask=None):
         Returns:
             torch.Tensor: The output tensor of shape (batch_size, sequence_length, dim).
         """
-        b, n, device = *x.shape[:2], x.device
+        b, _n, _device = *x.shape[:2], x.device
 
         x = self.norm(x)
         context = self.norm_context(context)
diff --git a/zeta/nn/attention/local_attention.py b/zeta/nn/attention/local_attention.py
index 323e36db..d3da6bcf 100644
--- a/zeta/nn/attention/local_attention.py
+++ b/zeta/nn/attention/local_attention.py
@@ -143,7 +143,7 @@ def forward(
         ), "cannot perform window size extrapolation if xpos is not turned on"
 
         (
-            shape,
+            _shape,
             autopad,
             pad_value,
             window_size,
@@ -176,7 +176,7 @@ def forward(
             (q, k, v),
         )
 
-        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
+        b, n, dim_head, device, _dtype = *q.shape, q.device, q.dtype
 
         scale = default(self.scale, dim_head**-0.5)
 
diff --git a/zeta/nn/attention/multi_modal_causal_attention.py b/zeta/nn/attention/multi_modal_causal_attention.py
index 1524133a..8a1061e8 100644
--- a/zeta/nn/attention/multi_modal_causal_attention.py
+++ b/zeta/nn/attention/multi_modal_causal_attention.py
@@ -20,7 +20,7 @@ def __init__(
         self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))
 
     def forward(self, visual_features, textual_features, mask=None):
-        b, n, _, h = *visual_features.shape, self.heads
+        _b, _n, _, h = *visual_features.shape, self.heads
 
         qkv_visual = self.to_qkv(visual_features).chunk(3, dim=-1)
         qkv_textual = self.to_qkv(textual_features).chunk(3, dim=-1)
diff --git a/zeta/nn/modules/__init__.py b/zeta/nn/modules/__init__.py
index 01d9a867..a5cd6e0c 100644
--- a/zeta/nn/modules/__init__.py
+++ b/zeta/nn/modules/__init__.py
@@ -220,6 +220,8 @@
 from zeta.nn.modules.simple_lstm import SimpleLSTM
 from zeta.nn.modules.simple_rnn import SimpleRNN
 from zeta.nn.modules.cope import CoPE
+from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention
+
 
 # from zeta.nn.modules.img_reshape import image_reshape
 # from zeta.nn.modules.flatten_features import flatten_features
@@ -442,4 +444,5 @@
     "SimpleLSTM",
     "SimpleRNN",
     "CoPE",
+    "MultiLayerKeyValueAttention",
 ]
diff --git a/zeta/nn/modules/multi_layer_key_cache.py b/zeta/nn/modules/multi_layer_key_cache.py
new file mode 100644
index 00000000..08f9e1ea
--- /dev/null
+++ b/zeta/nn/modules/multi_layer_key_cache.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+
+
+class MultiLayerKeyValueAttention(nn.Module):
+    def __init__(self, embed_size, num_heads, num_layers, kv_layers):
+        super(MultiLayerKeyValueAttention, self).__init__()
+        self.num_heads = num_heads
+        self.num_layers = num_layers
+        self.kv_layers = kv_layers  # m in the description
+        self.embed_size = embed_size
+        self.head_dim = embed_size // num_heads
+
+        assert (
+            self.head_dim * num_heads == embed_size
+        ), "Embedding size needs to be divisible by num_heads"
+
+        # Define the key and value projections for each layer
+        self.values = nn.ModuleList(
+            [
+                nn.Linear(embed_size, embed_size, bias=False)
+                for _ in range(kv_layers)
+            ]
+        )
+        self.keys = nn.ModuleList(
+            [
+                nn.Linear(embed_size, embed_size, bias=False)
+                for _ in range(kv_layers)
+            ]
+        )
+
+        # Define the query projections for each layer
+        self.queries = nn.ModuleList(
+            [
+                nn.Linear(embed_size, embed_size, bias=False)
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.fc_out = nn.Linear(embed_size, embed_size)
+
+    def forward(self, values, keys, queries):
+        N = queries.shape[0]
+        value_len, key_len, query_len = (
+            values.shape[1],
+            keys.shape[1],
+            queries.shape[1],
+        )
+
+        out = torch.zeros(N, query_len, self.embed_size).to(values.device)
+
+        for layer in range(self.num_layers):
+            kv_index = layer % self.kv_layers
+
+            values_layer = self.values[kv_index](values).view(
+                N, value_len, self.num_heads, self.head_dim
+            )
+            keys_layer = self.keys[kv_index](keys).view(
+                N, key_len, self.num_heads, self.head_dim
+            )
+            queries_layer = self.queries[layer](queries).view(
+                N, query_len, self.num_heads, self.head_dim
+            )
+
+            energy = torch.einsum(
+                "nqhd,nkhd->nhqk", [queries_layer, keys_layer]
+            )
+            attention = torch.softmax(
+                energy / (self.embed_size ** (1 / 2)), dim=3
+            )
+            out_layer = torch.einsum(
+                "nhql,nlhd->nqhd", [attention, values_layer]
+            ).reshape(N, query_len, self.embed_size)
+
+            out += out_layer
+
+        out = self.fc_out(out)
+        return out
+
+
+# Example usage
+embed_size = 256
+num_heads = 8
+num_layers = 4
+kv_layers = 2  # Number of layers with their own KV heads
+
+mlkv_attention = MultiLayerKeyValueAttention(
+    embed_size, num_heads, num_layers, kv_layers
+)
+values = torch.rand(32, 10, embed_size)  # batch size 32, sequence length 10
+keys = torch.rand(32, 10, embed_size)
+queries = torch.rand(32, 10, embed_size)
+
+output = mlkv_attention(values, keys, queries)
+print(output.shape)
diff --git a/zeta/nn/modules/perceiver_resampler.py b/zeta/nn/modules/perceiver_resampler.py
index a56a207b..f8f55f22 100644
--- a/zeta/nn/modules/perceiver_resampler.py
+++ b/zeta/nn/modules/perceiver_resampler.py
@@ -51,7 +51,7 @@ def forward(self, x, latents):
         x = self.norm_media(x)
         latents = self.norm_latents(latents)
 
-        b, m, h = *x.shape[:2], self.heads
+        _b, _m, h = *x.shape[:2], self.heads
         q = self.to_q(latents)
 
         kv_input = torch.cat((x, latents), dim=-2)
diff --git a/zeta/nn/modules/return_loss_text.py b/zeta/nn/modules/return_loss_text.py
index 7a8dd132..29018c87 100644
--- a/zeta/nn/modules/return_loss_text.py
+++ b/zeta/nn/modules/return_loss_text.py
@@ -26,7 +26,7 @@ def return_loss_text(
     Returns:
         Tensor: The computed cross-entropy loss.
     """
-    seq, labels = x[:, :-1], x[:, 1:]
+    _seq, labels = x[:, :-1], x[:, 1:]
 
     labels = labels.masked_fill(~mask[:, 1:], ignore_index)
 
diff --git a/zeta/nn/modules/sparse_moe.py b/zeta/nn/modules/sparse_moe.py
index b88c98a2..e0652244 100644
--- a/zeta/nn/modules/sparse_moe.py
+++ b/zeta/nn/modules/sparse_moe.py
@@ -300,7 +300,7 @@ def __init__(
         self.loss_coef = loss_coef
 
     def forward(self, inputs, **kwargs):
-        b, n, d, e = *inputs.shape, self.num_experts
+        _b, _n, d, e = *inputs.shape, self.num_experts
         dispatch_tensor, combine_tensor, loss = self.gate(inputs)
         expert_inputs = torch.einsum("bnd,bnec->ebcd", inputs, dispatch_tensor)
 
@@ -373,7 +373,7 @@ def __init__(
         self.loss_coef = loss_coef
 
     def forward(self, inputs, **kwargs):
-        b, n, d, eo, ei = (
+        _b, _n, d, eo, ei = (
             *inputs.shape,
             self.num_experts_outer,
             self.num_experts_inner,
diff --git a/zeta/nn/modules/top_n_gating.py b/zeta/nn/modules/top_n_gating.py
index 34f565da..acddb659 100644
--- a/zeta/nn/modules/top_n_gating.py
+++ b/zeta/nn/modules/top_n_gating.py
@@ -124,7 +124,7 @@ def forward(self, x, noise_gates=False, noise_mult=1.0):
         k - top-n experts
         """
 
-        *_, b, group_size, dim, dtype, top_n, num_gates, eps = (
+        *_, _b, group_size, _dim, dtype, top_n, num_gates, eps = (
             *x.shape,
             x.dtype,
             self.top_n,
diff --git a/zeta/structs/transformer.py b/zeta/structs/transformer.py
index ac6d24a1..acf032db 100644
--- a/zeta/structs/transformer.py
+++ b/zeta/structs/transformer.py
@@ -317,7 +317,7 @@ def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None):
             Intermediates: Intermediate values during attention computation.
         """
 
-        n, heads, kv_heads, device = (
+        _n, heads, kv_heads, device = (
            q.shape[-2],
            q.shape[1],
            k.shape[1],