From a79de59f379e027fc1f8ef66bff630ca79f5e3c3 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 18 Sep 2024 09:56:15 -0400 Subject: [PATCH] Add type hints to methods of JitWrapped Signed-off-by: Fabrice Normandin --- jax/_src/pjit.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py index fe40462db5d3..eccd94a864d3 100644 --- a/jax/_src/pjit.py +++ b/jax/_src/pjit.py @@ -812,11 +812,11 @@ def ax_leaf(l): class JitWrapped(stages.Wrapped[_P, _OutT]): - def eval_shape(self, *args, **kwargs): + def eval_shape(self, *args: _P.args, **kwargs: _P.kwargs): """See ``jax.eval_shape``.""" raise NotImplementedError - def trace(self, *args, **kwargs) -> stages.Traced: + def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> stages.Traced: raise NotImplementedError