Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[nnx] add Randomness guide #4216

Merged
merged 1 commit into from
Sep 26, 2024
Merged

[nnx] add Randomness guide #4216

merged 1 commit into from
Sep 26, 2024

Conversation

cgarciae
Copy link
Collaborator

@cgarciae cgarciae commented Sep 22, 2024

What does this PR do?

Adds Randoness guide.

Preview

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@cgarciae cgarciae force-pushed the nnx-randomness-guide branch 5 times, most recently from de289a6 to d9bd85c Compare September 24, 2024 10:49
Copy link
Collaborator

@IvyZX IvyZX left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great work! Very informative!

```

## Rngs, RngStream, and RngState
Flax NNX provides the `nnx.Rngs` type a the main convenice API for managing random state. Following Flax Linen's footsteps, `Rngs` has the ability to create multiple named RNG streams, each with its own state, for the purpose of allowing for tight control over randomness in the context of JAX transforms. Here's a breakdown of the main RNG-related types in Flax NNX:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit:
convenice -> convenience
for the purpose of allowing for tight control -> for the purpose of tight control

## Rngs, RngStream, and RngState
Flax NNX provides the `nnx.Rngs` type a the main convenice API for managing random state. Following Flax Linen's footsteps, `Rngs` has the ability to create multiple named RNG streams, each with its own state, for the purpose of allowing for tight control over randomness in the context of JAX transforms. Here's a breakdown of the main RNG-related types in Flax NNX:

* **Rngs**: The main user interface. It defines a set of named `RngStream` objects.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we add nnx. to all these titles? like nnx.Rngs, nnx.RngStream, nnx.RngState, nnx.RngKey, nnx.RngCount.

Note that the `key` attribute does not change when a new keys are generated.

### Standard stream names
There are only two standard stream names used by Flax NNX, shown in the table below:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

used by Flax NNX -> used by Flax NNX built-in layers

Just to avoid user confusion that other names are not allowed.

| `params` | Used for parameter initialization |
| `dropout`| Used by `Dropout` to create dropout masks |

`params` is used my most of the standard layers (`Linear`, `Conv`, `MultiHeadAttention`, etc.) during construction to initialize their parameters. `dropout` is used by the `Dropout` and `MultiHeadAttention` to generate dropout masks. Here's a simple example of a model using `params` and `dropout` streams:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: etc. -> etc

print(e)
```

The other option is to use the `nnx.split_rngs` decorator which will automatically split the random state of any `RngStream`s found in the inputs of the function, and will automatically "lower" them once the function call ends so the `Rngs` can be used outside again. `split_rngs` allows passing Filters to the `only` keyword argument to select the `RngStream`s that should be split. Using `split_rngs` is useful in combination with a transform but here we will show a simple example without any transforms to illustrate the concept, we'll use `split_rngs` with a transform on the next section.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have a feeling that we should introduce and suggest nnx.split_rngs first, since it works better with transform cases and doesn't have the drawback that splitting in stream has (aka. can't work outside transform).

Actually, if there isn't any case in which splitting in stream is better, maybe it's fine to simply avoid introducing that... or introducing that as a lower-level view of what nnx.split_rngs actually does.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I'll think about it. We might as well just delete the other example as stick with split_rngs.
The main motivation of showing how to manually lift was to make it more familiar as the operation is explicit but it might be less used.

```

## Transforms
As stated before, in Flax NNX random state is just another type of state, this means that there is nothing special about it regarding transforms. This means that you should be able to use the state handling APIs of each transform to get the results you want. In this section we will two examples of using random state in transforms, one with `pmap` and another one with `scan`.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

one with pmap and another one with scan.

What about this to highlight the intention of both examples:
one with pmap to split multiple RNG keys and one with scan to broadcast a single RNG key.

@copybara-service copybara-service bot merged commit 8b37d1a into main Sep 26, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-randomness-guide branch September 26, 2024 14:30
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants