Skip to content

Commit

Permalink
add typing overloads for variable
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Aug 1, 2023
1 parent 1e2d1c7 commit 8525a9b
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 3 deletions.
45 changes: 44 additions & 1 deletion flax/core/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -818,14 +818,57 @@ def put(target, key, val):

put(variables, name, value)

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
) -> Variable[T]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[True],
) -> Variable[T]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[False],
) -> Variable[meta.AxisMetadata[T]]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
...

def variable(
self,
col: str,
name: str, # pylint: disable=keyword-arg-before-vararg
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Variable[T]:
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
"""Creates a variable if it doesn't exist yet in this scope and returns it.
Args:
Expand Down
47 changes: 45 additions & 2 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1330,14 +1330,57 @@ def clone_fn(m: Module) -> Module:

return module

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
) -> Variable[T]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[True],
) -> Variable[T]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: Literal[False],
) -> Variable[meta.AxisMetadata[T]]:
...

@overload
def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
...

def variable(
self,
col: str,
name: str,
init_fn: Optional[Callable[..., Any]] = None,
init_fn: Optional[Callable[..., T]] = None,
*init_args,
unbox: bool = True,
) -> Variable:
) -> Union[Variable[T], Variable[meta.AxisMetadata[T]]]:
"""Declares and returns a variable in this Module.
See :mod:`flax.core.variables` for more information. See also :meth:`param`
Expand Down

0 comments on commit 8525a9b

Please sign in to comment.