Skip to content

Commit

Permalink
move tests
Browse files Browse the repository at this point in the history
  • Loading branch information
cgarciae committed Jul 24, 2023
1 parent b152c94 commit 27334f4
Show file tree
Hide file tree
Showing 14 changed files with 2,505 additions and 0 deletions.
561 changes: 561 additions & 0 deletions flax/experimental/nnx/docs/quick_start.ipynb

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions flax/experimental/nnx/ideas/nnx_example.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@

from functools import partial
from typing import Tuple

Expand Down
Empty file.
61 changes: 61 additions & 0 deletions flax/experimental/nnx/tests/test_containers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import jax
import numpy as np
import pytest

from flax.experimental import nnx


class TestContainers:

def test_node_idenpotence(self):
x = nnx.Node(1)
x = nnx.Node(x)

assert isinstance(x, nnx.Node)

def test_variable_idenpotence(self):
x = nnx.Variable(1)
x = nnx.Variable(x)

assert isinstance(x, nnx.Variable)
assert x.value == 1

def test_variable_cannot_change_collection(self):
x = nnx.Param(1)

with pytest.raises(ValueError, match="is not compatible with return type"):
x = nnx.BatchStat(x)

def test_container_cannot_change_type(self):
x = nnx.Variable(1)

with pytest.raises(ValueError, match="is not compatible with return type"):
x = nnx.Node(x)

x = nnx.Node(2)

with pytest.raises(ValueError, match="is not compatible with return type"):
x = nnx.Variable(x)

def test_static_is_empty(self):
leaves = jax.tree_util.tree_leaves(nnx.Static(1))

assert len(leaves) == 0

def test_static_empty_pytree(self):
static = nnx.Static(2)

static = jax.tree_map(lambda x: x + 1, static)

assert static.value == 2

def test_static_array_not_jitable(self):
@jax.jit
def f(x):
return x

# first time you don't get an error due to a bug in jax
f(nnx.Static(np.random.uniform(size=(10, 10))))

with pytest.raises(ValueError):
f(nnx.Static(np.random.uniform(size=(10, 10))))
100 changes: 100 additions & 0 deletions flax/experimental/nnx/tests/test_context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
from typing import Any

import jax
import numpy as np
import pytest

from flax.experimental import nnx
from flax.experimental.nnx.nnx.contextlib import _stable_hash


class TestContext:

def test_hash(self):
_hash = _stable_hash("hi")
assert isinstance(_hash, int)

def test_rng_stream(self):
key0 = jax.random.PRNGKey(0)
ctx = nnx.context(key0)
assert ctx._rngs["params"].count == 0

key1 = ctx.make_rng("params")
assert ctx._rngs["params"].count == 1
assert ctx._rngs["params"].key is key0
assert not np.equal(key0, key1).all()

key2 = ctx.make_rng("params")
assert ctx._rngs["params"].count == 2
assert ctx._rngs["params"].key is key0
assert not np.equal(key1, key2).all()

def test_rng_fork(self):
key0 = jax.random.PRNGKey(0)
ctx1 = nnx.context(key0)
ctx2 = ctx1.partition().merge()

assert ctx2._rngs["params"].count == 0
assert ctx2._rngs["params"].count_path == (0,)

key1 = ctx1.make_rng("params")
key2 = ctx2.make_rng("params")

assert not np.equal(key1, key2).all()

def test_rng_trace_level_constraints(self):
ctx = nnx.context(0)

@jax.jit
def f():
with pytest.raises(
nnx.TraceContextError,
match="Cannot use Context from a different trace level",
):
ctx.make_rng("params")

f()

@jax.jit
def f():
with pytest.raises(
nnx.TraceContextError,
match="Cannot use Context from a different trace level",
):
ctx.partition()

f()

ctx1: Any = None

@jax.jit
def g():
nonlocal ctx1
ctx1 = nnx.context(1)

g()

assert isinstance(ctx1, nnx.Context)
with pytest.raises(
nnx.TraceContextError,
match="Cannot use Context from a different trace level",
):
ctx1.make_rng("params")

def test_partition_merge(self):
ctx = nnx.context(dropout=0)

keys, ctxdef = ctx.partition()

assert "dropout" in keys
assert ctxdef._rng_counts == (("dropout", (0,)),)

ctx2 = ctxdef.merge(keys)

key1 = ctx.make_rng("dropout")
key2 = ctx2.make_rng("dropout")
assert not np.equal(key1, key2).all()

ctx3 = ctxdef.merge(keys)
key3 = ctx3.make_rng("dropout")
assert np.equal(key2, key3).all()
63 changes: 63 additions & 0 deletions flax/experimental/nnx/tests/test_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import jax
import jax.numpy as jnp
import optax

from flax.experimental import nnx


class TestHelpers:

def test_train_state(self):
m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2))

(params, batch_stats), moduledef = m.partition(nnx.Param, nnx.BatchStat)

state = nnx.TrainState(
moduledef,
params=params,
tx=optax.sgd(1.0),
batch_stats=batch_stats,
other=nnx.Node(100),
int=200,
static=nnx.Static(300),
)

leaves = jax.tree_util.tree_leaves(state)

assert 1 in leaves
assert 2 in leaves
assert 100 in leaves
assert 200 not in leaves
assert 300 not in leaves

def test_train_state_methods(self):
class Foo(nnx.Module):

def __init__(self, *, ctx: nnx.Context):
self.linear = nnx.Linear(2, 4, ctx=ctx)
self.batch_norm = nnx.BatchNorm(4, ctx=ctx)

def __call__(self, x: jax.Array, train: bool) -> jax.Array:
x = self.linear(x)
x = self.batch_norm(x, use_running_average=not train)
return x

module = Foo(ctx=nnx.context(0))
(params, batch_stats), moduledef = module.partition(nnx.Param, nnx.BatchStat)

state = nnx.TrainState(
moduledef,
params=params,
tx=optax.sgd(1.0),
batch_stats=batch_stats,
)

x = jax.numpy.ones((1, 2))
y, _updates = state.apply("params", "batch_stats")(x, train=True)

assert y.shape == (1, 4)

# fake gradient
grads = jax.tree_map(jnp.ones_like, state.params)
# test apply_gradients
state = state.apply_gradients(grads)
17 changes: 17 additions & 0 deletions flax/experimental/nnx/tests/test_ids.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import copy

from flax.experimental.nnx.nnx import ids


class TestIds:

def test_hashable(self):
id1 = ids.uuid()
id2 = ids.uuid()
assert id1 == id1
assert id1 != id2
assert hash(id1) != hash(id2)
id1c = copy.copy(id1)
id1dc = copy.deepcopy(id1)
assert hash(id1) != hash(id1c)
assert hash(id1) != hash(id1dc)
Loading

0 comments on commit 27334f4

Please sign in to comment.