Fast lookup table within JIT #10475
-
I have a large lookup table containing allowed values. Within a JIT compiled function, I would like to replace all values that are not present in the lookup table with 0. Something like this: >>> arr = jnp.arange(1000, dtype="int32")
# allowed values
>>> lut = jnp.array([37, 127, 666], dtype="int32")
# only pass through values that are in `lut`
>>> arr = arr.at[~jnp.isin(arr, lut)].set(0)
>>> arr
DeviceArray([0, 0, 0, ..., 0, 37, 0, ..., 127, ..., 666, ...]) This works fine, but has complexity def filter_allowed(arr, lut):
idx = jnp.searchsorted(lut, arr)
return jnp.where(lut[idx] == arr, arr, 0) Ideally, I would like to do it in from jax.experimental.sparse import BCOO
maxval = 10_000
lut_sparse = BCOO((jnp.ones(len(lut), dtype="bool"), lut), shape=(maxval, 1))
arr = arr * lut_sparse[arr] But it fails with: NotImplementedError: sparse rule for gather Does anyone have an idea what else I could try? |
Beta Was this translation helpful? Give feedback.
Replies: 3 comments 12 replies
-
Unfortunately, there's not any hashtable-like data structure available in XLA, so I think your best bet at the moment will be some kind of Regarding the sparse option, even were we to implement the sparse gather rule (something on my TODO list), under the hood it would be using something like |
Beta Was this translation helpful? Give feedback.
-
I think the cuckoo hashing can be implemented in JAX, if we can reserve one value as sentinel. import operator
import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
def rotl(x, n):
return (x << n) | (x >> (32 - n))
def xxhash(x, seed):
x = lax.bitcast_convert_type(x, jnp.uint32)
prime_1 = np.uint32(0x9E3779B1)
prime_2 = np.uint32(0x85EBCA77)
prime_3 = np.uint32(0xC2B2AE3D)
prime_5 = np.uint32(0x165667B1)
acc = np.uint32(seed) + prime_5
for _ in range(4):
lane = x & 255
acc = acc + lane * prime_5
acc = rotl(acc, 11) * prime_1
x = x >> 8
acc = acc ^ (acc >> 15)
acc = acc * prime_2
acc = acc ^ (acc >> 13)
acc = acc * prime_3
acc = acc ^ (acc >> 16)
return acc
capacity = int(1024 * 1024 * 2)
max_try = 128
def hashing(x, seed): # max capacity is 2 ** 24
z = xxhash(x, seed) >> 8
b1 = z & 255
z = z >> 8
b2 = z & 255
b3 = z >> 8
c1 = (b1 * capacity) >> 24
c2 = (b2 * capacity) >> 16
c3 = (b3 * capacity) >> 8
return c1 + c2 + c3
@jax.jit
def batchIndex(arr, idx):
return jax.vmap(operator.getitem)(arr, idx)
def cuckooHash(xs, seed=1):
hash_stack = jnp.stack([hashing(xs, 0), hashing(xs, seed)], -1)
assignment = jnp.zeros(len(xs), jnp.uint32)
def cond_fun(val):
_, collided, i = val
return collided & (i < max_try)
def body_fun(val):
assignment, _, i = val
hash_assigned = batchIndex(hash_stack, assignment)
count = jnp.bincount(hash_assigned, length=capacity)
collided = count[hash_assigned] > 1
return jnp.where(collided, 1 - assignment, assignment), jnp.any(collided), i + 1
assignment, collided, _ = lax.while_loop(cond_fun, body_fun, (assignment, True, 0))
if collided:
return cuckooHash(xs, seed + 1)
return batchIndex(hash_stack, assignment), seed
vals = jax.random.randint(jax.random.PRNGKey(0), (1024 * 1024,), jnp.uint32(0), jnp.uint32(16777216), jnp.uint32)
vals = jnp.unique(vals)
print(len(vals))
hash_assigned, seed = cuckooHash(vals)
print(seed)
lut = jnp.full(capacity, -1, jnp.uint32).at[hash_assigned].set(vals)
x = jnp.arange(1000, dtype=jnp.uint32)
y = jnp.where((lut[hashing(x, 0)] == x) | (lut[hashing(x, seed)] == x), x, 0)
print(y) |
Beta Was this translation helpful? Give feedback.
-
Is this implementation intended to be merged into lax? |
Beta Was this translation helpful? Give feedback.
I think the cuckoo hashing can be implemented in JAX, if we can reserve one value as sentinel.
If load factor is not important, a naive hash table can be used as well.EDIT: due to the birthday paradox, a naive hash table is unusable: it need Ω(n^2) space to contain n elements.
Note that it is very hard to implement a hash table with high update performance in XLA, but it is much easy to implement a static LUT.
Demo: