Fix (nn): add missing support for padding_mode (#709)
volcacius authored Sep 22, 2023
1 parent c95eafd commit 7196dc6
Showing 2 changed files with 30 additions and 29 deletions.
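
In short: QuantConv1d and QuantConv2d previously asserted padding_mode == 'zeros' and used a Brevitas-specific padding_type argument ('standard' or 'same'). This commit removes padding_type, exposes PyTorch's standard padding_mode argument instead, and routes the regular case through the superclass's _conv_forward, which already implements every padding mode. The one case that still needs special handling is padding='same' combined with stride > 1, which torch.nn.Conv1d/Conv2d reject at construction time; it is now emulated by padding manually. A minimal usage sketch of what this enables (the class names come from the diff; shapes and values are arbitrary):

import torch
from brevitas.nn import QuantConv2d

# padding_mode is now forwarded to torch.nn.Conv2d, so e.g. 'reflect' works
conv = QuantConv2d(3, 8, kernel_size=3, padding=1, padding_mode='reflect')
y = conv(torch.randn(1, 3, 32, 32))          # -> (1, 8, 32, 32)

# 'same' padding with stride > 1 no longer raises; it is emulated internally
conv_same = QuantConv2d(3, 8, kernel_size=3, stride=2, padding='same')
y2 = conv_same(torch.randn(1, 3, 32, 32))    # -> (1, 8, 16, 16)
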
src/brevitas/nn/quant_conv.py (55 changes: 26 additions & 29 deletions)
@@ -9,7 +9,6 @@
 from torch.nn import Conv1d
 from torch.nn import Conv2d
 from torch.nn import functional as F
-from torch.nn.functional import conv2d

 from brevitas.function.ops import max_int
 from brevitas.function.ops_ste import ceil_ste
@@ -35,8 +34,8 @@ def __init__(
             padding: Union[int, Tuple[int, int]] = 0,
             dilation: Union[int, Tuple[int, int]] = 1,
             groups: int = 1,
+            padding_mode: str = 'zeros',
             bias: bool = True,
-            padding_type: str = 'standard',
             weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
             bias_quant: Optional[BiasQuantType] = None,
             input_quant: Optional[ActQuantType] = None,
@@ -45,6 +44,12 @@ def __init__(
             device: Optional[torch.device] = None,
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
+        # avoid an init error in the super class by setting padding to 0
+        if padding_mode == 'zeros' and padding == 'same' and stride > 1:
+            padding = 0
+            is_same_padded_strided = True
+        else:
+            is_same_padded_strided = False
         Conv1d.__init__(
             self,
             in_channels=in_channels,
@@ -54,6 +59,7 @@ def __init__(
             padding=padding,
             dilation=dilation,
             groups=groups,
+            padding_mode=padding_mode,
             bias=bias,
             device=device,
             dtype=dtype)
@@ -65,9 +71,7 @@ def __init__(
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
             **kwargs)
-        assert self.padding_mode == 'zeros'
-        assert not (padding_type == 'same' and padding != 0)
-        self.padding_type = padding_type
+        self.is_same_padded_strided = is_same_padded_strided

     @property
     def per_elem_ops(self):
@@ -84,11 +88,7 @@ def output_channel_dim(self):
     def channelwise_separable(self) -> bool:
         return self.groups == self.in_channels

-    def conv1d_zeros_pad(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]):
-        out = F.conv1d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
-        return out
-
-    def conv1d_same_zeros_pad(self, x, weight, bias):
+    def conv1d_same_zeros_pad_stride(self, x, weight, bias):
         ih = x.size()[-1]
         kh = weight.size()[-1]
         sh = self.stride[0]
@@ -103,12 +103,10 @@ def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         return self.forward_impl(input)

     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
-        if self.padding_type == 'standard':
-            return self.conv1d_zeros_pad(x, quant_weight, quant_bias)
-        elif self.padding_type == 'same':
-            return self.conv1d_same_zeros_pad(x, quant_weight, quant_bias)
+        if self.is_same_padded_strided:
+            return self.conv1d_same_zeros_pad_stride(x, quant_weight, quant_bias)
         else:
-            raise NotImplementedError(f"Padding type {self.padding_type} not supported.")
+            return self._conv_forward(x, quant_weight, quant_bias)

     def max_acc_bit_width(self, input_bit_width, weight_bit_width):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
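
The conv1d_same_zeros_pad_stride body is truncated in the hunk above, but the visible prologue (input length, kernel size, stride) matches the usual TensorFlow-style SAME computation: choose total padding so that the output length equals ceil(input / stride), split it as evenly as possible with the extra element on the right, zero-pad, then convolve with padding 0. A standalone sketch of that arithmetic under those assumptions (the helper name and signature here are illustrative, not Brevitas API):

import math
import torch
import torch.nn.functional as F

def same_pad_conv1d(x, weight, bias, stride, dilation, groups):
    ih = x.size(-1)                      # input length
    kh = weight.size(-1)                 # kernel size
    oh = math.ceil(ih / stride)          # target output length for SAME padding
    pad = max((oh - 1) * stride + (kh - 1) * dilation + 1 - ih, 0)
    if pad > 0:
        # asymmetric split: the odd element goes to the right
        x = F.pad(x, [pad // 2, pad - pad // 2])
    return F.conv1d(x, weight, bias, stride, 0, dilation, groups)

x = torch.randn(1, 3, 10)
w = torch.randn(8, 3, 3)
print(same_pad_conv1d(x, w, None, stride=2, dilation=1, groups=1).shape)
# torch.Size([1, 8, 5]) == (1, 8, ceil(10 / 2))
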
@@ -130,8 +128,8 @@ def __init__(
             padding: Union[int, Tuple[int, int]] = 0,
             dilation: Union[int, Tuple[int, int]] = 1,
             groups: int = 1,
+            padding_mode: str = 'zeros',
             bias: bool = True,
-            padding_type: str = 'standard',
             weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
             bias_quant: Optional[BiasQuantType] = None,
             input_quant: Optional[ActQuantType] = None,
@@ -140,13 +138,20 @@ def __init__(
             device: Optional[torch.device] = None,
             dtype: Optional[torch.dtype] = None,
             **kwargs) -> None:
+        # avoid an init error in the super class by setting padding to 0
+        if padding_mode == 'zeros' and padding == 'same' and stride > 1:
+            padding = 0
+            is_same_padded_strided = True
+        else:
+            is_same_padded_strided = False
         Conv2d.__init__(
             self,
             in_channels=in_channels,
             out_channels=out_channels,
             kernel_size=kernel_size,
             stride=stride,
             padding=padding,
+            padding_mode=padding_mode,
             dilation=dilation,
             groups=groups,
             bias=bias,
@@ -160,9 +165,7 @@ def __init__(
             output_quant=output_quant,
             return_quant_tensor=return_quant_tensor,
             **kwargs)
-        assert self.padding_mode == 'zeros'
-        assert not (padding_type == 'same' and padding != 0)
-        self.padding_type = padding_type
+        self.is_same_padded_strided = is_same_padded_strided

     @property
     def per_elem_ops(self):
@@ -179,11 +182,7 @@ def output_channel_dim(self):
     def channelwise_separable(self) -> bool:
         return self.groups == self.in_channels

-    def conv2d_zeros_pad(self, x: Tensor, weight: Tensor, bias: Tensor):
-        out = conv2d(x, weight, bias, self.stride, self.padding, self.dilation, self.groups)
-        return out
-
-    def conv2d_same_zeros_pad(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]):
+    def conv2d_same_zeros_pad_stride(self, x: Tensor, weight: Tensor, bias: Optional[Tensor]):
         ih, iw = x.size()[-2:]
         kh, kw = weight.size()[-2:]
         sh, sw = self.stride
@@ -199,12 +198,10 @@ def forward(self, input: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         return self.forward_impl(input)

     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
-        if self.padding_type == 'standard':
-            return self.conv2d_zeros_pad(x, quant_weight, quant_bias)
-        elif self.padding_type == 'same':
-            return self.conv2d_same_zeros_pad(x, quant_weight, quant_bias)
+        if self.is_same_padded_strided:
+            return self.conv2d_same_zeros_pad_stride(x, quant_weight, quant_bias)
         else:
-            raise RuntimeError(f"Padding type {self.padding_type} not supported.")
+            return self._conv_forward(x, quant_weight, quant_bias)

     def max_acc_bit_width(self, input_bit_width: Tensor, weight_bit_width: Tensor):
         max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
src/brevitas/nn/quant_convtranspose.py (4 changes: 4 additions & 0 deletions)
@@ -37,6 +37,7 @@ def __init__(
             output_padding: int = 0,
             dilation: int = 1,
             groups: int = 1,
+            padding_mode: str = 'zeros',
             bias: bool = True,
             weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
             bias_quant: Optional[BiasQuantType] = None,
@@ -56,6 +57,7 @@ def __init__(
             output_padding=output_padding,
             dilation=dilation,
             groups=groups,
+            padding_mode=padding_mode,
             bias=bias,
             device=device,
             dtype=dtype)
@@ -132,6 +134,7 @@ def __init__(
             output_padding: Union[int, Tuple[int]] = 0,
             dilation: Union[int, Tuple[int]] = 1,
             groups: int = 1,
+            padding_mode: str = 'zeros',
             bias: bool = True,
             weight_quant: Optional[WeightQuantType] = Int8WeightPerTensorFloat,
             bias_quant: Optional[BiasQuantType] = None,
@@ -151,6 +154,7 @@ def __init__(
             output_padding=output_padding,
             dilation=dilation,
             groups=groups,
+            padding_mode=padding_mode,
             bias=bias,
             device=device,
             dtype=dtype)
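
For the transposed convolutions the change is a pure pass-through: padding_mode is accepted and forwarded to torch.nn.ConvTranspose1d/ConvTranspose2d. Note that PyTorch's transposed convolutions themselves only accept padding_mode='zeros' (anything else raises a ValueError in the superclass), so this completes the signature rather than enabling new modes. A usage sketch against the updated signature (shapes arbitrary):

import torch
from brevitas.nn import QuantConvTranspose2d

deconv = QuantConvTranspose2d(8, 3, kernel_size=2, stride=2, padding_mode='zeros')
y = deconv(torch.randn(1, 8, 16, 16))    # -> (1, 3, 32, 32)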
