
Commit

[CQ]
Kye committed Jan 13, 2024
1 parent 42e5b82 commit ea8af7a
Showing 3 changed files with 41 additions and 18 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -20,7 +20,7 @@ I added in many normalizations as I believe by default training stability would
### Usage
```python
import torch
-from mamba_transformer.model import MambaTransformer
+from mamba_transformer import MambaTransformer

# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch.randint(0, 100, (1, 10))
2 changes: 1 addition & 1 deletion example.py
@@ -1,5 +1,5 @@
import torch
-from mamba_transformer.model import MambaTransformer
+from mamba_transformer import MambaTransformer

# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch.randint(0, 100, (1, 10))
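The shortened import path in both README.md and example.py presupposes that the class is re-exported at the package level. A minimal sketch of what `mamba_transformer/__init__.py` would need to contain for this import to resolve (the file itself is not shown in this commit, so its contents are an assumption):

```python
# mamba_transformer/__init__.py -- assumed contents, not part of this diff.
# Re-export the model class so that `from mamba_transformer import MambaTransformer`
# works without reaching into the `model` submodule.
from mamba_transformer.model import MambaTransformer

__all__ = ["MambaTransformer"]
```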
55 changes: 39 additions & 16 deletions mamba_transformer/model.py
@@ -196,6 +196,38 @@ def forward(self, x: Tensor) -> Tensor:


class MambaTransformer(nn.Module):
"""
MambaTransformer is a PyTorch module that implements the Mamba Transformer model.
Args:
num_tokens (int): The number of tokens in the input vocabulary.
dim (int): The dimensionality of the token embeddings and model hidden states.
heads (int): The number of attention heads.
depth (int): The number of transformer blocks.
dim_head (int): The dimensionality of each attention head.
dropout (float, optional): The dropout rate. Defaults to 0.1.
ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.
d_state (int, optional): The dimensionality of the state embeddings. Defaults to None.
*args: Variable length argument list.
**kwargs: Arbitrary keyword arguments.
Examples:
>>> import torch
>>> from mt import MambaTransformer
>>> x = torch.randint(0, 100, (1, 10))
>>> model = MambaTransformer(
... num_tokens=100,
... dim=512,
... heads=8,
... depth=4,
... dim_head=64,
... d_state=512,
... dropout=0.1,
... ff_mult=4
... )
>>> print(model(x).shape)
torch.Size([1, 10, 100])
"""
    def __init__(
        self,
        num_tokens: int,
@@ -206,24 +238,10 @@ def __init__(
        dropout: float = 0.1,
        ff_mult: int = 4,
        d_state: int = None,
+        return_embeddings: bool = False,
        *args,
        **kwargs,
    ):
-        """
-        MambaTransformer is a PyTorch module that implements the Mamba Transformer model.
-
-        Args:
-            num_tokens (int): The number of tokens in the input vocabulary.
-            dim (int): The dimensionality of the token embeddings and model hidden states.
-            heads (int): The number of attention heads.
-            depth (int): The number of transformer blocks.
-            dim_head (int): The dimensionality of each attention head.
-            dropout (float, optional): The dropout rate. Defaults to 0.1.
-            ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.
-            d_state (int, optional): The dimensionality of the state embeddings. Defaults to None.
-            *args: Variable length argument list.
-            **kwargs: Arbitrary keyword arguments.
-        """
        super().__init__()
        self.dim = dim
        self.depth = depth
@@ -232,6 +250,7 @@ def __init__(
        self.dropout = dropout
        self.ff_mult = ff_mult
        self.d_state = d_state
+        self.return_embeddings = return_embeddings

        self.emb = nn.Embedding(num_tokens, dim)
        self.mt_block = MambaTransformerblock(
@@ -261,4 +280,8 @@ def forward(self, x: Tensor) -> Tensor:
"""
x = self.emb(x)
x = self.mt_block(x)
return self.to_logits(x)

if self.return_embeddings:
return x
else:
return self.to_logits(x)
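A minimal usage sketch of the new `return_embeddings` flag, reusing the constructor arguments from the docstring example above. With the flag set, `forward` skips the output projection and returns the final hidden states of width `dim` rather than vocabulary logits; the shapes in the comments are expected values, not output from a run.

```python
import torch
from mamba_transformer import MambaTransformer

x = torch.randint(0, 100, (1, 10))

# Default behaviour: hidden states are projected to vocabulary logits.
logits_model = MambaTransformer(
    num_tokens=100, dim=512, heads=8, depth=4, dim_head=64, d_state=512
)
print(logits_model(x).shape)  # expected: torch.Size([1, 10, 100])

# New in this commit: return the hidden states themselves.
emb_model = MambaTransformer(
    num_tokens=100, dim=512, heads=8, depth=4, dim_head=64, d_state=512,
    return_embeddings=True,
)
print(emb_model(x).shape)  # expected: torch.Size([1, 10, 512])
```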
