-
Notifications
You must be signed in to change notification settings - Fork 645
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[nnx] support pure dicts #4352
[nnx] support pure dicts #4352
Conversation
3872314
to
627bcdc
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Super cool!
@@ -86,6 +86,7 @@ def test_step(model: MLP, batch): | |||
total_steps = 10_000 | |||
for step, batch in enumerate(dataset(32)): | |||
train_step(model, optimizer, batch) | |||
print(nnx.graph.GRAPH_CONTEXT) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove print
flax/nnx/graph.py
Outdated
value = value.to_variable() | ||
children[key] = value | ||
elif noderef.index in index_ref: | ||
# if not is_state_leaf(value): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: remove comments
flax/nnx/graph.py
Outdated
f'Expected a Variable type for {key!r}, but got {type(value)}.' | ||
) | ||
assert isinstance(variabledef, VariableDef) | ||
# if not isinstance(value, VariableState): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: Same here
627bcdc
to
97c3e9c
Compare
What does this PR do?
nnx.merge
andnnx.update
can now use pure dictionaries returned fromState.to_pure_dict
.NodeDef.leaves
now contain a new type calledVariableDef
which contains the static information ofVariable
s, this makes it possible to reconstruct aVariables
from states that dont containVariableState
s as its the case with dictionaries returned byto_pure_dict
.Example: