Skip to content

Commit

Permalink
Copybara import of the project:
Browse files Browse the repository at this point in the history
--
1f27faa by Cristian Garcia <cgarcia.e88@gmail.com>:

[nnx] improve nnx basics

PiperOrigin-RevId: 625789767
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Apr 17, 2024
1 parent 9973693 commit 1fbe17d
Show file tree
Hide file tree
Showing 11 changed files with 398 additions and 614 deletions.
7 changes: 0 additions & 7 deletions flax/experimental/nnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,15 +33,8 @@
from .nnx.module import M as M
from .nnx.module import Module as Module
from .nnx.graph_utils import merge as merge
from .nnx.graph_utils import full_merge as full_merge
from .nnx.graph_utils import split as split
from .nnx.graph_utils import full_split as full_split
from .nnx.graph_utils import update as update
from .nnx.graph_utils import full_update as full_update
from .nnx.graph_utils import clone as clone
from .nnx.graph_utils import pop as pop
from .nnx.rnglib import init as init
from .nnx.rnglib import empty as empty
from .nnx.nn import initializers as initializers
from .nnx.nn.activations import celu as celu
from .nnx.nn.activations import elu as elu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,5 @@ def __call__(self, x):
# split the parameters into trainable and non-trainable parameters
trainable_params, non_trainable, static = model.split(is_trainable, ...)

print(
'trainable_params =',
jax.tree_util.tree_map(jax.numpy.shape, trainable_params),
)
print(
'non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable)
)
print('trainable_params =', jax.tree_util.tree_map(jax.numpy.shape, trainable_params))
print('non_trainable = ', jax.tree_util.tree_map(jax.numpy.shape, non_trainable))
Loading

0 comments on commit 1fbe17d

Please sign in to comment.