From 4a04f754ca8a009e0660b359d40bfc4b695275a5 Mon Sep 17 00:00:00 2001 From: IvyZX Date: Mon, 16 Sep 2024 16:56:48 -0700 Subject: [PATCH] Align bridge variable tree structures --- docs_nnx/guides/bridge_guide.ipynb | 388 +++++++++++++---------------- docs_nnx/guides/bridge_guide.md | 264 +++++++------------- flax/core/meta.py | 4 +- flax/nnx/bridge/variables.py | 74 +++++- flax/nnx/bridge/wrappers.py | 39 ++- tests/nnx/bridge/wrappers_test.py | 143 +++++++++-- tests/nnx/graph_utils_test.py | 4 +- 7 files changed, 477 insertions(+), 439 deletions(-) diff --git a/docs_nnx/guides/bridge_guide.ipynb b/docs_nnx/guides/bridge_guide.ipynb index 5db5904aa1..25967dd8c5 100644 --- a/docs_nnx/guides/bridge_guide.ipynb +++ b/docs_nnx/guides/bridge_guide.ipynb @@ -4,9 +4,9 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "# Use Flax NNX along with Flax Linen\n", + "# Use Flax NNX and Linen together\n", "\n", - "This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. \n", + "This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API.\n", "\n", "This will be helpful if you:\n", "\n", @@ -15,7 +15,7 @@ "\n", "We hope this allows you to move and try out NNX at your own pace, and leverage the best of both worlds. We will also talk about how to resolve the caveats of interoperating the two APIs, on a few aspects that they are fundamentally different.\n", "\n", - "**Note**: \n", + "**Note**:\n", "\n", "This guide is about glueing Linen and NNX modules. To migrate an existing Linen module to NNX, check out the [Migrate from Flax Linen to Flax NNX](https://flax.readthedocs.io/en/latest/nnx/haiku_linen_vs_nnx.html) guide. 
\n", "\n", @@ -90,34 +90,28 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "ToNNX(\n", - " module=LinenDot(\n", - " # attributes\n", - " out_dim = 64\n", - " w_init = init\n", - " ),\n", - " rngs=Rngs(\n", - " default=RngStream(\n", - " key=RngKey(\n", - " value=Array((), dtype=key) overlaying:\n", - " [0 0],\n", - " tag='default'\n", - " ),\n", - " count=RngCount(\n", - " value=Array(1, dtype=uint32),\n", - " tag='default'\n", - " )\n", - " )\n", - " ),\n", - " linen_collections=('params',),\n", - " params={'w': Param(\n", - " value=Array(shape=(32, 64), dtype=float32)\n", - " )}\n", - ")\n" - ] + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -131,14 +125,15 @@ " return x @ w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", - "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", - "y = model(x) # => `y = model.apply(var, x)` in Linen\n", + "model = bridge.ToNNX(LinenDot(64), \n", + " rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen\n", + "bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen\n", + "y = model(x) # => `y = model.apply(var, x)` in Linen\n", "\n", "nnx.display(model)\n", "\n", "# In-place swap your weight array and the model still works!\n", - "model.params['w'].value = jax.random.normal(jax.random.key(1), (32, 64))\n", + "model.w.value = jax.random.normal(jax.random.key(1), (32, 64))\n", "assert not jnp.allclose(y, model(x))" ] }, @@ -155,39 +150,28 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "NNXOuter(\n", - " dot=ToNNX(\n", - " module=LinenDot(\n", - " # attributes\n", - " out_dim = 64\n", - " w_init = init\n", - " ),\n", - " rngs=Rngs(\n", - " default=RngStream(\n", - " key=RngKey(\n", - " value=Array((), dtype=key) overlaying:\n", - " [0 0],\n", - " tag='default'\n", - " ),\n", - " count=RngCount(\n", - " value=Array(1, dtype=uint32),\n", - " tag='default'\n", - " )\n", - " )\n", - " ),\n", - " linen_collections=('params',),\n", - " params={'w': Param(\n", - " value=Array(shape=(32, 64), dtype=float32)\n", - " )}\n", - " ),\n", - " b=Param(\n", - " value=Array(shape=(1, 64), dtype=float32)\n", - " )\n", - ")\n" - ] + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -200,7 +184,7 @@ " return self.dot(x) + self.b\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them into one line too\n", + "model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line\n", "nnx.display(model)" ] }, @@ -219,8 +203,8 @@ "metadata": {}, "outputs": [], "source": [ - "assert isinstance(model.dot.params['w'], nnx.Param)\n", - "assert isinstance(model.dot.params['w'].value, jax.Array)" + "assert isinstance(model.dot.w, nnx.Param)\n", + "assert isinstance(model.dot.w.value, jax.Array)" ] }, { @@ -236,36 +220,28 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "NNXOuter(\n", - " dot=ToNNX(\n", - " module=LinenDot(\n", - " # attributes\n", - " out_dim = 64\n", - " w_init = init\n", - " ),\n", - " rngs=Rngs(\n", - " default=RngStream(\n", - " key=RngKey(\n", - " value=Array((), dtype=key) overlaying:\n", - " [0 0],\n", - " tag='default'\n", - " ),\n", - " count=RngCount(\n", - " value=Array(1, dtype=uint32),\n", - " tag='default'\n", - " )\n", - " )\n", - " ),\n", - " linen_collections=()\n", - " ),\n", - " b=Param(\n", - " value=Array(shape=(1, 64), dtype=float32)\n", - " )\n", - ")\n" - ] + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -279,39 +255,28 @@ "metadata": {}, "outputs": [ { - "name": "stdout", - "output_type": "stream", - "text": [ - "NNXOuter(\n", - " dot=ToNNX(\n", - " module=LinenDot(\n", - " # attributes\n", - " out_dim = 64\n", - " w_init = init\n", - " ),\n", - " rngs=Rngs(\n", - " default=RngStream(\n", - " key=RngKey(\n", - " value=Array((), dtype=key) overlaying:\n", - " [0 0],\n", - " tag='default'\n", - " ),\n", - " count=RngCount(\n", - " value=Array(1, dtype=uint32),\n", - " tag='default'\n", - " )\n", - " )\n", - " ),\n", - " linen_collections=('params',),\n", - " params={'w': Param(\n", - " value=Array(shape=(32, 64), dtype=float32)\n", - " )}\n", - " ),\n", - " b=Param(\n", - " value=Array(shape=(1, 64), dtype=float32)\n", - " )\n", - ")\n" - ] + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" } ], "source": [ @@ -348,18 +313,20 @@ "source": [ "class NNXDot(nnx.Module):\n", " def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs):\n", - " self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim)))\n", + " self.w = nnx.Param(nnx.initializers.lecun_normal()(\n", + " rngs.params(), (in_dim, out_dim)))\n", " def __call__(self, x: jax.Array):\n", " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.to_linen(NNXDot, 32, out_dim=64) # <- Pass in the arguments, not an actual module\n", + "# Pass in the arguments, not an actual module\n", + "model = bridge.to_linen(NNXDot, 32, out_dim=64)\n", "variables = model.init(jax.random.key(0), x)\n", "y = model.apply(variables, x)\n", "\n", "print(list(variables.keys()))\n", - "print(variables['params']['w'].value.shape) # => (32, 64)\n", - "print(y.shape) # => (4, 64)\n" + "print(variables['params']['w'].shape) # => (32, 64)\n", + "print(y.shape) # => (4, 64)\n" ] }, { @@ -406,39 +373,12 @@ " def __call__(self, x):\n", " return x + self.constant\n", "\n", - "# You have to use `skip_rng=True` because your module `__init__` don't take `rng` as argument\n", + "# You have to use `skip_rng=True` because this module's `__init__` don't\n", + "# take `rng` as argument\n", "model = bridge.ToLinen(NNXAddConstant, skip_rng=True)\n", "y, var = model.init_with_output(jax.random.key(0), x)" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "You may notice that you need to an additional `.value` to access this Flax `w` param. This is because all NNX variables will be wrapped with an `nnx.Variable` class, which will allow it to be annotated with various information, such as its partitioning. This was translated into an equivalent `nnx.bridge.NNXMeta` wrapper. 
\n", - "\n", - "If you use [partition metadata in Linen](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html), you can learn more about how that works in NNX in [Partition Metadata Section](#partition-metadata) below." - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "\n", - "\n" - ] - } - ], - "source": [ - "print(type(variables['params']['w'])) # => nnx.bridge.NNXMeta\n", - "print(type(variables['params']['w'].value)) # => jax.Array" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -448,7 +388,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 10, "metadata": {}, "outputs": [ { @@ -471,7 +411,7 @@ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenOuter(out_dim=64)\n", "y, variables = model.init_with_output(jax.random.key(0), x)\n", - "w, b = variables['params']['ToLinen_0']['w'].value, variables['params']['b']\n", + "w, b = variables['params']['ToLinen_0']['w'], variables['params']['b']\n", "print(w.shape, b.shape, y.shape)" ] }, @@ -497,14 +437,15 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0))\n", - "bridge.lazy_init(model, x) # We don't really need this b/c no extra params were created here,\n", - " # but it's a good practice to always add this line.\n", + "# We don't really need to call lazy_init because no extra params were created here,\n", + "# but it's a good practice to always add this line.\n", + "bridge.lazy_init(model, x)\n", "y1, y2 = model(x), model(x)\n", "assert not jnp.allclose(y1, y2) # Two runs yield different outputs!\n", "\n", @@ -532,7 +473,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 12, "metadata": {}, 
"outputs": [ { @@ -597,7 +538,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": 13, "metadata": {}, "outputs": [ { @@ -640,13 +581,13 @@ "\n", "x = jax.random.normal(jax.random.key(42), (2, 4))\n", "model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x)\n", - "print(model.params['w']) # Of type `nnx.Param` - note this is still under attribute `params`\n", - "print(model.params['b']) # Of type `nnx.Param`\n", - "print(model.counter['count']) # Of type `counter` - an auto-created dummy type from the name \"counter\"\n", - "print(type(model.counter['count']))\n", + "print(model.w) # Of type `nnx.Param` - note this is still under attribute `params`\n", + "print(model.b) # Of type `nnx.Param`\n", + "print(model.count) # Of type `counter` - auto-created type from the collection name\n", + "print(type(model.count))\n", "\n", - "y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n", - "print(model.intermediates['dot_sum']) # Of type `nnx.Intermediates`" + "y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger\n", + "print(model.dot_sum) # Of type `nnx.Intermediates`" ] }, { @@ -660,7 +601,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": 14, "metadata": {}, "outputs": [ { @@ -669,19 +610,20 @@ "text": [ "All Params: ['b', 'w']\n", "All Counters: ['count']\n", - "All the rest (intermediates and RNG keys): ['intermediates', 'rngs']\n" + "All the rest (intermediates and RNG keys): ['dot_sum', 'rngs']\n" ] } ], "source": [ "# Separate variables of different types with nnx.split\n", - "CountType = type(model.counter['count'])\n", + "CountType = type(model.count)\n", "static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...)\n", - "print('All Params:', list(params['params'].keys()))\n", - "print('All Counters:', list(counter['counter'].keys()))\n", + "print('All Params:', list(params.keys()))\n", + "print('All Counters:', 
list(counter.keys()))\n", "print('All the rest (intermediates and RNG keys):', list(the_rest.keys()))\n", "\n", - "model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time" + "model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time\n", + "y = model(x, mutable=True) # still works!" ] }, { @@ -695,7 +637,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": 15, "metadata": {}, "outputs": [ { @@ -703,10 +645,10 @@ "output_type": "stream", "text": [ "All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts']\n", - "{'w': NNXMeta(var_type=, value=Array([[ 0.2916921 , 0.22780475, 0.06553137],\n", + "{'w': Array([[ 0.2916921 , 0.22780475, 0.06553137],\n", " [ 0.17487915, -0.34043145, 0.24764155],\n", " [ 0.6420431 , 0.6220095 , -0.44769976],\n", - " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32), metadata={'get_value_hooks': (), 'set_value_hooks': (), 'create_value_hooks': (), 'add_axis_hooks': (), 'remove_axis_hooks': ()})}\n" + " [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)}\n" ] } ], @@ -738,7 +680,7 @@ "source": [ "## Partition metadata\n", "\n", - "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded. \n", + "Flax uses a metadata wrapper box over the raw JAX array to annotate how a variable should be sharded.\n", "\n", "In Linen, this is an optional feature that triggered by using `nn.with_partitioning` on initializers (see more on [Linen partition metadata guide](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html)). In NNX, since all NNX variables are wrapped by `nnx.Variable` class anyway, that class will hold the sharding annotations too. 
\n", "\n", @@ -760,7 +702,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": 16, "metadata": {}, "outputs": [ { @@ -779,13 +721,15 @@ " out_dim: int\n", " @nn.compact\n", " def __call__(self, x):\n", - " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), \n", + " w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), \n", + " ('in', 'out')), \n", " (x.shape[-1], self.out_dim))\n", " return x @ w\n", "\n", "@nnx.jit\n", "def create_sharded_nnx_module(x):\n", - " model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)\n", + " model = bridge.lazy_init(\n", + " bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x)\n", " state = nnx.state(model)\n", " sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state))\n", " nnx.update(model, sharded_state)\n", @@ -793,14 +737,15 @@ "\n", "\n", "print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...')\n", - "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out'))\n", + "mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), \n", + " axis_names=('in', 'out'))\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "with mesh:\n", " model = create_sharded_nnx_module(x)\n", "\n", - "print(type(model.params['w'])) # `nnx.Param`\n", - "print(model.params['w'].sharding) # The partition annotation attached with the weight `w`\n", - "print(model.params['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" + "print(type(model.w)) # `nnx.Param`\n", + "print(model.w.sharding) # The partition annotation attached with `w`\n", + "print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh" ] }, { @@ -809,26 +754,23 @@ "source": [ "### NNX to Linen\n", "\n", - "Since all NNX variables are wrapped with `nnx.Variable` box, the converted Linen module will 
have all variables boxed too. We have a default Linen partition metadata class called `bridge.NNXMeta` to store these converted NNX variables.\n", + "If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it.\n", "\n", - "`nnx.with_partitioning` will automatically shard the array with the annotation if it is called within a `jax.sharding.Mesh` context, so you don't need to do `with_sharding_constraint` yourself.\n", + "But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable.\n", "\n", "Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree." ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\n", - "('in', 'out')\n", - "GSPMDSharding({devices=[2,4]<=[8]})\n", - "\n" + "GSPMDSharding({devices=[2,4]<=[8]})\n" ] } ], @@ -841,17 +783,31 @@ " return x @ self.w\n", "\n", "x = jax.random.normal(jax.random.key(42), (4, 32))\n", - "model = bridge.to_linen(NNXDotWithParititioning, 32, 64)\n", "\n", - "with mesh:\n", - " variables = jax.jit(model.init)(jax.random.key(0), x)\n", + "@jax.jit\n", + "def create_sharded_variables(key, x):\n", + " model = bridge.to_linen(NNXDotWithParititioning, 32, 64)\n", + " variables = model.init(key, x)\n", + " # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", + " assert type(variables['params']['w']) == bridge.NNXMeta\n", + " # The annotation coming from the `nnx.Param` => (in, out)\n", + " assert variables['params']['w'].metadata['sharding'] == ('in', 'out')\n", + " \n", + " unboxed_variables = nn.unbox(variables)\n", + " variable_pspecs = 
nn.get_partition_spec(variables)\n", + " assert isinstance(unboxed_variables['params']['w'], jax.Array)\n", + " assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out')\n", + " \n", + " sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, \n", + " nn.unbox(variables),\n", + " nn.get_partition_spec(variables))\n", + " return sharded_vars\n", "\n", - "print(type(variables['params']['w'])) # A `NNXMeta` wrapper of the underlying `nnx.Param`\n", - "print(variables['params']['w'].metadata['sharding']) # The annotation coming from the `nnx.Param`\n", - "print(variables['params']['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh\n", + "with mesh:\n", + " variables = create_sharded_variables(jax.random.key(0), x)\n", "\n", - "unboxed = nn.unbox(variables)\n", - "print(type(unboxed['params']['w'])) # The raw jax.Array" + "# The underlying JAX array is sharded across the 2x4 mesh\n", + "print(variables['params']['w'].sharding)" ] }, { @@ -876,7 +832,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": 18, "metadata": {}, "outputs": [ { @@ -906,7 +862,7 @@ "x = jax.random.normal(jax.random.key(0), (4, 32))\n", "model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x)\n", "\n", - "print(model.linen_dot.params['kernel'].shape) # (4, 32, 64) - first axis with dim 4 got vmapped\n", + "print(model.linen_dot.kernel.shape) # (4, 32, 64) - first axis with dim 4 got vmapped\n", "y = model(x)\n", "print(y.shape)" ] @@ -924,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": 19, "metadata": {}, "outputs": [ { @@ -948,7 +904,7 @@ "x = jax.random.normal(jax.random.key(42), (4, 32))\n", "model = LinenVmapped(64)\n", "var = model.init(jax.random.key(0), x)\n", - "print(var['params']['VmapToLinen_0']['kernel'].value.shape) # (4, 32, 64) - leading dim 4 got vmapped\n", + "print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got 
vmapped\n", "y = model.apply(var, x)\n", "print(y.shape)" ] diff --git a/docs_nnx/guides/bridge_guide.md b/docs_nnx/guides/bridge_guide.md index 8b808d7f8f..cfc15b17f4 100644 --- a/docs_nnx/guides/bridge_guide.md +++ b/docs_nnx/guides/bridge_guide.md @@ -1,4 +1,4 @@ -# Use Flax NNX along with Flax Linen +# Use Flax NNX and Linen together This guide is for existing Flax users who want to make their codebase a mixture of Flax Linen and Flax NNX `Module`s, which is made possible thanks to the `flax.nnx.bridge` API. @@ -71,41 +71,24 @@ class LinenDot(nn.Module): return x @ w x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.ToNNX(LinenDot(64), rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen -bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen -y = model(x) # => `y = model.apply(var, x)` in Linen +model = bridge.ToNNX(LinenDot(64), + rngs=nnx.Rngs(0)) # => `model = LinenDot(64)` in Linen +bridge.lazy_init(model, x) # => `var = model.init(key, x)` in Linen +y = model(x) # => `y = model.apply(var, x)` in Linen nnx.display(model) # In-place swap your weight array and the model still works! -model.params['w'].value = jax.random.normal(jax.random.key(1), (32, 64)) +model.w.value = jax.random.normal(jax.random.key(1), (32, 64)) assert not jnp.allclose(y, model(x)) ``` - ToNNX( - module=LinenDot( - # attributes - out_dim = 64 - w_init = init - ), - rngs=Rngs( - default=RngStream( - key=RngKey( - value=Array((), dtype=key) overlaying: - [0 0], - tag='default' - ), - count=RngCount( - value=Array(1, dtype=uint32), - tag='default' - ) - ) - ), - linen_collections=('params',), - params={'w': Param( - value=Array(shape=(32, 64), dtype=float32) - )} - ) + +
+ + + +
`nnx.bridge.lazy_init` also works even if the top-level module is a pure-NNX one, so you can do sub-moduling as you wish: @@ -121,39 +104,16 @@ class NNXOuter(nnx.Module): return self.dot(x) + self.b x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit them into one line too +model = bridge.lazy_init(NNXOuter(64, rngs=nnx.Rngs(0)), x) # Can fit into one line nnx.display(model) ``` - NNXOuter( - dot=ToNNX( - module=LinenDot( - # attributes - out_dim = 64 - w_init = init - ), - rngs=Rngs( - default=RngStream( - key=RngKey( - value=Array((), dtype=key) overlaying: - [0 0], - tag='default' - ), - count=RngCount( - value=Array(1, dtype=uint32), - tag='default' - ) - ) - ), - linen_collections=('params',), - params={'w': Param( - value=Array(shape=(32, 64), dtype=float32) - )} - ), - b=Param( - value=Array(shape=(1, 64), dtype=float32) - ) - ) + +
+ + + +
The Linen weight is already converted to a typical NNX variable, which is a thin wrapper of the actual JAX array value within. Here, `w` is an `nnx.Param` because it belongs to the `params` collection of `LinenDot` module. @@ -162,8 +122,8 @@ We will talk more about different collections and types in the [NNX Variable <-> ```python -assert isinstance(model.dot.params['w'], nnx.Param) -assert isinstance(model.dot.params['w'].value, jax.Array) +assert isinstance(model.dot.w, nnx.Param) +assert isinstance(model.dot.w.value, jax.Array) ``` If you create this model witout using `nnx.bridge.lazy_init`, the NNX variables defined outside will be initialized as usual, but the Linen part (wrapped inside `ToNNX`) will not. @@ -174,32 +134,12 @@ partial_model = NNXOuter(64, rngs=nnx.Rngs(0)) nnx.display(partial_model) ``` - NNXOuter( - dot=ToNNX( - module=LinenDot( - # attributes - out_dim = 64 - w_init = init - ), - rngs=Rngs( - default=RngStream( - key=RngKey( - value=Array((), dtype=key) overlaying: - [0 0], - tag='default' - ), - count=RngCount( - value=Array(1, dtype=uint32), - tag='default' - ) - ) - ), - linen_collections=() - ), - b=Param( - value=Array(shape=(1, 64), dtype=float32) - ) - ) + +
+ + + +
@@ -208,35 +148,12 @@ full_model = bridge.lazy_init(partial_model, x) nnx.display(full_model) ``` - NNXOuter( - dot=ToNNX( - module=LinenDot( - # attributes - out_dim = 64 - w_init = init - ), - rngs=Rngs( - default=RngStream( - key=RngKey( - value=Array((), dtype=key) overlaying: - [0 0], - tag='default' - ), - count=RngCount( - value=Array(1, dtype=uint32), - tag='default' - ) - ) - ), - linen_collections=('params',), - params={'w': Param( - value=Array(shape=(32, 64), dtype=float32) - )} - ), - b=Param( - value=Array(shape=(1, 64), dtype=float32) - ) - ) + +
+ + + +
### NNX -> Linen @@ -249,18 +166,20 @@ This is because NNX module instance initializes all its variables eagerly when i ```python class NNXDot(nnx.Module): def __init__(self, in_dim: int, out_dim: int, rngs: nnx.Rngs): - self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (in_dim, out_dim))) + self.w = nnx.Param(nnx.initializers.lecun_normal()( + rngs.params(), (in_dim, out_dim))) def __call__(self, x: jax.Array): return x @ self.w x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.to_linen(NNXDot, 32, out_dim=64) # <- Pass in the arguments, not an actual module +# Pass in the arguments, not an actual module +model = bridge.to_linen(NNXDot, 32, out_dim=64) variables = model.init(jax.random.key(0), x) y = model.apply(variables, x) print(list(variables.keys())) -print(variables['params']['w'].value.shape) # => (32, 64) -print(y.shape) # => (4, 64) +print(variables['params']['w'].shape) # => (32, 64) +print(y.shape) # => (4, 64) ``` @@ -290,25 +209,12 @@ class NNXAddConstant(nnx.Module): def __call__(self, x): return x + self.constant -# You have to use `skip_rng=True` because your module `__init__` don't take `rng` as argument +# You have to use `skip_rng=True` because this module's `__init__` doesn't +# take `rng` as an argument model = bridge.ToLinen(NNXAddConstant, skip_rng=True) y, var = model.init_with_output(jax.random.key(0), x) ``` -You may notice that you need to an additional `.value` to access this Flax `w` param. This is because all NNX variables will be wrapped with an `nnx.Variable` class, which will allow it to be annotated with various information, such as its partitioning. This was translated into an equivalent `nnx.bridge.NNXMeta` wrapper. - -If you use [partition metadata in Linen](https://flax.readthedocs.io/en/latest/guides/parallel_training/flax_on_pjit.html), you can learn more about how that works in NNX in [Partition Metadata Section](#partition-metadata) below.
- - -```python -print(type(variables['params']['w'])) # => nnx.bridge.NNXMeta -print(type(variables['params']['w'].value)) # => jax.Array -``` - - - - - Similar to `ToNNX`, you can use `ToLinen` to create a submodule of another Linen module. @@ -324,7 +230,7 @@ class LinenOuter(nn.Module): x = jax.random.normal(jax.random.key(42), (4, 32)) model = LinenOuter(out_dim=64) y, variables = model.init_with_output(jax.random.key(0), x) -w, b = variables['params']['ToLinen_0']['w'].value, variables['params']['b'] +w, b = variables['params']['ToLinen_0']['w'], variables['params']['b'] print(w.shape, b.shape, y.shape) ``` @@ -345,8 +251,9 @@ If you convert a Linen module to NNX, you enjoy the stateful benefit and don't n ```python x = jax.random.normal(jax.random.key(42), (4, 32)) model = bridge.ToNNX(nn.Dropout(rate=0.5, deterministic=False), rngs=nnx.Rngs(dropout=0)) -bridge.lazy_init(model, x) # We don't really need this b/c no extra params were created here, - # but it's a good practice to always add this line. +# We don't really need to call lazy_init because no extra params were created here, +# but it's a good practice to always add this line. +bridge.lazy_init(model, x) y1, y2 = model(x), model(x) assert not jnp.allclose(y1, y2) # Two runs yield different outputs! 
@@ -431,13 +338,13 @@ class LinenMultiCollections(nn.Module): x = jax.random.normal(jax.random.key(42), (2, 4)) model = bridge.lazy_init(bridge.ToNNX(LinenMultiCollections(3), rngs=nnx.Rngs(0)), x) -print(model.params['w']) # Of type `nnx.Param` - note this is still under attribute `params` -print(model.params['b']) # Of type `nnx.Param` -print(model.counter['count']) # Of type `counter` - an auto-created dummy type from the name "counter" -print(type(model.counter['count'])) +print(model.w) # Of type `nnx.Param` - exposed directly on the module, not under a `params` attribute +print(model.b) # Of type `nnx.Param` +print(model.count) # Of type `counter` - auto-created type from the collection name +print(type(model.count)) -y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger -print(model.intermediates['dot_sum']) # Of type `nnx.Intermediates` +y = model(x, mutable=True) # Linen's `sow()` needs `mutable=True` to trigger +print(model.dot_sum) # Of type `nnx.Intermediates` ``` Param( @@ -465,18 +372,19 @@ This can be handy when you only want to set some variables as trainable. ```python # Separate variables of different types with nnx.split -CountType = type(model.counter['count']) +CountType = type(model.count) static, params, counter, the_rest = nnx.split(model, nnx.Param, CountType, ...) -print('All Params:', list(params['params'].keys())) -print('All Counters:', list(counter['counter'].keys())) +print('All Params:', list(params.keys())) +print('All Counters:', list(counter.keys())) print('All the rest (intermediates and RNG keys):', list(the_rest.keys())) model = nnx.merge(static, params, counter, the_rest) # You can merge them back at any time +y = model(x, mutable=True) # still works!
``` All Params: ['b', 'w'] All Counters: ['count'] - All the rest (intermediates and RNG keys): ['intermediates', 'rngs'] + All the rest (intermediates and RNG keys): ['dot_sum', 'rngs'] ### NNX to Linen @@ -507,10 +415,10 @@ print(var['params']) ``` All Linen collections: ['nnx', 'LoRAParam', 'params', 'counts'] - {'w': NNXMeta(var_type=, value=Array([[ 0.2916921 , 0.22780475, 0.06553137], + {'w': Array([[ 0.2916921 , 0.22780475, 0.06553137], [ 0.17487915, -0.34043145, 0.24764155], [ 0.6420431 , 0.6220095 , -0.44769976], - [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32), metadata={'get_value_hooks': (), 'set_value_hooks': (), 'create_value_hooks': (), 'add_axis_hooks': (), 'remove_axis_hooks': ()})} + [ 0.11161668, 0.83873135, -0.7446058 ]], dtype=float32)} ## Partition metadata @@ -535,13 +443,15 @@ class LinenDotWithPartitioning(nn.Module): out_dim: int @nn.compact def __call__(self, x): - w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), ('in', 'out')), + w = self.param('w', nn.with_partitioning(nn.initializers.lecun_normal(), + ('in', 'out')), (x.shape[-1], self.out_dim)) return x @ w @nnx.jit def create_sharded_nnx_module(x): - model = bridge.lazy_init(bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x) + model = bridge.lazy_init( + bridge.ToNNX(LinenDotWithPartitioning(64), rngs=nnx.Rngs(0)), x) state = nnx.state(model) sharded_state = nnx.with_sharding_constraint(state, nnx.get_partition_spec(state)) nnx.update(model, sharded_state) @@ -549,14 +459,15 @@ def create_sharded_nnx_module(x): print(f'We have {len(jax.devices())} fake JAX devices now to partition this model...') -mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), axis_names=('in', 'out')) +mesh = jax.sharding.Mesh(devices=mesh_utils.create_device_mesh((2, 4)), + axis_names=('in', 'out')) x = jax.random.normal(jax.random.key(42), (4, 32)) with mesh: model = create_sharded_nnx_module(x) -print(type(model.params['w'])) # `nnx.Param` 
-print(model.params['w'].sharding) # The partition annotation attached with the weight `w` -print(model.params['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh +print(type(model.w)) # `nnx.Param` +print(model.w.sharding) # The partition annotation attached with `w` +print(model.w.value.sharding) # The underlying JAX array is sharded across the 2x4 mesh ``` We have 8 fake JAX devices now to partition this model... @@ -567,9 +478,9 @@ print(model.params['w'].value.sharding) # The underlying JAX array is sharded ### NNX to Linen -Since all NNX variables are wrapped with `nnx.Variable` box, the converted Linen module will have all variables boxed too. We have a default Linen partition metadata class called `bridge.NNXMeta` to store these converted NNX variables. +If you are not using any metadata feature of the `nnx.Variable` (i.e., no sharding annotation, no registered hooks), the converted Linen module will not add a metadata wrapper to your NNX variable, and you don't need to worry about it. -`nnx.with_partitioning` will automatically shard the array with the annotation if it is called within a `jax.sharding.Mesh` context, so you don't need to do `with_sharding_constraint` yourself. +But if you did add sharding annotations to your NNX variables, `ToLinen` will convert them to a default Linen partition metadata class called `bridge.NNXMeta`, retaining all the metadata you put into the NNX variable. Like with any Linen metadata wrappers, you can use `linen.unbox()` to get the raw JAX array tree. 
@@ -583,23 +494,34 @@ class NNXDotWithParititioning(nnx.Module): return x @ self.w x = jax.random.normal(jax.random.key(42), (4, 32)) -model = bridge.to_linen(NNXDotWithParititioning, 32, 64) -with mesh: - variables = jax.jit(model.init)(jax.random.key(0), x) +@jax.jit +def create_sharded_variables(key, x): + model = bridge.to_linen(NNXDotWithParititioning, 32, 64) + variables = model.init(key, x) + # A `NNXMeta` wrapper of the underlying `nnx.Param` + assert type(variables['params']['w']) == bridge.NNXMeta + # The annotation coming from the `nnx.Param` => (in, out) + assert variables['params']['w'].metadata['sharding'] == ('in', 'out') + + unboxed_variables = nn.unbox(variables) + variable_pspecs = nn.get_partition_spec(variables) + assert isinstance(unboxed_variables['params']['w'], jax.Array) + assert variable_pspecs['params']['w'] == jax.sharding.PartitionSpec('in', 'out') + + sharded_vars = jax.tree.map(jax.lax.with_sharding_constraint, + nn.unbox(variables), + nn.get_partition_spec(variables)) + return sharded_vars -print(type(variables['params']['w'])) # A `NNXMeta` wrapper of the underlying `nnx.Param` -print(variables['params']['w'].metadata['sharding']) # The annotation coming from the `nnx.Param` -print(variables['params']['w'].value.sharding) # The underlying JAX array is sharded across the 2x4 mesh +with mesh: + variables = create_sharded_variables(jax.random.key(0), x) -unboxed = nn.unbox(variables) -print(type(unboxed['params']['w'])) # The raw jax.Array +# The underlying JAX array is sharded across the 2x4 mesh +print(variables['params']['w'].sharding) ``` - - ('in', 'out') GSPMDSharding({devices=[2,4]<=[8]}) - ## Lifted transforms @@ -631,7 +553,7 @@ class NNXVmapped(nnx.Module): x = jax.random.normal(jax.random.key(0), (4, 32)) model = bridge.lazy_init(NNXVmapped(64, 4, rngs=nnx.Rngs(0)), x) -print(model.linen_dot.params['kernel'].shape) # (4, 32, 64) - first axis with dim 4 got vmapped +print(model.linen_dot.kernel.shape) # (4, 32, 64) - first 
axis with dim 4 got vmapped y = model(x) print(y.shape) ``` @@ -659,7 +581,7 @@ class LinenVmapped(nn.Module): x = jax.random.normal(jax.random.key(42), (4, 32)) model = LinenVmapped(64) var = model.init(jax.random.key(0), x) -print(var['params']['VmapToLinen_0']['kernel'].value.shape) # (4, 32, 64) - leading dim 4 got vmapped +print(var['params']['VmapToLinen_0']['kernel'].shape) # (4, 32, 64) - leading dim 4 got vmapped y = model.apply(var, x) print(y.shape) ``` diff --git a/flax/core/meta.py b/flax/core/meta.py index 531b463c7d..eca56ffb7c 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -337,7 +337,7 @@ def get_partition_spec(tree: Any) -> Any: """Extracts a PartitionSpec tree from a PyTree containing ``Partitioned`` values.""" def f(x): - if isinstance(x, Partitioned): + if hasattr(x, 'get_partition_spec'): return x.get_partition_spec() # Unboxed arrays, which should be replicated across all devices elif hasattr(x, 'shape'): @@ -346,7 +346,7 @@ def f(x): return None return jax.tree_util.tree_map( - f, tree, is_leaf=lambda x: isinstance(x, Partitioned) + f, tree, is_leaf=lambda x: isinstance(x, AxisMetadata) ) diff --git a/flax/nnx/bridge/variables.py b/flax/nnx/bridge/variables.py index d73f645f3b..3e799bf4db 100644 --- a/flax/nnx/bridge/variables.py +++ b/flax/nnx/bridge/variables.py @@ -12,12 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import defaultdict from typing import Any, TypeVar import jax from flax import struct from flax.core import meta +from flax.nnx import spmd +from flax.nnx import traversals from flax.nnx import variables as variableslib +from flax.nnx.module import GraphDef import typing as tp @@ -105,6 +109,28 @@ def remove_axis(self, index: int, params: dict[Any, Any]) -> 'NNXMeta[A]': # TODO: implement this, supporting hooks return self + def get_partition_spec(self) -> jax.sharding.PartitionSpec: + """Returns the ``PartitionSpec`` for this partitioned value.""" + nnx_var = self.to_nnx_variable().to_state() + return spmd.get_partition_spec(nnx_var).value + + def to_nnx_variable(self) -> variableslib.Variable: + return self.var_type(self.value, **self.metadata) + + +def is_vanilla_variable(vs: variableslib.VariableState) -> bool: + """A variable state is vanilla if its metadata is essentially blank. + + Returns False only if it has non-empty hooks or any non-built-in attribute. + """ + for key, value in vs.get_metadata().items(): + if key.endswith('_hooks'): + if value != (): + return False + else: + return False + return True + def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: metadata = vs.get_metadata() @@ -113,6 +139,8 @@ def to_linen_var(vs: variableslib.VariableState) -> meta.AxisMetadata: if hasattr(linen_type, 'from_nnx_metadata'): return linen_type.from_nnx_metadata({'value': vs.value, **metadata}) return linen_type(vs.value, **metadata) + if is_vanilla_variable(vs): + return vs.value return NNXMeta(vs.type, vs.value, metadata) @@ -128,11 +156,53 @@ def to_nnx_var(col: str, x: meta.AxisMetadata | Any) -> variableslib.Variable: vtype = variable_type(col) if isinstance(x, NNXMeta): assert vtype == x.var_type, f'Type stored in NNXMeta {x.var_type} != type inferred from collection name {vtype}' - return x.var_type(x.value, **x.metadata) + return x.to_nnx_variable() if isinstance(x, meta.AxisMetadata): x_metadata = vars(x) if hasattr(x,
'to_nnx_metadata'): x_metadata = x.to_nnx_metadata() assert hasattr(x, 'value') return vtype(**x_metadata, linen_meta_type=type(x)) - return vtype(x) \ No newline at end of file + return vtype(x) + + +def _recursive_merge(dict1, dict2): + """Recursively merge two dicts.""" + flat_map = traversals.flatten_mapping(dict1) + flat_map |= traversals.flatten_mapping(dict2) + return traversals.unflatten_mapping(flat_map) + + +def linen_vars_to_nnx_attrs(variables: tp.Mapping[str, Any]) -> dict[str, Any]: + nnx_vars = jax.tree_util.tree_map_with_path( + lambda kp, x: to_nnx_var(get_col_name(kp), x), + variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) + nnx_attrs: dict[str, Any] = defaultdict(dict) + for _, col_tree in nnx_vars.items(): + assert isinstance(col_tree, dict) + for attr_name, value in col_tree.items(): + assert isinstance(attr_name, str) + if isinstance(value, tp.Mapping): # it's a sublayer + nnx_attrs[attr_name] = _recursive_merge(nnx_attrs[attr_name], value) + else: + nnx_attrs[attr_name] = value # it's a variable on this layer + return nnx_attrs + + +def nnx_attrs_to_linen_vars(nnx_attrs: dict) -> dict: + linen_structured = {} + for kp, v in traversals.flatten_mapping( + nnx_attrs, + is_leaf=lambda _, x: isinstance(x, variableslib.Variable | GraphDef), + ).items(): + if isinstance(v, variableslib.Variable): + col_name = variable_type_name(type(v)) + else: + col_name = 'nnx' # it must be an nnx.GraphDef, for some ToLinen submodule + linen_structured[(col_name, *kp)] = v + variables = traversals.unflatten_mapping(linen_structured) + variables = jax.tree.map(lambda x: to_linen_var(x.to_state()), + variables, + is_leaf=lambda x: isinstance(x, variableslib.Variable)) + return variables + diff --git a/flax/nnx/bridge/wrappers.py b/flax/nnx/bridge/wrappers.py index d209d89819..19c468afd3 100644 --- a/flax/nnx/bridge/wrappers.py +++ b/flax/nnx/bridge/wrappers.py @@ -104,7 +104,7 @@ class ToNNX(Module): >>> model = nnx.bridge.ToNNX(linen_module, 
rngs=nnx.Rngs(0)).lazy_init(x) >>> # Like Linen apply(), but using NNX's direct call method >>> y = model(x) - >>> nnx.state(model).params.kernel.value.shape + >>> model.kernel.shape (32, 64) Args: @@ -121,7 +121,7 @@ def __init__( ): self.module = module self.rngs = rngs - self.linen_collections: tuple[str, ...] = () + self.linen_attributes: tuple[str, ...] = () def lazy_init(self, *args, **kwargs): """A shortcut of calling `nnx.bridge.lazy_init()` upon this module.""" @@ -146,20 +146,17 @@ def __call__( _rngs['params'] = _rngs.pop('default') out, variables = self.module.init_with_output(_rngs, *args, method=method, **kwargs) - nnx_vars = jtu.tree_map_with_path( - lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x), - variables, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) - linen_collections = set() - for col, tree in nnx_vars.items(): - setattr(self, col, tree) - linen_collections.add(col) - self.linen_collections = tuple(linen_collections) # make it hashable + nnx_attrs = bv.linen_vars_to_nnx_attrs(variables) + linen_attributes = set() + for attr_name, value in nnx_attrs.items(): + setattr(self, attr_name, value) + linen_attributes.add(attr_name) + self.linen_attributes = tuple(linen_attributes) # make it hashable else: - variables = {col: jax.tree.map(lambda x: bv.to_linen_var(x.to_state()), - getattr(self, col), - is_leaf=lambda x: isinstance(x, nnx.Variable)) - for col in self.linen_collections} + nnx_attrs = {name: getattr(self, name) for name in self.linen_attributes} + variables = bv.nnx_attrs_to_linen_vars(nnx_attrs) + _rngs = ( {name: stream() for name, stream in rngs.items()} if rngs else {} ) @@ -168,11 +165,13 @@ def __call__( # Split out the updates if `mutable` is passed into the Flax module if kwargs.get('mutable', False) != False: out, updates = out - updates = jtu.tree_map_with_path( - lambda kp, x: bv.to_nnx_var(bv.get_col_name(kp), x), - updates, is_leaf=lambda x: isinstance(x, meta.AxisMetadata)) - for collection, value in 
updates.items(): - setattr(self, collection, value) + nnx_attrs = bv.linen_vars_to_nnx_attrs(updates) + for attr_name, value in nnx_attrs.items(): + if hasattr(self, attr_name) and isinstance(value, dict): + original_tree = getattr(self, attr_name) + setattr(self, attr_name, original_tree | value) + else: + setattr(self, attr_name, value) return out @@ -202,7 +201,7 @@ class ToLinen(linen.Module): >>> y, variables = model.init_with_output(jax.random.key(0), x) >>> y.shape (1, 64) - >>> variables['params']['kernel'].value.shape + >>> variables['params']['kernel'].shape (32, 64) >>> # The static GraphDef of the underlying NNX module >>> variables.keys() diff --git a/tests/nnx/bridge/wrappers_test.py b/tests/nnx/bridge/wrappers_test.py index 27f2927fd9..5b65603a24 100644 --- a/tests/nnx/bridge/wrappers_test.py +++ b/tests/nnx/bridge/wrappers_test.py @@ -50,11 +50,11 @@ def test_linen_to_nnx(self): model = bridge.ToNNX(linen_module, rngs=nnx.Rngs(0)).lazy_init(x) # like linen init y = model(x) # like linen apply assert y.shape == (1, 64) - self.assertIsInstance(model.params['kernel'], nnx.Variable) + self.assertIsInstance(model.kernel, nnx.Variable) # NNX automatically adds metadata box regardless of original Linen module. 
linen_vars = linen_module.init(jax.random.key(0), x) np.testing.assert_array_equal(linen_vars['params']['kernel'], - model.params['kernel'].value) + model.kernel.value) def test_linen_to_nnx_submodule(self): class NNXOuter(nnx.Module): @@ -74,10 +74,11 @@ def __call__(self, x): bridge.lazy_init(model, x) gdef_full, state = nnx.split(model) assert gdef_before_lazy_init != gdef_full - assert 'params' in state.nn_dense1 - assert 'batch_stats' in state.batchnorm + assert 'nn_dense1' in state + assert 'batchnorm' in state + assert 'kernel' in state.nn_dense1 y = model(x) - k, b = state.nn_dense1.params.kernel.value, state.b.value + k, b = state['nn_dense1']['kernel'].value, state['b'].value np.testing.assert_allclose(y, x @ k + b, rtol=1e-5) assert gdef_full == nnx.graphdef(model) # static data is stable now @@ -97,7 +98,7 @@ def dot(self, x): model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)) bridge.lazy_init(model, x, method=model.module.dot) y = model(x, method=model.module.dot) - np.testing.assert_allclose(y, x @ nnx.state(model).params.w.value) + np.testing.assert_allclose(y, x @ nnx.state(model).w.value) # lazy_init only initialized param w inside dot(), so calling __call__ should fail with self.assertRaises(flax.errors.ScopeParamNotFoundError): y = model(x) @@ -114,9 +115,9 @@ def __call__(self, x): x = lambda: jnp.zeros((), jnp.int32) model = bridge.ToNNX(Foo(), rngs=nnx.Rngs(0)).lazy_init(x) - assert nnx.state(model).counter.count.value == 0 + self.assertEqual(nnx.state(model).count.value, 0) y = model(x, mutable=True) - assert nnx.state(model).counter.count.value == 1 + self.assertEqual(nnx.state(model).count.value, 1) def test_linen_to_nnx_transform(self): class NNXOuter(nnx.Module): @@ -137,8 +138,8 @@ def vmap_fn(inner, x): model = NNXOuter(3, rngs=nnx.Rngs(0)) nnx.bridge.lazy_init(model, x) - self.assertEqual(model.inner.params['kernel'].shape, (5, 4, 3)) - self.assertEqual(model.inner.params['bias'].shape, (5, 3)) + self.assertEqual(model.inner.kernel.shape, 
(5, 4, 3)) + self.assertEqual(model.inner.bias.shape, (5, 3)) def test_linen_to_nnx_metadata(self): linen_module = nn.Dense( @@ -163,17 +164,59 @@ def create_sharded_nnx_module(x): # nn.Partitioned metadata boxes translated into valid nnx.Variable boxes. self.assertIsInstance(linen_vars['params']['kernel'], nn.Partitioned) self.assertIsInstance(linen_vars['params']['bias'], nn.LogicallyPartitioned) - self.assertIsInstance(nnx_model.params['kernel'], nnx.Variable) - assert nnx_model.params['kernel'].sharding == ('in', 'out') - assert nnx_model.params['kernel'].value.sharding.is_equivalent_to( + self.assertIsInstance(nnx_model.kernel, nnx.Variable) + assert nnx_model.kernel.sharding == ('in', 'out') + assert nnx_model.kernel.value.sharding.is_equivalent_to( jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('in', 'out')), ndim=2) - assert nnx_model.params['bias'].sharding == ('out-alias',) - assert nnx_model.params['bias'].sharding_rules == (('out-alias', 'out'),) - assert nnx_model.params['bias'].value.sharding.is_equivalent_to( + assert nnx_model.bias.sharding == ('out-alias',) + assert nnx_model.bias.sharding_rules == (('out-alias', 'out'),) + assert nnx_model.bias.value.sharding.is_equivalent_to( jax.sharding.NamedSharding(self.mesh, jax.sharding.PartitionSpec('out',)), ndim=1) + def test_linen_to_nnx_state_structure_consistency(self): + class LinenInner(nn.Module): + dout: int + @nn.compact + def __call__(self, x): + w = self.param('w', nn.initializers.lecun_normal(), (x.shape[-1], self.dout)) + return nn.Dropout(rate=0.5, deterministic=False)(x @ w) + + class LinenMiddle(nn.Module): + dout: int + @nn.compact + def __call__(self, x): + dot = LinenInner(self.dout, name='dot') + b = self.variable('bias', 'b', nn.initializers.zeros_init(), None, (1, self.dout)) + return dot(x) + b.value + + class Bias(nnx.Variable): pass + nnx.register_variable_name_type_pair('bias', Bias) + class NNXMiddle(nnx.Module): + def __init__(self, dout: int, *, rngs: 
nnx.Rngs): + self.dot = bridge.ToNNX(LinenInner(dout), rngs=rngs) + self.b = Bias(nnx.initializers.zeros_init()(rngs.params(), (1, dout))) + def __call__(self, x): + return self.dot(x) + self.b + + x = jax.random.normal(jax.random.key(42), (2, 4)) + from_top = bridge.lazy_init( + bridge.ToNNX(LinenMiddle(dout=3), rngs=nnx.Rngs(0, dropout=1)), x) + from_middle = bridge.lazy_init( + NNXMiddle(dout=3, rngs=nnx.Rngs(0, dropout=1)), x) + + # Remove the NNX-module-local RNG states, which will be different + # because the NNX modules are on different level + def get_weights(model): + return nnx.split(model, nnx.RngCount, nnx.RngKey, ...)[3] + from_top_weights = get_weights(from_top) + from_middle_weights = get_weights(from_middle) + + # Confirm the rest of the state has the same structure. + self.assertEqual(jax.tree.structure(from_top_weights), + jax.tree.structure(from_middle_weights)) + ################## ### NNXToLinen ### ################## @@ -183,7 +226,7 @@ def test_nnx_to_linen(self): x = jax.numpy.ones((1, 32)) y, variables = model.init_with_output(jax.random.key(0), x) assert y.shape == (1, 64) - np.testing.assert_allclose(y, x @ variables['params']['kernel'].value) + np.testing.assert_allclose(y, x @ variables['params']['kernel']) assert 'nnx' in variables assert isinstance(variables['nnx']['graphdef'], nnx.GraphDef) @@ -238,10 +281,10 @@ def __call__(self): model = bridge.ToLinen(Counter, skip_rng=True) variables = model.init(jax.random.key(0)) - assert variables['Count']['count'].value == 0 + assert variables['Count']['count'] == 0 _, updates = model.apply(variables, mutable='Count') - assert updates['Count']['count'].value == 1 + assert updates['Count']['count'] == 1 _ = model.apply(variables | updates) def test_nnx_to_linen_mutated_static_data(self): @@ -257,19 +300,19 @@ def __call__(self): model = bridge.ToLinen(Counter, skip_rng=True) variables = model.init(jax.random.key(0)) - assert variables['Count']['count'].value == 0 + assert 
variables['Count']['count'] == 0 # This does not work, because the __call__ also changes the static data of the model. _, updates = model.apply(variables, mutable='Count') - assert updates['Count']['count'].value == 1 - assert updates['Count']['count_nonzero'].value == 1 + assert updates['Count']['count'] == 1 + assert updates['Count']['count_nonzero'] == 1 with self.assertRaises(ValueError): _ = model.apply(variables | updates) # This makes sure the static data is updated too. Using mutable=True also works. _, updates = model.apply(variables, mutable=['Count', 'nnx']) - assert updates['Count']['count'].value == 1 - assert updates['Count']['count_nonzero'].value == 1 + assert updates['Count']['count'] == 1 + assert updates['Count']['count_nonzero'] == 1 _ = model.apply(variables | updates) def test_nnx_to_linen_transforms(self): @@ -288,7 +331,7 @@ def __call__(self, x): x = jax.random.normal(xkey, (2, 4)) model = LinenOuter(dout=3) y, var = model.init_with_output(pkey, x) - k = var['params']['VmapToLinen_0']['kernel'].value + k = var['params']['VmapToLinen_0']['kernel'] assert k.shape == (2, 4, 3) np.testing.assert_allclose(y, jnp.einsum('ab,abc->ac', x, k)) assert 'nnx' in var @@ -302,12 +345,60 @@ def test_nnx_to_linen_metadata(self): assert y.shape == (1, 64) self.assertIsInstance(variables['params']['kernel'], nnx.bridge.NNXMeta) assert variables['params']['kernel'].metadata['sharding'] == ('in', 'out') + self.assertEqual(nn.get_partition_spec(variables)['params']['kernel'], + jax.sharding.PartitionSpec('in', 'out')) np.testing.assert_allclose(y, x @ variables['params']['kernel'].value) def test_nnx_to_linen_metadata_transform(self): # TODO: add support and testing after axis add/remove in transform is fixed. 
pass + def test_nnx_to_linen_pytree_structure_consistency(self): + class NNXInner(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.w = nnx.Param(nnx.initializers.lecun_normal()(rngs.params(), (din, dout))) + self.dropout = nnx.Dropout(rate=0.5, rngs=rngs) + def __call__(self, x): + return self.dropout(x @ self.w) + + class Bias(nnx.Variable): pass + nnx.register_variable_name_type_pair('bias', Bias, overwrite=True) + class NNXMiddle(nnx.Module): + def __init__(self, din: int, dout: int, *, rngs: nnx.Rngs): + self.dot = NNXInner(din, dout, rngs=rngs) + self.b = Bias(nnx.initializers.zeros_init()(rngs.params(), (1, dout))) + def __call__(self, x): + return self.dot(x) + self.b + + class LinenMiddle(nn.Module): + dout: int + @nn.compact + def __call__(self, x): + dot = bridge.to_linen(NNXInner, x.shape[-1], self.dout, name='dot') + b = self.variable('bias', 'b', nn.initializers.zeros_init(), None, (1, self.dout)) + return dot(x) + b.value + + x = jax.random.normal(jax.random.key(42), (2, 4)) + keys = {'params': jax.random.key(0), 'dropout': jax.random.key(1)} + from_top = bridge.to_linen(NNXMiddle, din=4, dout=3).init(keys, x) + from_middle = LinenMiddle(dout=3).init(keys, x) + + # Remove the NNX-module-local RNG states, which will be different + # because the NNX modules are on different level + def get_weights(variables): + non_rngs = {} + for kp, v in flax.traverse_util.flatten_dict(variables).items(): + if 'rngs' not in kp and 'nnx' not in kp: + non_rngs[kp] = v + return flax.traverse_util.unflatten_dict(non_rngs) + from_top_weights = get_weights(from_top) + from_middle_weights = get_weights(from_middle) + + # Confirm the rest of the state has the same structure. 
+ self.assertEqual(jax.tree.structure(from_top_weights), + jax.tree.structure(from_middle_weights)) + + ############################ ### Hybrid mix-and-match ### ############################ @@ -355,7 +446,7 @@ def __call__(self, x): # Test the param value with disabled dropout model = bridge.lazy_init(NNXOuter(dout=3, dropout_rate=0., rngs=nnx.Rngs(default=1, dropout=2)), x) - w, b = model.inner.params['dot']['w'], model.inner.params['b'] + w, b = model.inner.dot['w'], model.inner.b self.assertIsInstance(w, nnx.Param) np.testing.assert_allclose(model(x), x @ w + b) assert hasattr(w, 'sharding') and w.sharding == ('in', 'out') diff --git a/tests/nnx/graph_utils_test.py b/tests/nnx/graph_utils_test.py index 57b0f2e3c1..9fbe7548da 100644 --- a/tests/nnx/graph_utils_test.py +++ b/tests/nnx/graph_utils_test.py @@ -510,8 +510,8 @@ def vmap_fn(inner, x): model = NNXOuter(3, rngs=nnx.Rngs(0)) nnx.bridge.lazy_init(model, x) - self.assertEqual(model.inner.params['kernel'].shape, (5, 4, 3)) - self.assertEqual(model.inner.params['bias'].shape, (5, 3)) + self.assertEqual(model.inner.kernel.shape, (5, 4, 3)) + self.assertEqual(model.inner.bias.shape, (5, 3)) def test_split_merge_context(self): m = nnx.Linear(2, 3, rngs=nnx.Rngs(0))