Commit [README]

Kye committed Jan 13, 2024
1 parent 903d61d commit f4099cd
Showing 1 changed file with 10 additions and 10 deletions: README.md
@@ -19,22 +19,22 @@ I added in many normalizations as I believe by default training stability would

### Usage
```python
import torch
from mamba_transformer.model import MambaTransformer

# Generate a random tensor of shape (1, 10) with values between 0 and 99
x = torch.randint(0, 100, (1, 10))

# Create an instance of the MambaTransformer model
model = MambaTransformer(
    num_tokens=100,  # Number of tokens in the input sequence
    dim=512,         # Dimension of the model
    heads=8,         # Number of attention heads
    depth=4,         # Number of transformer layers
    dim_head=64,     # Dimension of each attention head
    d_state=512,     # Dimension of the state
    dropout=0.1,     # Dropout rate
    ff_mult=4,       # Multiplier for the feed-forward layer dimension
)

# Pass the input tensor through the model and print the output shape
out = model(x)
print(out.shape)
```
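As a quick sanity check on these hyperparameters, the arithmetic can be sketched in plain Python. This assumes two things the snippet above does not state explicitly: that `heads * dim_head` is meant to equal `dim`, and that the model returns logits of shape `(batch, seq_len, num_tokens)`, as token-level language models typically do.

```python
# Hypothetical sanity check on the hyperparameters used above.
batch, seq_len = 1, 10
num_tokens, dim, heads, dim_head = 100, 512, 8, 64

# Assumption: per-head dimensions multiply back to the model dimension.
assert heads * dim_head == dim

# Assumption: the model emits one logit per vocabulary token per position.
expected_output_shape = (batch, seq_len, num_tokens)
print(expected_output_shape)  # (1, 10, 100)
```

If either assumption does not hold for this implementation, the printed shape would differ accordingly.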
