Fix (GPxQ): unwrap QuantTensor when dealing with QuantLinear
fabianandresgrob committed May 14, 2024
1 parent a1926f0 commit 6952ea4
Showing 4 changed files with 29 additions and 2 deletions.
3 changes: 3 additions & 0 deletions src/brevitas/graph/gpfq.py
@@ -168,6 +168,9 @@ def update_batch(self, module, input, current_layer):
     if isinstance(self.layer, qnn.QuantLinear):
         if len(inp.shape) > 2:
             inp = inp.reshape((-1, sum(inp.shape[2:])))
+        # Unwrap tensor value if quantized input
+        if isinstance(inp, QuantTensor):
+            inp = inp.value
         # For QuantLinear layer, groups will be 1
         inp_processed = inp.unsqueeze(0)

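The guard added here protects GPFQ's Hessian bookkeeping from quantized inputs. Below is a minimal sketch of the situation it handles, assuming only the brevitas APIs visible in this commit (QuantLinear, Int8ActPerTensorFloat, and QuantTensor.value): a layer built with return_quant_tensor=True hands a QuantTensor to the next layer, so the forward hook driving update_batch sees a wrapper rather than a plain torch.Tensor.

```python
import torch
import brevitas.nn as qnn
from brevitas.quant import Int8ActPerTensorFloat
from brevitas.quant_tensor import QuantTensor

# A QuantLinear configured to return quantized outputs emits QuantTensor,
# so whatever consumes its output (here, a hook on the next layer) must
# be prepared to unwrap it.
layer = qnn.QuantLinear(
    3, 16, True, input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
out = layer(torch.randn(8, 3))
assert isinstance(out, QuantTensor)

# Mirror of the added guard: fall back to the underlying tensor value.
inp = out.value if isinstance(out, QuantTensor) else out
assert isinstance(inp, torch.Tensor)
```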
4 changes: 4 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -21,6 +21,7 @@
 from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
+from brevitas.quant_tensor import QuantTensor


 class gptq_mode(gpxq_mode):
@@ -150,6 +151,9 @@ def update_batch(self, module, input, current_layer):
     if isinstance(self.layer, qnn.QuantLinear):
         if len(inp.shape) > 2:
             inp = inp.reshape((-1, sum(inp.shape[2:])))
+        # Unwrap tensor value if QuantTensor
+        if isinstance(inp, QuantTensor):
+            inp = inp.value
         inp = inp.t()
         # For QuantLinear layer, groups will be 1
         inp_processed = inp.unsqueeze(0)
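For orientation, update_batch runs inside the forward hooks that gptq_mode installs during calibration. A sketch of the surrounding loop, modeled on the brevitas PTQ examples (the gptq.model, gptq.num_layers, and gptq.update() names follow that example API and are assumptions, not part of this commit):

```python
import torch
from brevitas.graph.gptq import gptq_mode

def apply_gptq(model, calib_loader):
    # Sketch only: drive calibration batches through the hooked model so
    # each layer's update_batch accumulates statistics, then let GPTQ
    # update one layer per outer iteration.
    model.eval()
    with torch.no_grad():
        with gptq_mode(model) as gptq:
            gptq_model = gptq.model
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    gptq_model(images)
                gptq.update()
```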
22 changes: 21 additions & 1 deletion tests/brevitas/graph/equalization_fixtures.py
@@ -461,11 +461,31 @@ def forward(self, x):
     return QuantConvTransposeModel


+@pytest_cases.fixture()
+def quant_linear_model():
+
+    class QuantLinearModel(nn.Module):
+
+        def __init__(self) -> None:
+            super().__init__()
+            self.linear_0 = qnn.QuantLinear(
+                3, 16, True, input_quant=Int8ActPerTensorFloat, return_quant_tensor=True)
+            self.linear_1 = qnn.QuantLinear(16, 10, True)
+
+        def forward(self, x):
+            x = self.linear_0(x)
+            x = self.linear_1(x)
+            return x
+
+    return QuantLinearModel
+
+
 list_of_quant_fixtures = [
     'quant_conv_with_input_quant_model',
     'quant_convdepthconv_model',
     'quant_residual_model',
-    'quant_convtranspose_model']
+    'quant_convtranspose_model',
+    'quant_linear_model']

 toy_quant_model = fixture_union(
     'toy_quant_model', list_of_quant_fixtures, ids=list_of_quant_fixtures)
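The new fixture only takes effect because of the fixture_union at the bottom of the file: the union fixture re-runs every dependent test once per member. A self-contained sketch of that mechanism (model_a and model_b are illustrative names, not from the repository):

```python
import pytest_cases
from pytest_cases import fixture_union

@pytest_cases.fixture()
def model_a():
    return 'a'

@pytest_cases.fixture()
def model_b():
    return 'b'

# The union resolves to each member fixture in turn, so the test below
# runs twice, once per model, just as toy_quant_model fans out above.
toy_model = fixture_union(
    'toy_model', ['model_a', 'model_b'], ids=['model_a', 'model_b'])

def test_runs_once_per_member(toy_model):
    assert toy_model in ('a', 'b')
```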
2 changes: 1 addition & 1 deletion tests/brevitas/graph/test_gpxq.py
@@ -111,7 +111,7 @@ def test_toymodels(

     model_class = toy_quant_model
     model = model_class()
-    if 'mha' in test_id:
+    if 'mha' in test_id or 'linear' in test_id:
         inp = torch.randn(32, *IN_SIZE_LINEAR[1:])
     else:
         inp = torch.randn(32, *IN_SIZE_CONV_SMALL[1:])
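The routing change simply mirrors the input shapes the fixtures expect: linear-style models take the flat IN_SIZE_LINEAR layout, while the conv fixtures take NCHW. A tiny illustration with placeholder sizes (the real IN_SIZE_* constants live in the test module; the values below are assumptions chosen to match the fixture's in_features=3):

```python
import torch

# Assumed stand-ins for the test module's constants.
IN_SIZE_LINEAR = (1, 224, 3)         # (batch, ..., features) style input
IN_SIZE_CONV_SMALL = (1, 3, 10, 10)  # NCHW style input

test_id = 'quant_linear_model'
if 'mha' in test_id or 'linear' in test_id:
    inp = torch.randn(32, *IN_SIZE_LINEAR[1:])  # last dim 3 feeds linear_0
else:
    inp = torch.randn(32, *IN_SIZE_CONV_SMALL[1:])
```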
