Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timm backbones to mmdet models #848

Open
wants to merge 100 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
100 commits
Select commit Hold shift + click to select a range
79b8634
adding timm support
ai-fast-track May 13, 2021
9e19149
fixed argument names
ai-fast-track May 13, 2021
69bbc66
added imports
ai-fast-track May 13, 2021
fc1eb42
use ice_mobilenetv3_large_100 method name
ai-fast-track May 13, 2021
81efa63
add BACKBONES import
ai-fast-track May 13, 2021
4741230
add model attribute
ai-fast-track May 13, 2021
da65157
add torch.nn import
ai-fast-track May 13, 2021
ed29f12
fix path
ai-fast-track May 13, 2021
471ac1a
fixed file name
ai-fast-track May 13, 2021
7b76456
setweights_path to None for timm backbones
ai-fast-track May 13, 2021
f4a75f3
applied black
ai-fast-track May 13, 2021
7704a4f
added __call__
ai-fast-track May 13, 2021
c2ac87d
added a super class: MMDetTimmBackbone
ai-fast-track May 14, 2021
d27353b
added resnet backbones
ai-fast-track May 14, 2021
f00c7c3
added all the backbones
ai-fast-track May 14, 2021
47f8def
added support for all timm mobilenetnet backbones
ai-fast-track May 14, 2021
91b584d
move common code in a separate module
ai-fast-track May 14, 2021
34a4e9a
rename file
ai-fast-track May 14, 2021
207e544
added resne(s)t support
ai-fast-track May 14, 2021
706b169
added missing import
ai-fast-track May 14, 2021
bd06064
applied black
ai-fast-track May 14, 2021
ce8019a
fixed names and added imports
ai-fast-track May 14, 2021
9c8cb43
added missing imports
ai-fast-track May 14, 2021
3ec614f
remove unsupported models
ai-fast-track May 14, 2021
8f4375b
fixed typo
ai-fast-track May 14, 2021
0e54cce
remove import
ai-fast-track May 14, 2021
a3e7950
added resnest import
ai-fast-track May 14, 2021
3ad7ac3
added misssing backbone
ai-fast-track May 14, 2021
990d8eb
store backbones in folders by model
ai-fast-track May 15, 2021
925f916
expose separate model families
ai-fast-track May 15, 2021
71917fb
added mmdet folder
ai-fast-track May 15, 2021
42db30c
wrap backbone dict in ConfigDict
ai-fast-track May 17, 2021
28a1ed7
applied black
ai-fast-track May 17, 2021
c56558a
added param_groups support
ai-fast-track May 17, 2021
493bbff
set different param_groups for mmdet and timm
ai-fast-track May 18, 2021
6df61bf
fixed name
ai-fast-track May 18, 2021
e0d3967
add ssd stuff
ai-fast-track May 18, 2021
1620e1a
applied black
ai-fast-track May 18, 2021
406eca6
add soft_dependencies
ai-fast-track May 18, 2021
ca8a758
added pip install timm
ai-fast-track May 18, 2021
a895502
remove commented code
ai-fast-track May 18, 2021
1d42b9e
fix mmdet tests
ai-fast-track May 18, 2021
5219106
import SoftDependencies
ai-fast-track May 18, 2021
a334e87
removed unnecessary import
ai-fast-track May 18, 2021
2315a1e
added imports
ai-fast-track May 18, 2021
4504d62
move timm installation up
ai-fast-track May 18, 2021
057c621
set pretrained attribute in cfg.model
ai-fast-track May 26, 2021
52877df
added weight_path for mmdet backbones
ai-fast-track May 26, 2021
9b39f3d
removed pretrained attribute
ai-fast-track May 26, 2021
7093a13
added init_cfg
ai-fast-track May 26, 2021
750a34f
applied black
ai-fast-track May 26, 2021
0333783
added weights_url
ai-fast-track May 26, 2021
99f0a1d
include downloading timm backbone case
ai-fast-track May 27, 2021
6c458ef
create timm model from model_name
ai-fast-track May 27, 2021
f43ab5d
added BaseMobileNetV3 class
ai-fast-track May 27, 2021
84458d3
missing imports
ai-fast-track May 27, 2021
e7a9df7
only test v3_large_100
ai-fast-track May 27, 2021
e91e8b0
only test v3_large_100
ai-fast-track May 27, 2021
6d993f1
only get the 3 standard layers
ai-fast-track May 28, 2021
7f6260d
added get_feature_channels() method
ai-fast-track May 28, 2021
4f37bc9
return only feature_channels for out_indices
ai-fast-track May 28, 2021
2850e62
added resnet support
ai-fast-track May 28, 2021
17e6654
remove unused methods
ai-fast-track May 28, 2021
c499442
added resnest support
ai-fast-track May 28, 2021
c8afce6
initialize attributes in the base class
ai-fast-track May 28, 2021
9a0b7aa
added mobilenetv3_rw
ai-fast-track May 28, 2021
2149328
applied black
ai-fast-track May 31, 2021
b4c27de
added resnetrs50
ai-fast-track May 31, 2021
659f227
added extra keys
ai-fast-track May 31, 2021
168423d
added extra keys
ai-fast-track May 31, 2021
f83e112
added config_path
ai-fast-track Jun 8, 2021
b6cf0b0
added weights_url attribute
ai-fast-track Jun 8, 2021
88fa3ec
pass config_path argument to parent class
ai-fast-track Jun 9, 2021
f9acc9c
added download_weights()
ai-fast-track Jun 9, 2021
b076aad
set out_indices to (2, 3, 4) like mmdet and torchvision
ai-fast-track Jun 9, 2021
e4b8bfa
remove comma
ai-fast-track Jun 9, 2021
4720564
added default values
ai-fast-track Jun 9, 2021
27dcd78
added weights_url
ai-fast-track Jun 9, 2021
cb27c85
added log message
ai-fast-track Jun 9, 2021
ff51931
remove duplicated keys
ai-fast-track Jun 9, 2021
9ac8585
populare config_path if a user pass None
ai-fast-track Jun 14, 2021
2dd3080
add mmdet base path to config_path and raise errors
ai-fast-track Jun 14, 2021
d7bb82d
added resnest101e
ai-fast-track Jun 14, 2021
f64ab17
added weights_url
ai-fast-track Jun 14, 2021
ae34d3c
rename TIMM to Timm. raise ValueError for invalid values
ai-fast-track Jun 14, 2021
6ee7960
fix upper and lower bands
ai-fast-track Jun 14, 2021
88ee5a0
applied black
ai-fast-track Jun 18, 2021
562e783
handle Timm pretrained backbones
ai-fast-track Jun 18, 2021
ee3c8e1
added fcos config file to samples (for tests)
ai-fast-track Jul 16, 2021
20f7fd5
added tests for mobilnetv3 timm backbone
ai-fast-track Jul 16, 2021
59392fd
applied black
ai-fast-track Jul 16, 2021
65d9c50
added pretrained
ai-fast-track Jul 19, 2021
22c2276
removed comment
ai-fast-track Jul 19, 2021
94be54b
add utils module
ai-fast-track Jul 21, 2021
74c769b
added more tests
ai-fast-track Jul 21, 2021
ae13e67
added resnetrs50 to tests
ai-fast-track Jul 23, 2021
638ec6b
Created using Colab
ai-fast-track Jul 24, 2021
2917eb1
removed duplicate
ai-fast-track Jul 24, 2021
1332af4
added timm_mmdet notebook + its doc
ai-fast-track Jul 24, 2021
579c3fe
added efficientnet support
ai-fast-track Jul 27, 2021
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/ci-all-testing.yml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ jobs:
pip install torch=="1.8.1+cpu" torchvision=="0.9.1+cpu" -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full=="1.3.2+torch.1.8.0+cpu" -f https://download.openmmlab.com/mmcv/dist/index.html --use-deprecated=legacy-resolver
pip install mmdet==2.12.0 --upgrade
pip install timm
pip install -e ".[all,dev]"
pip install yolov5-icevision --upgrade

Expand Down
1 change: 1 addition & 0 deletions docs/mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ nav:
- Custom Parser: custom_parser.md
- Inference: inference.md
- Other Tutorials:
- MMDet + Timm Models: timm_mmdet_integration.md
- Model Tracking Using Wandb: wandb_efficientdet.md
- How to use negative samples: negative_samples.md
- Fixed Splitter: voc_predefined_splits.md
Expand Down
1 change: 1 addition & 0 deletions icevision/models/mmdet/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
from icevision.models.mmdet import common
from icevision.models.mmdet.models import *
from icevision.models.mmdet.utils import *
5 changes: 5 additions & 0 deletions icevision/models/mmdet/backbones/timm/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from icevision.models.mmdet.backbones.timm.common import *
from icevision.models.mmdet.backbones.timm.mobilenet import *
from icevision.models.mmdet.backbones.timm.resnest import *
from icevision.models.mmdet.backbones.timm.resnet import *
from icevision.models.mmdet.backbones.timm.efficientnet import *
32 changes: 32 additions & 0 deletions icevision/models/mmdet/backbones/timm/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
__all__ = ["MMDetTimmBase"]

import torch.nn as nn
from timm.models.registry import *
from typing import Tuple, Collection, List
from torch import Tensor


class MMDetTimmBase(nn.Module):
def __init__(
self,
model_name: str = None,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
):

super().__init__()
self.model_name = model_name
self.pretrained = pretrained
self.out_indices = out_indices
self.norm_eval = norm_eval
model_fn = model_entrypoint(self.model_name)
self.model = model_fn(
pretrained=self.pretrained, features_only=True, out_indices=out_indices
)

def init_weights(self, pretrained=None):
pass

def forward(self, x) -> Tuple[Tensor]: # should return a tuple
return tuple(self.model(x))
239 changes: 239 additions & 0 deletions icevision/models/mmdet/backbones/timm/efficientnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
__all__ = [
"EfficientNet_B1",
"EfficientNet_B2",
"EfficientNet_B3",
"EfficientNet_B4",
"EfficientNet_Lite0",
]

from icevision.models.mmdet.backbones.timm.common import *
from mmdet.models.builder import BACKBONES

from typing import Optional, Collection
from torch.nn.modules.batchnorm import _BatchNorm
from typing import Tuple, Collection, List

import timm
from timm.models.mobilenetv3 import *
from timm.models.registry import *

import torch.nn as nn
import torch


class BaseEfficientNet(MMDetTimmBase):
"""
Base class that implements model freezing and forward methods
"""

def __init__(
self,
model_name: str = None,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):

super().__init__(
model_name=model_name,
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
)
self.frozen_stages = frozen_stages
self.frozen_stem = frozen_stem

def test_pretrained_weights(self):
# Get model method from the timm registry by model name
model_fn = model_entrypoint(self.model_name)
model = model_fn(pretrained=self.pretrained)
assert torch.equal(self.model.conv_stem.weight, model.conv_stem.weight)

def post_init_setup(self):
self.freeze(
freeze_stem=self.frozen_stem,
freeze_blocks=self.frozen_stages,
)
# self.test_pretrained_weights()

def freeze(self, freeze_stem: bool = True, freeze_blocks: int = 1):
"Optionally freeze the stem and/or Inverted Residual blocks of the model"
if 0 > freeze_blocks > 8:
raise ValueError("freeze_blocks values must between 0 and 7 included")

m = self.model

# Stem freezing logic
if freeze_stem:
for l in [m.conv_stem, m.bn1]:
l.eval()
for param in l.parameters():
param.requires_grad = False

# `freeze_blocks=1` freezes the first block, and so on
for i, block in enumerate(m.blocks, start=1):
if i > freeze_blocks:
break
else:
block.eval()
for param in block.parameters():
param.requires_grad = False

def train(self, mode=True):
"Convert the model to training mode while optionally freezing BatchNorm"
super(BaseEfficientNet, self).train(mode)
self.freeze(
freeze_stem=self.frozen_stem,
freeze_blocks=self.frozen_stages,
)
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()


@BACKBONES.register_module(force=True)
class EfficientNet_B0(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet B0"
super().__init__(
model_name="efficientnet_b0",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)

self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b0_ra-3dd342df.pth"

self.post_init_setup()


@BACKBONES.register_module(force=True)
class EfficientNet_B1(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet B1"
super().__init__(
model_name="efficientnet_b1",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)
self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b1-533bc792.pth"

self.post_init_setup()


@BACKBONES.register_module(force=True)
class EfficientNet_B2(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet B2"
super().__init__(
model_name="efficientnet_b2",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)
self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b2_ra-bcdf34b7.pth"

self.post_init_setup()


@BACKBONES.register_module(force=True)
class EfficientNet_B3(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet B3"
super().__init__(
model_name="efficientnet_b3",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)
self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b3_ra2-cf984f9c.pth"

self.post_init_setup()


@BACKBONES.register_module(force=True)
class EfficientNet_B4(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet B4"
super().__init__(
model_name="efficientnet_b2",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)
self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_b4_ra2_320-7eb33cd5.pth"

self.post_init_setup()


@BACKBONES.register_module(force=True)
class EfficientNet_Lite0(BaseEfficientNet):
def __init__(
self,
pretrained: bool = True,
out_indices: Collection[int] = (2, 3, 4),
norm_eval: bool = True,
frozen_stages: int = 1,
frozen_stem: bool = True,
):
"EfficientNet Lite0"
super().__init__(
model_name="efficientnet_lite0",
pretrained=pretrained,
out_indices=out_indices,
norm_eval=norm_eval,
frozen_stages=frozen_stages,
frozen_stem=frozen_stem,
)
self.weights_url = "https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/efficientnet_lite0_ra-37913777.pth"

self.post_init_setup()
Loading