diff --git a/flax/linen/module.py b/flax/linen/module.py index a9d802d7d0..f4b6a3604e 100644 --- a/flax/linen/module.py +++ b/flax/linen/module.py @@ -1666,51 +1666,6 @@ def __call__(self, x): module = self.clone() return module, variables - @overload - def apply( - self, - variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - method: Union[Callable[..., Any], str, None] = None, - mutable: Literal[False], - capture_intermediates: Union[ - bool, Callable[['Module', str], bool] - ] = False, - **kwargs, - ) -> Any: - ... - - @overload - def apply( - self, - variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - method: Union[Callable[..., Any], str, None] = None, - mutable: CollectionFilter, - capture_intermediates: Union[ - bool, Callable[['Module', str], bool] - ] = False, - **kwargs, - ) -> Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]: - ... - - @overload - def apply( - self, - variables: VariableDict, - *args, - rngs: Optional[RNGSequences] = None, - method: Union[Callable[..., Any], str, None] = None, - mutable: CollectionFilter = False, - capture_intermediates: Union[ - bool, Callable[['Module', str], bool] - ] = False, - **kwargs, - ) -> Union[Any, Tuple[Any, Union[FrozenVariableDict, Dict[str, Any]]]]: - ... - @traceback_util.api_boundary def apply( self,