From 77e0fa092dfb579fba7c9e299436ed474ad1ce2a Mon Sep 17 00:00:00 2001 From: hglee98 Date: Thu, 14 Nov 2024 09:10:17 +0000 Subject: [PATCH] [fix] Add groups attribute --- src/netspresso_trainer/models/op/custom.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/src/netspresso_trainer/models/op/custom.py b/src/netspresso_trainer/models/op/custom.py index bc01d686..930dc7fb 100644 --- a/src/netspresso_trainer/models/op/custom.py +++ b/src/netspresso_trainer/models/op/custom.py @@ -181,6 +181,7 @@ def __init__(self, in_channels: int, out_channels: int, kernel_size: Union[int, Tuple[int, int]] = 3, + groups: int = 1, act_type: Optional[str] = None,): if act_type is None: act_type = 'silu' @@ -192,8 +193,9 @@ def __init__(self, self.in_channels = in_channels self.out_channels = out_channels - self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, use_act=False) - self.conv2 = ConvLayer(in_channels, out_channels, 1, use_act=False) + self.groups = groups + self.conv1 = ConvLayer(in_channels, out_channels, kernel_size, groups=groups, use_act=False) + self.conv2 = ConvLayer(in_channels, out_channels, 1, groups=groups, use_act=False) self.rbr_identity = nn.BatchNorm2d(num_features=in_channels) if out_channels == in_channels else None assert act_type in ACTIVATION_REGISTRY