diff --git a/mamba_transformer/model.py b/mamba_transformer/model.py
index 12e8100..355c298 100644
--- a/mamba_transformer/model.py
+++ b/mamba_transformer/model.py
@@ -12,10 +12,10 @@ class RMSNorm(nn.Module):
     def __init__(self, dim: int):
         super().__init__()
-        self.scale = dim**-0.5
+        self.scale = dim ** (-0.5)
         self.g = nn.Parameter(torch.ones(dim))

-    def forward(self, x: Tensor):
+    def forward(self, x: Tensor) -> Tensor:
         return F.normalize(x, dim=-1) * self.scale * self.g
@@ -97,6 +97,7 @@ def forward(self, x: Tensor) -> Tensor:
         x, _, _ = self.attn(x)
         x = self.norm(x)
         x = self.ffn(x)
+        return x
@@ -172,33 +173,28 @@ def __init__(
         self.transformer_depth = transformer_depth
         self.mamba_depth = mamba_depth

-        self.mamba_blocks = nn.ModuleList([])
-        self.transformer_blocks = nn.ModuleList([])
-        self.ffn_blocks = nn.ModuleList([])
-
-        self.mamba_blocks.append(
+        # Mamba, Transformer, and ffn blocks
+        self.mamba_blocks = nn.ModuleList([
             MambaBlock(dim, mamba_depth, d_state, *args, **kwargs)
-        )
-
-        # Transformer and ffn blocks
-        for _ in range(depth):
-            self.ffn_blocks.append(
-                FeedForward(dim, dim, ff_mult, *args, **kwargs)
-            )
-
-        for _ in range(transformer_depth):
-            self.transformer_blocks.append(
-                TransformerBlock(
-                    dim,
-                    heads,
-                    dim_head,
-                    dropout,
-                    ff_mult,
-                    use_linear_attn,
-                    *args,
-                    **kwargs,
-                )
-            )
+            for _ in range(mamba_depth)
+        ])
+        self.transformer_blocks = nn.ModuleList([
+            TransformerBlock(
+                dim,
+                heads,
+                dim_head,
+                dropout,
+                ff_mult,
+                use_linear_attn,
+                *args,
+                **kwargs,
+            ) for _ in range(transformer_depth)
+        ])
+
+        self.ffn_blocks = nn.ModuleList([
+            FeedForward(dim, dim, ff_mult, *args, **kwargs)
+            for _ in range(depth)
+        ])

         # Layernorm
         self.norm = nn.LayerNorm(dim)
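For reference (not part of the patch): a minimal standalone sketch of the two fixes the diff applies, namely building per-layer modules with nn.ModuleList list comprehensions instead of append loops, and returning the result from forward. ToyBlock and ToyModel are hypothetical stand-ins for MambaBlock / TransformerBlock / FeedForward and the full model, not names from the repo.

import torch
from torch import nn


class ToyBlock(nn.Module):
    # Hypothetical stand-in for MambaBlock / TransformerBlock / FeedForward.
    def __init__(self, dim: int):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


class ToyModel(nn.Module):
    def __init__(self, dim: int, depth: int):
        super().__init__()
        # One entry per layer, built the same way as the refactored constructor.
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(depth)])
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            x = block(x)
        # Explicit return, mirroring the "+ return x" fix in TransformerBlock.forward.
        return self.norm(x)


x = torch.randn(2, 16, 64)
print(ToyModel(dim=64, depth=3)(x).shape)  # torch.Size([2, 16, 64])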