diff --git a/examples/imagenet/models.py b/examples/imagenet/models.py index fe423eaa96..b942ea4600 100644 --- a/examples/imagenet/models.py +++ b/examples/imagenet/models.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Flax implementation of ResNet V1.""" +"""Flax implementation of ResNet V1.5.""" # See issue #620. # pytype: disable=wrong-arg-count @@ -87,7 +87,7 @@ def __call__(self, x): class ResNet(nn.Module): - """ResNetV1.""" + """ResNetV1.5.""" stage_sizes: Sequence[int] block_cls: ModuleDef diff --git a/examples/imagenet/models_test.py b/examples/imagenet/models_test.py index 338df46b2b..d135c3e9a6 100644 --- a/examples/imagenet/models_test.py +++ b/examples/imagenet/models_test.py @@ -26,11 +26,11 @@ jax.config.update('jax_disable_most_optimizations', True) -class ResNetV1Test(parameterized.TestCase): - """Test cases for ResNet v1 model definition.""" +class ResNetTest(parameterized.TestCase): + """Test cases for ResNet v1.5 model definition.""" - def test_resnet_v1_model(self): - """Tests ResNet V1 model definition and output (variables).""" + def test_resnet_model(self): + """Tests ResNet V1.5 model definition and output (variables).""" rng = jax.random.key(0) model_def = models.ResNet50(num_classes=10, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32)) @@ -43,8 +43,8 @@ def test_resnet_v1_model(self): self.assertLen(variables['params'], 19) @parameterized.product(model=(models.ResNet18, models.ResNet18Local)) - def test_resnet_18_v1_model(self, model): - """Tests ResNet18 V1 model definition and output (variables).""" + def test_resnet_18_model(self, model): + """Tests ResNet18 V1.5 model definition and output (variables).""" rng = jax.random.key(0) model_def = model(num_classes=2, dtype=jnp.float32) variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32))