XLA doesn't have any support for 1-bit variables, but recent versions of JAX do have some experimental support for 4-bit integers via the `int4` dtype:

```python
In [1]: import jax.numpy as jnp

In [2]: x = jnp.arange(4, dtype='int4')

In [3]: print(x)
[0 1 2 3]
```

The only way I know of to use bit arrays is via `jnp.packbits` and `jnp.unpackbits`:

```python
In [4]: x = jnp.array([1, 0, 1, 1, 0, 0, 1, 0], dtype='uint8')

In [5]: bits = jnp.packbits(x)

In [6]: print(bits)
[178]

In [7]: print(jnp.unpackbits(bits))
[1 0 1 1 0 0 1 0]
```
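`jnp.packbits` always packs into `uint8`. If you specifically want `uint32` words, as with a C++ `std::bitset` (for example 256 booleans into 8 words), one option is to do the shifting by hand. This is a minimal sketch, not a JAX API: `pack_uint32` and `unpack_uint32` are hypothetical helper names, and it assumes the input length is a multiple of 32 and that least-significant-bit-first ordering within each word is acceptable.

```python
import jax.numpy as jnp


def pack_uint32(bits):
    """Pack a boolean array (length a multiple of 32) into uint32 words.

    Hypothetical helper. Bit i of word k holds bits[32 * k + i], i.e. the
    least-significant bit comes first, similar to std::bitset's bit 0.
    """
    bits = jnp.asarray(bits, dtype=jnp.uint32).reshape(-1, 32)
    shifts = jnp.arange(32, dtype=jnp.uint32)
    # Every bit lands in its own position, so summing is equivalent to OR-ing
    # and the total never exceeds 2**32 - 1 (no uint32 overflow).
    return jnp.sum(bits << shifts, axis=-1, dtype=jnp.uint32)


def unpack_uint32(words):
    """Inverse of pack_uint32: expand uint32 words back into booleans."""
    shifts = jnp.arange(32, dtype=jnp.uint32)
    # Shift each word right by 0..31 and keep the lowest bit.
    bits = (words[:, None] >> shifts) & 1
    return bits.astype(bool).reshape(-1)


# Usage: 256 booleans -> 8 uint32 words.
x = jnp.zeros(256, dtype=bool).at[0].set(True).at[33].set(True)
words = pack_uint32(x)      # shape (8,), dtype uint32
restored = unpack_uint32(words)
```

Here `x[0]` sets bit 0 of word 0 and `x[33]` sets bit 1 of word 1. Another route is to reshape the `uint8` output of `jnp.packbits` to `(-1, 4)` and reinterpret it with `jax.lax.bitcast_convert_type`, but then the bit order inside each word also depends on byte order, which the explicit shifts avoid.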
Can I use JAX to compress 32 boolean variables into one uint32 variable, like C++'s bitset functionality? I want to implement architectures like binary neural networks in JAX.
For example, I have a boolean array of length 256 that I want to compress into a uint32 array of length 8.
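For the binary-neural-network use case, packed words are typically consumed with XOR (or XNOR) followed by a population count. The sketch below is an illustration under assumptions, not code from this thread: `binary_dot` is a hypothetical name, bit 1 is taken to encode +1 and bit 0 to encode -1, and both operands must be packed with the same bit order and identical (e.g. zero) padding bits. For two ±1 vectors of length n, the dot product equals n minus twice the number of positions where they differ, and `jax.lax.population_count` counts those positions in each word.

```python
import jax
import jax.numpy as jnp


def binary_dot(a_words, b_words, n_bits):
    """Dot product of two {-1, +1} vectors stored as packed uint32 words.

    Hypothetical helper. Assumes bit 1 encodes +1 and bit 0 encodes -1,
    that both operands use the same bit order, and that any padding bits
    are identical in both (so they never count as mismatches).
    """
    # XOR marks the positions where the two vectors disagree.
    diff = jnp.bitwise_xor(a_words, b_words)
    # Count mismatching bits per word, then sum over the last axis.
    mismatches = jnp.sum(jax.lax.population_count(diff).astype(jnp.int32), axis=-1)
    # agreements - mismatches == (n_bits - mismatches) - mismatches
    return n_bits - 2 * mismatches
```

Because the result depends only on which positions differ, the particular packing order does not matter as long as both operands share it.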