diff --git a/tests/core/test_interpreter.py b/tests/core/test_interpreter.py index 7249e3f..b906020 100644 --- a/tests/core/test_interpreter.py +++ b/tests/core/test_interpreter.py @@ -10,11 +10,13 @@ class AutoScaleInterpreterTests(chex.TestCase): - def test__identity(self): + @chex.variants(with_jit=True, without_jit=True) + def test__scaled_identity_function(self): def func(x): return x - asfunc = autoscale(func) + # Autoscale + (optional) jitting. + asfunc = self.variant(autoscale(func)) scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32) scaled_outputs = asfunc(scaled_inputs) @@ -31,11 +33,13 @@ def func(x): assert jaxpr.outvars[0].aval.shape == expected.shape assert jaxpr.outvars[1].aval.shape == () - def test__mul(self): + @chex.variants(with_jit=True, without_jit=True) + def test__scaled_mul_function(self): def func(x, y): return x * y - asfunc = autoscale(func) + # Autoscale + (optional) jitting. + asfunc = self.variant(autoscale(func)) x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32) y = scaled_array([1.5, 1.5], 2, dtype=np.float32)