diff --git a/README.md b/README.md
index 924111c..fca4cd6 100644
--- a/README.md
+++ b/README.md
@@ -60,3 +60,14 @@ answer = model.generate(prompt, max_length = 64) # (2, 64)
     url = {https://api.semanticscholar.org/CorpusID:236171087}
 }
 ```
+
+```bibtex
+@article{Zhu2024HyperConnections,
+    title   = {Hyper-Connections},
+    author  = {Defa Zhu and Hongzhi Huang and Zihao Huang and Yutao Zeng and Yunyao Mao and Banggu Wu and Qiyang Min and Xun Zhou},
+    journal = {ArXiv},
+    year    = {2024},
+    volume  = {abs/2409.19606},
+    url     = {https://api.semanticscholar.org/CorpusID:272987528}
+}
+```
diff --git a/coconut_pytorch/coconut.py b/coconut_pytorch/coconut.py
index 7d00f1c..584e4b9 100644
--- a/coconut_pytorch/coconut.py
+++ b/coconut_pytorch/coconut.py
@@ -14,6 +14,8 @@
 
 from x_transformers.attend import Attend
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 # helper functions
 
 def exists(v):
@@ -138,7 +140,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
+        num_residual_streams = 4
     ):
         super().__init__()
         self.dim = dim
@@ -146,12 +149,14 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
         layers = ModuleList([])
 
         for _ in range(depth):
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs),
-                FeedForward(dim = dim, mult = ff_mult)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs)),
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult))
             ]))
 
         self.layers = layers
@@ -201,20 +206,22 @@ def forward(
 
         next_keys_values = []
 
+        x = self.expand_streams(x)
+
         for attn, ff in self.layers:
-            attn_out, key_values = attn(
+            x, key_values = attn(
                 x,
                 rotary_pos_emb = rotary_pos_emb,
                 cached_kv = next(cached_kv_iter, None),
                 return_cached_kv = True
             )
 
-            x = attn_out + x
-
             next_keys_values.append(key_values)
 
-            x = ff(x) + x
+            x = ff(x)
+
+        x = self.reduce_streams(x)
 
         embeds = self.norm(x)
diff --git a/coconut_pytorch/multi_stream_coconut.py b/coconut_pytorch/multi_stream_coconut.py
index b5dfb28..8d11b18 100644
--- a/coconut_pytorch/multi_stream_coconut.py
+++ b/coconut_pytorch/multi_stream_coconut.py
@@ -15,6 +15,8 @@
 
 from x_transformers.attend import Attend
 
+from hyper_connections import get_init_and_expand_reduce_stream_functions
+
 # helper functions
 
 def exists(v):
@@ -144,7 +146,8 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
-        attend_kwargs: dict = dict()
+        attend_kwargs: dict = dict(),
+        num_residual_streams = 4
     ):
         super().__init__()
         self.dim = dim
@@ -152,12 +155,14 @@ def __init__(
         self.token_emb = nn.Embedding(num_tokens, dim)
         self.rotary_emb = RotaryEmbedding(dim_head)
 
+        init_hyper_conn, self.expand_streams, self.reduce_streams = get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
+
         layers = ModuleList([])
 
         for _ in range(depth):
             layers.append(ModuleList([
-                Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs),
-                FeedForward(dim = dim, mult = ff_mult)
+                init_hyper_conn(dim = dim, branch = Attention(dim = dim, dim_head = dim_head, heads = heads, attend_kwargs = attend_kwargs)),
+                init_hyper_conn(dim = dim, branch = FeedForward(dim = dim, mult = ff_mult))
             ]))
 
         self.layers = layers
@@ -201,19 +206,21 @@ def forward(
 
         next_keys_values = []
 
+        x = self.expand_streams(x)
+
         for attn, ff in self.layers:
-            attn_out, key_values = attn(
+            x, key_values = attn(
                 x,
                 cached_kv = next(cached_kv_iter, None),
                 return_cached_kv = True
             )
 
-            x = attn_out + x
-
             next_keys_values.append(key_values)
 
-            x = ff(x) + x
+            x = ff(x)
+
+        x = self.reduce_streams(x)
 
         embeds = self.norm(x)
diff --git a/pyproject.toml b/pyproject.toml
index bf7ba18..6d7d88b 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "coconut-pytorch"
-version = "0.0.28"
+version = "0.0.29"
 description = "Coconut in Pytorch"
 authors = [
     { name = "Phil Wang", email = "lucidrains@gmail.com" }
 ]
@@ -26,6 +26,7 @@ classifiers=[
 
 dependencies = [
     'einops>=0.8.0',
+    'hyper-connections>=0.1.0',
     'rotary-embedding-torch>=0.5.3',
     'x-transformers>=1.42.26',
     'torch>=2.4'
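For reference, a minimal usage sketch of the new `num_residual_streams` argument introduced by this change. Only `num_residual_streams` (and its behavior of falling back to a plain residual stream when set to 1, per `disable = num_residual_streams == 1` above) comes from this diff; the `Coconut` class name and the remaining constructor arguments are assumptions based on the repo's README and may differ.

```python
# hypothetical usage sketch: constructor arguments other than
# `num_residual_streams` are assumed from the README, not from this diff
import torch
from coconut_pytorch import Coconut

model = Coconut(
    num_reasoning_steps = 3,    # assumed
    num_tokens = 256,           # assumed
    dim = 512,                  # assumed
    depth = 6,                  # assumed
    num_residual_streams = 4    # new: number of hyper-connection streams; 1 disables them
)

prompt = torch.randint(0, 256, (2, 1024))

answer = model.generate(prompt, max_length = 64) # (2, 64)
```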