Skip to content

Commit

Permalink
Update NNX State.split method docs in statelib.py
Browse files Browse the repository at this point in the history
  • Loading branch information
8bitmp3 committed Dec 17, 2024
1 parent fc38f21 commit 3fe09b9
Showing 1 changed file with 9 additions and 8 deletions.
17 changes: 9 additions & 8 deletions flax/nnx/statelib.py
Original file line number Diff line number Diff line change
Expand Up @@ -264,10 +264,10 @@ def split(
def split( # type: ignore[misc]
self, first: filterlib.Filter, /, *filters: filterlib.Filter
) -> tp.Union[State[K, V], tuple[State[K, V], ...]]:
"""Split a ``State`` into one or more ``State``'s. The
user must pass at least one ``Filter`` (i.e. :class:`Variable`),
and the filters must be exhaustive (i.e. they must cover all
:class:`Variable` types in the ``State``).
"""Splits a :class:`flax.nnx.State` into one or more ``nnx.State``'s.
You must pass at least one NNX ``Filter`` (``flax.nnx.filterlib``)
(i.e. :class:`flax.nnx.Variable`), and the ``Filter``'s must be exhaustive
(i.e. they must cover all ``nnx.Variable`` types in the ``nnx.State``).
Example usage::
Expand All @@ -285,10 +285,11 @@ def split( # type: ignore[misc]
>>> param, batch_stats = state.split(nnx.Param, nnx.BatchStat)
Arguments:
first: The first filter
*filters: The optional, additional filters to group the state into mutually exclusive substates.
first: The first NNX ``Filter``.
*filters: The optional, additional NNX ``Filter``'s to group the
:class:`flax.nnx.State` into mutually exclusive substates.
Returns:
One or more ``States`` equal to the number of filters passed.
One or more ``nnx.State``'s equal to the number of NNX ``Filter``'s passed.
"""
filters = (first, *filters)
*states_, rest = _split_state(self.flat_state(), *filters)
Expand Down Expand Up @@ -492,4 +493,4 @@ def create_path_filters(state: State):
if isinstance(value, (variablelib.Variable, variablelib.VariableState)):
value = value.value
value_paths.setdefault(value, set()).add(path)
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}
return {filterlib.PathIn(*value_paths[value]): value for value in value_paths}

0 comments on commit 3fe09b9

Please sign in to comment.