
Fast lookup table within JIT #10475

Answered by YouJiacheng
dionhaefner asked this question in Q&A

I think cuckoo hashing can be implemented in JAX, provided we can reserve one key value as a sentinel to mark empty slots.
If the load factor is not important, a naive hash table could be used as well.
EDIT: due to the birthday paradox, a naive hash table (one slot per key, no collision handling) is unusable: it needs Ω(n^2) slots to hold n elements without collisions.
Note that it is very hard to implement a hash table with good update performance in XLA, but it is much easier to implement a static lookup table (LUT).
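To make the birthday-paradox argument concrete: if n keys are hashed uniformly into m slots with no collision handling, the probability that every key gets its own slot is ∏_{i=0}^{n-1} (1 - i/m) ≈ exp(-n²/(2m)). That stays bounded away from zero only when m grows like n², which is where the Ω(n^2) space bound comes from. A quick check in plain Python (no JAX needed; `p_no_collision` is a hypothetical helper name):

```python
import math


def p_no_collision(n: int, m: int) -> float:
    """Probability that n keys hashed uniformly into m slots all land in distinct slots."""
    if n > m:
        return 0.0  # pigeonhole: a collision is certain
    # P = prod_{i=0}^{n-1} (1 - i/m), summed in log space for numerical stability
    return math.exp(math.fsum(math.log1p(-i / m) for i in range(n)))


for n in (100, 1_000):
    # m = n (linear space) vs m = 10n vs m = n^2 (quadratic space)
    print(n, p_no_collision(n, n), p_no_collision(n, 10 * n), p_no_collision(n, n * n))
```

With m = n the probability is vanishingly small, while with m = n² it is roughly exp(-1/2) ≈ 0.6, independent of n.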
Demo:

import operator

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax

def rotl(x, n):
    return (x << n) | (x >> (32 - n))

def xxhash(x, seed):
    x = lax.bitcast_convert_type(x, jnp.uint32)
    prime_1 = np.uint32(0x9E3779B1)
    prime_2 = np.uint32(0…
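A minimal sketch of the static-LUT idea with cuckoo hashing, assuming uint32 keys and reserving 0xFFFFFFFF as the empty-slot sentinel. The table is built once on the host with a plain NumPy loop, and only the lookup runs under `jax.jit`. The hash is a murmur3-style `fmix32` finalizer chosen for illustration, not the xxhash-based hash from the demo; `hash_u32`, `build_cuckoo`, `make_lookup`, `EMPTY` and `SEEDS` are hypothetical names.

```python
import numpy as np
import jax
import jax.numpy as jnp

EMPTY = np.uint32(0xFFFFFFFF)     # reserved sentinel: marks an empty slot, must never be a key
SEEDS = (0x9E3779B1, 0x85EBCA77)  # hypothetical seeds, one per cuckoo sub-table


def hash_u32(x, seed):
    """Hypothetical 32-bit hash: murmur3's fmix32 finalizer applied to x ^ seed."""
    h = jnp.asarray(x, dtype=jnp.uint32) ^ jnp.uint32(seed)
    h = h ^ (h >> 16)
    h = h * jnp.uint32(0x85EBCA6B)  # uint32 multiply wraps modulo 2**32
    h = h ^ (h >> 13)
    h = h * jnp.uint32(0xC2B2AE35)
    return h ^ (h >> 16)


def build_cuckoo(keys, values, size, seeds=SEEDS, max_kicks=500):
    """Build the table once on the host (plain NumPy loop, never traced)."""
    keys = np.asarray(keys, dtype=np.uint32)
    values = np.asarray(values, dtype=np.int32)
    assert len(np.unique(keys)) == len(keys) and not np.any(keys == EMPTY)
    # Candidate slot of every key in each sub-table, using the same hash as the lookup.
    slots = np.stack([np.asarray(hash_u32(keys, s) % size) for s in seeds])
    table = np.full((2, size), -1, dtype=np.int64)  # index into `keys`; -1 means empty
    for i in range(len(keys)):
        cur, t = i, 0
        for _ in range(max_kicks):
            pos = slots[t, cur]
            evicted = table[t, pos]  # place `cur`, kicking out any previous occupant
            table[t, pos] = cur
            cur = evicted
            if cur == -1:
                break
            t = 1 - t  # the evicted key moves to its slot in the other sub-table
        else:
            raise RuntimeError("cuckoo insertion failed; try other seeds or a larger size")
    occupied = table >= 0
    safe = np.maximum(table, 0)
    tab_keys = np.where(occupied, keys[safe], EMPTY).astype(np.uint32)
    tab_vals = np.where(occupied, values[safe], 0).astype(np.int32)
    return jnp.asarray(tab_keys), jnp.asarray(tab_vals)


def make_lookup(tab_keys, tab_vals, seeds=SEEDS, default=-1):
    size = tab_keys.shape[1]

    @jax.jit
    def lookup(query):
        q = jnp.asarray(query, dtype=jnp.uint32)
        # A key can only live in one of two slots: two gathers, no loops or branches.
        i0 = (hash_u32(q, seeds[0]) % size).astype(jnp.int32)
        i1 = (hash_u32(q, seeds[1]) % size).astype(jnp.int32)
        valid = q != EMPTY  # never report the sentinel itself as a hit
        hit0 = valid & (tab_keys[0, i0] == q)
        hit1 = valid & (tab_keys[1, i1] == q)
        return jnp.where(hit0, tab_vals[0, i0],
                         jnp.where(hit1, tab_vals[1, i1], default))

    return lookup


keys = np.arange(1, 1001, dtype=np.uint32) * 7  # 7, 14, ..., 7000
values = np.arange(1000, dtype=np.int32)
tab_keys, tab_vals = build_cuckoo(keys, values, size=4096)
lookup = make_lookup(tab_keys, tab_vals)
print(lookup(jnp.array([7, 14, 13, 7000], dtype=jnp.uint32)))
```

The design choice follows the point above: insertion (with its data-dependent eviction loop) stays on the host, while the jitted lookup is a fixed number of hashes, gathers and compares, which maps cleanly onto XLA. Two sub-tables of size `size` can only hold fewer than `2 * size` keys in practice; the sketch assumes a load factor well below 50%.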

Replies: 3 comments, 12 replies

1 reply
@YouJiacheng

10 replies
@dionhaefner

@YouJiacheng

@YouJiacheng

@dionhaefner

@YouJiacheng

Answer selected by dionhaefner

1 reply
@jakevdp

Category: Q&A
Labels: None yet
4 participants