Skip to content

Commit

Permalink
Also improve jax.eval_shape while we're at it
Browse files Browse the repository at this point in the history
Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 18, 2024
1 parent 2035a15 commit ae53c9a
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2518,7 +2518,7 @@ def _sds_aval_mapping(x):


@api_boundary
def eval_shape(fun: Callable, *args, **kwargs):
def eval_shape(fun: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs):
"""Compute the shape/dtype of ``fun`` without any FLOPs.
This utility function is useful for performing shape inference. Its
Expand Down

0 comments on commit ae53c9a

Please sign in to comment.