diff --git a/src/brevitas/proxy/groupwise_int_runtime_quant.py b/src/brevitas/proxy/groupwise_int_runtime_quant.py index ec9418e19..453cb3f9b 100644 --- a/src/brevitas/proxy/groupwise_int_runtime_quant.py +++ b/src/brevitas/proxy/groupwise_int_runtime_quant.py @@ -31,7 +31,7 @@ def create_quant_tensor( qt_args: Union[torch.Tensor, Tuple[Any]], x: Optional[GroupwiseIntQuantTensor] = None) -> GroupwiseIntQuantTensor: if x is None: - value, scale, zero_point, bit_width, = qt_args + value, scale, zero_point, bit_width = qt_args out = GroupwiseIntQuantTensor( value, scale, diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 03303bcc8..b2ded7239 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -123,11 +123,11 @@ def internal_forward(self, force_eval): return out def retrieve_attribute(self, attribute, force_eval): - if self.is_quant_enabled: + if self._cached_act is not None: + return getattr(self._cached_act, attribute) + elif self.is_quant_enabled: out = self.internal_forward(force_eval) return getattr(out, attribute) - elif self._cached_act is not None: - return getattr(self._cached_act, attribute) elif self._cached_act is None: return None diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index 661328d7e..a6cdd2af7 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -178,7 +178,7 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}, 'po2_scale': { 'stats': { - 'per_group': MXInt8Act}}}}, + 'per_group': {'sym':MXInt8Act} }}}}, 'float': { 'static': { 'float_scale': {