-
Notifications
You must be signed in to change notification settings - Fork 2.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Deprecate passing NdArrays with ndim != 1
and non-arraylike
inputs to jnp.trim_zeros
#23563
Conversation
NdArrays
and Python lists
to jnp.trim_zeros
NdArrays with ndim != 1
and Python lists
to jnp.trim_zeros
3279fad
to
b6fbbe4
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall! A few comments below.
NdArrays with ndim != 1
and Python lists
to jnp.trim_zeros
NdArrays with ndim != 1
and non-arraylike
inputs to jnp.trim_zeros
d961e13
to
8714f40
Compare
d74ccfd
to
0047b06
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks!
Hi @jakevdp, Should I need to rebase onto the latest main to fix these |
Yeah, try rebasing. I'm having trouble telling where the error is coming from, because line 6750 in your PR looks like it's the middle of a docstring |
8fefaef
to
fc96d03
Compare
I have rebased onto latest upstream main. Could you please trigger the tests now? |
Looks good, thanks! Could you please squash the changes? |
4497be8
to
32a971f
Compare
Thanks! Squashed into a single commit. |
d7589da
to
e02f580
Compare
There's a merge conflict that blocked submitting this. Could you please rebase on an updated main branch? |
Thanks - can you please squash your changes into a single commit? |
bbd2b56
to
2714469
Compare
Squashed into single commit |
commit 093c6e9 Merge: e1a77ee d0cb318 Author: Keshav <keshavb@nvidia.com> Date: Wed Sep 18 13:37:44 2024 -0700 Merge remote-tracking branch 'upstream/main' into disable_remat_pass commit e1a77ee Author: Keshav <keshavb@nvidia.com> Date: Wed Sep 18 13:35:37 2024 -0700 minor changes commit d0cb318 Merge: b51c653 bef36c4 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 13:34:11 2024 -0700 Merge pull request jax-ml#23736 from hawkinsp:changelog PiperOrigin-RevId: 676111400 commit b51c653 Merge: dbc03cf 57a4b76 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 13:33:05 2024 -0700 Merge pull request jax-ml#23737 from jakevdp:digitize-doc PiperOrigin-RevId: 676111220 commit dbc03cf Author: Dan Foreman-Mackey <danfm@google.com> Date: Wed Sep 18 12:39:58 2024 -0700 Re-land jax-ml#23261 with appropriate compatibility checks. PiperOrigin-RevId: 676092618 commit b164d67 Merge: cd04d0f 541b3a3 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 12:05:03 2024 -0700 Merge pull request jax-ml#23247 from kaixih:sliding_window_attn PiperOrigin-RevId: 676079831 commit 57a4b76 Author: Jake VanderPlas <jakevdp@google.com> Date: Wed Sep 18 11:59:00 2024 -0700 Improve documentation for jnp.digitize commit bef36c4 Author: Peter Hawkins <phawkins@google.com> Date: Wed Sep 18 18:57:03 2024 +0000 Add Python 3.13 wheels to changelog. commit cd04d0f Merge: 016c499 c756d9b Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 10:00:03 2024 -0700 Merge pull request jax-ml#23726 from hawkinsp:debug PiperOrigin-RevId: 676030839 commit 016c499 Author: Sergei Lebedev <slebedev@google.com> Date: Wed Sep 18 09:56:44 2024 -0700 Removed leftover usages of GPUGridSpec from Pallas Mosaic GPU tests PiperOrigin-RevId: 676029854 commit 9dd363d Author: Luke Baumann <lukebaumann@google.com> Date: Wed Sep 18 09:28:25 2024 -0700 Export `jax.lib.xla_extension.ifrt_programs`. PiperOrigin-RevId: 676020419 commit e27f1e9 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 09:03:55 2024 -0700 Change Python version 3.13.0rc2 to 3.13.0-rc.2. The value is taken from [the versions manifest](https://raw.githubusercontent.com/actions/python-versions/main/versions-manifest.json). PiperOrigin-RevId: 676012255 commit 442e863 Author: Sergei Lebedev <slebedev@google.com> Date: Wed Sep 18 08:56:49 2024 -0700 Added a missing branch to `mgpu.FragmentedArray.astype` Previously, an unsupported cast produced a `NameError` instead. PiperOrigin-RevId: 676010161 commit 6236b8f Merge: 826843a 1cc9661 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 08:57:38 2024 -0700 Merge pull request jax-ml#23667 from dfm:always-lower-jnp-dot-to-dot-general PiperOrigin-RevId: 676010154 commit c756d9b Author: Peter Hawkins <phawkins@google.com> Date: Wed Sep 18 15:44:45 2024 +0000 Fix error in debugger tests that is showing up in CI. I'm unsure why this started happening now, but sometimes we get an invalid offset for a frame. Be tolerant of that case. commit 826843a Merge: c191bbc 922e652 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 08:42:39 2024 -0700 Merge pull request jax-ml#23723 from hawkinsp:setuptools PiperOrigin-RevId: 676005613 commit c191bbc Author: Yash Katariya <yashkatariya@google.com> Date: Wed Sep 18 08:40:30 2024 -0700 Make `debug.print` work with static args. Fixes: jax-ml#23600 PiperOrigin-RevId: 676005582 commit 1cc9661 Author: Dan Foreman-Mackey <danfm@google.com> Date: Mon Sep 16 14:18:29 2024 -0400 Unconditionally lower jnp.dot to lax.dot_general. jax-ml#16721 added a condition to lower calls to `jnp.dot` with scalar inputs to `lax.mul` instead of `lax.dot_general`. AFAICT, jax-ml#16826 fixed the issue that this was solving, so this condition should no longer be necessary. Removing this condition simplifies the addition of new arguments to `dot` and `dot_general`, including the `algorithm` parameter that I am currently working on in jax-ml#23574, so now seemed like a good time to remove it! commit 922e652 Author: Peter Hawkins <phawkins@google.com> Date: Wed Sep 18 15:17:49 2024 +0000 Replace plat-name with plat_name. The former seems to elicit a deprecation warning from setuptools recently. commit 69ba060 Author: Dan Foreman-Mackey <danfm@google.com> Date: Wed Sep 18 07:40:58 2024 -0700 Reverts e15ec1e PiperOrigin-RevId: 675987338 commit 44a7f04 Merge: 0a29696 2834c13 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 07:31:00 2024 -0700 Merge pull request jax-ml#23708 from jakevdp:sort-complex PiperOrigin-RevId: 675983957 commit 0a29696 Merge: e15ec1e 73c38cb Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 07:08:24 2024 -0700 Merge pull request jax-ml#23698 from dfm:dev-clang-warning PiperOrigin-RevId: 675977448 commit 2834c13 Author: Jake VanderPlas <jakevdp@google.com> Date: Tue Sep 17 15:32:25 2024 -0700 jnp.sort_complex: fix output for N-dimensional inputs commit 73c38cb Author: Dan Foreman-Mackey <danfm@google.com> Date: Tue Sep 17 14:00:21 2024 -0400 Add a note to the developer docs making it clear that clang is the only toolchain that is actively supported for source compilation. As discussed in jax-ml#23687 commit e15ec1e Merge: 48d8fce 3f2bc9b Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 06:56:28 2024 -0700 Merge pull request jax-ml#23261 from joaospinto:stablehlo.tan PiperOrigin-RevId: 675973798 commit 48d8fce Merge: 4e6f690 2714469 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 06:54:28 2024 -0700 Merge pull request jax-ml#23563 from rajasekharporeddy:testbranch1 PiperOrigin-RevId: 675973225 commit 4e6f690 Merge: b7c91e9 611ad63 Author: jax authors <google-ml-automation@google.com> Date: Wed Sep 18 06:35:15 2024 -0700 Merge pull request jax-ml#23653 from apaszke:torchsaic PiperOrigin-RevId: 675967844 commit b7c91e9 Author: Sergei Lebedev <slebedev@google.com> Date: Wed Sep 18 06:22:14 2024 -0700 Lookup `shape` and `dtype` directly on `state.AbstractRef` instead of going through `inner_aval` This is just a cleanup. No behavior changes are expected. PiperOrigin-RevId: 675964703 commit 611ad63 Author: Adam Paszke <adam.paszke@gmail.com> Date: Fri Sep 6 16:09:58 2024 +0000 Add basic PyTorch integration for Mosaic GPU We have already had most of the relevant pieces and we only needed to connect them together. The most sensitive change is perhaps that I needed to expose one more symbol from the XLA GPU plugin, but I don't think it should be a problem. commit e903369 Author: Sergei Lebedev <slebedev@google.com> Date: Wed Sep 18 05:25:37 2024 -0700 Pulled `scratch_shapes` into `GridSpec` It is supported by Mosaic TPU and Mosaic GPU and unsupported by Triton. PiperOrigin-RevId: 675950199 commit 2714469 Author: rajasekharporeddy <rajasekharp@google.com> Date: Wed Sep 18 17:06:28 2024 +0530 Deprecate passing NdArrays with ndim != 1 and non-arraylike inputs to jnp.trim_zeros commit b904599 Author: Sergei Lebedev <slebedev@google.com> Date: Wed Sep 18 04:23:25 2024 -0700 `pl.debug_print` no longer restricts values to be scalars This allows printing arrays on Triton and soon on Mosaic GPU. PiperOrigin-RevId: 675935666 commit 988ed2b Author: jax authors <google-ml-automation@google.com> Date: Tue Sep 17 21:09:26 2024 -0700 Add support for SMEM windows in Pallas custom pipeline. PiperOrigin-RevId: 675822640 commit f79d85b Merge: 1b74cfd cc28d63 Author: Keshav <keshavb@nvidia.com> Date: Tue Sep 17 18:58:33 2024 -0700 Merge remote-tracking branch 'upstream/main' into disable_remat_pass commit cc28d63 Merge: 8bcdb12 9d3762b Author: jax authors <google-ml-automation@google.com> Date: Tue Sep 17 17:36:36 2024 -0700 Merge pull request jax-ml#23682 from sharadmv:pallas-async-docs PiperOrigin-RevId: 675770723 commit 1b74cfd Author: Keshav <keshavb@nvidia.com> Date: Tue Sep 17 17:23:30 2024 -0700 disable remat hlo pass by default commit 8bcdb12 Author: jax authors <google-ml-automation@google.com> Date: Tue Sep 17 16:50:55 2024 -0700 Add CI jobs for python 3.13.0rc2. PiperOrigin-RevId: 675758096 commit 8b5b717 Author: Yash Katariya <yashkatariya@google.com> Date: Tue Sep 17 16:39:55 2024 -0700 Fix jaxpr equation context propagation in jaxpr equations when `inline=True`. PiperOrigin-RevId: 675754808 commit 86fe463 Author: Parker Schuh <parkers@google.com> Date: Tue Sep 17 16:10:41 2024 -0700 [Take 2] Generalize global jit cpp cache keys so we can add more keys than the current donate_argnums. This allows us to get more cache hits globally. For example: Before: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache miss After: jax.jit(f, out_shardings=s)(arr) jax.jit(f, out_shardings=s)(arr) # cpp cache hit Reverts b615266 PiperOrigin-RevId: 675746175 commit e92a599 Author: Christos Perivolaropoulos <cperivol@google.com> Date: Tue Sep 17 15:26:42 2024 -0700 [mosaic_gpu] Better error message for misaligned tma_transpose with dtype. PiperOrigin-RevId: 675731295 commit 7864648 Merge: 3f2c58b 83a7555 Author: jax authors <google-ml-automation@google.com> Date: Tue Sep 17 15:12:50 2024 -0700 Merge pull request jax-ml#23679 from selamw1:docstring_sort_complex PiperOrigin-RevId: 675726527 commit 83a7555 Author: selamw1 <selamw@google.com> Date: Mon Sep 16 16:47:52 2024 -0700 docstring_sort_complex_added input_array_modified commit 9d3762b Author: Sharad Vikram <sharad.vikram@gmail.com> Date: Mon Sep 16 19:18:22 2024 -0700 [Pallas] Add design note for async ops on TPU commit 3f2bc9b Author: Joao Sousa-Pinto <joaospinto@gmail.com> Date: Mon Aug 26 17:25:16 2024 -0700 Lower tan to StableHLO instead of CHLO. Fixes jax-ml#23259 commit 541b3a3 Author: kaixih <kaixih@nvidia.com> Date: Mon Aug 26 17:32:38 2024 +0000 New feature
JAX APIs doesn't accept non-arraylike inputs, but
jnp.trim_zeros
works with Python list inputs.Additionally, the behavior of
jnp.trim_zeros
is inconsistent withnp.trim_zeros
forNdArray
inputs withndim != 1
.