diff --git a/docs/guides/flax_basics.ipynb b/docs/guides/flax_basics.ipynb index b76289be82..73a57785a8 100644 --- a/docs/guides/flax_basics.ipynb +++ b/docs/guides/flax_basics.ipynb @@ -68,7 +68,7 @@ "source": [ "import jax\n", "from typing import Any, Callable, Sequence\n", - "from jax import lax, random, numpy as jnp\n", + "from jax import random, numpy as jnp\n", "import flax\n", "from flax import linen as nn" ] @@ -775,8 +775,7 @@ " kernel = self.param('kernel',\n", " self.kernel_init, # Initialization function\n", " (inputs.shape[-1], self.features)) # shape info.\n", - " y = lax.dot_general(inputs, kernel,\n", - " (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot?\n", + " y = jnp.dot(inputs, kernel)\n", " bias = self.param('bias', self.bias_init, (self.features,))\n", " y = y + bias\n", " return y\n", @@ -866,7 +865,6 @@ " ra_mean = self.variable('batch_stats', 'mean',\n", " lambda s: jnp.zeros(s),\n", " x.shape[1:])\n", - " mean = ra_mean.value # This will either get the value or trigger init\n", " bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:])\n", " if is_initialized:\n", " ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)\n", diff --git a/docs/guides/flax_basics.md b/docs/guides/flax_basics.md index 1df7944414..fed965fed4 100644 --- a/docs/guides/flax_basics.md +++ b/docs/guides/flax_basics.md @@ -45,7 +45,7 @@ Here we provide the code needed to set up the environment for our notebook. import jax from typing import Any, Callable, Sequence -from jax import lax, random, numpy as jnp +from jax import random, numpy as jnp import flax from flax import linen as nn ``` @@ -394,8 +394,7 @@ class SimpleDense(nn.Module): kernel = self.param('kernel', self.kernel_init, # Initialization function (inputs.shape[-1], self.features)) # shape info. - y = lax.dot_general(inputs, kernel, - (((inputs.ndim - 1,), (0,)), ((), ())),) # TODO Why not jnp.dot? + y = jnp.dot(inputs, kernel) bias = self.param('bias', self.bias_init, (self.features,)) y = y + bias return y @@ -448,7 +447,6 @@ class BiasAdderWithRunningMean(nn.Module): ra_mean = self.variable('batch_stats', 'mean', lambda s: jnp.zeros(s), x.shape[1:]) - mean = ra_mean.value # This will either get the value or trigger init bias = self.param('bias', lambda rng, shape: jnp.zeros(shape), x.shape[1:]) if is_initialized: ra_mean.value = self.decay * ra_mean.value + (1.0 - self.decay) * jnp.mean(x, axis=0, keepdims=True)