Skip to content

Commit

Permalink
add overload signatures for Module.apply
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 31, 2023
1 parent d69d106 commit 1e2d1c7
Showing 1 changed file with 45 additions and 0 deletions.
45 changes: 45 additions & 0 deletions flax/linen/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -1623,6 +1623,51 @@ def __call__(self, x):
module = self.clone()
return module, variables

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: Literal[False],
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Any:
...

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter,
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]:
...

@overload
def apply(
self,
variables: VariableDict,
*args,
rngs: Optional[RNGSequences] = None,
method: Union[Callable[..., Any], str, None] = None,
mutable: CollectionFilter = False,
capture_intermediates: Union[
bool, Callable[['Module', str], bool]
] = False,
**kwargs,
) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]:
...

@traceback_util.api_boundary
def apply(
self,
Expand Down

0 comments on commit 1e2d1c7

Please sign in to comment.