Skip to content

Commit

Permalink
[nnx] fix ToLinen kwargs
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 695920522
  • Loading branch information
Cristian Garcia authored and Flax Authors committed Nov 13, 2024
1 parent d31f290 commit 86ff7af
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions flax/nnx/bridge/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
import typing as tp
from typing import Any

from flax import nnx
from flax import linen
from flax import nnx
from flax.core import FrozenDict
from flax.core import meta
from flax.nnx import graph
from flax.nnx.bridge import variables as bv
from flax.nnx.module import GraphDef, Module
from flax.nnx.object import Object
from flax.nnx.rnglib import Rngs
from flax.nnx.statelib import State
from flax.nnx.object import Object
import jax
from jax import tree_util as jtu

Expand Down Expand Up @@ -220,7 +221,7 @@ class ToLinen(linen.Module):
"""
nnx_class: tp.Callable[..., Module]
args: tp.Sequence = ()
kwargs: tp.Mapping = dataclasses.field(default_factory=dict)
kwargs: tp.Mapping[str, tp.Any] = FrozenDict({})
skip_rng: bool = False
metadata_type: tp.Type = bv.NNXMeta

Expand Down Expand Up @@ -277,4 +278,4 @@ def _update_variables(self, module):
def to_linen(nnx_class: tp.Callable[..., Module], *args,
name: str | None = None, **kwargs):
"""Shortcut of `nnx.bridge.ToLinen` if user is not changing any of its default fields."""
return ToLinen(nnx_class, args=args, kwargs=kwargs, name=name)
return ToLinen(nnx_class, args=args, kwargs=FrozenDict(kwargs), name=name)

0 comments on commit 86ff7af

Please sign in to comment.