diff --git a/jax/_src/api.py b/jax/_src/api.py index de819fb50bc1..66095a19f1ac 100644 --- a/jax/_src/api.py +++ b/jax/_src/api.py @@ -2518,7 +2518,7 @@ def _sds_aval_mapping(x): @api_boundary -def eval_shape(fun: Callable, *args, **kwargs): +def eval_shape(fun: Callable[_P, Any], *args: _P.args, **kwargs: _P.kwargs): """Compute the shape/dtype of ``fun`` without any FLOPs. This utility function is useful for performing shape inference. Its