Skip to content

Commit

Permalink
Add type hints to methods of JitWrapped
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 18, 2024
1 parent ae53c9a commit a58adb8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,11 +812,11 @@ def ax_leaf(l):

class JitWrapped(stages.Wrapped[_P, _OutT]):

def eval_shape(self, *args, **kwargs):
def eval_shape(self, *args: _P.args, **kwargs: _P.kwargs):
"""See ``jax.eval_shape``."""
raise NotImplementedError

def trace(self, *args, **kwargs) -> stages.Traced:
def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> stages.Traced:
raise NotImplementedError


Expand Down

0 comments on commit a58adb8

Please sign in to comment.