fix (nn/conv): Fix regression introduced in #1017 (#1019)
nickfraser authored Sep 4, 2024
1 parent 0dfed16 commit 9693ff5
Showing 3 changed files with 86 additions and 8 deletions.
9 changes: 6 additions & 3 deletions src/brevitas/nn/quant_conv.py
@@ -46,7 +46,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
@@ -132,7 +133,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
@@ -220,7 +222,8 @@ def __init__(
dtype: Optional[torch.dtype] = None,
**kwargs) -> None:
# avoid an init error in the super class by setting padding to 0
-if padding_mode == 'zeros' and padding == 'same' and any(map(lambda x: x > 1, stride)):
+if padding_mode == 'zeros' and padding == 'same' and (stride > 1 if isinstance(
+stride, int) else any(map(lambda x: x > 1, stride))):
padding = 0
is_same_padded_strided = True
else:
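The change above addresses the failure mode of the previous check: when stride is passed as a plain int (e.g. stride=1) together with padding='same', map(lambda x: x > 1, stride) raises TypeError: 'int' object is not iterable. The fixed condition compares an int stride directly and only iterates when the stride is a sequence. A minimal sketch of the predicate as a standalone function; the helper name same_padded_strided below is illustrative, not part of the Brevitas API:

from typing import Sequence, Union

def same_padded_strided(padding_mode: str,
                        padding: Union[int, str, Sequence[int]],
                        stride: Union[int, Sequence[int]]) -> bool:
    # True only for padding_mode='zeros', padding='same', and a stride > 1 in some
    # dimension; this is the case the superclass cannot initialise directly, so the
    # modules reset padding to 0 and handle the 'same' padding manually.
    if padding_mode != 'zeros' or padding != 'same':
        return False
    if isinstance(stride, int):
        return stride > 1
    return any(s > 1 for s in stride)

# The int case no longer raises; the tuple cases behave as before.
assert same_padded_strided('zeros', 'same', 1) is False
assert same_padded_strided('zeros', 'same', (1, 1)) is False
assert same_padded_strided('zeros', 'same', (2, 1)) is True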
43 changes: 40 additions & 3 deletions tests/brevitas/nn/test_conv2d.py
@@ -1,10 +1,10 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import pytest_cases
import torch
from torch.nn import BatchNorm2d
from torch.nn import Conv2d
from torch.nn import Module

from brevitas.inject.defaults import Int8BiasPerTensorFloatInternalScaling
from brevitas.nn import QuantConv2d
@@ -18,12 +18,24 @@

class TestQuantConv2d:

-def test_module_init(self):
+@pytest_cases.parametrize(
+'kwargs',
+[{}, {
+'padding': 'same', 'stride': 1}, {
+'padding': 'same', 'stride': (1, 1)}, {
+'padding': 'same', 'stride': (2, 1)}],
+ids=[
+'defaults',
+'padding="same",stride=1',
+'padding="same",stride=(1,1)',
+'padding="same",stride=(2,1)'])
+def test_module_init(self, kwargs):
mod = QuantConv2d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
-bias=False)
+bias=False,
+**kwargs)

def test_fp_quant_module(self):
float_mod = Conv2d(
@@ -102,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
merge_bn(mod, bn)
inp = torch.randn(1, INPUT_CHANNELS, 20, 20)
mod(inp)

+@pytest_cases.parametrize(
+'kwargs',
+[{
+'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': (1, 1), 'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': (2, 1), 'is_same_padded_strided': True}, {
+'padding': 0, 'stride': (2, 1), 'is_same_padded_strided': False}],
+ids=[
+'defaults',
+'padding="same",stride=1',
+'padding="same",stride=(1,1)',
+'padding="same",stride=(2,1)',
+'padding=0,stride=(2,1)'])
+def test_is_same_padded_strided(self, kwargs):
+is_same_padded_strided = kwargs['is_same_padded_strided']
+del kwargs['is_same_padded_strided']
+mod = QuantConv2d(
+out_channels=OUTPUT_CHANNELS,
+in_channels=INPUT_CHANNELS,
+kernel_size=KERNEL_SIZE,
+bias=False,
+**kwargs)
+assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"
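For context, a short usage sketch of the configurations the new tests exercise, assuming stock quantization defaults; the channel and kernel sizes are arbitrary stand-ins for the test-suite constants:

from brevitas.nn import QuantConv2d

# padding='same' with an integer stride: the configuration that previously failed at init.
mod = QuantConv2d(in_channels=8, out_channels=16, kernel_size=3,
                  stride=1, padding='same', bias=False)
assert not mod.is_same_padded_strided

# A stride > 1 along any dimension with padding='same' takes the manual-padding path.
mod = QuantConv2d(in_channels=8, out_channels=16, kernel_size=3,
                  stride=(2, 1), padding='same', bias=False)
assert mod.is_same_padded_strided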
42 changes: 40 additions & 2 deletions tests/brevitas/nn/test_conv3d.py
@@ -1,6 +1,7 @@
# Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause

+import pytest_cases
import torch
from torch.nn import BatchNorm3d
from torch.nn import Conv3d
@@ -17,12 +18,24 @@

class TestQuantConv3d:

-def test_module_init(self):
+@pytest_cases.parametrize(
+'kwargs',
+[{}, {
+'padding': 'same', 'stride': 1}, {
+'padding': 'same', 'stride': (1, 1, 1)}, {
+'padding': 'same', 'stride': (2, 1, 1)}],
+ids=[
+'defaults',
+'padding="same",stride=1',
+'padding="same",stride=(1,1,1)',
+'padding="same",stride=(2,1,1)'])
+def test_module_init(self, kwargs):
mod = QuantConv3d(
out_channels=OUTPUT_CHANNELS,
in_channels=INPUT_CHANNELS,
kernel_size=KERNEL_SIZE,
-bias=False)
+bias=False,
+**kwargs)

def test_fp_quant_module(self):
float_mod = Conv3d(
@@ -101,3 +114,28 @@ def test_internally_scaled_int_bias_after_bn_merge(self):
merge_bn(mod, bn)
inp = torch.randn(1, INPUT_CHANNELS, 20, 20, 20)
mod(inp)

+@pytest_cases.parametrize(
+'kwargs',
+[{
+'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': 1, 'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': (1, 1, 1), 'is_same_padded_strided': False}, {
+'padding': 'same', 'stride': (2, 1, 1), 'is_same_padded_strided': True}, {
+'padding': 0, 'stride': (2, 1, 1), 'is_same_padded_strided': False}],
+ids=[
+'defaults',
+'padding="same",stride=1',
+'padding="same",stride=(1,1,1)',
+'padding="same",stride=(2,1,1)',
+'padding=0,stride=(2,1,1)'])
+def test_is_same_padded_strided(self, kwargs):
+is_same_padded_strided = kwargs['is_same_padded_strided']
+del kwargs['is_same_padded_strided']
+mod = QuantConv3d(
+out_channels=OUTPUT_CHANNELS,
+in_channels=INPUT_CHANNELS,
+kernel_size=KERNEL_SIZE,
+bias=False,
+**kwargs)
+assert is_same_padded_strided == mod.is_same_padded_strided, f"Expected is_same_padded_strided={is_same_padded_strided}, found is_same_padded_strided={mod.is_same_padded_strided}"
