Commit
hyper connected coconut
lucidrains committed Dec 31, 2024
1 parent 9a26cfd commit a764be1
Showing 4 changed files with 41 additions and 15 deletions.
11 changes: 11 additions & 0 deletions README.md
@@ -60,3 +60,14 @@ answer = model.generate(prompt, max_length = 64) # (2, 64)
     url = {https://api.semanticscholar.org/CorpusID:236171087}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title = {Hyper-Connections},
+    author = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year = {2024},
+    volume = {abs/2409.19606},
+    url = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
21 changes: 14 additions & 7 deletions coconut_pytorch/coconut.py
@@ -14,6 +14,8 @@
 
 from x_transformers.attend import Attend
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 # helper functions
 
 def exists(v):
@@ -138,20 +140,23 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
+        num_residual_streams = 4
     ):
         super().__init__()
         self.dim = dim
 
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
         layers = ModuleList([])
 
         for _ in range(depth):
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs),
-                FeedForward(dim = dim, mult = ff_mult)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs)),
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult))
             ]))
 
         self.layers = layers
@@ -201,20 +206,22 @@ def forward(
 
         next_keys_values = []
 
+        x = self.expand_streams(x)
+
         for attn, ff in self.layers:
 
-            attn_out, key_values = attn(
+            x, key_values = attn(
                 x,
                 rotary_pos_emb = rotary_pos_emb,
                 cached_kv = next(cached_kv_iter, None),
                 return_cached_kv = True
             )
 
-            x = attn_out + x
-
             next_keys_values.append(key_values)
 
-            x = ff(x) + x
+            x = ff(x)
 
+        x = self.reduce_streams(x)
+
         embeds = self.norm(x)
 
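For context on the two hunks above: the explicit residual additions (`x = attn_out + x` and `x = ff(x) + x`) are removed because the hyper-connection wrapper now applies the residual internally, while `expand_streams` / `reduce_streams` widen the hidden state into several residual streams before the layer stack and collapse them afterwards. Below is a minimal, hypothetical sketch of that pattern in isolation, using only the calls that appear in this diff; the toy feedforward branch and tensor sizes are invented for illustration and are not part of this repository.

```python
import torch
from torch import nn

from hyper_connections import get_init_and_expand_reduce_stream_functions

dim, depth, num_residual_streams = 64, 2, 4

# one initializer for wrapping branches, plus callables that widen / collapse the stream dimension
init_hyper_conn, expand_streams, reduce_streams = get_init_and_expand_reduce_stream_functions(
    num_residual_streams,
    disable = num_residual_streams == 1  # a single stream falls back to a plain residual
)

# each branch is wrapped, so the residual add happens inside the hyper-connection
layers = nn.ModuleList([
    init_hyper_conn(dim = dim, branch = nn.Sequential(
        nn.Linear(dim, dim * 4),
        nn.GELU(),
        nn.Linear(dim * 4, dim)
    ))
    for _ in range(depth)
])

x = torch.randn(1, 16, dim)

x = expand_streams(x)   # replicate the residual into num_residual_streams streams

for layer in layers:
    x = layer(x)        # no explicit `x = out + x` needed any more

x = reduce_streams(x)   # merge the streams back into a single residual

print(x.shape)          # expected: torch.Size([1, 16, 64])
```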
21 changes: 14 additions & 7 deletions coconut_pytorch/multi_stream_coconut.py
@@ -15,6 +15,8 @@
 
 from x_transformers.attend import Attend
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 # helper functions
 
 def exists(v):
@@ -144,20 +146,23 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
+        num_residual_streams = 4
     ):
         super().__init__()
         self.dim = dim
 
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
         layers = ModuleList([])
 
         for _ in range(depth):
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs),
-                FeedForward(dim = dim, mult = ff_mult)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs)),
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult))
             ]))
 
         self.layers = layers
@@ -201,19 +206,21 @@ def forward(
 
         next_keys_values = []
 
+        x = self.expand_streams(x)
+
         for attn, ff in self.layers:
 
-            attn_out, key_values = attn(
+            x, key_values = attn(
                 x,
                 cached_kv = next(cached_kv_iter, None),
                 return_cached_kv = True
             )
 
-            x = attn_out + x
-
             next_keys_values.append(key_values)
 
-            x = ff(x) + x
+            x = ff(x)
 
+        x = self.reduce_streams(x)
+
         embeds = self.norm(x)
 
3 changes: 2 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "coconut-pytorch"
-version = "0.0.28"
+version = "0.0.29"
 description = "Coconut in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
@@ -26,6 +26,7 @@ classifiers=[
 
 dependencies = [
     'einops>=0.8.0',
+    'hyper-connections>=0.1.0',
     'rotary-embedding-torch>=0.5.3',
     'x-transformers>=1.42.26',
     'torch>=2.4'
