forked from johannesu/NASNet-keras
-
Notifications
You must be signed in to change notification settings - Fork 0
/
nasnet_test.py
62 lines (45 loc) · 1.94 KB
/
nasnet_test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from __future__ import division
import logging
logging.getLogger('tensorflow').setLevel(logging.WARNING)
import unittest
import nasnet
import keras.backend as K
class ModelTest(unittest.TestCase):
    """Smoke tests for the model builders in the ``nasnet`` module.

    Each test builds a model and checks two things:

    * its total parameter count, in millions;
    * the spatial (height, width) size of the tensor fed into the layer
      named ``'last_relu'``.
    """

    def check_parameter_count(self, model, target_in_m, delta=0.1):
        """Assert that ``model`` has roughly ``target_in_m`` million params.

        Args:
            model: A built Keras model exposing ``count_params()`` and ``name``.
            target_in_m: Expected parameter count, in millions.
            delta: Allowed absolute difference, in millions (default 0.1).
        """
        # Relies on true division (``from __future__ import division``) so the
        # count stays a fractional number of millions on Python 2 as well.
        count = model.count_params() / 10 ** 6
        msg = '{} params #{}M supposed to be #{}M.'.format(
            model.name, count, target_in_m)
        self.assertAlmostEqual(target_in_m, count, msg=msg, delta=delta)

    def check_penultimate_shape(self, model, target_shape):
        """Assert the spatial size of the input to the ``'last_relu'`` layer.

        Args:
            model: A Keras model containing a layer named ``'last_relu'``.
            target_shape: Expected ``(height, width)`` tuple.
        """
        layer = model.get_layer('last_relu')
        if K.image_data_format() == 'channels_first':
            # (batch, channels, height, width) -> (height, width)
            shape = layer.input_shape[2:]
        else:
            # (batch, height, width, channels) -> (height, width)
            shape = layer.input_shape[1:3]
        self.assertEqual(shape, target_shape)

    def test_cifar_10(self):
        """CIFAR-10 variant, with and without the auxiliary output head."""
        model = nasnet.cifar10()
        self.check_parameter_count(model, 3.3)
        self.check_penultimate_shape(model, (8, 8))

        aux_model = nasnet.cifar10(aux_output=True)
        self.check_parameter_count(aux_model, 4.9)
        # With aux_output=True the model is expected to have two outputs.
        self.assertEqual(len(aux_model.output), 2)

    def test_mobile(self):
        """Mobile variant, with and without the auxiliary output head."""
        model = nasnet.mobile()
        self.check_parameter_count(model, 5.3)
        self.check_penultimate_shape(model, (7, 7))

        aux_model = nasnet.mobile(aux_output=True)
        self.check_parameter_count(aux_model, 7.7)
        self.assertEqual(len(aux_model.output), 2)

    def test_large(self):
        """Large variant, with and without the auxiliary output head."""
        model = nasnet.large()
        self.check_parameter_count(model, 88.9)
        self.check_penultimate_shape(model, (11, 11))

        aux_model = nasnet.large(aux_output=True)
        self.check_parameter_count(aux_model, 93.5)
        self.assertEqual(len(aux_model.output), 2)

    def test_channel_first(self):
        """Build the CIFAR-10 model under the ``channels_first`` data format."""
        # The data format set below outlives this test. Register the restore as
        # a cleanup *before* switching, so it runs even if model construction
        # or an assertion fails. Restore the format that was actually active
        # rather than assuming 'channels_last'.
        self.addCleanup(K.set_image_data_format, K.image_data_format())
        K.set_image_data_format('channels_first')

        model = nasnet.cifar10()
        self.check_parameter_count(model, 3.3)
        self.check_penultimate_shape(model, (8, 8))