Utilities for neural signed distance fields in JAX.
```
sdf_jax
├── discretize.py     # utils for dense 2D and 3D grid evaluation of a field
├── examples.py       # for debugging: simple analytical SDFs like the sphere
├── hash_encoding.py  # Multiresolution Hash Encoding
└── util.py           # plotting utils for level-sets from marching cubes
```
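For orientation, the sketch below shows the kind of dense grid evaluation that `discretize.py` is intended for, using a simple analytical sphere SDF like the ones in `examples.py`. It is written against plain JAX because the exact function names inside those modules are not reproduced here; treat it as an illustration, not the repo's API.

```python
import jax.numpy as jnp

def sphere_sdf(x, radius=0.5):
    # Signed distance to a sphere centered at the origin:
    # negative inside, positive outside.
    return jnp.linalg.norm(x, axis=-1) - radius

# Dense 3D grid over [-1, 1]^3 (the resolution is illustrative).
n = 64
xs = jnp.linspace(-1.0, 1.0, n)
grid = jnp.stack(jnp.meshgrid(xs, xs, xs, indexing="ij"), axis=-1)  # (n, n, n, 3)
values = sphere_sdf(grid.reshape(-1, 3)).reshape(n, n, n)

# The zero level-set of `values` can then be extracted with marching cubes
# and visualized with the plotting helpers in util.py.
```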
The Multiresolution Hash Encoding in `sdf_jax/hash_encoding.py` implements the method described in:

> **Instant Neural Graphics Primitives with a Multiresolution Hash Encoding**
> Thomas Müller, Alex Evans, Christoph Schied, Alexander Keller
> ACM Transactions on Graphics (SIGGRAPH), July 2022
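The core of the encoding is a spatial hash over integer grid coordinates at each resolution level. The snippet below sketches the hash function from the paper (XOR of coordinates multiplied by large primes, modulo the table size); it illustrates the idea only and is not necessarily how `hash_encoding.py` implements it internally.

```python
import jax.numpy as jnp

# Primes used in the paper's spatial hash (Müller et al., 2022), for up to 3 dims.
PRIMES = jnp.array([1, 2654435761, 805459861], dtype=jnp.uint32)

def spatial_hash(coords, hashmap_size):
    # coords: (..., d) integer grid coordinates with d <= 3.
    coords = coords.astype(jnp.uint32)
    h = coords[..., 0] * PRIMES[0]
    for i in range(1, coords.shape[-1]):
        # XOR the prime-scaled coordinates; uint32 arithmetic wraps as in the paper.
        h = h ^ (coords[..., i] * PRIMES[i])
    return h % hashmap_size  # index into the level's feature table
```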
Below is an example of how to wrap the Hash Encoding inside a treex layer:
```python
from sdf_jax import hash_encoding

import jax.numpy as jnp
import jax.random as jrandom
import treex as tx

class HashEmbedding(tx.Module):
    theta: jnp.ndarray = tx.Parameter.node()

    def __init__(
        self,
        levels: int = 16,
        hashmap_size_log2: int = 14,
        features_per_entry: int = 2,
        nmin: int = 16,
        nmax: int = 512,
    ):
        self.levels = levels
        self.hashmap_size_log2 = hashmap_size_log2
        self.features_per_entry = features_per_entry
        self.nmin = nmin
        self.nmax = nmax

    def __call__(self, x):
        assert x.ndim == 1
        if self.initializing():
            # Initialize one hash table per level with small uniform values.
            hashmap_size = 1 << self.hashmap_size_log2
            key = tx.next_key()
            self.theta = jrandom.uniform(
                key,
                (self.levels, hashmap_size, self.features_per_entry),
                minval=-0.0001,
                maxval=0.0001,
            )
        y = hash_encoding.encode(x, self.theta, self.nmin, self.nmax)
        return y.reshape(-1)

x = jnp.ones(3)
emb = HashEmbedding().init(key=42, inputs=x)
print(emb(x).shape) # (32,) which is (levels * features_per_entry,)
```
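In a full network, the 32-dimensional embedding would typically be fed to a small MLP that predicts the signed distance. The sketch below does this with a hand-rolled two-layer MLP in plain JAX, reusing `emb` and `x` from the example above; the layer sizes and initialization are illustrative and not part of sdf_jax.

```python
import jax
import jax.numpy as jnp
import jax.random as jrandom

def init_mlp(key, in_dim=32, hidden=64):
    # Two-layer MLP: embedding -> hidden -> scalar SDF value.
    k1, k2 = jrandom.split(key)
    return {
        "w1": jrandom.normal(k1, (in_dim, hidden)) * 0.1,
        "b1": jnp.zeros(hidden),
        "w2": jrandom.normal(k2, (hidden, 1)) * 0.1,
        "b2": jnp.zeros(1),
    }

def sdf(params, embedding):
    h = jax.nn.relu(embedding @ params["w1"] + params["b1"])
    return (h @ params["w2"] + params["b2"])[0]

params = init_mlp(jrandom.PRNGKey(0))
print(sdf(params, emb(x)))  # scalar signed distance prediction at x
```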
To install this repo and its dev dependencies reproducibly:
- Use Poetry. Make sure you have a local installation of Python >=3.8 (e.g. by running `pyenv local 3.X.X`) and run: `poetry install`
- Alternatively, I've also included a `requirements.txt` that was generated from the `pyproject.toml` and `poetry.lock` files; it can be installed with `pip install -r requirements.txt`.