From c17877b4f2b1a0736928b620c943eccd1629931e Mon Sep 17 00:00:00 2001 From: Paul Balanca Date: Fri, 10 Nov 2023 11:08:21 +0000 Subject: [PATCH] Add jit unit test to autoscale decorator. (#12) Making sure we cover eager + jit modes in our testing. --- tests/core/test_interpreter.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 7249e3f..b906020 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -10,11 +10,13 @@ class AutoScaleInterpreterTests(chex.TestCase): - def test__identity(self): + @chex.variants(with_jit=True, without_jit=True) + def test__scaled_identity_function(self): def func(x): return x - asfunc = autoscale(func) + # Autoscale + (optional) jitting. + asfunc = self.variant(autoscale(func)) scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32) scaled_outputs = asfunc(scaled_inputs) @@ -31,11 +33,13 @@ def func(x): assert jaxpr.outvars[0].aval.shape == expected.shape assert jaxpr.outvars[1].aval.shape == () - def test__mul(self): + @chex.variants(with_jit=True, without_jit=True) + def test__scaled_mul_function(self): def func(x, y): return x * y - asfunc = autoscale(func) + # Autoscale + (optional) jitting. + asfunc = self.variant(autoscale(func)) x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32) y = scaled_array([1.5, 1.5], 2, dtype=np.float32)