diff --git a/flax/linen/module.py b/flax/linen/module.py index 2d1e7e2a52..6f4554d482 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1623,6 +1623,51 @@ def __call__(self, x): module = self.clone() return module, variables + @overload + def apply( + self, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + method: Union[Callable[..., Any], str, None] = None, + mutable: Literal[False], + capture_intermediates: Union[ + bool, Callable[['Module', str], bool] + ] = False, + **kwargs, + ) -> Any: + ... + + @overload + def apply( + self, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter, + capture_intermediates: Union[ + bool, Callable[['Module', str], bool] + ] = False, + **kwargs, + ) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: + ... + + @overload + def apply( + self, + variables: VariableDict, + *args, + rngs: Optional[RNGSequences] = None, + method: Union[Callable[..., Any], str, None] = None, + mutable: CollectionFilter = False, + capture_intermediates: Union[ + bool, Callable[['Module', str], bool] + ] = False, + **kwargs, + ) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: + ... + @traceback_util.api_boundary def apply( self,