-
-
Notifications
You must be signed in to change notification settings - Fork 98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Mamba (minimal) #918
Closed
Closed
Add Mamba (minimal) #918
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
- Makes the safetensors module private. - Doesn't get exported on the preamble, avoiding a naming clash with the safetensors external crate. - Change how and when the period is inserted. - This should make it closer to how the fields are accessed in the code.
- Add the try_normalize_rms related functions. - Add the `LayerRMSNorm1D` module.
- Add `TrySplitShapeAlong` and `TrySplitTensorAlong`. - Minor linting and docs fix. TODO - Check if the tape should be returned. If not, it can be removed from the interface. - Add cuda kernel. - Consider a different interface, where it could get split in more than two tensors - possibly stated on a vec. In this way it could get closer to the pytorch interface (chunks).
- Also added `from_fn` for Arrays. Note: the interface currently requires two passes for construction, one for creating a list of tensors with NoneTape and another for putting tapes into those tensors.
Remove ftz
swfsql
force-pushed
the
mamba-minimal
branch
2 times, most recently
from
February 7, 2024 22:02
cadf65c
to
9a2cf25
Compare
This alternative method: - Requires load/read to decide whether it should skip missing tensors; - Requires load/read/save/write to decide how should keys be mapped.
swfsql
force-pushed
the
mamba-minimal
branch
2 times, most recently
from
February 9, 2024 17:27
1207867
to
ce6d624
Compare
swfsql
force-pushed
the
mamba-minimal
branch
2 times, most recently
from
February 9, 2024 17:55
3f392a6
to
165abc9
Compare
- Add stateless forward impl. - Efficient for training (but training is not yet implemented). - Input requires the entire sequence, and requires no state cache. - Generates one output for each input sequence. - Add stateful forward impl. - Efficient for inference. - Input requires the last single sequence point, and requires the last state cache. - Generates a single output referring to the last input.
swfsql
force-pushed
the
mamba-minimal
branch
from
February 20, 2024 02:52
165abc9
to
bff1b65
Compare
I'll prioritize moving this experiment to a separate crate, but feel free to ping in case anyone have some question or suggestion. |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Ports a minimal (non-optimized) implementation of Mamba (submitted on 2023-12), highly related to S4 (submitted on 2021-10).
In short and simple terms, Mamba is an alternative, with trade-offs, to the attention mechanism. Mamba can be used in RNNs that steps over a single sequence point at a time (instead of requiring to observe multiple sequence points at the same time, but it needs to carry the previous state over), and so it's memory and time requirements are fixed for each sequence point.
Implementation references:
This pr requires others (some of which are drafts or are useful for an app using this Module):
_MM_SET_FLUSH_ZERO_MODE
#912core::ops::Sub
for Dim types #914TryUnstack
for tensors. #919The commits specific to this Mamba pr are:
Tasks
forward_mut
and backpropagation). Note: This stateless interface is appropriate for training only, not for inference.forward_mut
for training. Note: This stateful interface is appropriate for inference only, not for training.Vec
conversion near the end ofselective_scan
for the stateless version.Youtube Videos
S4
Mamba