Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add RMS normalization #916

Draft
wants to merge 9 commits into
base: main
Choose a base branch
from
Draft

Add RMS normalization #916

wants to merge 9 commits into from

Conversation

swfsql
Copy link
Contributor

@swfsql swfsql commented Feb 1, 2024

  • Add the try_normalize_rms related functions.
  • Add the LayerRMSNorm1D module.

Implements RMS layer normalization as described in Root Mean Square Layer Normalization.
The layer normalizes a tensor axis to have stddev of 1.0, but differently from the other normal layer normalization, the mean is not forced to zero.
Computes tensor / (tensor.square().mean() + epsilon).sqrt().

  • Not sure if the bias (delta) should be removed from this layer. It may not be needed.

Note: I haven't made an actual pytorch test to compare, not even locally, but have made a bigger test that depend on this functionality and it appeared to work ok. So this PR should be considered a draft.

@swfsql swfsql changed the title add RMS normalization Add RMS normalization Feb 1, 2024
@swfsql swfsql mentioned this pull request Feb 2, 2024
13 tasks
@swfsql swfsql marked this pull request as draft March 1, 2024 14:54
- Add the try_normalize_rms related functions.
- Add the `LayerRMSNorm1D` module.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants