Skip to content

Commit

Permalink
Merge pull request #23556 from selamw1:docstring_frombuffer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 675240113
  • Loading branch information
Google-ML-Automation committed Sep 16, 2024
2 parents 8ab66c8 + 7dde9b2 commit 90f532a
Showing 1 changed file with 42 additions and 1 deletion.
43 changes: 42 additions & 1 deletion jax/_src/numpy/lax_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5312,9 +5312,50 @@ def array_equiv(a1: ArrayLike, a2: ArrayLike) -> Array:

# General np.from* style functions mostly delegate to numpy.

@util.implements(np.frombuffer)
def frombuffer(buffer: bytes | Any, dtype: DTypeLike = float,
count: int = -1, offset: int = 0) -> Array:
r"""Convert a buffer into a 1-D JAX array.
JAX implementation of :func:`numpy.frombuffer`.
Args:
buffer: an object containing the data. It must be either a bytes object with
a length that is an integer multiple of the dtype element size, or
it must be an object exporting the `Python buffer interface`_.
dtype: optional. Desired data type for the array. Default is ``float64``.
This specifes the dtype used to parse the buffer, but note that after parsing,
64-bit values will be cast to 32-bit JAX arrays if the ``jax_enable_x64``
flag is set to ``False``.
count: optional integer specifying the number of items to read from the buffer.
If -1 (default), all items from the buffer are read.
offset: optional integer specifying the number of bytes to skip at the beginning
of the buffer. Default is 0.
Returns:
A 1-D JAX array representing the interpreted data from the buffer.
See also:
- :func:`jax.numpy.fromstring`: convert a string of text into 1-D JAX array.
Examples:
Using a bytes buffer:
>>> buf = b"\x00\x01\x02\x03\x04"
>>> jnp.frombuffer(buf, dtype=jnp.uint8)
Array([0, 1, 2, 3, 4], dtype=uint8)
>>> jnp.frombuffer(buf, dtype=jnp.uint8, offset=1)
Array([1, 2, 3, 4], dtype=uint8)
Constructing a JAX array via the Python buffer interface, using Python's
built-in :mod:`array` module.
>>> from array import array
>>> pybuffer = array('i', [0, 1, 2, 3, 4])
>>> jnp.frombuffer(pybuffer, dtype=jnp.int32)
Array([0, 1, 2, 3, 4], dtype=int32)
.. _Python buffer interface: https://docs.python.org/3/c-api/buffer.html
"""
return asarray(np.frombuffer(buffer=buffer, dtype=dtype, count=count, offset=offset))


Expand Down

0 comments on commit 90f532a

Please sign in to comment.