This repository has been archived by the owner on May 16, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
test.py
57 lines (45 loc) · 1.59 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import unittest
import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
from flax.training.train_state import TrainState
import optax
import orbax.checkpoint as ocp
import chex
from parameterized import parameterized
import models
# Hide every GPU from TensorFlow so it cannot preallocate all GPU memory
# (presumably to leave the device free for JAX, which this file also uses).
tf.config.set_visible_devices([], "GPU")
def model_outputs(model_name, seed=0, batch_size=100):
    """Return top-1 predictions of the Flax port and of the Keras reference.

    Builds ``models.<model_name>`` in Flax and restores its state from the
    Orbax checkpoint at ``weights/<model_name>``. It then runs that model and
    ``tf.keras.applications.<model_name>`` on the same random input batch.

    Args:
        model_name: Class name shared by ``models`` and
            ``tf.keras.applications`` (e.g. ``"DenseNet121"``).
        seed: Seed for the random input batch, so that a failing comparison
            can be reproduced exactly.
        batch_size: Number of random 224x224x3 inputs to compare on.

    Returns:
        A pair ``(flax_preds, tf_preds)`` holding the argmax over the last
        axis of each model's output.
    """
    flax_model = getattr(models, model_name)()
    # The freshly initialised state only gives the checkpointer a target
    # structure to restore into; the zero optimizer is never stepped.
    empty_state = TrainState.create(
        apply_fn=flax_model.apply,
        params=flax_model.init(jax.random.PRNGKey(0), jnp.zeros((1, 224, 224, 3))),
        tx=optax.set_to_zero(),
    )
    train_state = ocp.PyTreeCheckpointer().restore(
        f"weights/{model_name}", item=empty_state
    )
    del empty_state  # Drop the randomly initialised copy of the weights.

    # Seeded, explicitly float32 input. It is deterministic across runs, and
    # both frameworks see identical data instead of JAX silently downcasting
    # a float64 array (its default with x64 disabled).
    rng = np.random.default_rng(seed)
    x = rng.uniform(high=1.0, size=(batch_size, 224, 224, 3)).astype(np.float32)

    logits = train_state.apply_fn(train_state.params, x, train=False)
    tf_model = getattr(tf.keras.applications, model_name)()
    tf_logits = tf_model(x, training=False).numpy()
    return jnp.argmax(logits, axis=-1), jnp.argmax(tf_logits, axis=-1)
class TestModels(unittest.TestCase):
    """Check that each Flax port predicts the same classes as Keras."""

    # Each entry is a 1-tuple of positional arguments for ``test_Model``.
    # A bare ``("Name")`` is just a string, not a tuple.
    # TODO(review): InceptionV3, ResNetRS50, ResNetRS101, ResNetRS152,
    # ResNetRS270, ResNetRS350 and ResNetRS420 were commented out of this
    # list. Presumably their ports do not match Keras yet; confirm before
    # re-enabling.
    @parameterized.expand([
        ("DenseNet121",),
        ("DenseNet169",),
        ("DenseNet201",),
        ("MobileNetV2",),
    ])
    def test_Model(self, model_name):
        """Assert that Flax and Keras argmax predictions agree for ``model_name``."""
        chex.assert_trees_all_close(*model_outputs(model_name))
if __name__ == "__main__":
    # Discover and run every TestCase in this module when executed directly.
    unittest.main(module="__main__")