From 62b2289bc629cc324cb1d82d103ea772372f62ef Mon Sep 17 00:00:00 2001 From: Luke Prince Date: Wed, 8 Nov 2023 12:18:52 +0000 Subject: [PATCH] precommit --- jax_scaled_arithmetics/interpreters/autoscale.py | 5 +++-- jax_scaled_arithmetics/lax/__init__.py | 1 + 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/jax_scaled_arithmetics/interpreters/autoscale.py b/jax_scaled_arithmetics/interpreters/autoscale.py index 6d003d9..65b2dc9 100644 --- a/jax_scaled_arithmetics/interpreters/autoscale.py +++ b/jax_scaled_arithmetics/interpreters/autoscale.py @@ -1,12 +1,13 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. +from functools import wraps + import jax import numpy as np - from jax import core from jax._src.util import safe_map + from ..lax import scaled_ops_registry -from functools import wraps def autoscale(fun): diff --git a/jax_scaled_arithmetics/lax/__init__.py b/jax_scaled_arithmetics/lax/__init__.py index 1fdc265..5d97aa8 100644 --- a/jax_scaled_arithmetics/lax/__init__.py +++ b/jax_scaled_arithmetics/lax/__init__.py @@ -1,6 +1,7 @@ # Copyright (c) 2023 Graphcore Ltd. All rights reserved. from jax import lax + from .scaled_ops import * scaled_ops_registry = {lax.mul_p: scaled_mul}