From f4099cdfc1486c64cb78c7a4f1e6e15ff23eb013 Mon Sep 17 00:00:00 2001
From: Kye
Date: Sat, 13 Jan 2024 00:56:45 -0500
Subject: [PATCH] [README]

---
 README.md | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/README.md b/README.md
index 60e9959..a42bfde 100644
--- a/README.md
+++ b/README.md
@@ -19,22 +19,22 @@ I added in many normalizations as I believe by default training stability would
 ### Usage
 ```python
-import torch
-from mt import MambaTransformer
+import torch
+from mamba_transformer.model import MambaTransformer

 # Generate a random tensor of shape (1, 10) with values between 0 and 99
 x = torch.randint(0, 100, (1, 10))

 # Create an instance of the MambaTransformer model
 model = MambaTransformer(
-    num_tokens=100,  # Number of tokens in the input sequence
-    dim=512,  # Dimension of the model
-    heads=8,  # Number of attention heads
-    depth=4,  # Number of transformer layers
-    dim_head=64,  # Dimension of each attention head
-    d_state=512,  # Dimension of the state
-    dropout=0.1,  # Dropout rate
-    ff_mult=4  # Multiplier for the feed-forward layer dimension
+    num_tokens=100,  # Number of tokens in the input sequence
+    dim=512,  # Dimension of the model
+    heads=8,  # Number of attention heads
+    depth=4,  # Number of transformer layers
+    dim_head=64,  # Dimension of each attention head
+    d_state=512,  # Dimension of the state
+    dropout=0.1,  # Dropout rate
+    ff_mult=4,  # Multiplier for the feed-forward layer dimension
 )

 # Pass the input tensor through the model and print the output shape