Fix for conv3d/convtranspose3d
Giuseppe5 committed Mar 14, 2024
1 parent ded1a72 commit 230188a
Showing 2 changed files with 41 additions and 6 deletions.
src/brevitas/nn/quant_convtranspose.py (8 additions, 1 deletion)
@@ -312,7 +312,14 @@ def compute_output_padding(self, inp, output_size):
     def conv_transpose3d_zeros_pad(
             self, x: Tensor, weight: Tensor, bias: Optional[Tensor], output_padding):
         out = conv_transpose3d(
-            x, weight, bias, self.stride, self.padding, output_padding, self.groups, self.dilation)
+            x,
+            weight,
+            bias,
+            stride=self.stride,
+            padding=self.padding,
+            output_padding=output_padding,
+            groups=self.groups,
+            dilation=self.dilation)
         return out

     def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Optional[Tensor]):
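The only functional change above is that the transposed-convolution arguments are now passed by keyword rather than by position. That guards against argument-order slips: the conv_* and conv_transpose_* functionals order their trailing parameters differently (F.conv3d takes dilation before groups, F.conv_transpose3d takes groups before dilation). A minimal standalone sketch of the keyword-style call on plain tensors; the shapes and values here are illustrative, not taken from Brevitas:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 4, 5, 5, 5)  # (N, C_in, D, H, W)
    w = torch.randn(4, 8, 3, 3, 3)  # (C_in, C_out // groups, kD, kH, kW)
    out = F.conv_transpose3d(
        x,
        w,
        bias=None,
        stride=(2, 2, 2),
        padding=(1, 1, 1),
        output_padding=(1, 1, 1),
        groups=1,
        dilation=(1, 1, 1))
    print(out.shape)  # torch.Size([1, 8, 10, 10, 10])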
src/brevitas/quant_tensor/torch_handler.py (33 additions, 5 deletions)
@@ -168,21 +168,33 @@ def conv2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
     return output


+@implements(F.conv3d)
+def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
+    output = quant_layer(F.conv3d, quant_input, quant_weight, bias, *args, **kwargs)
+    return output
+
+
 @implements(F.conv1d)
 def conv1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
     output = quant_layer(F.conv1d, quant_input, quant_weight, bias, *args, **kwargs)
     return output


+@implements(F.conv_transpose1d)
+def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
+    output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs)
+    return output
+
+
 @implements(F.conv_transpose2d)
 def conv_transpose2d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
     output = quant_layer(F.conv_transpose2d, quant_input, quant_weight, bias, *args, **kwargs)
     return output


-@implements(F.conv_transpose1d)
-def conv_transpose1d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
-    output = quant_layer(F.conv_transpose1d, quant_input, quant_weight, bias, *args, **kwargs)
+@implements(F.conv_transpose3d)
+def conv_transpose3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
+    output = quant_layer(F.conv_transpose3d, quant_input, quant_weight, bias, *args, **kwargs)
     return output
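For context, @implements registers each handler against the corresponding torch functional so that calls involving quantized tensors route to it; before this commit, F.conv3d and F.conv_transpose3d had no registered handlers. The decorator itself is defined elsewhere in Brevitas, so the following is only a rough sketch of the registry pattern it implies, with HANDLED_FUNCTIONS as a hypothetical name:

    import torch.nn.functional as F

    HANDLED_FUNCTIONS = {}  # hypothetical registry: torch functional -> handler

    def implements(torch_function):
        def decorator(handler):
            HANDLED_FUNCTIONS[torch_function] = handler
            return handler
        return decorator

    @implements(F.conv3d)
    def conv3d_handler(quant_input, quant_weight, bias=None, *args, **kwargs):
        ...  # delegate to the quantized implementation, as in the diff above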


@@ -326,9 +338,25 @@ def max_acc_bit_width_convtranspose2d(
     return max_output_bit_width


+def max_acc_bit_width_convtranspose3d(
+        input_bit_width, weight_bit_width, weight_shape, *args, **kwargs):
+    stride = kwargs['stride']
+    max_uint_input = max_int(bit_width=input_bit_width, signed=False, narrow_range=False)
+    max_kernel_val = max_int(bit_width=weight_bit_width, signed=False, narrow_range=False)
+    out_channel = weight_shape[1]
+    kernel_shape = weight_shape[2:]
+    patch_size = max(math.ceil(kernel_shape[0] / stride[0]), 1) * max(
+        math.ceil(kernel_shape[1] / stride[1]), 1) * max(math.ceil(kernel_shape[2] / stride[2]), 1)
+    max_uint_output = max_uint_input * max_kernel_val * patch_size * out_channel
+    max_output_bit_width = ceil_ste(torch.log2(max_uint_output))
+    return max_output_bit_width
+
+
 IMPLS = {
-    F.conv2d: max_acc_bit_width_convnd,
     F.conv1d: max_acc_bit_width_convnd,
+    F.conv2d: max_acc_bit_width_convnd,
+    F.conv3d: max_acc_bit_width_convnd,
     F.linear: max_acc_bit_width_linear,
+    F.conv_transpose1d: max_acc_bit_width_convtranspose1d,
     F.conv_transpose2d: max_acc_bit_width_convtranspose2d,
-    F.conv_transpose1d: max_acc_bit_width_convtranspose1d}
+    F.conv_transpose3d: max_acc_bit_width_convtranspose3d}
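A worked example of the bound the new max_acc_bit_width_convtranspose3d computes, using hypothetical values: 8-bit unsigned input and weights, weight_shape = (4, 8, 3, 3, 3), stride = (2, 2, 2). The max(ceil(k / s), 1) terms count how many kernel taps can overlap a single output voxel per dimension:

    import math

    # Hypothetical 8-bit example for max_acc_bit_width_convtranspose3d.
    max_uint_input = 2 ** 8 - 1         # 255: unsigned, non-narrow range
    max_kernel_val = 2 ** 8 - 1         # 255
    out_channel = 8                     # weight_shape[1]
    patch_size = math.ceil(3 / 2) ** 3  # 2 * 2 * 2 = 8 overlapping taps per voxel
    max_uint_output = max_uint_input * max_kernel_val * patch_size * out_channel
    print(max_uint_output)                        # 4161600
    print(math.ceil(math.log2(max_uint_output)))  # 22: worst-case accumulator width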
