Skip to content

Commit

Permalink
Merge pull request #23094 from jakevdp:basearray-annotations
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 664852746
  • Loading branch information
jax authors committed Aug 19, 2024
2 parents 772e042 + dd697a9 commit 0666ccc
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 59 deletions.
127 changes: 78 additions & 49 deletions jax/_src/basearray.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,29 @@
import abc
from collections.abc import Callable, Sequence
from types import ModuleType
from typing import Any, Union
from typing import Any, Protocol, Union, runtime_checkable
import numpy as np

from jax._src.sharding import Sharding

# TODO(jakevdp) de-duplicate this with the DTypeLike definition in typing.py.
# We redefine these here to prevent circular imports.
@runtime_checkable
class SupportsDType(Protocol):
@property
def dtype(self) -> np.dtype: ...
DTypeLike = Union[str, type[Any], np.dtype, SupportsDType]

Axis = Union[int, Sequence[int], None]
Shard = Any

# TODO: alias this to xla_client.Traceback
Device = Any
Traceback = Any

# TODO(jakevdp): fix import cycles and import this from jax._src.lax.
PrecisionLike = Any


class Array(abc.ABC):
aval: Any
Expand Down Expand Up @@ -117,72 +129,89 @@ class Array(abc.ABC):
def __release_buffer__(self, view: memoryview) -> None: ...

# np.ndarray methods:
def all(self, axis: int | Sequence[int] | None = None, out=None,
keepdims=None, *, where: ArrayLike | None = ...) -> Array: ...
def any(self, axis: int | Sequence[int] | None = None, out=None,
keepdims=None, *, where: ArrayLike | None = ...) -> Array: ...
def argmax(self, axis: int | None = None, out=None, keepdims=None) -> Array: ...
def argmin(self, axis: int | None = None, out=None, keepdims=None) -> Array: ...
def argpartition(self, kth, axis=-1, kind='introselect', order=None) -> Array: ...
def argsort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ...
def astype(self, dtype) -> Array: ...
def choose(self, choices, out=None, mode='raise') -> Array: ...
def clip(self, min=None, max=None, out=None) -> Array: ...
def compress(self, condition, axis: int | None = None, out=None) -> Array: ...
def all(self, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
def any(self: Array, axis: Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array: ...
def argmax(self: Array, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array: ...
def argmin(self, axis: int | None = None, out: None = None,
keepdims: bool | None = None) -> Array: ...
def argpartition(self, kth, axis=-1, kind='introselect', order: None = None) -> Array: ...
def argsort(self, axis: int | None = -1, kind='quicksort', order: None = None) -> Array: ...
def astype(self, dtype: DTypeLike | None = None, max: ArrayLike | None = None) -> Array: ...
def choose(self, choices: Sequence[ArrayLike], out: None = None, mode: str = 'raise') -> Array: ...
def clip(self, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array: ...
def compress(self, condition: ArrayLike,
axis: int | None = None, *, out: None = None,
size: int | None = None, fill_value: ArrayLike = 0) -> Array: ...
def conj(self) -> Array: ...
def conjugate(self) -> Array: ...
def copy(self) -> Array: ...
def cumprod(self, axis: int | Sequence[int] | None = None,
dtype=None, out=None) -> Array: ...
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def cumsum(self, axis: int | Sequence[int] | None = None,
dtype=None, out=None) -> Array: ...
def diagonal(self, offset=0, axis1: int = 0, axis2: int = 1) -> Array: ...
def dot(self, b, *, precision=None) -> Array: ...
def flatten(self) -> Array: ...
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def diagonal(self, offset: int = 0, axis1: int = 0, axis2: int = 1) -> Array: ...
def dot(self, b: ArrayLike, *, precision: PrecisionLike = None,
preferred_element_type: DTypeLike | None = None) -> Array: ...
def flatten(self, order: str = "C") -> Array: ...
@property
def imag(self) -> Array: ...
def item(self, *args) -> Any: ...
def max(self, axis: int | Sequence[int] | None = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
def mean(self, axis: int | Sequence[int] | None = None, dtype=None,
out=None, keepdims=False, *, where=None,) -> Array: ...
def min(self, axis: int | Sequence[int] | None = None, out=None,
keepdims=None, initial=None, where=None) -> Array: ...
def item(self, *args: int) -> Any: ...
def max(self, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: ...
def mean(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, *,
where: ArrayLike | None = None) -> Array: ...
def min(self, axis: Axis = None, out: None = None,
keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None) -> Array: ...
@property
def nbytes(self) -> int: ...
def nonzero(self, *, size=None, fill_value=None) -> Array: ...
def prod(self, axis: int | Sequence[int] | None = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def ptp(self, axis: int | Sequence[int] | None = None, out=None,
keepdims=False,) -> Array: ...
def ravel(self, order='C') -> Array: ...
def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None,
size: int | None = None,) -> tuple[Array, ...]: ...
def prod(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False,
initial: ArrayLike | None = None, where: ArrayLike | None = None,
promote_integers: bool = True) -> Array: ...
def ptp(self, axis: Axis = None, out: None = None,
keepdims: bool = False) -> Array: ...
def ravel(self, order: str = 'C') -> Array: ...
@property
def real(self) -> Array: ...
def repeat(self, repeats, axis: int | None = None, *,
total_repeat_length=None) -> Array: ...
def reshape(self, *args, order='C') -> Array: ...
def round(self, decimals=0, out=None) -> Array: ...
def searchsorted(self, v, side='left', sorter=None) -> Array: ...
def sort(self, axis: int | None = -1, kind='quicksort', order=None) -> Array: ...
def repeat(self, repeats: ArrayLike, axis: int | None = None, *,
total_repeat_length: int | None = None) -> Array: ...
def reshape(self, *args: Any, order: str = "C") -> Array: ...
def round(self, decimals: int = 0, out: None = None) -> Array: ...
def searchsorted(self, v: ArrayLike, side: str = 'left',
sorter: ArrayLike | None = None, *, method: str = 'scan') -> Array: ...
def sort(self, axis: int | None = -1, *, kind: None = None,
order: None = None, stable: bool = True, descending: bool = False) -> Array: ...
def squeeze(self, axis: int | Sequence[int] | None = None) -> Array: ...
def std(self, axis: int | Sequence[int] | None = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def sum(self, axis: int | Sequence[int] | None = None, dtype=None,
out=None, keepdims=None, initial=None, where=None) -> Array: ...
def std(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
def sum(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, keepdims: bool = False, initial: ArrayLike | None = None,
where: ArrayLike | None = None, promote_integers: bool = True) -> Array: ...
def swapaxes(self, axis1: int, axis2: int) -> Array: ...
def take(self, indices, axis: int | None = None, out=None,
mode=None) -> Array: ...
def tobytes(self, order='C') -> bytes: ...
def take(self, indices: ArrayLike, axis: int | None = None, out: None = None,
mode: str | None = None, unique_indices: bool = False, indices_are_sorted: bool = False,
fill_value: StaticScalar | None = None) -> Array: ...
def tobytes(self, order: str = 'C') -> bytes: ...
def tolist(self) -> list[Any]: ...
def trace(self, offset=0, axis1: int = 0, axis2: int = 1, dtype=None,
out=None) -> Array: ...
def transpose(self, *args) -> Array: ...
def trace(self, offset: int | ArrayLike = 0, axis1: int = 0, axis2: int = 1,
dtype: DTypeLike | None = None, out: None = None) -> Array: ...
def transpose(self, *args: Any) -> Array: ...
@property
def T(self) -> Array: ...
@property
def mT(self) -> Array: ...
def var(self, axis: int | Sequence[int] | None = None,
dtype=None, out=None, ddof=0, keepdims=False, *, where=None) -> Array: ...
def var(self, axis: Axis = None, dtype: DTypeLike | None = None,
out: None = None, ddof: int = 0, keepdims: bool = False, *,
where: ArrayLike | None = None, correction: int | float | None = None) -> Array: ...
def view(self, dtype=None, type=None) -> Array: ...

# Even though we don't always support the NumPy array protocol, e.g., for
Expand Down
19 changes: 9 additions & 10 deletions jax/_src/numpy/array_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
# functions, which can themselves handle instances from any of these classes.


def _all(self: ArrayLike, axis: reductions.Axis = None, out: None = None,
def _all(self: Array, axis: reductions.Axis = None, out: None = None,
keepdims: bool = False, *, where: ArrayLike | None = None) -> Array:
"""Test whether all array elements along a given axis evaluate to True.
Expand Down Expand Up @@ -107,7 +107,8 @@ def _argsort(self: Array, axis: int | None = -1, *, kind: None = None, order: No
return lax_numpy.argsort(self, axis=axis, kind=kind, order=order,
stable=stable, descending=descending)

def _astype(self: Array, dtype: DTypeLike, copy: bool = False, device: xc.Device | Sharding | None = None) -> Array:
def _astype(self: Array, dtype: DTypeLike | None, copy: bool = False,
device: xc.Device | Sharding | None = None) -> Array:
"""Copy the array and cast to a specified dtype.
This is implemented via :func:`jax.lax.convert_element_type`, which may
Expand All @@ -124,13 +125,12 @@ def _choose(self: Array, choices: Sequence[ArrayLike], out: None = None, mode: s
"""
return lax_numpy.choose(self, choices=choices)

def _clip(number: ArrayLike,
min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
def _clip(self: Array, min: ArrayLike | None = None, max: ArrayLike | None = None) -> Array:
"""Return an array whose values are limited to a specified range.
Refer to :func:`jax.numpy.clip` for full documentation.
"""
return lax_numpy.clip(number, min=min, max=max)
return lax_numpy.clip(self, min=min, max=max)

def _compress(self: Array, condition: ArrayLike,
axis: int | None = None, *, out: None = None,
Expand Down Expand Up @@ -163,15 +163,15 @@ def _copy(self: Array) -> Array:
"""
return lax_numpy.copy(self)

def _cumprod(self: Array, /, axis: int | Sequence[int] | None = None,
def _cumprod(self: Array, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Return the cumulative product of the array.
Refer to :func:`jax.numpy.cumprod` for the full documentation.
"""
return reductions.cumprod(self, axis=axis, dtype=dtype, out=out)

def _cumsum(self: Array, /, axis: int | Sequence[int] | None = None,
def _cumsum(self: Array, axis: int | Sequence[int] | None = None,
dtype: DTypeLike | None = None, out: None = None) -> Array:
"""Return the cumulative sum of the array.
Expand Down Expand Up @@ -258,9 +258,8 @@ def _nbytes_property(self: Array) -> int:
"""Total bytes consumed by the elements of the array."""
return np.size(self) * dtypes.dtype(self, canonicalize=True).itemsize

def _nonzero(self: Array, *, size: int | None = None,
fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None
) -> tuple[Array, ...]:
def _nonzero(self: Array, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None,
size: int | None = None) -> tuple[Array, ...]:
"""Return indices of nonzero elements of an array.
Refer to :func:`jax.numpy.nonzero` for the full documentation.
Expand Down

0 comments on commit 0666ccc

Please sign in to comment.