Skip to content

Commit

Permalink
register scaled op as decorator
Browse files Browse the repository at this point in the history
  • Loading branch information
lyprince committed Nov 9, 2023
1 parent 8352e80 commit 59c43e6
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 4 deletions.
14 changes: 12 additions & 2 deletions jax_scaled_arithmetics/core/interpreters.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,18 @@
_scaled_ops_registry = {}


def register_scaled_op(lax_func, scaled_func):
_scaled_ops_registry[lax_func] = scaled_func
def get_lax_prim(scaled_func):
try:
op = getattr(jax.lax, scaled_func.__name__.replace("scaled_", ""))
except AttributeError:
print(f"Could not find corresponding jax.lax primitive for {scaled_func.__name__}")
raise
return op


def register_scaled_op(scaled_func):
lax_prim = get_lax_prim(scaled_func)
_scaled_ops_registry[lax_prim] = scaled_func


def autoscale(fun):
Expand Down
3 changes: 1 addition & 2 deletions jax_scaled_arithmetics/lax/scaled_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@
from jax_scaled_arithmetics.core import ScaledArray


@core.register_scaled_op
def scaled_mul_p(A: ScaledArray, B: ScaledArray) -> ScaledArray:
return ScaledArray(A.data * B.data, A.scale * B.scale)


core.register_scaled_op(lax.mul_p, scaled_mul_p)

__all__ = ["scaled_mul_p"]

0 comments on commit 59c43e6

Please sign in to comment.