Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros #23563

Merged
merged 1 commit into from
Sep 18, 2024

Conversation

rajasekharporeddy
Copy link
Contributor

@rajasekharporeddy rajasekharporeddy commented Sep 11, 2024

JAX APIs doesn't accept non-arraylike inputs, but jnp.trim_zeros works with Python list inputs.

>>> jnp.trim_zeros([0, 2, 3, 1, 0, 0])
Array([2, 3, 1], dtype=int32)

Additionally, the behavior of jnp.trim_zeros is inconsistent with np.trim_zeros for NdArray inputs with ndim != 1.

>>> x = jnp.array([[0, 1, 2],
...                [3, 0, 0]])
>>> jnp.trim_zeros(x)
Array([[3, 0, 0]], dtype=int32)
>>> np.trim_zeros(x)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
>>> x1 = jnp.array([[0, 0, 0, 1, 2, 0],
...                 [0, 3, 4, 5, 0, 0]])
>>> jnp.trim_zeros(x1)
Array([], shape=(0, 6), dtype=int32)
>>> np.trim_zeros(x1)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

@rajasekharporeddy rajasekharporeddy changed the title Deprecate passing NdArrays and Python lists to jnp.trim_zeros Deprecate passing NdArrays with ndim != 1 and Python lists to jnp.trim_zeros Sep 11, 2024
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good overall! A few comments below.

jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
@rajasekharporeddy rajasekharporeddy changed the title Deprecate passing NdArrays with ndim != 1 and Python lists to jnp.trim_zeros Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros Sep 11, 2024
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Outdated Show resolved Hide resolved
jax/_src/numpy/lax_numpy.py Show resolved Hide resolved
CHANGELOG.md Outdated Show resolved Hide resolved
Copy link
Collaborator

@jakevdp jakevdp left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

@google-ml-butler google-ml-butler bot added kokoro:force-run pull ready Ready for copybara import and testing labels Sep 11, 2024
@rajasekharporeddy
Copy link
Contributor Author

Hi @jakevdp, Should I need to rebase onto the latest main to fix these lint_and_typecheck errors?

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 16, 2024

Yeah, try rebasing. I'm having trouble telling where the error is coming from, because line 6750 in your PR looks like it's the middle of a docstring

@jakevdp jakevdp self-assigned this Sep 16, 2024
@rajasekharporeddy
Copy link
Contributor Author

I have rebased onto latest upstream main. Could you please trigger the tests now?

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 16, 2024

Looks good, thanks! Could you please squash the changes?

@rajasekharporeddy
Copy link
Contributor Author

Thanks! Squashed into a single commit.

CHANGELOG.md Show resolved Hide resolved
@jakevdp jakevdp added pull ready Ready for copybara import and testing and removed pull ready Ready for copybara import and testing labels Sep 17, 2024
@jakevdp
Copy link
Collaborator

jakevdp commented Sep 17, 2024

There's a merge conflict that blocked submitting this. Could you please rebase on an updated main branch?

@jakevdp
Copy link
Collaborator

jakevdp commented Sep 18, 2024

Thanks - can you please squash your changes into a single commit?

@rajasekharporeddy
Copy link
Contributor Author

Squashed into single commit

@copybara-service copybara-service bot merged commit 48d8fce into jax-ml:main Sep 18, 2024
15 checks passed
@rajasekharporeddy rajasekharporeddy deleted the testbranch1 branch September 18, 2024 15:40
keshavb96 added a commit to keshavb96/jax that referenced this pull request Sep 18, 2024
commit 093c6e9
Merge: e1a77ee d0cb318
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Sep 18 13:37:44 2024 -0700

    Merge remote-tracking branch 'upstream/main' into disable_remat_pass

commit e1a77ee
Author: Keshav <keshavb@nvidia.com>
Date:   Wed Sep 18 13:35:37 2024 -0700

    minor changes

commit d0cb318
Merge: b51c653 bef36c4
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 13:34:11 2024 -0700

    Merge pull request jax-ml#23736 from hawkinsp:changelog

    PiperOrigin-RevId: 676111400

commit b51c653
Merge: dbc03cf 57a4b76
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 13:33:05 2024 -0700

    Merge pull request jax-ml#23737 from jakevdp:digitize-doc

    PiperOrigin-RevId: 676111220

commit dbc03cf
Author: Dan Foreman-Mackey <danfm@google.com>
Date:   Wed Sep 18 12:39:58 2024 -0700

    Re-land jax-ml#23261 with appropriate compatibility checks.

    PiperOrigin-RevId: 676092618

commit b164d67
Merge: cd04d0f 541b3a3
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 12:05:03 2024 -0700

    Merge pull request jax-ml#23247 from kaixih:sliding_window_attn

    PiperOrigin-RevId: 676079831

commit 57a4b76
Author: Jake VanderPlas <jakevdp@google.com>
Date:   Wed Sep 18 11:59:00 2024 -0700

    Improve documentation for jnp.digitize

commit bef36c4
Author: Peter Hawkins <phawkins@google.com>
Date:   Wed Sep 18 18:57:03 2024 +0000

    Add Python 3.13 wheels to changelog.

commit cd04d0f
Merge: 016c499 c756d9b
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 10:00:03 2024 -0700

    Merge pull request jax-ml#23726 from hawkinsp:debug

    PiperOrigin-RevId: 676030839

commit 016c499
Author: Sergei Lebedev <slebedev@google.com>
Date:   Wed Sep 18 09:56:44 2024 -0700

    Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests

    PiperOrigin-RevId: 676029854

commit 9dd363d
Author: Luke Baumann <lukebaumann@google.com>
Date:   Wed Sep 18 09:28:25 2024 -0700

    Export `jax.lib.xla_extension.ifrt_programs`.

    PiperOrigin-RevId: 676020419

commit e27f1e9
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 09:03:55 2024 -0700

    Change Python version 3.13.0rc2 to 3.13.0-rc.2.

    The value is taken from [the versions manifest](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json).

    PiperOrigin-RevId: 676012255

commit 442e863
Author: Sergei Lebedev <slebedev@google.com>
Date:   Wed Sep 18 08:56:49 2024 -0700

    Added a missing branch to `mgpu.FragmentedArray.astype`

    Previously, an unsupported cast produced a `NameError` instead.

    PiperOrigin-RevId: 676010161

commit 6236b8f
Merge: 826843a 1cc9661
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 08:57:38 2024 -0700

    Merge pull request jax-ml#23667 from dfm:always-lower-jnp-dot-to-dot-general

    PiperOrigin-RevId: 676010154

commit c756d9b
Author: Peter Hawkins <phawkins@google.com>
Date:   Wed Sep 18 15:44:45 2024 +0000

    Fix error in debugger tests that is showing up in CI.

    I'm unsure why this started happening now, but sometimes we get an
    invalid offset for a frame. Be tolerant of that case.

commit 826843a
Merge: c191bbc 922e652
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 08:42:39 2024 -0700

    Merge pull request jax-ml#23723 from hawkinsp:setuptools

    PiperOrigin-RevId: 676005613

commit c191bbc
Author: Yash Katariya <yashkatariya@google.com>
Date:   Wed Sep 18 08:40:30 2024 -0700

    Make `debug.print` work with static args. Fixes: jax-ml#23600

    PiperOrigin-RevId: 676005582

commit 1cc9661
Author: Dan Foreman-Mackey <danfm@google.com>
Date:   Mon Sep 16 14:18:29 2024 -0400

    Unconditionally lower jnp.dot to lax.dot_general.

    jax-ml#16721 added a condition to lower
    calls to `jnp.dot` with scalar inputs to `lax.mul` instead of
    `lax.dot_general`. AFAICT, jax-ml#16826
    fixed the issue that this was solving, so this condition should no
    longer be necessary. Removing this condition simplifies the addition of
    new arguments to `dot` and `dot_general`, including the `algorithm`
    parameter that I am currently working on in
    jax-ml#23574, so now seemed like a good time
    to remove it!

commit 922e652
Author: Peter Hawkins <phawkins@google.com>
Date:   Wed Sep 18 15:17:49 2024 +0000

    Replace plat-name with plat_name.

    The former seems to elicit a deprecation warning from setuptools
    recently.

commit 69ba060
Author: Dan Foreman-Mackey <danfm@google.com>
Date:   Wed Sep 18 07:40:58 2024 -0700

    Reverts e15ec1e

    PiperOrigin-RevId: 675987338

commit 44a7f04
Merge: 0a29696 2834c13
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 07:31:00 2024 -0700

    Merge pull request jax-ml#23708 from jakevdp:sort-complex

    PiperOrigin-RevId: 675983957

commit 0a29696
Merge: e15ec1e 73c38cb
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 07:08:24 2024 -0700

    Merge pull request jax-ml#23698 from dfm:dev-clang-warning

    PiperOrigin-RevId: 675977448

commit 2834c13
Author: Jake VanderPlas <jakevdp@google.com>
Date:   Tue Sep 17 15:32:25 2024 -0700

    jnp.sort_complex: fix output for N-dimensional inputs

commit 73c38cb
Author: Dan Foreman-Mackey <danfm@google.com>
Date:   Tue Sep 17 14:00:21 2024 -0400

    Add a note to the developer docs making it clear that clang is the only
    toolchain that is actively supported for source compilation.

    As discussed in jax-ml#23687

commit e15ec1e
Merge: 48d8fce 3f2bc9b
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 06:56:28 2024 -0700

    Merge pull request jax-ml#23261 from joaospinto:stablehlo.tan

    PiperOrigin-RevId: 675973798

commit 48d8fce
Merge: 4e6f690 2714469
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 06:54:28 2024 -0700

    Merge pull request jax-ml#23563 from rajasekharporeddy:testbranch1

    PiperOrigin-RevId: 675973225

commit 4e6f690
Merge: b7c91e9 611ad63
Author: jax authors <google-ml-automation@google.com>
Date:   Wed Sep 18 06:35:15 2024 -0700

    Merge pull request jax-ml#23653 from apaszke:torchsaic

    PiperOrigin-RevId: 675967844

commit b7c91e9
Author: Sergei Lebedev <slebedev@google.com>
Date:   Wed Sep 18 06:22:14 2024 -0700

    Lookup `shape` and `dtype` directly on `state.AbstractRef` instead of going through `inner_aval`

    This is just a cleanup. No behavior changes are expected.

    PiperOrigin-RevId: 675964703

commit 611ad63
Author: Adam Paszke <adam.paszke@gmail.com>
Date:   Fri Sep 6 16:09:58 2024 +0000

    Add basic PyTorch integration for Mosaic GPU

    We have already had most of the relevant pieces and we only needed
    to connect them together. The most sensitive change is perhaps that
    I needed to expose one more symbol from the XLA GPU plugin, but I don't
    think it should be a problem.

commit e903369
Author: Sergei Lebedev <slebedev@google.com>
Date:   Wed Sep 18 05:25:37 2024 -0700

    Pulled `scratch_shapes` into `GridSpec`

    It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton.

    PiperOrigin-RevId: 675950199

commit 2714469
Author: rajasekharporeddy <rajasekharp@google.com>
Date:   Wed Sep 18 17:06:28 2024 +0530

    Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros

commit b904599
Author: Sergei Lebedev <slebedev@google.com>
Date:   Wed Sep 18 04:23:25 2024 -0700

    `pl.debug_print` no longer restricts values to be scalars

    This allows printing arrays on Triton and soon on Mosaic GPU.

    PiperOrigin-RevId: 675935666

commit 988ed2b
Author: jax authors <google-ml-automation@google.com>
Date:   Tue Sep 17 21:09:26 2024 -0700

    Add support for SMEM windows in Pallas custom pipeline.

    PiperOrigin-RevId: 675822640

commit f79d85b
Merge: 1b74cfd cc28d63
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 17 18:58:33 2024 -0700

    Merge remote-tracking branch 'upstream/main' into disable_remat_pass

commit cc28d63
Merge: 8bcdb12 9d3762b
Author: jax authors <google-ml-automation@google.com>
Date:   Tue Sep 17 17:36:36 2024 -0700

    Merge pull request jax-ml#23682 from sharadmv:pallas-async-docs

    PiperOrigin-RevId: 675770723

commit 1b74cfd
Author: Keshav <keshavb@nvidia.com>
Date:   Tue Sep 17 17:23:30 2024 -0700

    disable remat hlo pass by default

commit 8bcdb12
Author: jax authors <google-ml-automation@google.com>
Date:   Tue Sep 17 16:50:55 2024 -0700

    Add CI jobs for python 3.13.0rc2.

    PiperOrigin-RevId: 675758096

commit 8b5b717
Author: Yash Katariya <yashkatariya@google.com>
Date:   Tue Sep 17 16:39:55 2024 -0700

    Fix jaxpr equation context propagation in jaxpr equations when `inline=True`.

    PiperOrigin-RevId: 675754808

commit 86fe463
Author: Parker Schuh <parkers@google.com>
Date:   Tue Sep 17 16:10:41 2024 -0700

    [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums.

    This allows us to get more cache hits globally. For example:

    Before:

    jax.jit(f, out_shardings=s)(arr)
    jax.jit(f, out_shardings=s)(arr)  # cpp cache miss
    After:

    jax.jit(f, out_shardings=s)(arr)
    jax.jit(f, out_shardings=s)(arr)  # cpp cache hit

    Reverts b615266

    PiperOrigin-RevId: 675746175

commit e92a599
Author: Christos Perivolaropoulos <cperivol@google.com>
Date:   Tue Sep 17 15:26:42 2024 -0700

    [mosaic_gpu] Better error message for misaligned tma_transpose with dtype.

    PiperOrigin-RevId: 675731295

commit 7864648
Merge: 3f2c58b 83a7555
Author: jax authors <google-ml-automation@google.com>
Date:   Tue Sep 17 15:12:50 2024 -0700

    Merge pull request jax-ml#23679 from selamw1:docstring_sort_complex

    PiperOrigin-RevId: 675726527

commit 83a7555
Author: selamw1 <selamw@google.com>
Date:   Mon Sep 16 16:47:52 2024 -0700

    docstring_sort_complex_added

    input_array_modified

commit 9d3762b
Author: Sharad Vikram <sharad.vikram@gmail.com>
Date:   Mon Sep 16 19:18:22 2024 -0700

    [Pallas] Add design note for async ops on TPU

commit 3f2bc9b
Author: Joao Sousa-Pinto <joaospinto@gmail.com>
Date:   Mon Aug 26 17:25:16 2024 -0700

    Lower tan to StableHLO instead of CHLO.

    Fixes jax-ml#23259

commit 541b3a3
Author: kaixih <kaixih@nvidia.com>
Date:   Mon Aug 26 17:32:38 2024 +0000

    New feature
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants