Skip to content

Commit

Permalink
Move DarknetBlock
Browse files Browse the repository at this point in the history
  • Loading branch information
illian01 committed Oct 17, 2023
1 parent 0c3dfb7 commit 1f55168
Showing 1 changed file with 27 additions and 26 deletions.
53 changes: 27 additions & 26 deletions src/netspresso_trainer/models/op/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -555,32 +555,6 @@ def __init__(
kernel_size=1,
stride=1, act_type=act_type)

# Newly define because of slight difference with Bottleneck of custom.py
class DarknetBlock(nn.Module):
# Standard bottleneck
def __init__(
self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
#depthwise=False,
act_type="silu",
):
super().__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvLayer(in_channels=in_channels, out_channels=hidden_channels,
kernel_size=1, stride=1, act_type=act_type)
self.conv2 = ConvLayer(in_channels=hidden_channels, out_channels=out_channels,
kernel_size=3, stride=1, act_type=act_type)
self.use_add = shortcut and in_channels == out_channels

def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y

block = DarknetBlock

module_list = [
Expand Down Expand Up @@ -628,3 +602,30 @@ def forward(self, x):
x = torch.cat([x] + [m(x) for m in self.m], dim=1)
x = self.conv2(x)
return x


# Newly defined because of slight difference with Bottleneck of custom.py
class DarknetBlock(nn.Module):
# Standard bottleneck
def __init__(
self,
in_channels,
out_channels,
shortcut=True,
expansion=0.5,
#depthwise=False,
act_type="silu",
):
super().__init__()
hidden_channels = int(out_channels * expansion)
self.conv1 = ConvLayer(in_channels=in_channels, out_channels=hidden_channels,
kernel_size=1, stride=1, act_type=act_type)
self.conv2 = ConvLayer(in_channels=hidden_channels, out_channels=out_channels,
kernel_size=3, stride=1, act_type=act_type)
self.use_add = shortcut and in_channels == out_channels

def forward(self, x):
y = self.conv2(self.conv1(x))
if self.use_add:
y = y + x
return y

0 comments on commit 1f55168

Please sign in to comment.