diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index 8ba6d2912..ef720d092 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -17,7 +17,6 @@
 from brevitas.graph.gpxq import StopFwdException
 from brevitas.graph.gpxq import SUPPORTED_CONV_OP
 import brevitas.nn as qnn
-from brevitas.quant_tensor import QuantTensor
 
 
 class gpfq_mode(gpxq_mode):
@@ -313,32 +312,10 @@ def __init__(
         self.accumulator_bit_width = accumulator_bit_width
         assert self.accumulator_bit_width is not None
 
-    def process_input(self, inp):
-        inp = super().process_input(inp)
-        inp = self.layer.input_quant(inp)
-
-        is_quant_enabled = self.layer.weight_quant.is_quant_enabled
-
-        # If using quantized activations, inp could be QuantTensor. In
-        # this case, we overwrite the metadata.
-        if isinstance(inp, QuantTensor):
-            if is_quant_enabled and self.quant_input is None:
-                self.quant_input = QuantTensor(
-                    value=torch.empty(
-                        1, dtype=self.layer.weight.dtype, device=self.layer.weight.device),
-                    scale=inp.scale,
-                    zero_point=inp.zero_point,
-                    bit_width=inp.bit_width,
-                    signed=inp.signed,
-                    training=inp.training)
-            inp = inp.value
-
-        return inp
-
     def single_layer_update(self):
         # raise error in case no quant-input is here
-        if self.quant_input is None:
-            raise ValueError('Expected self.quant_input to calculate L1-norm upper bound, but recevied None. ' + \
+        if self.quant_metadata is None:
+            raise ValueError('Expected self.quant_metadata to calculate L1-norm upper bound, but received None. ' + \
                 'Make sure that either the input to the model is a QuantTensor or the layer has an input quant enabled. ' \
                 'Also, check if `use_quant_activations=True` in `gpfq_mode` when `accumulator_bit_width` is specified. ' + \
                 'Alternatively, provide a custom `a2q_layer_filter_fnc` to `gpfq_mode` to filter layers without a quant_tensor input.')
@@ -356,8 +333,8 @@ def single_layer_update(self):
         self.quantized_input = self.quantized_input.to(dev)
 
         # get upper bound
-        input_bit_width = self.quant_input.bit_width
-        input_is_signed = self.quant_input.signed
+        input_bit_width = self.quant_metadata.bit_width
+        input_is_signed = self.quant_metadata.signed
         T = get_upper_bound_on_l1_norm(
             torch.tensor(self.accumulator_bit_width), input_bit_width, input_is_signed)
         s = self.layer.weight_quant.scale()
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index aa174f5a3..c38a15712 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -18,6 +18,8 @@
 from brevitas.graph.calibrate import DisableEnableQuantization
 from brevitas.graph.calibrate import restore_return_quant_tensor
 import brevitas.nn as qnn
+from brevitas.quant_tensor import QuantTensor
+from brevitas.utils.quant_utils import _CachedIO
 
 SUPPORTED_CONV_OP = (
     qnn.QuantConv1d,
@@ -216,11 +218,21 @@ def __init__(
         self.disable_pre_forward_hook = False
 
         # Some layers require knowledge from quant inputs to compute quant weights
-        self.quant_input = None
+        self.quant_metadata = None
 
     def process_input(self, inp):
         # Input is a tuple, so we take first element
         inp = inp[0]
+        inp = self.layer.input_quant(inp)
+
+        is_quant_enabled = self.layer.weight_quant.is_quant_enabled
+
+        # If using quantized activations, inp could be QuantTensor. In
+        # this case, we overwrite the metadata.
+        if isinstance(inp, QuantTensor):
+            if is_quant_enabled and self.quant_metadata is None:
+                self.quant_metadata = _CachedIO(inp, metadata_only=True)
+            inp = inp.value
 
         # If input is unbatched, add batch_size = 1
         if len(inp.shape) == 1:
@@ -232,6 +244,7 @@
             batch_dim = inp.names.index('N')
             inp.rename_(None)
             inp = inp.transpose(0, batch_dim)
+
         return inp
 
     @abstractmethod