Skip to content

Commit

Permalink
Fix (gpxq): adding input quant to process input (#943)
Browse files Browse the repository at this point in the history
  • Loading branch information
i-colbert authored Apr 26, 2024
1 parent 670420f commit bae1e26
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 28 deletions.
31 changes: 4 additions & 27 deletions src/brevitas/graph/gpfq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
from brevitas.graph.gpxq import StopFwdException
from brevitas.graph.gpxq import SUPPORTED_CONV_OP
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor


class gpfq_mode(gpxq_mode):
Expand Down Expand Up @@ -313,32 +312,10 @@ def __init__(
self.accumulator_bit_width = accumulator_bit_width
assert self.accumulator_bit_width is not None

def process_input(self, inp):
    """Quantize the incoming activation and cache its quant metadata.

    Runs the base-class preprocessing first, then the layer's input
    quantizer. If the result is a ``QuantTensor`` and weight quantization
    is enabled, a one-time metadata snapshot (scale, zero-point,
    bit-width, sign, training flag) is stored on ``self.quant_input``
    with a 1-element placeholder value, and the tensor is unwrapped so
    callers always receive a plain tensor.
    """
    inp = self.layer.input_quant(super().process_input(inp))

    weight_quant_on = self.layer.weight_quant.is_quant_enabled

    # Quantized activations may produce a QuantTensor; capture its
    # metadata once, then strip the wrapper for downstream processing.
    if isinstance(inp, QuantTensor):
        if weight_quant_on and self.quant_input is None:
            placeholder = torch.empty(
                1, dtype=self.layer.weight.dtype, device=self.layer.weight.device)
            self.quant_input = QuantTensor(
                value=placeholder,
                scale=inp.scale,
                zero_point=inp.zero_point,
                bit_width=inp.bit_width,
                signed=inp.signed,
                training=inp.training)
        inp = inp.value

    return inp

def single_layer_update(self):
# raise error in case no quant-input is here
if self.quant_input is None:
raise ValueError('Expected self.quant_input to calculate L1-norm upper bound, but recevied None. ' + \
if self.quant_metadata is None:
raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but recevied None. ' + \
'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
Expand All @@ -356,8 +333,8 @@ def single_layer_update(self):
self.quantized_input = self.quantized_input.to(dev)

# get upper bound
input_bit_width = self.quant_input.bit_width
input_is_signed = self.quant_input.signed
input_bit_width = self.quant_metadata.bit_width
input_is_signed = self.quant_metadata.signed
T = get_upper_bound_on_l1_norm(
torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
s = self.layer.weight_quant.scale()
Expand Down
15 changes: 14 additions & 1 deletion src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from brevitas.graph.calibrate import DisableEnableQuantization
from brevitas.graph.calibrate import restore_return_quant_tensor
import brevitas.nn as qnn
from brevitas.quant_tensor import QuantTensor
from brevitas.utils.quant_utils import _CachedIO

SUPPORTED_CONV_OP = (
qnn.QuantConv1d,
Expand Down Expand Up @@ -216,11 +218,21 @@ def __init__(

self.disable_pre_forward_hook = False
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_input = None
self.quant_metadata = None

def process_input(self, inp):
# Input is a tuple, so we take first element
inp = inp[0]
inp = self.layer.input_quant(inp)

is_quant_enabled = self.layer.weight_quant.is_quant_enabled

# If using quantized activations, inp could be QuantTensor. In
# this case, we overwrite the metadata.
if isinstance(inp, QuantTensor):
if is_quant_enabled and self.quant_metadata is None:
self.quant_metadata = _CachedIO(inp, metadata_only=True)
inp = inp.value

# If input is unbatched, add batch_size = 1
if len(inp.shape) == 1:
Expand All @@ -232,6 +244,7 @@ def process_input(self, inp):
batch_dim = inp.names.index('N')
inp.rename_(None)
inp = inp.transpose(0, batch_dim)

return inp

@abstractmethod
Expand Down

0 comments on commit bae1e26

Please sign in to comment.