
Flax NNX GSPMD guide #4220

Merged: 1 commit merged into google:main on Sep 24, 2024
Conversation

IvyZX (Collaborator) commented Sep 23, 2024:

Add a guide on doing GSPMD-style sharding annotation on NNX models.

It covers everything in the Linen pjit guide, but with better explanations and demonstrations, and more concise code!

Also added a small example of loading a sharded model from a checkpoint.

Preview

@IvyZX requested a review from cgarciae on September 23, 2024 at 23:13

## Flax and `jax.jit` scaled up
cgarciae (Collaborator): Maybe we should change the writing here to talk about nnx.jit?

IvyZX (Author): I want to convey the idea that we are essentially using JAX's compilation machinery for the scaling-up work. I renamed the title and added another paragraph explaining this (and mentioning nnx.jit there).

Comment on lines 94 to 100:

```python
self.w2 = nnx.Param(
  nnx.with_partitioning(init_fn, ('model', None))(
    rngs.params(), (depth, depth))  # RNG key and shape for W2 creation
)
```

cgarciae (Collaborator): This is a good opportunity to show how to manually add the sharding metadata.

Suggested change, from:

```python
self.w2 = nnx.Param(
  nnx.with_partitioning(init_fn, ('model', None))(
    rngs.params(), (depth, depth))  # RNG key and shape for W2 creation
)
```

to:

```python
self.w2 = nnx.Param(
  init_fn(rngs.params(), (depth, depth)),  # RNG key and shape for W2 creation
  sharding=('model', None),
)
```
IvyZX (Author): Good idea!
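For context, here is a minimal sketch of the two equivalent annotation styles discussed above. The 4x4 shape, `nnx.initializers.lecun_normal` as the `init_fn`, and `nnx.Rngs(0)` are illustrative assumptions, not the guide's exact values:

```python
from flax import nnx

init_fn = nnx.initializers.lecun_normal()
rngs = nnx.Rngs(0)

# Style 1: wrap the initializer so the created Param carries sharding metadata.
w_a = nnx.Param(
    nnx.with_partitioning(init_fn, ('model', None))(rngs.params(), (4, 4)))

# Style 2: pass the sharding metadata directly to the Param constructor.
w_b = nnx.Param(
    init_fn(rngs.params(), (4, 4)),
    sharding=('model', None),
)
# Both Params now carry the same `sharding` annotation, which the jit-time
# sharding step later in the guide reads to build PartitionSpecs.
```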

```python
# In data parallelism, the input / intermediate value's first dimension (batch)
# will be sharded on the `data` axis
y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
z = jnp.dot(y, self.w2.value)
```

cgarciae (Collaborator): Variables can be used as JAX arrays thanks to the `__jax_array__` protocol.

Suggested change, from:

```python
z = jnp.dot(y, self.w2.value)
```

to:

```python
z = jnp.dot(y, self.w2)
```
IvyZX (Author) commented Sep 24, 2024: For some reason this will fail later when I do:

```python
with mesh:
  output = sharded_model(input)
```

with the error `AttributeError: 'tuple' object has no attribute '_device_assignment'`.

I'll keep this as-is for now.
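For reference, a small standalone check (the shapes are arbitrary) showing that `.value` exposes the underlying `jax.Array`, which is what the `jnp.dot(y, self.w2.value)` form relies on:

```python
import jax
import jax.numpy as jnp
from flax import nnx

w = nnx.Param(jnp.ones((4, 4)))
print(isinstance(w.value, jax.Array))            # True: .value is the raw array
print(jnp.dot(jnp.ones((2, 4)), w.value).shape)  # (2, 4)
```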

```python
print(unsharded_model.w2.value.sharding)  # SingleDeviceSharding
```

We should leverage JAX's compilation mechanism, aka. `jax.jit`, to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a jitted function:
cgarciae (Collaborator): Suggested change, from:

We should leverage JAX's compilation mechanism, aka. `jax.jit`, to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a jitted function:

to:

We should leverage JAX's compilation mechanism, via `nnx.jit`, to create the sharded model. The key is to initialize a model and assign shardings upon the model state within a jitted function:

IvyZX (Author): Done.

Comment on lines 142 to 151:

1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jax.jit` how to shard a variable!

1. Throw away the unsharded state and return the model based upon the sharded state.

1. Compile the whole function with `nnx.jit` instead of `jax.jit` because it allows the output to be a stateful NNX module.

1. Run it under a device mesh context so that JAX knows which devices to shard it to.

cgarciae (Collaborator) commented Sep 24, 2024: Suggestion: replace jax.jit with nnx.jit in the other points and remove the point where you suggest using nnx.jit instead of jax.jit.

Suggested change, from:

1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `jax.jit` how to shard a variable!
1. Throw away the unsharded state and return the model based upon the sharded state.
1. Compile the whole function with `nnx.jit` instead of `jax.jit` because it allows the output to be a stateful NNX module.
1. Run it under a device mesh context so that JAX knows which devices to shard it to.

to:

1. Call [`jax.lax.with_sharding_constraint`](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.with_sharding_constraint.html) to bind the model state with the sharding annotations. This API tells the top-level `nnx.jit` how to shard a variable!
1. Throw away the unsharded state and return the model based upon the sharded state.
1. Run it under a device mesh context so that JAX knows which devices to shard it to.

IvyZX (Author): Hmm... I think we should still briefly explain why using nnx.jit is a better pattern. Especially since we are making transforms closer to JAX style now, we should assume some users have experience with jax.jit. I can remove the mentions of jax.jit here and direct users more explicitly to nnx.jit.
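To make these numbered steps concrete, here is a minimal sketch in the spirit of the guide. The `DotReluDot` module, the 1024 depth, and the 2x4 (`data`, `model`) mesh of 8 devices are illustrative assumptions rather than the guide's exact code, and it assumes the `nnx.get_partition_spec` / `nnx.state` / `nnx.update` helpers available in current Flax NNX:

```python
import jax
import numpy as np
from flax import nnx
from jax.sharding import Mesh, PartitionSpec

class DotReluDot(nnx.Module):  # assumed stand-in for the guide's model
  def __init__(self, depth: int, rngs: nnx.Rngs):
    init_fn = nnx.initializers.lecun_normal()
    self.w1 = nnx.Param(
        nnx.with_partitioning(init_fn, (None, 'model'))(rngs.params(), (depth, depth)))
    self.w2 = nnx.Param(
        nnx.with_partitioning(init_fn, ('model', None))(rngs.params(), (depth, depth)))

  def __call__(self, x):
    y = jax.nn.relu(x @ self.w1.value)
    # Constrain the intermediate so its batch dimension lands on the `data` axis.
    y = jax.lax.with_sharding_constraint(y, PartitionSpec('data', 'model'))
    return y @ self.w2.value

# Assumes 8 local devices arranged as a 2x4 (`data`, `model`) mesh.
mesh = Mesh(np.array(jax.devices()).reshape(2, 4), ('data', 'model'))

@nnx.jit
def create_sharded_model():
  model = DotReluDot(1024, rngs=nnx.Rngs(0))   # params start out unsharded here
  state = nnx.state(model)                     # extract the model state
  pspecs = nnx.get_partition_spec(state)       # read the sharding annotations
  sharded_state = jax.lax.with_sharding_constraint(state, pspecs)
  nnx.update(model, sharded_state)             # swap the sharded arrays back in
  return model                                 # nnx.jit lets us return a Module

with mesh:
  sharded_model = create_sharded_model()
```

Under the mesh context, the compiled function materializes each parameter directly with its annotated sharding, so the full unsharded model never has to fit on a single device.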


Now, from initialization or from checkpoint, we have a sharded model. To carry out the compiled, scaled-up training, we need to shard the inputs as well. In this data parallelism example, the training data has its batch dimension sharded across the `data` device axis, so you should put your data in sharding `('data', None)`. You can use `jax.device_put` for this.

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jax.jit`. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.
cgarciae (Collaborator): Suggested change, from:

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `jax.jit`. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.

to:

Note that with the correct sharding for all inputs, the output will be sharded in the most natural way even without `nnx.jit`. See the example below - even without `jax.lax.with_sharding_constraint` on the output `y`, it was still sharded as `('data', None)`.

IvyZX (Author): Done.
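As a sketch of that input-sharding step, reusing the `mesh` and `sharded_model` from the earlier sketch (the batch size 8 and feature size 1024 are arbitrary placeholders):

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec

# Shard the batch dimension of the inputs across the `data` mesh axis.
data_sharding = NamedSharding(mesh, PartitionSpec('data', None))
input = jax.device_put(jnp.ones((8, 1024)), data_sharding)
label = jax.device_put(jnp.ones((8, 1024)), data_sharding)

with mesh:
  output = sharded_model(input)
print(output.sharding)  # batch dimension comes out sharded over `data`, i.e. ('data', None)
```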

```python
new_state = block_all(train_step(sharded_model, optimizer, input, label))
```

## Logical axis annotation
cgarciae (Collaborator): Nice! I didn't know about sharding_rules. In nnx_lm1b we have this other pattern, which maps the mesh axes in the constructor:

IvyZX (Author): Yeah, I added them recently to align with Linen's LogicallyPartitioned. It's just annotations, so there are actually a ton of ways to make them work, and I like how you made it in nnx_lm1b!
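As a rough illustration of what logical axis names buy you, here is a hand-rolled sketch that translates logical names to mesh axes with a plain dict before building a `PartitionSpec`; this is not the `sharding_rules` API itself, just the idea behind it (the axis names and rules are made up for the example):

```python
from jax.sharding import PartitionSpec

# A parameter annotated with logical axis names instead of mesh axis names.
logical_axes = ('embed', 'hidden')

# Rules mapping logical names onto physical mesh axes.
rules = {'embed': None, 'hidden': 'model'}

# Translate before building the PartitionSpec used for sharding constraints.
pspec = PartitionSpec(*(rules[name] for name in logical_axes))
print(pspec)  # PartitionSpec(None, 'model')
```

Keeping the model annotated with logical names means the same model code can be resharded for a different mesh layout by swapping the rules alone.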

The copybara-service bot merged commit 9c86396 into google:main on Sep 24, 2024.

17 checks passed