-
Notifications
You must be signed in to change notification settings - Fork 645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Upgrade Flax NNX Filters doc #4199
base: main
Are you sure you want to change the base?
Conversation
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
cea5fb0
to
2d64500
Compare
57a361e
to
239c9de
Compare
ef24a1e
to
7c5afdb
Compare
a73e710
to
e5586f4
Compare
e5586f4
to
a0a19ed
Compare
@@ -31,28 +37,29 @@ print(f'{params = }') | |||
print(f'{batch_stats = }') | |||
``` | |||
|
|||
Here `nnx.Param` and `nnx.BatchStat` are used as Filters to split the model into two groups: one with the parameters and the other with the batch statistics. However, this begs the following questions: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't agree removing this explanation. Can you explain why its removed?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This was not removed - it's above the code block on line 22. LMKWYT.
In the following example [`nnx.Param`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.Param)
and [`nnx.BatchStat`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/variables.html#flax.nnx.BatchStat)
are used as `Filter`s to split the model into two groups: one
with the parameters and the other with the batch statistics:
```{code-cell} ipython3
from flax import nnx
class Foo(nnx.Module):
def __init__(self):
self.a = nnx.Param(0)
self.b = nnx.BatchStat(True)
foo = Foo()
graphdef, params, batch_stats = nnx.split(foo, nnx.Param, nnx.BatchStat)
print(f'{params = }')
print(f'{batch_stats = }')
@cgarciae
@@ -8,12 +8,18 @@ jupytext: | |||
jupytext_version: 1.13.8 | |||
--- | |||
|
|||
# Using Filters | |||
# Using `Filter`s |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can remove codefont Markdown formatting from section titles similar to Flax Basics @cgarciae
No description provided.