-
Notifications
You must be signed in to change notification settings - Fork 645
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
14 changed files
with
2,505 additions
and
0 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
|
||
from functools import partial | ||
from typing import Tuple | ||
|
||
|
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
import jax | ||
import numpy as np | ||
import pytest | ||
|
||
from flax.experimental import nnx | ||
|
||
|
||
class TestContainers: | ||
|
||
def test_node_idenpotence(self): | ||
x = nnx.Node(1) | ||
x = nnx.Node(x) | ||
|
||
assert isinstance(x, nnx.Node) | ||
|
||
def test_variable_idenpotence(self): | ||
x = nnx.Variable(1) | ||
x = nnx.Variable(x) | ||
|
||
assert isinstance(x, nnx.Variable) | ||
assert x.value == 1 | ||
|
||
def test_variable_cannot_change_collection(self): | ||
x = nnx.Param(1) | ||
|
||
with pytest.raises(ValueError, match="is not compatible with return type"): | ||
x = nnx.BatchStat(x) | ||
|
||
def test_container_cannot_change_type(self): | ||
x = nnx.Variable(1) | ||
|
||
with pytest.raises(ValueError, match="is not compatible with return type"): | ||
x = nnx.Node(x) | ||
|
||
x = nnx.Node(2) | ||
|
||
with pytest.raises(ValueError, match="is not compatible with return type"): | ||
x = nnx.Variable(x) | ||
|
||
def test_static_is_empty(self): | ||
leaves = jax.tree_util.tree_leaves(nnx.Static(1)) | ||
|
||
assert len(leaves) == 0 | ||
|
||
def test_static_empty_pytree(self): | ||
static = nnx.Static(2) | ||
|
||
static = jax.tree_map(lambda x: x + 1, static) | ||
|
||
assert static.value == 2 | ||
|
||
def test_static_array_not_jitable(self): | ||
@jax.jit | ||
def f(x): | ||
return x | ||
|
||
# first time you don't get an error due to a bug in jax | ||
f(nnx.Static(np.random.uniform(size=(10, 10)))) | ||
|
||
with pytest.raises(ValueError): | ||
f(nnx.Static(np.random.uniform(size=(10, 10)))) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
from typing import Any | ||
|
||
import jax | ||
import numpy as np | ||
import pytest | ||
|
||
from flax.experimental import nnx | ||
from flax.experimental.nnx.nnx.contextlib import _stable_hash | ||
|
||
|
||
class TestContext: | ||
|
||
def test_hash(self): | ||
_hash = _stable_hash("hi") | ||
assert isinstance(_hash, int) | ||
|
||
def test_rng_stream(self): | ||
key0 = jax.random.PRNGKey(0) | ||
ctx = nnx.context(key0) | ||
assert ctx._rngs["params"].count == 0 | ||
|
||
key1 = ctx.make_rng("params") | ||
assert ctx._rngs["params"].count == 1 | ||
assert ctx._rngs["params"].key is key0 | ||
assert not np.equal(key0, key1).all() | ||
|
||
key2 = ctx.make_rng("params") | ||
assert ctx._rngs["params"].count == 2 | ||
assert ctx._rngs["params"].key is key0 | ||
assert not np.equal(key1, key2).all() | ||
|
||
def test_rng_fork(self): | ||
key0 = jax.random.PRNGKey(0) | ||
ctx1 = nnx.context(key0) | ||
ctx2 = ctx1.partition().merge() | ||
|
||
assert ctx2._rngs["params"].count == 0 | ||
assert ctx2._rngs["params"].count_path == (0,) | ||
|
||
key1 = ctx1.make_rng("params") | ||
key2 = ctx2.make_rng("params") | ||
|
||
assert not np.equal(key1, key2).all() | ||
|
||
def test_rng_trace_level_constraints(self): | ||
ctx = nnx.context(0) | ||
|
||
@jax.jit | ||
def f(): | ||
with pytest.raises( | ||
nnx.TraceContextError, | ||
match="Cannot use Context from a different trace level", | ||
): | ||
ctx.make_rng("params") | ||
|
||
f() | ||
|
||
@jax.jit | ||
def f(): | ||
with pytest.raises( | ||
nnx.TraceContextError, | ||
match="Cannot use Context from a different trace level", | ||
): | ||
ctx.partition() | ||
|
||
f() | ||
|
||
ctx1: Any = None | ||
|
||
@jax.jit | ||
def g(): | ||
nonlocal ctx1 | ||
ctx1 = nnx.context(1) | ||
|
||
g() | ||
|
||
assert isinstance(ctx1, nnx.Context) | ||
with pytest.raises( | ||
nnx.TraceContextError, | ||
match="Cannot use Context from a different trace level", | ||
): | ||
ctx1.make_rng("params") | ||
|
||
def test_partition_merge(self): | ||
ctx = nnx.context(dropout=0) | ||
|
||
keys, ctxdef = ctx.partition() | ||
|
||
assert "dropout" in keys | ||
assert ctxdef._rng_counts == (("dropout", (0,)),) | ||
|
||
ctx2 = ctxdef.merge(keys) | ||
|
||
key1 = ctx.make_rng("dropout") | ||
key2 = ctx2.make_rng("dropout") | ||
assert not np.equal(key1, key2).all() | ||
|
||
ctx3 = ctxdef.merge(keys) | ||
key3 = ctx3.make_rng("dropout") | ||
assert np.equal(key2, key3).all() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import jax | ||
import jax.numpy as jnp | ||
import optax | ||
|
||
from flax.experimental import nnx | ||
|
||
|
||
class TestHelpers: | ||
|
||
def test_train_state(self): | ||
m = nnx.Dict(a=nnx.Param(1), b=nnx.BatchStat(2)) | ||
|
||
(params, batch_stats), moduledef = m.partition(nnx.Param, nnx.BatchStat) | ||
|
||
state = nnx.TrainState( | ||
moduledef, | ||
params=params, | ||
tx=optax.sgd(1.0), | ||
batch_stats=batch_stats, | ||
other=nnx.Node(100), | ||
int=200, | ||
static=nnx.Static(300), | ||
) | ||
|
||
leaves = jax.tree_util.tree_leaves(state) | ||
|
||
assert 1 in leaves | ||
assert 2 in leaves | ||
assert 100 in leaves | ||
assert 200 not in leaves | ||
assert 300 not in leaves | ||
|
||
def test_train_state_methods(self): | ||
class Foo(nnx.Module): | ||
|
||
def __init__(self, *, ctx: nnx.Context): | ||
self.linear = nnx.Linear(2, 4, ctx=ctx) | ||
self.batch_norm = nnx.BatchNorm(4, ctx=ctx) | ||
|
||
def __call__(self, x: jax.Array, train: bool) -> jax.Array: | ||
x = self.linear(x) | ||
x = self.batch_norm(x, use_running_average=not train) | ||
return x | ||
|
||
module = Foo(ctx=nnx.context(0)) | ||
(params, batch_stats), moduledef = module.partition(nnx.Param, nnx.BatchStat) | ||
|
||
state = nnx.TrainState( | ||
moduledef, | ||
params=params, | ||
tx=optax.sgd(1.0), | ||
batch_stats=batch_stats, | ||
) | ||
|
||
x = jax.numpy.ones((1, 2)) | ||
y, _updates = state.apply("params", "batch_stats")(x, train=True) | ||
|
||
assert y.shape == (1, 4) | ||
|
||
# fake gradient | ||
grads = jax.tree_map(jnp.ones_like, state.params) | ||
# test apply_gradients | ||
state = state.apply_gradients(grads) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
import copy | ||
|
||
from flax.experimental.nnx.nnx import ids | ||
|
||
|
||
class TestIds: | ||
|
||
def test_hashable(self): | ||
id1 = ids.uuid() | ||
id2 = ids.uuid() | ||
assert id1 == id1 | ||
assert id1 != id2 | ||
assert hash(id1) != hash(id2) | ||
id1c = copy.copy(id1) | ||
id1dc = copy.deepcopy(id1) | ||
assert hash(id1) != hash(id1c) | ||
assert hash(id1) != hash(id1dc) |
Oops, something went wrong.