
Commit

[CODE QUALITY]
Kye committed Jan 13, 2024
1 parent 854b370 commit aa5d384
Showing 3 changed files with 24 additions and 29 deletions.
20 changes: 9 additions & 11 deletions example.py
@@ -1,22 +1,20 @@
 import torch
-import torch
 from mamba_transformer.model import MambaTransformer
 
 # Generate a random tensor of shape (1, 10) with values between 0 and 99
 x = torch.randint(0, 100, (1, 10))
 
 # Create an instance of the MambaTransformer model
 model = MambaTransformer(
-    num_tokens=100, # Number of tokens in the input sequence
-    dim=512, # Dimension of the model
-    heads=8, # Number of attention heads
-    depth=4, # Number of transformer layers
-    dim_head=64, # Dimension of each attention head
-    d_state=512, # Dimension of the state
-    dropout=0.1, # Dropout rate
-    ff_mult=4 # Multiplier for the feed-forward layer dimension
+    num_tokens=100,  # Number of tokens in the input sequence
+    dim=512,  # Dimension of the model
+    heads=8,  # Number of attention heads
+    depth=4,  # Number of transformer layers
+    dim_head=64,  # Dimension of each attention head
+    d_state=512,  # Dimension of the state
+    dropout=0.1,  # Dropout rate
+    ff_mult=4,  # Multiplier for the feed-forward layer dimension
 )
 
 # Pass the input tensor through the model and print the output shape
 print(model(x).shape)
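
A quick sanity check on that last line (illustrative only, not part of the commit): assuming MambaTransformerblock preserves the sequence length and the nn.Linear(dim, num_tokens) head shown in model.py below produces per-token logits, the printed shape should be:

    out = model(x)
    print(out.shape)  # expected: torch.Size([1, 10, 100]), i.e. (batch, seq_len, num_tokens)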


12 changes: 6 additions & 6 deletions mamba_transformer/__init__.py
@@ -1,7 +1,7 @@
-from mamba_transformer.model import RMSNorm, MambaTransformerblock, MambaTransformer
+from mamba_transformer.model import (
+    RMSNorm,
+    MambaTransformerblock,
+    MambaTransformer,
+)
 
-__all__ = [
-    "RMSNorm",
-    "MambaTransformerblock",
-    "MambaTransformer"
-]
+__all__ = ["RMSNorm", "MambaTransformerblock", "MambaTransformer"]

21 changes: 9 additions & 12 deletions mamba_transformer/model.py
@@ -7,11 +7,11 @@
 class RMSNorm(nn.Module):
     def __init__(self, dim: int):
         super().__init__()
-        self.scale = dim ** -0.5
+        self.scale = dim**-0.5
         self.g = nn.Parameter(torch.ones(dim))
 
     def forward(self, x: Tensor):
-        return F.normalize(x, dim = - 1) * self.scale * self.g
+        return F.normalize(x, dim=-1) * self.scale * self.g
 
 
 class MultiQueryTransformerBlock(nn.Module):
@@ -109,11 +109,11 @@ class MambaTransformerblock(nn.Module):
         transformer_blocks (nn.ModuleList): List of MultiQueryTransformerBlock instances.
         ffn_blocks (nn.ModuleList): List of FeedForward instances.
         norm (nn.LayerNorm): Layer normalization module.
 
     Examples:
         import torch
-        import torch
         from mt import MambaTransformerblock
 
         x = torch.randn(1, 10, 512)
         model = MambaTransformerblock(
             dim=512,
@@ -232,7 +232,7 @@ def __init__(
         self.dropout = dropout
         self.ff_mult = ff_mult
         self.d_state = d_state
 
         self.emb = nn.Embedding(num_tokens, dim)
         self.mt_block = MambaTransformerblock(
             dim,
@@ -246,10 +246,9 @@ def __init__(
             **kwargs,
         )
         self.to_logits = nn.Sequential(
-            RMSNorm(dim),
-            nn.Linear(dim, num_tokens)
+            RMSNorm(dim), nn.Linear(dim, num_tokens)
         )
 
     def forward(self, x: Tensor) -> Tensor:
         """
         Forward pass of the MambaTransformer model.
@@ -263,5 +262,3 @@ def forward(self, x: Tensor) -> Tensor:
         x = self.emb(x)
         x = self.mt_block(x)
         return self.to_logits(x)
-
-
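
A minimal sketch of consuming these logits for greedy next-token selection (illustrative only; it reuses the hypothetical model and x from example.py above and assumes the head returns per-token logits):

    logits = model(x)                      # (batch, seq_len, num_tokens), e.g. (1, 10, 100)
    next_token = logits[:, -1].argmax(-1)  # id of the most likely token at the last position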

