Skip to content

Commit

Permalink
Merge pull request #3344 from chiamp:resnet
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 566395726
  • Loading branch information
Flax Authors committed Sep 18, 2023
2 parents 830d335 + 69b3732 commit ed2b1cf
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions examples/imagenet/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

"""Flax implementation of ResNet V1."""
"""Flax implementation of ResNet V1.5."""

# See issue #620.
# pytype: disable=wrong-arg-count
Expand Down Expand Up @@ -87,7 +87,7 @@ def __call__(self, x):


class ResNet(nn.Module):
"""ResNetV1."""
"""ResNetV1.5."""

stage_sizes: Sequence[int]
block_cls: ModuleDef
Expand Down
12 changes: 6 additions & 6 deletions examples/imagenet/models_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
jax.config.update('jax_disable_most_optimizations', True)


class ResNetV1Test(parameterized.TestCase):
"""Test cases for ResNet v1 model definition."""
class ResNetTest(parameterized.TestCase):
"""Test cases for ResNet v1.5 model definition."""

def test_resnet_v1_model(self):
"""Tests ResNet V1 model definition and output (variables)."""
def test_resnet_model(self):
"""Tests ResNet V1.5 model definition and output (variables)."""
rng = jax.random.key(0)
model_def = models.ResNet50(num_classes=10, dtype=jnp.float32)
variables = model_def.init(rng, jnp.ones((8, 224, 224, 3), jnp.float32))
Expand All @@ -43,8 +43,8 @@ def test_resnet_v1_model(self):
self.assertLen(variables['params'], 19)

@parameterized.product(model=(models.ResNet18, models.ResNet18Local))
def test_resnet_18_v1_model(self, model):
"""Tests ResNet18 V1 model definition and output (variables)."""
def test_resnet_18_model(self, model):
"""Tests ResNet18 V1.5 model definition and output (variables)."""
rng = jax.random.key(0)
model_def = model(num_classes=2, dtype=jnp.float32)
variables = model_def.init(rng, jnp.ones((1, 64, 64, 3), jnp.float32))
Expand Down

0 comments on commit ed2b1cf

Please sign in to comment.