
Commit

[CLEANUP]
Kye Gomez authored and Kye Gomez committed Jun 14, 2024
1 parent 41e1f0a commit c20c516
Showing 12 changed files with 110 additions and 12 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "zetascale"
version = "2.5.6"
version = "2.5.8"
description = "Rapidly Build, Optimize, and Deploy SOTA AI Models"
authors = ["Zeta Team <kye@apac.ai>"]
license = "MIT"
2 changes: 1 addition & 1 deletion zeta/nn/attention/attend.py
@@ -305,7 +305,7 @@ def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None):
Intermediates: Intermediate values during attention computation.
"""

-        n, heads, kv_heads, device = (
+        _n, heads, kv_heads, device = (
q.shape[-2],
q.shape[1],
k.shape[1],
2 changes: 1 addition & 1 deletion zeta/nn/attention/cross_attention.py
@@ -69,7 +69,7 @@ def forward(self, x, context, mask=None):
Returns:
torch.Tensor: The output tensor of shape (batch_size, sequence_length, dim).
"""
-        b, n, device = *x.shape[:2], x.device
+        b, _n, _device = *x.shape[:2], x.device

x = self.norm(x)
context = self.norm_context(context)
4 changes: 2 additions & 2 deletions zeta/nn/attention/local_attention.py
@@ -143,7 +143,7 @@ def forward(
), "cannot perform window size extrapolation if xpos is not turned on"

(
-            shape,
+            _shape,
autopad,
pad_value,
window_size,
@@ -176,7 +176,7 @@ def forward(
(q, k, v),
)

-        b, n, dim_head, device, dtype = *q.shape, q.device, q.dtype
+        b, n, dim_head, device, _dtype = *q.shape, q.device, q.dtype

scale = default(self.scale, dim_head**-0.5)

2 changes: 1 addition & 1 deletion zeta/nn/attention/multi_modal_causal_attention.py
@@ -20,7 +20,7 @@ def __init__(
self.to_out = nn.Sequential(nn.Linear(dim, dim), nn.Dropout(dropout))

def forward(self, visual_features, textual_features, mask=None):
-        b, n, _, h = *visual_features.shape, self.heads
+        _b, _n, _, h = *visual_features.shape, self.heads

qkv_visual = self.to_qkv(visual_features).chunk(3, dim=-1)
qkv_textual = self.to_qkv(textual_features).chunk(3, dim=-1)
3 changes: 3 additions & 0 deletions zeta/nn/modules/__init__.py
@@ -220,6 +220,8 @@
from zeta.nn.modules.simple_lstm import SimpleLSTM
from zeta.nn.modules.simple_rnn import SimpleRNN
from zeta.nn.modules.cope import CoPE
+from zeta.nn.modules.multi_layer_key_cache import MultiLayerKeyValueAttention


# from zeta.nn.modules.img_reshape import image_reshape
# from zeta.nn.modules.flatten_features import flatten_features
@@ -442,4 +444,5 @@
"SimpleLSTM",
"SimpleRNN",
"CoPE",
"MultiLayerKeyValueAttention",
]
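
With the export above in place, the new attention module becomes importable directly from zeta.nn.modules. A minimal usage sketch, assuming the released 2.5.8 package includes this __init__.py change; shapes and hyperparameters here are arbitrary and chosen only for illustration:

import torch
from zeta.nn.modules import MultiLayerKeyValueAttention

# Hypothetical configuration: batch 2, sequence length 16, embed_size 64,
# 4 attention layers sharing 2 key/value projections.
mlkv = MultiLayerKeyValueAttention(
    embed_size=64, num_heads=4, num_layers=4, kv_layers=2
)
x = torch.rand(2, 16, 64)
out = mlkv(x, x, x)  # forward takes (values, keys, queries)
print(out.shape)     # expected: torch.Size([2, 16, 64])
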
95 changes: 95 additions & 0 deletions zeta/nn/modules/multi_layer_key_cache.py
@@ -0,0 +1,95 @@
import torch
import torch.nn as nn


class MultiLayerKeyValueAttention(nn.Module):
def __init__(self, embed_size, num_heads, num_layers, kv_layers):
super(MultiLayerKeyValueAttention, self).__init__()
self.num_heads = num_heads
self.num_layers = num_layers
self.kv_layers = kv_layers # m in the description
self.embed_size = embed_size
self.head_dim = embed_size // num_heads

assert (
self.head_dim * num_heads == embed_size
), "Embedding size needs to be divisible by num_heads"

# Define the key and value projections for each layer
self.values = nn.ModuleList(
[
nn.Linear(embed_size, embed_size, bias=False)
for _ in range(kv_layers)
]
)
self.keys = nn.ModuleList(
[
nn.Linear(embed_size, embed_size, bias=False)
for _ in range(kv_layers)
]
)

# Define the query projections for each layer
self.queries = nn.ModuleList(
[
nn.Linear(embed_size, embed_size, bias=False)
for _ in range(num_layers)
]
)

self.fc_out = nn.Linear(embed_size, embed_size)

def forward(self, values, keys, queries):
N = queries.shape[0]
value_len, key_len, query_len = (
values.shape[1],
keys.shape[1],
queries.shape[1],
)

out = torch.zeros(N, query_len, self.embed_size).to(values.device)

for layer in range(self.num_layers):
kv_index = layer % self.kv_layers

values_layer = self.values[kv_index](values).view(
N, value_len, self.num_heads, self.head_dim
)
keys_layer = self.keys[kv_index](keys).view(
N, key_len, self.num_heads, self.head_dim
)
queries_layer = self.queries[layer](queries).view(
N, query_len, self.num_heads, self.head_dim
)

energy = torch.einsum(
"nqhd,nkhd->nhqk", [queries_layer, keys_layer]
)
attention = torch.softmax(
energy / (self.embed_size ** (1 / 2)), dim=3
)
out_layer = torch.einsum(
"nhql,nlhd->nqhd", [attention, values_layer]
).reshape(N, query_len, self.embed_size)

out += out_layer

out = self.fc_out(out)
return out


# Example usage
embed_size = 256
num_heads = 8
num_layers = 4
kv_layers = 2 # Number of layers with their own KV heads

mlkv_attention = MultiLayerKeyValueAttention(
embed_size, num_heads, num_layers, kv_layers
)
values = torch.rand(32, 10, embed_size) # batch size 32, sequence length 10
keys = torch.rand(32, 10, embed_size)
queries = torch.rand(32, 10, embed_size)

output = mlkv_attention(values, keys, queries)
print(output.shape)
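
The core idea of the new module is visible in the forward loop above: each of the num_layers query projections is used in turn, while only kv_layers key/value projections exist and are reused round-robin via kv_index = layer % kv_layers. A small sketch of that mapping for the example configuration at the end of the file (illustration only, not part of the commit):

# Illustrates the KV-sharing schedule used in MultiLayerKeyValueAttention.forward.
num_layers = 4  # query layers, as in the example above
kv_layers = 2   # shared key/value projections
for layer in range(num_layers):
    kv_index = layer % kv_layers  # same cycling rule as in forward()
    print(f"query layer {layer} -> KV projection {kv_index}")
# query layer 0 -> KV projection 0
# query layer 1 -> KV projection 1
# query layer 2 -> KV projection 0
# query layer 3 -> KV projection 1
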
2 changes: 1 addition & 1 deletion zeta/nn/modules/perceiver_resampler.py
@@ -51,7 +51,7 @@ def forward(self, x, latents):
x = self.norm_media(x)
latents = self.norm_latents(latents)

-        b, m, h = *x.shape[:2], self.heads
+        _b, _m, h = *x.shape[:2], self.heads
q = self.to_q(latents)

kv_input = torch.cat((x, latents), dim=-2)
2 changes: 1 addition & 1 deletion zeta/nn/modules/return_loss_text.py
@@ -26,7 +26,7 @@ def return_loss_text(
Returns:
Tensor: The computed cross-entropy loss.
"""
-    seq, labels = x[:, :-1], x[:, 1:]
+    _seq, labels = x[:, :-1], x[:, 1:]

labels = labels.masked_fill(~mask[:, 1:], ignore_index)

4 changes: 2 additions & 2 deletions zeta/nn/modules/sparse_moe.py
@@ -300,7 +300,7 @@ def __init__(
self.loss_coef = loss_coef

def forward(self, inputs, **kwargs):
-        b, n, d, e = *inputs.shape, self.num_experts
+        _b, _n, d, e = *inputs.shape, self.num_experts
dispatch_tensor, combine_tensor, loss = self.gate(inputs)
expert_inputs = torch.einsum("bnd,bnec->ebcd", inputs, dispatch_tensor)

@@ -373,7 +373,7 @@ def __init__(
self.loss_coef = loss_coef

def forward(self, inputs, **kwargs):
-        b, n, d, eo, ei = (
+        _b, _n, d, eo, ei = (
*inputs.shape,
self.num_experts_outer,
self.num_experts_inner,
2 changes: 1 addition & 1 deletion zeta/nn/modules/top_n_gating.py
@@ -124,7 +124,7 @@ def forward(self, x, noise_gates=False, noise_mult=1.0):
k - top-n experts
"""

-        *_, b, group_size, dim, dtype, top_n, num_gates, eps = (
+        *_, _b, group_size, _dim, dtype, top_n, num_gates, eps = (
*x.shape,
x.dtype,
self.top_n,
2 changes: 1 addition & 1 deletion zeta/structs/transformer.py
@@ -317,7 +317,7 @@ def forward(self, q, k, v, mask=None, attn_bias=None, prev_attn=None):
Intermediates: Intermediate values during attention computation.
"""

-        n, heads, kv_heads, device = (
+        _n, heads, kv_heads, device = (
q.shape[-2],
q.shape[1],
k.shape[1],
