diff --git a/CHANGELOG.md b/CHANGELOG.md index e9f6acf70a68..55e3f882fcbb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ Remember to align the itemized text with the first line of an item within a list ## jax 0.4.27 +* New Features + * {func}`jax.numpy.astype` now supports a new `device` keyword argument. + * Changes * {func}`jax.pure_callback` and {func}`jax.experimental.io_callback` now use {class}`jax.Array` instead of {class}`np.ndarray`. You can recover @@ -26,6 +29,12 @@ Remember to align the itemized text with the first line of an item within a list * The `device()` method of JAX arrays has been removed, after being deprecated since JAX v0.4.21. Use `arr.devices()` instead. +* Bug fixes + * {func}`jax.numpy.astype` will now always return a copy when `copy=True`. + Previously, no copy would be made when the output array would have the same + dtype as the input array. This may result in some increased memory usage. + To prevent copying when possible, set `copy=None`. To error when a copy is + required, set `copy=False`. ## jaxlib 0.4.27 diff --git a/jax/_src/numpy/array_methods.py b/jax/_src/numpy/array_methods.py index 98eea8887198..831fef33627c 100644 --- a/jax/_src/numpy/array_methods.py +++ b/jax/_src/numpy/array_methods.py @@ -31,11 +31,13 @@ import numpy as np import jax from jax import lax +from jax.sharding import Sharding from jax._src import core from jax._src import dtypes from jax._src.api_util import _ensure_index_tuple from jax._src.array import ArrayImpl from jax._src.lax import lax as lax_internal +from jax._src.lib import xla_client as xc from jax._src.numpy import lax_numpy from jax._src.numpy import reductions from jax._src.numpy import ufuncs @@ -55,7 +57,7 @@ # functions, which can themselves handle instances from any of these classes. 
-def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: +def _astype(arr: ArrayLike, dtype: DTypeLike, copy: bool = True, device: xc.Device | Sharding | None = None) -> Array: """Copy the array and cast to a specified dtype. This is implemented via :func:`jax.lax.convert_element_type`, which may @@ -63,7 +65,7 @@ def _astype(arr: ArrayLike, dtype: DTypeLike) -> Array: some cases. In particular, the details of float-to-int and int-to-float casts are implementation dependent. """ - return lax_numpy.astype(arr, dtype) + return lax_numpy.astype(arr, dtype, copy=copy, device=device) def _nbytes(arr: ArrayLike) -> int: diff --git a/jax/_src/numpy/lax_numpy.py b/jax/_src/numpy/lax_numpy.py index 4b23ca210b88..0158fb63eeb3 100644 --- a/jax/_src/numpy/lax_numpy.py +++ b/jax/_src/numpy/lax_numpy.py @@ -41,7 +41,7 @@ import opt_einsum import jax -from jax import jit +from jax import jit, device_put from jax import errors from jax import lax from jax.sharding import Sharding, SingleDeviceSharding @@ -2262,17 +2262,65 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike: In particular, the details of float-to-int and int-to-float casts are implementation dependent. """) -def astype(x: ArrayLike, dtype: DTypeLike | None, /, *, copy: bool = True) -> Array: +def astype(x: ArrayLike, dtype: DTypeLike | None, + /, *, copy: bool = True, + device: xc.Device | Sharding | None = None) -> Array: util.check_arraylike("astype", x) x_arr = asarray(x) - del copy # unused in JAX if dtype is None: dtype = dtypes.canonicalize_dtype(float_) dtypes.check_user_dtype_supported(dtype, "astype") + if ( + issubdtype(x_arr.dtype, complexfloating) + and dtypes.isdtype(dtype, ("integral", "real floating")) + ): + warnings.warn( + "Casting from complex to real dtypes will soon raise a ValueError. 
" + "Please first use jnp.real or jnp.imag to take the real/imaginary " + "component of your input.", + DeprecationWarning, stacklevel=2 + ) + + out = x_arr # convert_element_type(complex, bool) has the wrong semantics. if np.dtype(dtype) == bool and issubdtype(x_arr.dtype, complexfloating): - return (x_arr != _lax_const(x_arr, 0)) - return lax.convert_element_type(x_arr, dtype) + out = (x_arr != _lax_const(x_arr, 0)) + out = _place_array(out, device=device, copy=copy) + + # We offer a more specific warning than the usual ComplexWarning so we prefer + # to issue our warning. + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ComplexWarning) + return lax.convert_element_type(out, dtype) + +def _get_device_set(x: ArrayLike | xc.Device | Sharding | None): + if x is None: + return None + elif isinstance(x, Sharding): + return x.device_set + elif isinstance(x, xc.Device): + return {x} + elif hasattr(x, "devices") and not isinstance(x, core.Tracer): + return x.devices() + +def _place_array(x: jax.Array, device: xc.Device | Sharding | None = None, copy=None): + # TODO(micky774): Fine tune mechanics in future PRs as we formalize device + # placement semantics + devices = _get_device_set(device) + src_devices = _get_device_set(x) + if devices is not None and src_devices != devices: + if copy is not None and not copy: + raise ValueError( + f"Specified {device=} which requires a copy since the source devices " + f"are {src_devices}, however copy=False. Set copy=True or " + "copy=None to perform the requested operation." 
+ ) + out = device_put(x, device) + else: + out = x + if copy: + return _array_copy(out) + return out @util.implements(np.asarray, lax_description=_ARRAY_DOC) diff --git a/jax/experimental/array_api/_data_type_functions.py b/jax/experimental/array_api/_data_type_functions.py index 4f72fcba29d0..770d264c1c07 100644 --- a/jax/experimental/array_api/_data_type_functions.py +++ b/jax/experimental/array_api/_data_type_functions.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import builtins import functools from typing import NamedTuple @@ -19,6 +21,9 @@ import jax.numpy as jnp +from jax._src.lib import xla_client as xc +from jax._src.sharding import Sharding +from jax._src import dtypes as _dtypes from jax.experimental.array_api._dtypes import ( bool, int8, int16, int32, int64, uint8, uint16, uint32, uint64, float32, float64, complex64, complex128 @@ -124,8 +129,19 @@ def _promote_types(t1, t2): raise ValueError("No promotion path for {t1} & {t2}") -def astype(x, dtype, /, *, copy=True): - return jnp.array(x, dtype=dtype, copy=copy) +def astype(x, dtype, /, *, copy: builtins.bool = True, device: xc.Device | Sharding | None = None): + src_dtype = x.dtype if hasattr(x, "dtype") else _dtypes.dtype(x) + if ( + src_dtype is not None + and _dtypes.isdtype(src_dtype, "complex floating") + and _dtypes.isdtype(dtype, ("integral", "real floating")) + ): + raise ValueError( + "Casting from complex to non-complex dtypes is not permitted. Please " + "first use jnp.real or jnp.imag to take the real/imaginary component of " + "your input." 
+ ) + return jnp.astype(x, dtype, copy=copy, device=device) def can_cast(from_, to, /): diff --git a/jax/numpy/__init__.pyi b/jax/numpy/__init__.pyi index 5a2046f6f7a1..e930e1e43837 100644 --- a/jax/numpy/__init__.pyi +++ b/jax/numpy/__init__.pyi @@ -13,6 +13,8 @@ from jax._src.typing import ( Array, ArrayLike, DType, DTypeLike, DimSize, DuckTypedArray, Shape, DeprecatedArg ) +from jax._src.sharding import Sharding +from jax._src.lib import xla_client as xc from jax.numpy import fft as fft, linalg as linalg from jax.sharding import Sharding as _Sharding import numpy as _np @@ -115,7 +117,7 @@ def asarray( ) -> Array: ... def asin(x: ArrayLike, /) -> Array: ... def asinh(x: ArrayLike, /) -> Array: ... -def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ...) -> Array: ... +def astype(a: ArrayLike, dtype: Optional[DTypeLike], /, *, copy: builtins.bool = ..., device: xc.Device | Sharding | None = ...) -> Array: ... def atan(x: ArrayLike, /) -> Array: ... def atan2(x: ArrayLike, y: ArrayLike, /) -> Array: ... def atanh(x: ArrayLike, /) -> Array: ... diff --git a/tests/lax_numpy_test.py b/tests/lax_numpy_test.py index c957ed669b3a..4b56fa4f618a 100644 --- a/tests/lax_numpy_test.py +++ b/tests/lax_numpy_test.py @@ -3840,6 +3840,36 @@ def testAstypeBool(self, from_dtype, use_method, to_dtype='bool'): self._CheckAgainstNumpy(np_op, jnp_op, args_maker) self._CompileAndCheck(jnp_op, args_maker) + @jtu.sample_product( + change_dtype=[True, False], + copy=[True, False], + change_device=[True, False], + ) + def testAstypeCopy(self, change_dtype, copy, change_device): + if change_device and not jtu.test_device_matches(["gpu"]): + raise unittest.SkipTest( + "Testing device transfer requires at least two available devices." 
+ ) + + dtype = 'float32' if change_dtype else 'int32' + device = jax.devices("cpu")[0] if change_device else None + expect_copy = change_dtype or copy or change_device + x = jnp.arange(5, dtype='int32') + y = x.astype(dtype, copy=copy, device=device) + + assert y.dtype == dtype + if change_device: + assert y.devices() == {device} + else: + y.delete() + assert x.is_deleted() != expect_copy + + def testAstypeComplexDowncast(self): + x = jnp.array(2.0+1.5j, dtype='complex64') + msg = "Casting from complex to non-complex dtypes will soon raise " + with self.assertWarns(DeprecationWarning, msg=msg): + x.astype('float32') + def testAstypeInt4(self): # Test converting from int4 to int8 x = np.array([1, -2, -3, 4, -8, 7], dtype=jnp.int4)