Skip to content

Commit

Permalink
Add jit unit test to autoscale decorator. (#12)
Browse files Browse the repository at this point in the history
Making sure we cover eager + jit modes in our testing.
  • Loading branch information
balancap authored Nov 10, 2023
1 parent 16fc32e commit c17877b
Showing 1 changed file with 8 additions and 4 deletions.
12 changes: 8 additions & 4 deletions tests/core/test_interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,13 @@


class AutoScaleInterpreterTests(chex.TestCase):
def test__identity(self):
@chex.variants(with_jit=True, without_jit=True)
def test__scaled_identity_function(self):
def func(x):
return x

asfunc = autoscale(func)
# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))

scaled_inputs = scaled_array([1.0, 2.0], 1, dtype=np.float32)
scaled_outputs = asfunc(scaled_inputs)
Expand All @@ -31,11 +33,13 @@ def func(x):
assert jaxpr.outvars[0].aval.shape == expected.shape
assert jaxpr.outvars[1].aval.shape == ()

def test__mul(self):
@chex.variants(with_jit=True, without_jit=True)
def test__scaled_mul_function(self):
def func(x, y):
return x * y

asfunc = autoscale(func)
# Autoscale + (optional) jitting.
asfunc = self.variant(autoscale(func))

x = scaled_array([-2.0, 2.0], 0.5, dtype=np.float32)
y = scaled_array([1.5, 1.5], 2, dtype=np.float32)
Expand Down

0 comments on commit c17877b

Please sign in to comment.