diff --git a/flax/core/scope.py b/flax/core/scope.py index a1844f4d77..910fd42c32 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -818,6 +818,49 @@ def put(target, key, val): put(variables, name, value) + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + def variable( self, col: str, @@ -825,7 +868,7 @@ def variable( init_fn: Optional[Callable[..., T]] = None, *init_args, unbox: bool = True, - ) -> Variable[T]: + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: """Creates a variable if it doesn't exist yet in this scope and returns it. Args: diff --git a/flax/linen/module.py b/flax/linen/module.py index 6f4554d482..a9d802d7d0 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1330,14 +1330,57 @@ def clone_fn(m: Module) -> Module: return module + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., Any]] = None, + init_fn: Optional[Callable[..., T]] = None, *init_args, unbox: bool = True, - ) -> Variable: + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: """Declares and returns a variable in this Module. See :mod:`flax.core.variables` for more information. See also :meth:`param`