From d5845516b79f0ae7254665b4137dcb3b06f1a5fb Mon Sep 17 00:00:00 2001 From: Flax Team Date: Sun, 22 Sep 2024 13:59:04 -0700 Subject: [PATCH] Minor documentation fixes for AxisMetadata. PiperOrigin-RevId: 677531963 --- flax/core/meta.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/flax/core/meta.py b/flax/core/meta.py index eca56ffb7c..278e5b51a0 100644 --- a/flax/core/meta.py +++ b/flax/core/meta.py @@ -59,13 +59,13 @@ class AxisMetadata(Generic[A], metaclass=abc.ABCMeta): def unbox(self) -> A: """Returns the content of the AxisMetadata box. - Note that unlike ``meta.unbox`` the unbox call should recursively unbox + Note that unlike ``meta.unbox`` the unbox call should not recursively unbox metadata. It should simply return value that it wraps directly even if that value itself is an instance of AxisMetadata. In practise, AxisMetadata subclasses should be registered as PyTree nodes to support passing instances to JAX and Flax APIs. The leaves returned for this - note should correspond to the value returned by unbox. + node should correspond to the value returned by unbox. Returns: The unboxed value.