From 854b370b3f455a469699eec50e7188bb925d929d Mon Sep 17 00:00:00 2001 From: Kye Date: Sat, 13 Jan 2024 00:37:52 -0500 Subject: [PATCH] [FEAT][MambaTransformer] [README] --- README.md | 113 ++++---------- example.py | 22 +++ mamba_transformer/__init__.py | 7 + mamba_transformer/model.py | 267 ++++++++++++++++++++++++++++++++++ pyproject.toml | 22 +-- 5 files changed, 339 insertions(+), 92 deletions(-) create mode 100644 mamba_transformer/__init__.py create mode 100644 mamba_transformer/model.py diff --git a/README.md b/README.md index 98e66c5..67af5db 100644 --- a/README.md +++ b/README.md @@ -1,104 +1,53 @@ [![Multi-Modality](agorabanner.png)](https://discord.gg/qUtxnK2NMf) -# Python Package Template -A easy, reliable, fluid template for python packages complete with docs, testing suites, readme's, github workflows, linting and much much more +# Mamba Transformer +Integrating Mamba/SSMs with Transformer for Enhanced Long Context and High-Quality Sequence Modeling. +This is 100% novel architecture that I have designed to combine the strengths and weaknesses out of SSMs and Attention for an all-new advanced architecture with the purpose of surpassing our old limits. Faster processing speed, longer context lengths, lower perplexity over long sequences, enhanced and superior reasoning while remaining small and compact. -## Installation +The architecture is essentially: `x -> norm -> mamba -> norm -> transformer -> norm -> ffn -> norm -> out`. -You can install the package using pip +I added in many normalizations as I believe by default training stability would be severly degraded due to 2 foreign architecture's integrating with one another. -```bash -pip install -e . -``` -## Structure -``` -├── LICENSE -├── Makefile -├── README.md -├── agorabanner.png -├── example.py -├── package -│ ├── __init__.py -│ ├── main.py -│ └── subfolder -│ ├── __init__.py -│ └── main.py -├── pyproject.toml -└── requirements.txt - -2 directories, 11 files -``` -# Usage -# Documentation +## Install +`pip3 install mambatransformer` -### Code Quality 🧹 +### Usage +```python +import torch +from mt import MambaTransformer -We provide two handy commands inside the `Makefile`, namely: +# Generate a random tensor of shape (1, 10) with values between 0 and 99 +x = torch.randint(0, 100, (1, 10)) -- `make style` to format the code -- `make check_code_quality` to check code quality (PEP8 basically) +# Create an instance of the MambaTransformer model +model = MambaTransformer( + num_tokens=100, # Number of tokens in the input sequence + dim=512, # Dimension of the model + heads=8, # Number of attention heads + depth=4, # Number of transformer layers + dim_head=64, # Dimension of each attention head + d_state=512, # Dimension of the state + dropout=0.1, # Dropout rate + ff_mult=4 # Multiplier for the feed-forward layer dimension +) -So far, **there is no types checking with mypy**. See [issue](https://github.com/roboflow-ai/template-python/issues/4). +# Pass the input tensor through the model and print the output shape +print(model(x).shape) -### Tests 🧪 -[`pytests`](https://docs.pytest.org/en/7.1.x/) is used to run our tests. +# to train +model.eval() -### Publish on PyPi 🚀 +# Would you like to train this model? Zeta Corporation offers unmatchable GPU clusters at unbeatable prices, let's partner! -**Important**: Before publishing, edit `__version__` in [src/__init__](/src/__init__.py) to match the wanted new version. +# Tokenizer +model.generate(text) -We use [`twine`](https://twine.readthedocs.io/en/stable/) to make our life easier. You can publish by using ``` -export PYPI_USERNAME="you_username" -export PYPI_PASSWORD="your_password" -export PYPI_TEST_PASSWORD="your_password_for_test_pypi" -make publish -e PYPI_USERNAME=$PYPI_USERNAME -e PYPI_PASSWORD=$PYPI_PASSWORD -e PYPI_TEST_PASSWORD=$PYPI_TEST_PASSWORD -``` - -You can also use token for auth, see [pypi doc](https://pypi.org/help/#apitoken). In that case, - -``` -export PYPI_USERNAME="__token__" -export PYPI_PASSWORD="your_token" -export PYPI_TEST_PASSWORD="your_token_for_test_pypi" -make publish -e PYPI_USERNAME=$PYPI_USERNAME -e PYPI_PASSWORD=$PYPI_PASSWORD -e PYPI_TEST_PASSWORD=$PYPI_TEST_PASSWORD -``` - -**Note**: We will try to push to [test pypi](https://test.pypi.org/) before pushing to pypi, to assert everything will work - -### CI/CD 🤖 - -We use [GitHub actions](https://github.com/features/actions) to automatically run tests and check code quality when a new PR is done on `main`. - -On any pull request, we will check the code quality and tests. - -When a new release is created, we will try to push the new code to PyPi. We use [`twine`](https://twine.readthedocs.io/en/stable/) to make our life easier. - -The **correct steps** to create a new realease are the following: -- edit `__version__` in [src/__init__](/src/__init__.py) to match the wanted new version. -- create a new [`tag`](https://git-scm.com/docs/git-tag) with the release name, e.g. `git tag v0.0.1 && git push origin v0.0.1` or from the GitHub UI. -- create a new release from GitHub UI - -The CI will run when you create the new release. - -# Docs -We use MK docs. This repo comes with the zeta docs. All the docs configurations are already here along with the readthedocs configs - -# Q&A - -## Why no cookiecutter? -This is a template repo, it's meant to be used inside GitHub upon repo creation. - -## Why reinvent the wheel? - -There are several very good templates on GitHub, I prefer to use code we wrote instead of blinding taking the most starred template and having features we don't need. From experience, it's better to keep it simple and general enough for our specific use cases. - -# Architecture # License MIT diff --git a/example.py b/example.py index e69de29..fc51873 100644 --- a/example.py +++ b/example.py @@ -0,0 +1,22 @@ +import torch +from mamba_transformer.model import MambaTransformer + +# Generate a random tensor of shape (1, 10) with values between 0 and 99 +x = torch.randint(0, 100, (1, 10)) + +# Create an instance of the MambaTransformer model +model = MambaTransformer( + num_tokens=100, # Number of tokens in the input sequence + dim=512, # Dimension of the model + heads=8, # Number of attention heads + depth=4, # Number of transformer layers + dim_head=64, # Dimension of each attention head + d_state=512, # Dimension of the state + dropout=0.1, # Dropout rate + ff_mult=4 # Multiplier for the feed-forward layer dimension +) + +# Pass the input tensor through the model and print the output shape +print(model(x).shape) + + diff --git a/mamba_transformer/__init__.py b/mamba_transformer/__init__.py new file mode 100644 index 0000000..f244c0a --- /dev/null +++ b/mamba_transformer/__init__.py @@ -0,0 +1,7 @@ +from mamba_transformer.model import RMSNorm, MambaTransformerblock, MambaTransformer + +__all__ = [ + "RMSNorm", + "MambaTransformerblock", + "MambaTransformer" +] \ No newline at end of file diff --git a/mamba_transformer/model.py b/mamba_transformer/model.py new file mode 100644 index 0000000..247ac01 --- /dev/null +++ b/mamba_transformer/model.py @@ -0,0 +1,267 @@ +import torch +from torch import nn, Tensor +from zeta.nn import MambaBlock, FeedForward, MultiQueryAttention +import torch.nn.functional as F + + +class RMSNorm(nn.Module): + def __init__(self, dim: int): + super().__init__() + self.scale = dim ** -0.5 + self.g = nn.Parameter(torch.ones(dim)) + + def forward(self, x: Tensor): + return F.normalize(x, dim = - 1) * self.scale * self.g + + +class MultiQueryTransformerBlock(nn.Module): + """ + MultiQueryTransformerBlock is a module that represents a single block of the Multi-Query Transformer. + It consists of a multi-query attention layer, a feed-forward network, and layer normalization. + + Args: + dim (int): The input and output dimension of the block. + heads (int): The number of attention heads. + dim_head (int): The dimension of each attention head. + dropout (float, optional): The dropout probability. Defaults to 0.1. + ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4. + + Attributes: + dim (int): The input and output dimension of the block. + heads (int): The number of attention heads. + dim_head (int): The dimension of each attention head. + dropout (float): The dropout probability. + ff_mult (int): The multiplier for the feed-forward network dimension. + attn (MultiQueryAttention): The multi-query attention layer. + ffn (FeedForward): The feed-forward network. + norm (nn.LayerNorm): The layer normalization. + + Methods: + forward(x: Tensor) -> Tensor: + Performs a forward pass of the MultiQueryTransformerBlock. + + """ + + def __init__( + self, + dim: int, + heads: int, + dim_head: int, + dropout: float = 0.1, + ff_mult: int = 4, + *args, + **kwargs, + ): + super().__init__() + self.dim = dim + self.heads = heads + self.dim_head = dim_head + self.dropout = dropout + self.ff_mult = ff_mult + + self.attn = MultiQueryAttention(dim, heads, *args, **kwargs) + + self.ffn = FeedForward(dim, dim, ff_mult, *args, **kwargs) + + # Normalization + self.norm = nn.LayerNorm(dim) + + def forward(self, x: Tensor) -> Tensor: + """ + Performs a forward pass of the MultiQueryTransformerBlock. + + Args: + x (Tensor): The input tensor. + + Returns: + Tensor: The output tensor. + + """ + x, _, _ = self.attn(x) + x = self.norm(x) + x = self.ffn(x) + return x + + +class MambaTransformerblock(nn.Module): + """ + MambaTransformerblock is a module that represents a block in the Mamba Transformer model. + + Args: + dim (int): The input dimension of the block. + heads (int): The number of attention heads in the block. + depth (int): The number of layers in the block. + dim_head (int): The dimension of each attention head. + dropout (float, optional): The dropout rate. Defaults to 0.1. + ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4. + d_state (int, optional): The dimension of the state. Defaults to None. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Attributes: + dim (int): The input dimension of the block. + depth (int): The number of layers in the block. + dim_head (int): The dimension of each attention head. + d_state (int): The dimension of the state. + dropout (float): The dropout rate. + ff_mult (int): The multiplier for the feed-forward network dimension. + mamba_blocks (nn.ModuleList): List of MambaBlock instances. + transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances. + ffn_blocks (nn.ModuleList): List of FeedForward instances. + norm (nn.LayerNorm): Layer normalization module. + + Examples: + import torch + from mt import MambaTransformerblock + + x = torch.randn(1, 10, 512) + model = MambaTransformerblock( + dim=512, + heads=8, + depth=4, + dim_head=64, + d_state=512, + dropout=0.1, + ff_mult=4 + ) + print(model(x).shape) + + + """ + + def __init__( + self, + dim: int, + heads: int, + depth: int, + dim_head: int, + dropout: float = 0.1, + ff_mult: int = 4, + d_state: int = None, + *args, + **kwargs, + ): + super().__init__() + self.dim = dim + self.depth = depth + self.dim_head = dim_head + self.d_state = d_state + self.dropout = dropout + self.ff_mult = ff_mult + self.d_state = d_state + + self.mamba_blocks = nn.ModuleList([]) + self.transformer_blocks = nn.ModuleList([]) + self.ffn_blocks = nn.ModuleList([]) + + self.mamba_blocks.append( + MambaBlock(dim, depth, d_state, *args, **kwargs) + ) + + # Transformer and ffn blocks + for _ in range(depth): + self.transformer_blocks.append( + MultiQueryTransformerBlock( + dim, + heads, + dim_head, + dropout, + ff_mult, + *args, + **kwargs, + ) + ) + + self.ffn_blocks.append( + FeedForward(dim, dim, ff_mult, *args, **kwargs) + ) + + # Layernorm + self.norm = nn.LayerNorm(dim) + + def forward(self, x: Tensor) -> Tensor: + for mamba, attn, ffn in zip( + self.mamba_blocks, + self.transformer_blocks, + self.ffn_blocks, + ): + x = self.norm(x) + x = mamba(x) + x + x = self.norm(x) + x = attn(x) + x + x = self.norm(x) + x = ffn(x) + x + + return x + + +class MambaTransformer(nn.Module): + def __init__( + self, + num_tokens: int, + dim: int, + heads: int, + depth: int, + dim_head: int, + dropout: float = 0.1, + ff_mult: int = 4, + d_state: int = None, + *args, + **kwargs, + ): + """ + MambaTransformer is a PyTorch module that implements the Mamba Transformer model. + + Args: + num_tokens (int): The number of tokens in the input vocabulary. + dim (int): The dimensionality of the token embeddings and model hidden states. + heads (int): The number of attention heads. + depth (int): The number of transformer blocks. + dim_head (int): The dimensionality of each attention head. + dropout (float, optional): The dropout rate. Defaults to 0.1. + ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4. + d_state (int, optional): The dimensionality of the state embeddings. Defaults to None. + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + """ + super().__init__() + self.dim = dim + self.depth = depth + self.dim_head = dim_head + self.d_state = d_state + self.dropout = dropout + self.ff_mult = ff_mult + self.d_state = d_state + + self.emb = nn.Embedding(num_tokens, dim) + self.mt_block = MambaTransformerblock( + dim, + heads, + depth, + dim_head, + dropout, + ff_mult, + d_state, + *args, + **kwargs, + ) + self.to_logits = nn.Sequential( + RMSNorm(dim), + nn.Linear(dim, num_tokens) + ) + + def forward(self, x: Tensor) -> Tensor: + """ + Forward pass of the MambaTransformer model. + + Args: + x (Tensor): Input tensor of shape (batch_size, sequence_length). + + Returns: + Tensor: Output tensor of shape (batch_size, sequence_length, num_tokens). + """ + x = self.emb(x) + x = self.mt_block(x) + return self.to_logits(x) + + diff --git a/pyproject.toml b/pyproject.toml index 5d4ac8e..640e26e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,15 +3,15 @@ requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" [tool.poetry] -name = "paper" +name = "mambatransformer" version = "0.0.1" -description = "Paper - Pytorch" +description = "MambaTransformer - Pytorch" license = "MIT" authors = ["Kye Gomez "] -homepage = "https://github.com/kyegomez/paper" -documentation = "https://github.com/kyegomez/paper" # Add this if you have documentation. +homepage = "https://github.com/kyegomez/MambaTransformer" +documentation = "https://github.com/kyegomez/MambaTransformer" readme = "README.md" # Assuming you have a README.md -repository = "https://github.com/kyegomez/paper" +repository = "https://github.com/kyegomez/MambaTransformer" keywords = ["artificial intelligence", "deep learning", "optimizers", "Prompt Engineering"] classifiers = [ "Development Status :: 4 - Beta", @@ -20,15 +20,17 @@ classifiers = [ "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3.9" ] +packages = [ + { include = "mamba_transformer" }, + { include = "mamba_transformer/**/*.py" }, +] + [tool.poetry.dependencies] python = "^3.6" -swarms = "*" zetascale = "*" - -[tool.poetry.dev-dependencies] -# Add development dependencies here - +torch = "*" +einops = "*" [tool.poetry.group.lint.dependencies] ruff = "^0.1.6"