[refactor] RT-DETR hybrid encoder #581

Merged: 14 commits, Nov 26, 2024
@@ -28,106 +28,7 @@
 from ...utils import BackboneOutput
 from ...op.registry import ACTIVATION_REGISTRY
 
 
-# TODO: Replace with custom implementation
-class ConvNormLayer(nn.Module):
-    def __init__(self, ch_in, ch_out, kernel_size, stride, padding=None, bias=False, act=None):
-        super().__init__()
-        self.conv = nn.Conv2d(
-            ch_in,
-            ch_out,
-            kernel_size,
-            stride,
-            padding=(kernel_size - 1) // 2 if padding is None else padding,
-            bias=bias)
-        self.norm = nn.BatchNorm2d(ch_out)
-        self.act = nn.Identity() if act is None else ACTIVATION_REGISTRY[act]()
-
-    def forward(self, x):
-        return self.act(self.norm(self.conv(x)))
-
-
-# TODO: Replace with custom implementation
-class RepVggBlock(nn.Module):
-    def __init__(self, ch_in, ch_out, act='relu'):
-        super().__init__()
-        self.ch_in = ch_in
-        self.ch_out = ch_out
-        self.conv1 = ConvNormLayer(ch_in, ch_out, 3, 1, padding=1, act=None)
-        self.conv2 = ConvNormLayer(ch_in, ch_out, 1, 1, padding=0, act=None)
-        self.act = nn.Identity() if act is None else ACTIVATION_REGISTRY[act]()
-
-    def forward(self, x):
-        if hasattr(self, 'conv'):
-            y = self.conv(x)
-        else:
-            y = self.conv1(x) + self.conv2(x)
-
-        return self.act(y)
-
-    def convert_to_deploy(self):
-        if not hasattr(self, 'conv'):
-            self.conv = nn.Conv2d(self.ch_in, self.ch_out, 3, 1, padding=1)
-
-        kernel, bias = self.get_equivalent_kernel_bias()
-        self.conv.weight.data = kernel
-        self.conv.bias.data = bias
-        # self.__delattr__('conv1')
-        # self.__delattr__('conv2')
-
-    def get_equivalent_kernel_bias(self):
-        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
-        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
-
-        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1), bias3x3 + bias1x1
-
-    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
-        if kernel1x1 is None:
-            return 0
-        else:
-            return F.pad(kernel1x1, [1, 1, 1, 1])
-
-    def _fuse_bn_tensor(self, branch: ConvNormLayer):
-        if branch is None:
-            return 0, 0
-        kernel = branch.conv.weight
-        running_mean = branch.norm.running_mean
-        running_var = branch.norm.running_var
-        gamma = branch.norm.weight
-        beta = branch.norm.bias
-        eps = branch.norm.eps
-        std = (running_var + eps).sqrt()
-        t = (gamma / std).reshape(-1, 1, 1, 1)
-        return kernel * t, beta - running_mean * gamma / std
-
-
-# TODO: Replace with custom implementation
-class CSPRepLayer(nn.Module):
-    def __init__(self,
-                 in_channels,
-                 out_channels,
-                 num_blocks=3,
-                 expansion=1.0,
-                 bias=None,
-                 act="silu"):
-        super(CSPRepLayer, self).__init__()
-        hidden_channels = int(out_channels * expansion)
-        self.conv1 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
-        self.conv2 = ConvNormLayer(in_channels, hidden_channels, 1, 1, bias=bias, act=act)
-        self.bottlenecks = nn.Sequential(*[
-            RepVggBlock(hidden_channels, hidden_channels, act=act) for _ in range(num_blocks)
-        ])
-        if hidden_channels != out_channels:
-            self.conv3 = ConvNormLayer(hidden_channels, out_channels, 1, 1, bias=bias, act=act)
-        else:
-            self.conv3 = nn.Identity()
-
-    def forward(self, x):
-        x_1 = self.conv1(x)
-        x_1 = self.bottlenecks(x_1)
-        x_2 = self.conv2(x)
-        return self.conv3(x_1 + x_2)
+from ...op.custom import RepVGGBlock, ConvLayer, CSPRepLayer
 
 
 # transformer
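
For reference, the _fuse_bn_tensor folding above uses the standard conv-BN identity: y = gamma * (conv(x) - mu) / sigma + beta is itself a convolution with weight W' = W * (gamma / sigma) and bias b' = beta - mu * gamma / sigma, where sigma = sqrt(running_var + eps). A minimal sketch of that check (illustrative names, not part of this PR):

import torch
import torch.nn as nn

conv = nn.Conv2d(8, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8).eval()  # eval mode: folding is exact only with running stats
std = (bn.running_var + bn.eps).sqrt()
t = (bn.weight / std).reshape(-1, 1, 1, 1)
fused = nn.Conv2d(8, 8, 3, padding=1)
fused.weight.data = conv.weight * t
fused.bias.data = bn.bias - bn.running_mean * bn.weight / std

x = torch.randn(1, 8, 16, 16)
assert torch.allclose(bn(conv(x)), fused(x), atol=1e-6)
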
@@ -251,7 +152,7 @@ def __init__(
         self.lateral_convs = nn.ModuleList()
         self.fpn_blocks = nn.ModuleList()
         for _ in range(len(self.in_channels) - 1, 0, -1):
-            self.lateral_convs.append(ConvNormLayer(self.hidden_dim, self.hidden_dim, 1, 1, act=act))
+            self.lateral_convs.append(ConvLayer(self.hidden_dim, self.hidden_dim, kernel_size=1, stride=1, act_type=act))
             self.fpn_blocks.append(
                 CSPRepLayer(self.hidden_dim * 2, self.hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
             )
@@ -261,12 +162,11 @@ def __init__(
         self.pan_blocks = nn.ModuleList()
         for _ in range(len(self.in_channels) - 1):
             self.downsample_convs.append(
-                ConvNormLayer(self.hidden_dim, self.hidden_dim, 3, 2, act=act)
+                ConvLayer(self.hidden_dim, self.hidden_dim, kernel_size=3, stride=2, act_type=act)
             )
             self.pan_blocks.append(
                 CSPRepLayer(self.hidden_dim * 2, self.hidden_dim, round(3 * depth_mult), act=act, expansion=expansion)
             )
-
         self._reset_parameters()
 
     @property
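
For context, these lateral/FPN and downsample/PAN module lists are consumed in a top-down then bottom-up pass. A simplified sketch of the upstream RT-DETR forward pattern (not code from this diff; assumes torch.nn.functional imported as F and projected backbone features feats ordered from high to low spatial resolution):

inner_outs = [feats[-1]]
for idx in range(len(feats) - 1, 0, -1):  # top-down FPN: upsample and fuse
    feat_high = self.lateral_convs[len(feats) - 1 - idx](inner_outs[0])
    inner_outs[0] = feat_high
    upsample = F.interpolate(feat_high, scale_factor=2.0, mode='nearest')
    inner_outs.insert(0, self.fpn_blocks[len(feats) - 1 - idx](torch.cat([upsample, feats[idx - 1]], dim=1)))
outs = [inner_outs[0]]
for idx in range(len(feats) - 1):  # bottom-up PAN: downsample and fuse
    downsample = self.downsample_convs[idx](outs[-1])
    outs.append(self.pan_blocks[idx](torch.cat([downsample, inner_outs[idx + 1]], dim=1)))
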
117 changes: 117 additions & 0 deletions src/netspresso_trainer/models/op/custom.py
@@ -20,6 +20,7 @@
 from functools import partial
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+import numpy as np
 import torch
 import torch.nn as nn
 from torch import Tensor
@@ -171,6 +172,96 @@ def forward(self, x: Union[Tensor, Proxy]) -> Union[Tensor, Proxy]:
         x = self.final_act(x)
         return x
 
+class RepVGGBlock(nn.Module):
+    """
+    A RepVGG-style block that combines a 3x3 conv branch, a 1x1 (point-wise) conv
+    branch, and an identity (BatchNorm) branch, all of which can be fused into a
+    single 3x3 conv for deployment.
+    This implementation is based on https://github.com/lyuwenyu/RT-DETR/blob/b444daf79cf25f95b740ae71e80fd165e892739a/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L35.
+    """
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 kernel_size: Union[int, Tuple[int, int]] = 3,
+                 groups: int = 1,
+                 act_type: Optional[str] = None):
+        super().__init__()
+        if act_type is None:
+            act_type = 'silu'
+        assert isinstance(in_channels, int)
+        assert isinstance(out_channels, int)
+        assert isinstance(act_type, str)
+        assert kernel_size == 3
+
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.groups = groups
+        self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, groups=groups, use_act=False)
+        self.conv2 = ConvLayer(in_channels, out_channels, 1, groups=groups, use_act=False)
+        self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels else nn.Identity()
+
+        assert act_type in ACTIVATION_REGISTRY
+        self.act = ACTIVATION_REGISTRY[act_type]()
+
+    def forward(self, x: Union[Tensor, Proxy]) -> Union[Tensor, Proxy]:
+        if hasattr(self, 'conv'):
+            return self.conv(x)
+        y = self.conv1(x) + self.conv2(x)
+        # The identity branch is a BatchNorm only when in_channels == out_channels;
+        # otherwise it is a placeholder and must not be added to the sum.
+        if isinstance(self.rbr_identity, nn.BatchNorm2d):
+            y = y + self.rbr_identity(x)
+
+        return self.act(y)
+
+    def convert_to_deploy(self):
+        if not hasattr(self, 'conv'):
+            self.conv = nn.Conv2d(self.in_channels, self.out_channels, kernel_size=3, stride=1, padding=1)
+
+        kernel, bias = self.get_equivalent_kernel_bias()
+        self.conv.weight.data = kernel
+        self.conv.bias.data = bias
+        # self.__delattr__('conv1')
+        # self.__delattr__('conv2')
+
+    def get_equivalent_kernel_bias(self):
+        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.conv1)
+        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.conv2)
+        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
+
+        return kernel3x3 + self._pad_1x1_to_3x3_tensor(kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
+
+    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
+        if kernel1x1 is None:
+            return 0
+        else:
+            return nn.functional.pad(kernel1x1, [1, 1, 1, 1])
+
+    def _fuse_bn_tensor(self, branch: Union[ConvLayer, nn.BatchNorm2d, nn.Identity]):
+        if isinstance(branch, nn.Identity):
+            return 0, 0
+        if isinstance(branch, nn.BatchNorm2d):
+            # Express the identity mapping as a 3x3 kernel so the BN branch
+            # can be folded the same way as the conv branches.
+            if not hasattr(self, 'id_tensor'):
+                input_dim = self.in_channels // self.groups
+                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3), dtype=np.float32)
+                for i in range(self.in_channels):
+                    kernel_value[i, i % input_dim, 1, 1] = 1
+                self.id_tensor = torch.from_numpy(kernel_value).to(branch.weight.device)
+            kernel = self.id_tensor
+            running_mean = branch.running_mean
+            running_var = branch.running_var
+            gamma = branch.weight
+            beta = branch.bias
+            eps = branch.eps
+        else:
+            assert isinstance(branch, ConvLayer)
+            kernel = branch.block.conv.weight
+            running_mean = branch.block.norm.running_mean
+            running_var = branch.block.norm.running_var
+            gamma = branch.block.norm.weight
+            beta = branch.block.norm.bias
+            eps = branch.block.norm.eps
+        std = (running_var + eps).sqrt()
+        t = (gamma / std).reshape(-1, 1, 1, 1)
+        return kernel * t, beta - running_mean * gamma / std
+
+
 class BasicBlock(nn.Module):
     expansion: int = 1
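
A quick sanity check for the reparameterization above (illustrative, not part of the diff): in eval mode, where BatchNorm uses its running statistics, the fused conv must reproduce the three-branch output. This assumes 'relu' is a key in ACTIVATION_REGISTRY:

block = RepVGGBlock(64, 64, act_type='relu').eval()
x = torch.randn(1, 64, 32, 32)
with torch.no_grad():
    y_train = block(x)         # 3x3 + 1x1 + identity branches
    block.convert_to_deploy()  # folds all branches into a single 3x3 conv
    y_deploy = block(x)
assert torch.allclose(y_train, y_deploy, atol=1e-5)
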
@@ -725,6 +816,32 @@ def forward(self, x):
         return self.conv3(x)
 
 
+class CSPRepLayer(nn.Module):
+    def __init__(self,
+                 in_channels: int,
+                 out_channels: int,
+                 num_blocks: int = 3,
+                 expansion: float = 1.0,
+                 bias: bool = False,
+                 act: str = "silu"):
+        super(CSPRepLayer, self).__init__()
+        hidden_channels = int(out_channels * expansion)
+        self.conv1 = ConvLayer(in_channels, hidden_channels, kernel_size=1, stride=1, bias=bias, act_type=act)
+        self.conv2 = ConvLayer(in_channels, hidden_channels, kernel_size=1, stride=1, bias=bias, act_type=act)
+        self.bottlenecks = nn.Sequential(*[
+            RepVGGBlock(hidden_channels, hidden_channels, act_type=act) for _ in range(num_blocks)
+        ])
+        if hidden_channels != out_channels:
+            self.conv3 = ConvLayer(hidden_channels, out_channels, kernel_size=1, stride=1, bias=bias, act_type=act)
+        else:
+            self.conv3 = nn.Identity()
+
+    def forward(self, x):
+        x_1 = self.conv1(x)
+        x_1 = self.bottlenecks(x_1)
+        x_2 = self.conv2(x)
+        return self.conv3(x_1 + x_2)
+
 class SPPBottleneck(nn.Module):
     """Spatial pyramid pooling layer used in YOLOv3-SPP"""
 
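
Shape-wise, CSPRepLayer preserves spatial resolution and maps in_channels to out_channels: conv1 feeds the RepVGG bottleneck stack, conv2 is a parallel shortcut, and conv3 projects the sum when the hidden width differs from out_channels. A brief illustrative check (names not from the PR):

layer = CSPRepLayer(512, 256, num_blocks=3, act="silu").eval()
feat = torch.randn(1, 512, 20, 20)
out = layer(feat)
assert out.shape == (1, 256, 20, 20)  # channels remapped, spatial size preserved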