
[nnx] add transforms guide #4197

Merged: 1 commit merged into main on Sep 23, 2024
Conversation

cgarciae
Collaborator

@cgarciae cgarciae commented Sep 16, 2024

What does this PR do?

Adds the Transforms guide.

Preview


Base automatically changed from nnx-transforms-guide to main September 16, 2024 20:32
@cgarciae cgarciae force-pushed the nnx-real-transforms-guide branch 7 times, most recently from f0d5f35 to f46c9ea Compare September 19, 2024 08:55
Collaborator

@IvyZX IvyZX left a comment


Thanks for making this guide! Super helpful and cool. Just a few nits on wording.

+++

### Graph updates propagation
JAX models inputs to transformations as trees, Flax NNX models inputs as graphs to allow for sharing references. However, to express most of Python's object model Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local (updates to globals inside transforms are not supported). This means that you can modify graph structure as needed, including updating existing attributes, adding/deleting attributes, swapping attributes, sharing (new) references between objects, sharing Variables between objects, etc. The sky is the limit!
Collaborator


JAX models inputs to transformations as trees, Flax NNX models inputs as graphs to allow for sharing references.

A bit hard to read - maybe:
JAX transformations see inputs as trees of arrays, and Flax NNX sees inputs as graphs of Python references.

However, to express most of Python's object model Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local (updates to globals inside transforms are not supported).

This line also a bit verbose? Maybe just:
Flax NNX's state propagation machinery can track arbitrary updates to the objects as long as they're local to the input graph (updates to globals inside transforms are not supported).
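The tree-vs-graph distinction discussed above can be illustrated with a plain-Python sketch (this is not real Flax NNX API, just an analogy): in a graph, two owners can hold the same reference, so an update through one is visible through the other, whereas flattening to a tree of values loses that sharing.

```python
# Plain-Python sketch (hypothetical `Variable` class, not Flax NNX API):
# graphs preserve shared references; trees duplicate repeated values.

class Variable:
    def __init__(self, value):
        self.value = value

# Two "modules" sharing the same Variable instance: a graph, not a tree.
shared = Variable(0)
encoder = {'weight': shared}
decoder = {'weight': shared}

# An update through one reference is visible through the other.
encoder['weight'].value = 42
assert decoder['weight'].value == 42

# Flattening to a tree of raw values loses the sharing:
# the two leaves are now independent copies.
tree = {'encoder': encoder['weight'].value, 'decoder': decoder['weight'].value}
tree['encoder'] = 0
assert tree['decoder'] == 42  # unaffected by the other leaf
```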

)
x = jax.random.normal(random.key(1), (10, 2))

def crazy_vector_dot(weights: Weights, x: jax.Array):
Collaborator


Not here, but I was hoping to see an example of transforming and using an nnx.Module method to showcase that it works and can be a natural pattern for users to take, since most transforms happen not at top level but in-between two layer definitions.

Collaborator Author


It's a good point. I'll add a variation of the first example using vmap over __call__ so users know that it's possible.

> With great power comes great responsibility.
> <br> \- Uncle Ben

While this feature is very powerful, it must be used with care as it can clash with JAX's underlying assumptions for certain transformations. For example, `jit` expects the structure of the inputs to be stable in order to cache the compiled function, changing the graph structure inside a `nnx.jit`-ed function cause continuous recompilations and performance degradation, `scan` on the other hand only allows a fixed `carry` structure, so adding/removing substates declared as carry will cause an error.
Collaborator


nit:

For example, jit expects the structure of the inputs to be stable in order to cache the compiled function, changing the graph structure inside a nnx.jit-ed function cause continuous recompilations and performance degradation, scan on the other hand only allows a fixed carry structure, so adding/removing substates declared as carry will cause an error.

For example, jit expects the structure of the inputs to be stable in order to cache the compiled function, so changing the graph structure inside an nnx.jit-ed function causes continuous recompilations and performance degradation. On the other hand, scan only allows a fixed carry structure, so adding/removing substates declared as carry will cause an error.
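The recompilation point can be demonstrated with plain `jax.jit`: the Python body of a jitted function runs once per trace, so a Python side effect in the body counts compilations. Calling with the same input structure hits the cache; a new structure forces a retrace (the `traces` list here is just an illustrative counter).

```python
import jax
import jax.numpy as jnp

traces = []  # records one entry per trace/compilation

@jax.jit
def f(params):
  traces.append(sorted(params))  # Python side effect: runs only when tracing
  return sum(v.sum() for v in params.values())

f({'a': jnp.ones(3)})
f({'a': jnp.ones(3)})                    # same structure: cached, no retrace
f({'a': jnp.ones(3), 'b': jnp.ones(3)})  # new structure: retrace
assert len(traces) == 2
```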

)
x = jax.random.normal(random.key(1), (10, 2))


Collaborator


It's probably better to only call vmap once when only one call is needed, to avoid confusion. Same for the example below.

state_axes = nnx.StateAxes({nnx.Param: 0, Count: None}) # broadcast Count
@nnx.vmap(in_axes=(state_axes, 0), out_axes=1)
def stateful_vector_dot(weights: Weights, x: jax.Array):
  assert weights.kernel.ndim == 2, 'Batch dimensions not allowed'
  assert x.ndim == 1, 'Batch dimensions not allowed'
  weights.count += 1
  return x @ weights.kernel + weights.bias

y = stateful_vector_dot(weights, x)
y = stateful_vector_dot(weights, x)
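For comparison, the same broadcasting pattern can be written with plain `jax.vmap` over arrays: the kernel and bias are broadcast with `in_axes=None` (playing the role of the `StateAxes` `None` entries above), while `x` is mapped over axis 0 (function and variable names here are illustrative).

```python
import jax
import jax.numpy as jnp

def vector_dot(kernel, bias, x):
  assert x.ndim == 1, 'Batch dimensions not allowed'
  return x @ kernel + bias

# kernel and bias are broadcast (None); x is mapped over axis 0
batched_dot = jax.vmap(vector_dot, in_axes=(None, None, 0), out_axes=1)

kernel = jnp.ones((2, 3))
bias = jnp.zeros((3,))
x = jnp.ones((10, 2))
y = batched_dot(kernel, bias, x)
assert y.shape == (3, 10)  # out_axes=1 places the batch on axis 1
```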

+++

### Random State
In Flax NNX random state is just regular state. This means that it's stored inside Modules that need it and treated as any other type of state. This is a simplification over Flax Linen, where random state was handled by a separate mechanism. In practice Modules usually keep that need random state simply need a references to a `Rngs` object that is passed to them during initialization, and use it to generate a unique key for each random operation.
Collaborator


In practice Modules usually keep that need random state simply need a references to a Rngs object that is passed to them during initialization, and use it to generate a unique key for each random operation.

What about:
In practice Modules simply need to keep a reference to a Rngs object that is passed to them during initialization, and use it to generate a unique key for each random operation.
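The "unique key for each random operation" idea can be sketched with JAX's key-splitting primitives, which is what such an `Rngs` stream boils down to (the variable names here are illustrative):

```python
import jax

# Sketch: a root key is split so each random operation gets a unique key,
# analogous to drawing fresh keys from an nnx.Rngs stream.
root = jax.random.key(0)
k1, k2 = jax.random.split(root)
a = jax.random.normal(k1, (3,))
b = jax.random.normal(k2, (3,))
assert not bool((a == b).all())  # distinct keys yield distinct samples
```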

@cgarciae
Collaborator Author

Thanks @IvyZX for the detailed feedback. I've integrated all the suggestions.

@copybara-service copybara-service bot merged commit b2277ab into main Sep 23, 2024
19 checks passed
@copybara-service copybara-service bot deleted the nnx-real-transforms-guide branch September 23, 2024 16:35