Skip to content

Commit

Permalink
Improve typing of jax.jit
Browse files Browse the repository at this point in the history
- Fix for #23719

Signed-off-by: Fabrice Normandin <normandf@mila.quebec>
  • Loading branch information
lebrice committed Sep 21, 2024
1 parent a2b3919 commit 32de804
Show file tree
Hide file tree
Showing 11 changed files with 25 additions and 24 deletions.
5 changes: 2 additions & 3 deletions jax/_src/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,9 +137,8 @@ def _update_debug_special_thread_local(_):

float0 = dtypes.float0


def jit(
fun: Callable,
fun: Callable[stages._P, stages._OutT],
in_shardings=sharding_impls.UNSPECIFIED,
out_shardings=sharding_impls.UNSPECIFIED,
static_argnums: int | Sequence[int] | None = None,
Expand All @@ -151,7 +150,7 @@ def jit(
backend: str | None = None,
inline: bool = False,
abstracted_axes: Any | None = None,
) -> pjit.JitWrapped:
) -> pjit.JitWrapped[stages._P, stages._OutT]:
"""Sets up ``fun`` for just-in-time compilation with XLA.
Args:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/lax/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -1649,7 +1649,7 @@ def _lu_solve(lu: Array, permutation: Array, b: Array, trans: int) -> Array:
def lu_solve(lu: ArrayLike, permutation: ArrayLike, b: ArrayLike,
trans: int = 0) -> Array:
"""LU solve with broadcasting."""
return _lu_solve(lu, permutation, b, trans)
return _lu_solve(lu, permutation, b, trans) # type: ignore[arg-type]


# QR decomposition
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,15 +168,15 @@ def _cumprod(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None
Refer to :func:`jax.numpy.cumprod` for the full documentation.
"""
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _cumsum(self: Array, axis: reductions.Axis = None, dtype: DTypeLike | None = None,
out: None = None) -> Array:
"""Return the cumulative sum of the array.
Refer to :func:`jax.numpy.cumsum` for the full documentation.
"""
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out)
return reductions.cumsum(self, axis=axis, dtype=dtype, out=out) # type: ignore[arg-type]

def _diagonal(self: Array, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array:
"""Return the specified diagonal from the array.
Expand Down
10 changes: 5 additions & 5 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1124,7 +1124,7 @@ def flip(m: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
[6, 5]]], dtype=int32)
"""
util.check_arraylike("flip", m)
return _flip(asarray(m), reductions._ensure_optional_axes(axis))
return _flip(asarray(m), reductions._ensure_optional_axes(axis)) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',))
def _flip(m: Array, axis: int | tuple[int, ...] | None = None) -> Array:
Expand Down Expand Up @@ -1982,7 +1982,7 @@ def squeeze(a: ArrayLike, axis: int | Sequence[int] | None = None) -> Array:
Array([0, 1, 2], dtype=int32)
"""
util.check_arraylike("squeeze", a)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None)
return _squeeze(asarray(a), _ensure_index_tuple(axis) if axis is not None else None) # type: ignore[arg-type]

@partial(jit, static_argnames=('axis',), inline=True)
def _squeeze(a: Array, axis: tuple[int, ...]) -> Array:
Expand Down Expand Up @@ -7259,7 +7259,7 @@ def delete(
obj = asarray(obj).ravel()
obj = clip(where(obj < 0, obj + a.shape[axis], obj), 0, a.shape[axis])
obj = sort(obj)
obj -= arange(len(obj)) # type: ignore[arg-type,operator]
obj -= arange(len(obj))
i = arange(a.shape[axis] - obj.size)
i += (i[None, :] >= obj[:, None]).sum(0)
return a[(slice(None),) * axis + (i,)]
Expand Down Expand Up @@ -11034,8 +11034,8 @@ def digitize(x: ArrayLike, bins: ArrayLike, right: bool = False,
kwds: dict[str, str] = {} if method is None else {'method': method}
return where(
bins_arr[-1] >= bins_arr[0],
searchsorted(bins_arr, x, side=side, **kwds),
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds)
searchsorted(bins_arr, x, side=side, **kwds), # type: ignore[arg-type]
bins_arr.shape[0] - searchsorted(bins_arr[::-1], x, side=side, **kwds) # type: ignore[arg-type]
)


Expand Down
2 changes: 1 addition & 1 deletion jax/_src/numpy/ufunc_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def __call__(self, *args: ArrayLike, out: None = None, where: None = None) -> An
if where is not None:
raise NotImplementedError(f"where argument of {self}")
call = self.__static_props['call'] or self._call_vectorized
return call(*args)
return call(*args) # type: ignore[arg-type]

@partial(jax.jit, static_argnames=['self'])
def _call_vectorized(self, *args):
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/pjit.py
Original file line number Diff line number Diff line change
Expand Up @@ -808,7 +808,7 @@ def ax_leaf(l):
return broadcast_prefix(abstracted_axes, args, ax_leaf)


class JitWrapped(stages.Wrapped):
class JitWrapped(stages.Wrapped[stages._P, stages._OutT]):

def eval_shape(self, *args, **kwargs):
"""See ``jax.eval_shape``."""
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/cluster/vq.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,5 +70,5 @@ def vq(obs: ArrayLike, code_book: ArrayLike, check_finite: bool = True) -> tuple
raise ValueError("ndim different than 1 or 2 are not supported")
dist = vmap(lambda ob: jnp.linalg.norm(ob[None] - cb_arr, axis=-1))(obs_arr)
code = jnp.argmin(dist, axis=-1)
dist_min = vmap(operator.getitem)(dist, code)
dist_min = vmap(operator.getitem)(dist, code) # type: ignore[call-overload]
return code, dist_min
4 changes: 2 additions & 2 deletions jax/_src/scipy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: Literal[False]) -> Ar
@overload
def _svd(x: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]: ...

@partial(jit, static_argnames=('full_matrices', 'compute_uv'))
@partial(jit, static_argnames=('full_matrices', 'compute_uv')) # type: ignore[misc]
def _svd(a: ArrayLike, *, full_matrices: bool, compute_uv: bool) -> Array | tuple[Array, Array, Array]:
a, = promote_dtypes_inexact(jnp.asarray(a))
return lax_linalg.svd(a, full_matrices=full_matrices, compute_uv=compute_uv)
Expand Down Expand Up @@ -545,7 +545,7 @@ def schur(a: ArrayLike, output: str = 'real') -> tuple[Array, Array]:
if output not in ('real', 'complex'):
raise ValueError(
f"Expected 'output' to be either 'real' or 'complex', got {output=}.")
return _schur(a, output)
return _schur(a, output) # type: ignore[arg-type]


def inv(a: ArrayLike, overwrite_a: bool = False, check_finite: bool = True) -> Array:
Expand Down
2 changes: 1 addition & 1 deletion jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -1805,7 +1805,7 @@ def sph_harm(m: Array,
int, n_max, 'The `n_max` argument of `jnp.scipy.special.sph_harm` must '
'be statically specified to use `sph_harm` within JAX transformations.')

return _sph_harm(m, n, theta, phi, n_max)
return _sph_harm(m, n, theta, phi, n_max) # type: ignore[arg-type]


# exponential integrals
Expand Down
12 changes: 7 additions & 5 deletions jax/_src/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
import functools
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any, NamedTuple, Protocol, Union, runtime_checkable
from typing import Any, NamedTuple, ParamSpec, Protocol, TypeVar, Union, runtime_checkable

import jax

Expand Down Expand Up @@ -751,9 +751,11 @@ def lower(self, *, lowering_platforms: tuple[str, ...] | None = None,
raise ValueError(msg) from None
return Lowered(lowering, self.args_info, self._out_tree)

_P = ParamSpec("_P")
_OutT = TypeVar("_OutT", covariant=True) # pytype: disable=not-supported-yet

@runtime_checkable
class Wrapped(Protocol):
class Wrapped(Protocol[_P, _OutT]):
"""A function ready to be traced, lowered, and compiled.
This protocol reflects the output of functions such as
Expand All @@ -762,11 +764,11 @@ class Wrapped(Protocol):
to compilation, and the result compiled prior to execution.
"""

def __call__(self, *args, **kwargs):
def __call__(self, *args: _P.args, **kwargs: _P.kwargs) -> _OutT:
"""Executes the wrapped function, lowering and compiling as needed."""
raise NotImplementedError

def trace(self, *args, **kwargs) -> Traced:
def trace(self, *args: _P.args, **kwargs: _P.kwargs) -> Traced:
"""Trace this function explicitly for the given arguments.
A traced function is staged out of Python and translated to a jaxpr. It is
Expand All @@ -777,7 +779,7 @@ def trace(self, *args, **kwargs) -> Traced:
"""
raise NotImplementedError

def lower(self, *args, **kwargs) -> Lowered:
def lower(self, *args: _P.args, **kwargs: _P.kwargs) -> Lowered:
"""Lower this function explicitly for the given arguments.
A lowered function is staged out of Python and translated to a
Expand Down
4 changes: 2 additions & 2 deletions jax/_src/third_party/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,8 +272,8 @@ def fresnel(x: ArrayLike) -> tuple[Array, Array]:
c_large = c_inf
s_large = s_inf
else:
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi
s_large = 0.5 - 1 / (jnp.pi * x) * cospi
c_large = 0.5 + 1 / (jnp.pi * x) * sinpi # type: ignore[assignment]
s_large = 0.5 - 1 / (jnp.pi * x) * cospi # type: ignore[assignment]

# Other x values
t = jnp.pi * x2
Expand Down

0 comments on commit 32de804

Please sign in to comment.