Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement JAX pow2_decompose primitive. #100

Merged
merged 1 commit into from
Feb 9, 2024
Merged

Commits on Feb 9, 2024

  1. Implement JAX pow2_decompose primitive.

    The primitive `pow2_decompose` is the core decomposition kernel used everywhere in AutoScale/Scalify,
    meaning it is worth properly formalizing it as a JAX primitive, simplifying the Jaxpr level graph
    and allowing proper custom kernel optimization on different HW platforms (GPU, IPU, TPU, ...).
    
    NOTE: this PR is fixing additional subnormal related bugs, due to inconsistency of jnp.frexp vs Numpy.
    See: jax-ml/jax#19689
    balancap committed Feb 9, 2024
    Configuration menu
    Copy the full SHA
    61d62ab View commit details
    Browse the repository at this point in the history