diff --git a/jax/_src/basearray.pyi b/jax/_src/basearray.pyi index e4c7fdc4a9ab..bd546169d341 100644 --- a/jax/_src/basearray.pyi +++ b/jax/_src/basearray.pyi @@ -171,7 +171,7 @@ class Array(abc.ABC): @property def nbytes(self) -> int: ... def nonzero(self, *, fill_value: None | ArrayLike | tuple[ArrayLike, ...] = None, - size: int | None = None,) -> tuple[Array, ...]: ... + size: int | None = None) -> tuple[Array, ...]: ... def prod(self, axis: Axis = None, dtype: DTypeLike | None = None, out: None = None, keepdims: bool = False, initial: ArrayLike | None = None, where: ArrayLike | None = None,