From 8525a9b7eabcd818caab2840092f9608236d8f21 Mon Sep 17 00:00:00 2001 From: Cristian Garcia Date: Tue, 1 Aug 2023 20:48:57 +0000 Subject: [PATCH] add typing overloads for variable --- flax/core/scope.py | 45 +++++++++++++++++++++++++++++++++++++++++- flax/linen/module.py | 47 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 89 insertions(+), 3 deletions(-) diff --git a/flax/core/scope.py b/flax/core/scope.py index a1844f4d77..910fd42c32 100644 --- a/flax/core/scope.py +++ b/flax/core/scope.py @@ -818,6 +818,49 @@ def put(target, key, val): put(variables, name, value) + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + def variable( self, col: str, @@ -825,7 +868,7 @@ def variable( init_fn: Optional[Callable[..., T]] = None, *init_args, unbox: bool = True, - ) -> Variable[T]: + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: """Creates a variable if it doesn't exist yet in this scope and returns it. Args: diff --git a/flax/linen/module.py b/flax/linen/module.py index 6f4554d482..a9d802d7d0 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1330,14 +1330,57 @@ def clone_fn(m: Module) -> Module: return module + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[True], + ) -> Variable[T]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: Literal[False], + ) -> Variable[meta.AxisMetadata[T]]: + ... + + @overload + def variable( + self, + col: str, + name: str, + init_fn: Optional[Callable[..., T]] = None, + *init_args, + unbox: bool = True, + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: + ... + def variable( self, col: str, name: str, - init_fn: Optional[Callable[..., Any]] = None, + init_fn: Optional[Callable[..., T]] = None, *init_args, unbox: bool = True, - ) -> Variable: + ) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]: """Declares and returns a variable in this Module. See :mod:`flax.core.variables` for more information. See also :meth:`param`