Skip to content

Commit

Permalink
Merge pull request #22664 from jakevdp:astype-device
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 656016734
  • Loading branch information
jax authors committed Jul 25, 2024
2 parents 593afa6 + 81b9db6 commit f17d0f3
Showing 1 changed file with 38 additions and 17 deletions.
55 changes: 38 additions & 17 deletions jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3479,15 +3479,44 @@ def _convert_to_array_if_dtype_fails(x: ArrayLike) -> ArrayLike:

deprecations.register("jax-numpy-astype-complex-to-real")

@util.implements(getattr(np, "astype", None), lax_description="""
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
""")
def astype(x: ArrayLike, dtype: DTypeLike | None,
/, *, copy: bool = False,
device: xc.Device | Sharding | None = None) -> Array:
"""Convert an array to a specified dtype.
JAX imlementation of :func:`numpy.astype`.
This is implemented via :func:`jax.lax.convert_element_type`, which may
have slightly different behavior than :func:`numpy.astype` in some cases.
In particular, the details of float-to-int and int-to-float casts are
implementation dependent.
Args:
x: input array to convert
dtype: output dtype
copy: if True, then always return a copy. If False (default) then only
return a copy if necessary.
device: optionally specify the device to which the output will be committed.
Returns:
An array with the same shape as ``x``, containing values of the specified
dtype.
See Also:
- :func:`jax.lax.convert_element_type`: lower-level function for XLA-style
dtype conversions.
Examples:
>>> x = jnp.array([0, 1, 2, 3])
>>> x
Array([0, 1, 2, 3], dtype=int32)
>>> x.astype('float32')
Array([0.0, 1.0, 2.0, 3.0], dtype=float32)
>>> y = jnp.array([0.0, 0.5, 1.0])
>>> y.astype(int) # truncates fractional values
Array([0, 0, 1], dtype=int32)
"""
util.check_arraylike("astype", x)
x_arr = asarray(x)

Expand All @@ -3510,17 +3539,9 @@ def astype(x: ArrayLike, dtype: DTypeLike | None,
# to issue our warning.
with warnings.catch_warnings():
warnings.simplefilter("ignore", ComplexWarning)
return _place_array(
lax.convert_element_type(x_arr, dtype),
device=device, copy=copy,
)

def _place_array(x, device=None, copy=None):
# TODO(micky774): Implement in future PRs as we formalize device placement
# semantics
if copy:
return _array_copy(x)
return x
result = lax_internal._convert_element_type(
x_arr, dtype, sharding=_normalize_to_sharding(device))
return _array_copy(result) if copy else result


@util.implements(np.asarray, lax_description=_ARRAY_DOC)
Expand Down

0 comments on commit f17d0f3

Please sign in to comment.