Skip to content

Commit

Permalink
precommit
Browse files Browse the repository at this point in the history
  • Loading branch information
lyprince committed Nov 8, 2023
1 parent 9611bb3 commit 62b2289
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 2 deletions.
5 changes: 3 additions & 2 deletions jax_scaled_arithmetics/interpreters/autoscale.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from functools import wraps

import jax
import numpy as np

from jax import core
from jax._src.util import safe_map

from ..lax import scaled_ops_registry
from functools import wraps


def autoscale(fun):
Expand Down
1 change: 1 addition & 0 deletions jax_scaled_arithmetics/lax/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# Copyright (c) 2023 Graphcore Ltd. All rights reserved.

from jax import lax

from .scaled_ops import *

scaled_ops_registry = {lax.mul_p: scaled_mul}

0 comments on commit 62b2289

Please sign in to comment.