diff --git a/docs/guides/haiku_migration_guide.rst b/docs/guides/haiku_migration_guide.rst index 58f14380ea..0affc4be92 100644 --- a/docs/guides/haiku_migration_guide.rst +++ b/docs/guides/haiku_migration_guide.rst @@ -891,7 +891,7 @@ The Flax team recommends a more Module-centric approach that uses `__call__` to model = hk.transform_with_state(forward) - params, state = model.init(PRNGKey(0), jax.numpy.ones((1, 64))) + params, state = model.init(random.key(0), jax.numpy.ones((1, 64))) --- @@ -906,5 +906,5 @@ The Flax team recommends a more Module-centric approach that uses `__call__` to return output model = FooModule() - variables = model.init(PRNGKey(0), jax.numpy.ones((1, 64))) + variables = model.init(random.key(0), jax.numpy.ones((1, 64))) params, counter = variables['params'], variables['counter'] \ No newline at end of file