Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add model: CSPDarknet #195

Merged
merged 5 commits into from
Oct 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

## New Features:

- No changes to highlight.
- Add CSPDarknet by `@illian01` in [PR 195](https://github.com/Nota-NetsPresso/netspresso-trainer/pull/195)

## Bug Fixes:

Expand Down
1 change: 1 addition & 0 deletions src/netspresso_trainer/models/backbones/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# from .core import *
from .experimental.darknet import cspdarknet
from .experimental.efficientformer import efficientformer
from .experimental.mobilenetv3 import mobilenetv3_small
from .experimental.mobilevit import mobilevit
Expand Down
150 changes: 150 additions & 0 deletions src/netspresso_trainer/models/backbones/experimental/darknet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,150 @@
"""
Based on the Darknet implementation of Megvii.
https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/models/darknet.py
"""

import torch
from torch import nn

from ...op.custom import ConvLayer, CSPLayer, Focus, SPPBottleneck
from ...utils import BackboneOutput

# Public API of this module: only the `cspdarknet` factory is exported.
__all__ = ['cspdarknet']
# Task names for which `CSPDarknet.task_support` returns True.
SUPPORTING_TASK = ['detection']


class CSPDarknet(nn.Module):
    """CSPDarknet backbone: a ``Focus`` stem followed by four "dark" stages.

    ``dark2`` to ``dark4`` are each a stride-2 3x3 ``ConvLayer`` followed by a
    ``CSPLayer``. ``dark5`` inserts an ``SPPBottleneck`` between the two and
    builds its ``CSPLayer`` with ``shortcut=False``. Channel widths double at
    every stage, starting from ``int(wid_mul * 64)`` at the stem.

    Args:
        task: Task name, compared case-insensitively. For ``'segmentation'``
            and ``'detection'`` the forward pass returns the feature maps named
            in ``out_features``; for any other task it returns the globally
            average-pooled, flattened ``dark5`` feature.
        dep_mul: Depth multiplier. The base bottleneck count is
            ``max(round(dep_mul * 3), 1)``; ``dark3`` and ``dark4`` use three
            times that count.
        wid_mul: Width multiplier. The stem width is ``int(wid_mul * 64)``.
        out_features: Names of the stages to expose as intermediate features,
            chosen from ``'stem'``, ``'dark2'``, ``'dark3'``, ``'dark4'`` and
            ``'dark5'``.
        act_type: Activation identifier forwarded to every layer.
        **kwargs: Extra configuration keys; accepted and ignored.

    Raises:
        AssertionError: If ``out_features`` is empty.
        KeyError: If ``out_features`` names a stage that does not exist.
    """

    def __init__(
        self,
        task,
        dep_mul,
        wid_mul,
        out_features=("dark3", "dark4", "dark5"),
        act_type="silu",
        **kwargs
    ):
        super().__init__()
        assert out_features, "please provide output features of Darknet"

        self.task = task.lower()
        self.use_intermediate_features = self.task in ['segmentation', 'detection']
        self.out_features = out_features

        base_channels = int(wid_mul * 64)  # 64 when wid_mul == 1.0
        base_depth = max(round(dep_mul * 3), 1)  # 3 when dep_mul == 1.0

        # Output width of every stage that `forward` records. 'stem' is listed
        # because `forward` exposes it too; leaving it out made
        # out_features=("stem", ...) fail here with a bare KeyError.
        stage_channels = {
            'stem': base_channels,
            'dark2': base_channels * 2,
            'dark3': base_channels * 4,
            'dark4': base_channels * 8,
            'dark5': base_channels * 16,
        }

        # Resolve the requested features up front so that a typo fails before
        # any layer is allocated, and with a message naming the valid choices.
        intermediate_dims = []
        for out_feature in out_features:
            if out_feature not in stage_channels:
                raise KeyError(
                    f"Unknown out_feature {out_feature!r}; "
                    f"expected one of {list(stage_channels)}"
                )
            intermediate_dims.append(stage_channels[out_feature])

        # NOTE: submodule names and creation order below are part of the
        # checkpoint format (state_dict keys); keep them stable.

        # stem
        self.stem = Focus(3, base_channels, ksize=3, act_type=act_type)

        # dark2
        self.dark2 = nn.Sequential(
            ConvLayer(in_channels=base_channels,
                      out_channels=base_channels * 2,
                      kernel_size=3,
                      stride=2,
                      act_type=act_type),
            CSPLayer(
                base_channels * 2,
                base_channels * 2,
                n=base_depth,
                act_type=act_type,
            ),
        )

        # dark3
        self.dark3 = nn.Sequential(
            ConvLayer(in_channels=base_channels * 2,
                      out_channels=base_channels * 4,
                      kernel_size=3,
                      stride=2,
                      act_type=act_type),
            CSPLayer(
                base_channels * 4,
                base_channels * 4,
                n=base_depth * 3,
                act_type=act_type,
            ),
        )

        # dark4
        self.dark4 = nn.Sequential(
            ConvLayer(in_channels=base_channels * 4,
                      out_channels=base_channels * 8,
                      kernel_size=3,
                      stride=2,
                      act_type=act_type),
            CSPLayer(
                base_channels * 8,
                base_channels * 8,
                n=base_depth * 3,
                act_type=act_type,
            ),
        )

        # dark5
        self.dark5 = nn.Sequential(
            ConvLayer(in_channels=base_channels * 8,
                      out_channels=base_channels * 16,
                      kernel_size=3,
                      stride=2,
                      act_type=act_type),
            SPPBottleneck(base_channels * 16, base_channels * 16, act_type=act_type),
            CSPLayer(
                base_channels * 16,
                base_channels * 16,
                n=base_depth,
                shortcut=False,
                act_type=act_type,
            ),
        )

        self.avgpool = nn.AdaptiveAvgPool2d(1)

        self._feature_dim = stage_channels['dark5']
        self._intermediate_features_dim = intermediate_dims

    def forward(self, x):
        """Run the backbone.

        Returns:
            ``BackboneOutput(intermediate_features=[...])`` holding the maps
            named in ``out_features`` (in that order) when the task uses
            intermediate features; otherwise
            ``BackboneOutput(last_feature=...)`` holding the average-pooled
            ``dark5`` map flattened from dim 1.
        """
        outputs_dict = {}
        x = self.stem(x)
        outputs_dict["stem"] = x
        x = self.dark2(x)
        outputs_dict["dark2"] = x
        x = self.dark3(x)
        outputs_dict["dark3"] = x
        x = self.dark4(x)
        outputs_dict["dark4"] = x
        x = self.dark5(x)
        outputs_dict["dark5"] = x

        if self.use_intermediate_features:
            all_hidden_states = [outputs_dict[out_name] for out_name in self.out_features]
            return BackboneOutput(intermediate_features=all_hidden_states)

        x = self.avgpool(x)
        x = torch.flatten(x, 1)

        return BackboneOutput(last_feature=x)

    @property
    def feature_dim(self):
        """Channel count of the final (``dark5``) stage."""
        return self._feature_dim

    @property
    def intermediate_features_dim(self):
        """Channel counts of the stages in ``out_features``, in the same order."""
        return self._intermediate_features_dim

    def task_support(self, task):
        """Return True when ``task`` (case-insensitive) is in ``SUPPORTING_TASK``."""
        return task.lower() in SUPPORTING_TASK


def cspdarknet(task, conf_model_backbone) -> CSPDarknet:
    """Build a ``CSPDarknet`` backbone for ``task``.

    ``conf_model_backbone`` is unpacked as keyword arguments of the
    ``CSPDarknet`` constructor.
    """
    backbone = CSPDarknet(task, **conf_model_backbone)
    return backbone
140 changes: 140 additions & 0 deletions src/netspresso_trainer/models/op/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,3 +489,143 @@ def forward(self, x: Tensor) -> Tensor:

# def __repr__(self):
# return "{}(type={})".format(self.__class__.__name__, self.pool_type)


class Focus(nn.Module):
    """Fold each 2x2 spatial neighbourhood into the channel axis, then convolve.

    The input is split into four interleaved sub-grids (even/odd rows crossed
    with even/odd columns), which are concatenated along dim 1 and passed
    through a single ``ConvLayer``.
    """

    def __init__(self, in_channels, out_channels, ksize=1, stride=1, act_type="silu"):
        super().__init__()
        # Four sub-grids are stacked channel-wise, so the conv sees 4x channels.
        self.conv = ConvLayer(
            in_channels=in_channels * 4,
            out_channels=out_channels,
            kernel_size=ksize,
            stride=stride,
            act_type=act_type,
        )

    def forward(self, x):
        # x: (b, c, h, w) -> stacked: (b, 4c, h/2, w/2)
        even_rows = x[..., ::2, :]
        odd_rows = x[..., 1::2, :]
        # Channel order: top-left, bottom-left, top-right, bottom-right.
        stacked = torch.cat(
            (
                even_rows[..., ::2],
                odd_rows[..., ::2],
                even_rows[..., 1::2],
                odd_rows[..., 1::2],
            ),
            dim=1,
        )
        return self.conv(stacked)


class CSPLayer(nn.Module):
    """C3 in yolov5, CSP Bottleneck with 3 convolutions.

    The input is projected by two parallel 1x1 convolutions. One projection
    runs through ``n`` stacked ``DarknetBlock`` modules, the other is left
    as is; both are concatenated on dim 1 and fused by a third 1x1 convolution.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        n=1,
        shortcut=True,
        expansion=0.5,
        act_type="silu",
    ):
        """
        Args:
            in_channels (int): input channels.
            out_channels (int): output channels.
            n (int): number of Bottlenecks. Default value: 1.
            shortcut (bool): forwarded to every ``DarknetBlock``.
            expansion (float): ratio of hidden to output channels.
            act_type (str): activation identifier forwarded to every layer.
        """
        super().__init__()
        hidden_channels = int(out_channels * expansion)

        def pointwise(c_in, c_out):
            # All three projections are 1x1, stride-1 convolutions.
            return ConvLayer(in_channels=c_in,
                             out_channels=c_out,
                             kernel_size=1,
                             stride=1, act_type=act_type)

        self.conv1 = pointwise(in_channels, hidden_channels)
        self.conv2 = pointwise(in_channels, hidden_channels)
        self.conv3 = pointwise(2 * hidden_channels, out_channels)

        bottlenecks = []
        for _ in range(n):
            bottlenecks.append(
                DarknetBlock(
                    in_channels=hidden_channels,
                    out_channels=hidden_channels,
                    shortcut=shortcut,
                    expansion=1.0,
                    act_type=act_type
                )
            )
        self.m = nn.Sequential(*bottlenecks)

    def forward(self, x):
        bottleneck_branch = self.conv1(x)
        bypass_branch = self.conv2(x)
        bottleneck_branch = self.m(bottleneck_branch)
        fused = torch.cat((bottleneck_branch, bypass_branch), dim=1)
        return self.conv3(fused)


class SPPBottleneck(nn.Module):
    """Spatial pyramid pooling layer used in YOLOv3-SPP.

    A 1x1 convolution halves the channels, the result is max-pooled once per
    entry of ``kernel_sizes`` (stride 1, padding ``ks // 2``), and the
    un-pooled map plus every pooled map are concatenated on dim 1 before a
    final 1x1 convolution.
    """

    def __init__(
        self, in_channels, out_channels, kernel_sizes=(5, 9, 13), act_type="silu"
    ):
        super().__init__()
        hidden_channels = in_channels // 2
        self.conv1 = ConvLayer(in_channels=in_channels, out_channels=hidden_channels,
                               kernel_size=1, stride=1, act_type=act_type)

        pools = nn.ModuleList()
        for ks in kernel_sizes:
            pools.append(nn.MaxPool2d(kernel_size=ks, stride=1, padding=ks // 2))
        self.m = pools

        # One un-pooled copy plus one pooled copy per kernel size.
        conv2_channels = hidden_channels * (len(kernel_sizes) + 1)
        self.conv2 = ConvLayer(in_channels=conv2_channels, out_channels=out_channels,
                               kernel_size=1, stride=1, act_type=act_type)

    def forward(self, x):
        reduced = self.conv1(x)
        pyramid = [reduced]
        for pool in self.m:
            pyramid.append(pool(reduced))
        return self.conv2(torch.cat(pyramid, dim=1))


# Newly defined because of slight difference with Bottleneck of custom.py
class DarknetBlock(nn.Module):
    """Standard bottleneck: a 1x1 convolution followed by a 3x3 convolution.

    The input is added to the output only when ``shortcut`` is set and
    ``in_channels == out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        shortcut=True,
        expansion=0.5,
        act_type="silu",
    ):
        super().__init__()
        hidden_channels = int(out_channels * expansion)
        self.conv1 = ConvLayer(
            in_channels=in_channels,
            out_channels=hidden_channels,
            kernel_size=1,
            stride=1,
            act_type=act_type,
        )
        self.conv2 = ConvLayer(
            in_channels=hidden_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            act_type=act_type,
        )
        # A residual add needs matching channel counts on both sides.
        self.use_add = shortcut and in_channels == out_channels

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out + x if self.use_add else out
3 changes: 2 additions & 1 deletion src/netspresso_trainer/models/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import torch.nn as nn

from .backbones import efficientformer, mobilenetv3_small, mobilevit, resnet50, segformer, vit
from .backbones import cspdarknet, efficientformer, mobilenetv3_small, mobilevit, resnet50, segformer, vit
from .full import pidnet
from .heads.classification import fc
from .heads.detection import faster_rcnn
Expand All @@ -16,6 +16,7 @@
'mobilevit': mobilevit,
'vit': vit,
'efficientformer': efficientformer,
'cspdarknet': cspdarknet
}

MODEL_HEAD_DICT: Dict[str, Callable[..., nn.Module]] = {
Expand Down