From 1e706c24a14a87ed9971473721459fbf025d4ad3 Mon Sep 17 00:00:00 2001
From: Fabrice Normandin
Date: Wed, 18 Sep 2024 09:47:27 -0400
Subject: [PATCH] Improve typing of `jax.jit` - Fix for #23719

Signed-off-by: Fabrice Normandin
---
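With `Wrapped` made generic over a `ParamSpec` (`_P`) and a covariant
result type (`_OutT`), type checkers can carry a function's signature
through `jax.jit` instead of collapsing the result to an untyped
callable. A rough sketch of the intended effect (the exact rendering of
the inferred types varies between mypy and pyright):

    import jax
    import jax.numpy as jnp

    def selu(x: jax.Array, alpha: float = 1.67) -> jax.Array:
        return alpha * jnp.where(x > 0, x, jnp.exp(x) - 1)

    jitted = jax.jit(selu)

    # Before this patch, `jitted` was a plain `JitWrapped`, so calling it
    # produced `Any`. With `jit` typed as
    # `(Callable[_P, _OutT], ...) -> JitWrapped[_P, _OutT]`, the call
    # below is inferred as `jax.Array`, and e.g. `jitted("oops")` becomes
    # a static type error.
    out = jitted(jnp.ones(3))

The `# type: ignore` comments added across the tree silence pre-existing
imprecisions that the stricter signatures now surface; they do not
change runtime behaviour.
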
 jax/_src/api.py                       |  4 ++--
 jax/_src/lax/linalg.py                |  2 +-
 jax/_src/numpy/array_methods.py       |  4 ++--
 jax/_src/numpy/lax_numpy.py           | 10 +++++-----
 jax/_src/numpy/ufunc_api.py           |  2 +-
 jax/_src/pjit.py                      |  2 +-
 jax/_src/scipy/cluster/vq.py          |  2 +-
 jax/_src/scipy/linalg.py              |  4 ++--
 jax/_src/scipy/special.py             |  2 +-
 jax/_src/stages.py                    | 12 +++++++-----
 jax/_src/third_party/scipy/special.py |  4 ++--
 11 files changed, 25 insertions(+), 23 deletions(-)

diff --git a/jax/_src/api.py b/jax/_src/api.py
index bd8a951954ac..053ff26d933e 100644
--- a/jax/_src/api.py
+++ b/jax/_src/api.py
@@ -139,7 +139,7 @@ def _update_debug_special_thread_local(_):
 
 
 def jit(
-    fun: Callable,
+    fun: Callable[stages._P, stages._OutT],
     in_shardings=sharding_impls.UNSPECIFIED,
     out_shardings=sharding_impls.UNSPECIFIED,
     static_argnums: int | Sequence[int] | None = None,
@@ -151,7 +151,7 @@ def jit(
     backend: str | None = None,
     inline: bool = False,
     abstracted_axes: Any | None = None,
-) -> pjit.JitWrapped:
+) -> pjit.JitWrapped[stages._P, stages._OutT]:
   """Sets up ``fun`` for just-in-time compilation with XLA.
 
   Args:
diff --git a/jax/_src/lax/linalg.py b/jax/_src/lax/linalg.py
index ef6a5a11a56e..bbef2c870c7c 100644
--- a/jax/_src/lax/linalg.py
+++ b/jax/_src/lax/linalg.py
@@ -1649,7 +1649,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
 def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
              trans: int = 0) -> Array:
   """LU solve with broadcasting."""
-  return _lu_solve(lu, permutation, b, trans)
+  return _lu_solve(lu, permutation, b, trans)  # type: ignore[arg-type]
 
 
 # QR decomposition
diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py
index 547fe1247459..bb198bdbc677 100644
--- a/jax/_src/numpy/array_methods.py
+++ b/jax/_src/numpy/array_methods.py
@@ -168,7 +168,7 @@ def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None
 
   Refer to :func:`jax.numpy.cumprod` for the full documentation.
   """
-  return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
+  return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)  # type: ignore[arg-type]
 
 def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
             out: None = None) -> Array:
@@ -176,7 +176,7 @@ def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None =
 
   Refer to :func:`jax.numpy.cumsum` for the full documentation.
   """
-  return reductions.cumsum(self, axis=axis, dtype=dtype, out=out)
+  return reductions.cumsum(self, axis=axis, dtype=dtype, out=out)  # type: ignore[arg-type]
 
 def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array:
   """Return the specified diagonal from the array.
diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py
index b71412e586bd..6ba4e41245f6 100644
--- a/jax/_src/numpy/lax_numpy.py
+++ b/jax/_src/numpy/lax_numpy.py
@@ -1124,7 +1124,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
            [6, 5]]], dtype=int32)
   """
   util.check_arraylike("flip", m)
-  return _flip(asarray(m), reductions._ensure_optional_axes(axis))
+  return _flip(asarray(m), reductions._ensure_optional_axes(axis))  # type: ignore[arg-type]
 
 @partial(jit, static_argnames=('axis',))
 def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
@@ -1982,7 +1982,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
     Array([0, 1, 2], dtype=int32)
   """
   util.check_arraylike("squeeze", a)
-  return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
+  return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)  # type: ignore[arg-type]
 
 @partial(jit, static_argnames=('axis',), inline=True)
 def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
@@ -7259,7 +7259,7 @@ def delete(
   obj = asarray(obj).ravel()
   obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
   obj = sort(obj)
-  obj -= arange(len(obj))  # type: ignore[arg-type,operator]
+  obj -= arange(len(obj))
   i = arange(a.shape[axis] - obj.size)
   i += (i[None, :] >= obj[:, None]).sum(0)
   return a[(slice(None),) * axis + (i,)]
@@ -11034,8 +11034,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
   kwds: dict[str, str] = {} if method is None else {'method': method}
   return where(
       bins_arr[-1] >= bins_arr[0],
-      searchsorted(bins_arr, x, side=side, **kwds),
-      bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
+      searchsorted(bins_arr, x, side=side, **kwds),  # type: ignore[arg-type]
+      bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)  # type: ignore[arg-type]
   )
 
 
diff --git a/jax/_src/numpy/ufunc_api.py b/jax/_src/numpy/ufunc_api.py
index 3473e8a7468a..96ced5abb45d 100644
--- a/jax/_src/numpy/ufunc_api.py
+++ b/jax/_src/numpy/ufunc_api.py
@@ -174,7 +174,7 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An
     if where is not None:
       raise NotImplementedError(f"where argument of {self}")
     call = self.__static_props['call'] or self._call_vectorized
-    return call(*args)
+    return call(*args)  # type: ignore[arg-type]
 
   @partial(jax.jit, static_argnames=['self'])
   def _call_vectorized(self, *args):
diff --git a/jax/_src/pjit.py b/jax/_src/pjit.py
index ac1318ed7810..f940dce14139 100644
--- a/jax/_src/pjit.py
+++ b/jax/_src/pjit.py
@@ -808,7 +808,7 @@ def ax_leaf(l):
   return broadcast_prefix(abstracted_axes, args, ax_leaf)
 
 
-class JitWrapped(stages.Wrapped):
+class JitWrapped(stages.Wrapped[stages._P, stages._OutT]):
 
   def eval_shape(self, *args, **kwargs):
     """See ``jax.eval_shape``."""
diff --git a/jax/_src/scipy/cluster/vq.py b/jax/_src/scipy/cluster/vq.py
index a82c8928644d..05cad3504b98 100644
--- a/jax/_src/scipy/cluster/vq.py
+++ b/jax/_src/scipy/cluster/vq.py
@@ -70,5 +70,5 @@ def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple
     raise ValueError("ndim different than 1 or 2 are not supported")
   dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
   code = jnp.argmin(dist, axis=-1)
-  dist_min = vmap(operator.getitem)(dist, code)
+  dist_min = vmap(operator.getitem)(dist, code)  # type: ignore[call-overload]
   return code, dist_min
diff --git a/jax/_src/scipy/linalg.py b/jax/_src/scipy/linalg.py
index d014e5ceb24e..9066076c5fbf 100644
--- a/jax/_src/scipy/linalg.py
+++ b/jax/_src/scipy/linalg.py
@@ -219,7 +219,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Ar
 @overload
 def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...
 
-@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
+@partial(jit, static_argnames=('full_matrices', 'compute_uv'))  # type: ignore[misc]
 def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]:
   a, = promote_dtypes_inexact(jnp.asarray(a))
   return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
@@ -545,7 +545,7 @@ def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
   if output not in ('real', 'complex'):
     raise ValueError(
         f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
-  return _schur(a, output)
+  return _schur(a, output)  # type: ignore[arg-type]
 
 
 def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
diff --git a/jax/_src/scipy/special.py b/jax/_src/scipy/special.py
index 837aa011f165..77d7db65c737 100644
--- a/jax/_src/scipy/special.py
+++ b/jax/_src/scipy/special.py
@@ -1805,7 +1805,7 @@ def sph_harm(m: Array,
       int, n_max,
       'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
       'be statically specified to use `sph_harm` within JAX transformations.')
-  return _sph_harm(m, n, theta, phi, n_max)
+  return _sph_harm(m, n, theta, phi, n_max)  # type: ignore[arg-type]
 
 
 # exponential integrals
diff --git a/jax/_src/stages.py b/jax/_src/stages.py
index 3a2c375b64db..41d5cdc6c785 100644
--- a/jax/_src/stages.py
+++ b/jax/_src/stages.py
@@ -33,7 +33,7 @@
 import functools
 from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
+from typing import Any, NamedTuple, ParamSpec, Protocol, TypeVar, Union, runtime_checkable
 
 import jax
 
@@ -751,9 +751,11 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
       raise ValueError(msg) from None
     return Lowered(lowering, self.args_info, self._out_tree)
 
+_P = ParamSpec("_P")
+_OutT = TypeVar("_OutT", covariant=True)  # pytype: disable=not-supported-yet
 
 @runtime_checkable
-class Wrapped(Protocol):
+class Wrapped(Protocol[_P, _OutT]):
   """A function ready to be traced, lowered, and compiled.
 
   This protocol reflects the output of functions such as
@@ -762,11 +764,11 @@ class Wrapped(Protocol):
   to compilation, and the result compiled prior to execution.
   """
 
-  def __call__(self, *args, **kwargs):
+  def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _OutT:
     """Executes the wrapped function, lowering and compiling as needed."""
     raise NotImplementedError
 
-  def trace(self, *args, **kwargs) -> Traced:
+  def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced:
    """Trace this function explicitly for the given arguments.
 
     A traced function is staged out of Python and translated to a jaxpr. It is
@@ -777,7 +779,7 @@ def trace(self, *args, **kwargs) -> Traced:
     """
     raise NotImplementedError
 
-  def lower(self, *args, **kwargs) -> Lowered:
+  def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered:
     """Lower this function explicitly for the given arguments.
 
     A lowered function is staged out of Python and translated to a
diff --git a/jax/_src/third_party/scipy/special.py b/jax/_src/third_party/scipy/special.py
index 67ef09f6de37..3c565f04623f 100644
--- a/jax/_src/third_party/scipy/special.py
+++ b/jax/_src/third_party/scipy/special.py
@@ -272,8 +272,8 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]:
     c_large = c_inf
     s_large = s_inf
   else:
-    c_large = 0.5 + 1 / (jnp.pi * x) * sinpi
-    s_large = 0.5 - 1 / (jnp.pi * x) * cospi
+    c_large = 0.5 + 1 / (jnp.pi * x) * sinpi  # type: ignore[assignment]
+    s_large = 0.5 - 1 / (jnp.pi * x) * cospi  # type: ignore[assignment]
 
   # Other x values
   t = jnp.pi * x2
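
Postscript: because `trace` and `lower` now also take `*args: _P.args,
**kwargs: _P.kwargs`, explicit staging is checked against the wrapped
function's annotated signature as well. A minimal sketch (the shapes are
illustrative, and checker output is approximate):

    import jax
    import jax.numpy as jnp

    @jax.jit
    def matmul(x: jax.Array, y: jax.Array) -> jax.Array:
        return x @ y

    x = jnp.ones((4, 3))
    y = jnp.ones((3, 2))

    # `matmul.trace(x, y)` returns a `Traced` and `matmul.lower(x, y)` a
    # `Lowered`, and both calls are now checked against `matmul`'s
    # parameters, e.g. `matmul.lower(x)` (missing `y`) is flagged
    # statically instead of surfacing only at trace time.
    lowered = matmul.lower(x, y)
    compiled = lowered.compile()
    out = compiled(x, y)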