From ea8af7aa28446f49cbcb2de5d89cd99f3832624f Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 13 Jan 2024 09:44:37 -0500
Subject: [PATCH] [CQ]

---
 README.md                  |  2 +-
 example.py                 |  2 +-
 mamba_transformer/model.py | 55 +++++++++++++++++++++++++++-----------
 3 files changed, 41 insertions(+), 18 deletions(-)

diff --git a/README.md b/README.md
index a42bfde..ab51374 100644
--- a/README.md
+++ b/README.md
@@ -20,7 +20,7 @@ I added in many normalizations as I believe by default training stability would
 ### Usage
 ```python
 import torch
-from mamba_transformer.model import MambaTransformer
+from mamba_transformer import MambaTransformer
 
 # Generate a random tensor of shape (1, 10) with values between 0 and 99
 x = torch.randint(0, 100, (1, 10))
diff --git a/example.py b/example.py
index 2a19ff2..135c460 100644
--- a/example.py
+++ b/example.py
@@ -1,5 +1,5 @@
 import torch
-from mamba_transformer.model import MambaTransformer
+from mamba_transformer import MambaTransformer
 
 # Generate a random tensor of shape (1, 10) with values between 0 and 99
 x = torch.randint(0, 100, (1, 10))
diff --git a/mamba_transformer/model.py b/mamba_transformer/model.py
index 714ac9c..78e59fe 100644
--- a/mamba_transformer/model.py
+++ b/mamba_transformer/model.py
@@ -196,6 +196,38 @@ def forward(self, x: Tensor) -> Tensor:
 
 
 class MambaTransformer(nn.Module):
+    """
+    MambaTransformer is a PyTorch module that implements the Mamba Transformer model.
+
+    Args:
+        num_tokens (int): The number of tokens in the input vocabulary.
+        dim (int): The dimensionality of the token embeddings and model hidden states.
+        heads (int): The number of attention heads.
+        depth (int): The number of transformer blocks.
+        dim_head (int): The dimensionality of each attention head.
+        dropout (float, optional): The dropout rate. Defaults to 0.1.
+        ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.
+        d_state (int, optional): The dimensionality of the state embeddings. Defaults to None.
+        *args: Variable length argument list.
+        **kwargs: Arbitrary keyword arguments.
+
+    Examples:
+        >>> import torch
+        >>> from mamba_transformer import MambaTransformer
+        >>> x = torch.randint(0, 100, (1, 10))
+        >>> model = MambaTransformer(
+        ...     num_tokens=100,
+        ...     dim=512,
+        ...     heads=8,
+        ...     depth=4,
+        ...     dim_head=64,
+        ...     d_state=512,
+        ...     dropout=0.1,
+        ...     ff_mult=4
+        ... )
+        >>> print(model(x).shape)
+        torch.Size([1, 10, 100])
+    """
     def __init__(
         self,
         num_tokens: int,
@@ -206,24 +238,10 @@ def __init__(
         dropout: float = 0.1,
         ff_mult: int = 4,
         d_state: int = None,
+        return_embeddings: bool = False,
         *args,
         **kwargs,
     ):
-        """
-        MambaTransformer is a PyTorch module that implements the Mamba Transformer model.
-
-        Args:
-            num_tokens (int): The number of tokens in the input vocabulary.
-            dim (int): The dimensionality of the token embeddings and model hidden states.
-            heads (int): The number of attention heads.
-            depth (int): The number of transformer blocks.
-            dim_head (int): The dimensionality of each attention head.
-            dropout (float, optional): The dropout rate. Defaults to 0.1.
-            ff_mult (int, optional): The multiplier for the feed-forward network dimension. Defaults to 4.
-            d_state (int, optional): The dimensionality of the state embeddings. Defaults to None.
-            *args: Variable length argument list.
-            **kwargs: Arbitrary keyword arguments.
- """ super().__init__() self.dim = dim self.depth = depth @@ -232,6 +250,7 @@ def __init__( self.dropout = dropout self.ff_mult = ff_mult self.d_state = d_state + self.return_embeddings = return_embeddings self.emb = nn.Embedding(num_tokens, dim) self.mt_block = MambaTransformerblock( @@ -261,4 +280,8 @@ def forward(self, x: Tensor) -> Tensor: """ x = self.emb(x) x = self.mt_block(x) - return self.to_logits(x) + + if self.return_embeddings: + return x + else: + return self.to_logits(x)