From 02f5b6b54992a4b41cfd91eb5dfd675ac300c27d Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Mon, 3 Jun 2024 11:40:58 +0200 Subject: [PATCH] Fix (calibrate): fix for minifloat act calibration (#966) --- src/brevitas/graph/calibrate.py | 6 +++--- .../imagenet_classification/ptq/ptq_common.py | 15 ++++++++------- 2 files changed, 11 insertions(+), 10 deletions(-) diff --git a/src/brevitas/graph/calibrate.py b/src/brevitas/graph/calibrate.py index 8ac55caaa..a5d7e08e8 100644 --- a/src/brevitas/graph/calibrate.py +++ b/src/brevitas/graph/calibrate.py @@ -15,7 +15,7 @@ from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer as QuantWBIOL from brevitas.proxy.parameter_quant import BiasQuantProxyFromInjectorBase from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjectorBase -from brevitas.proxy.runtime_quant import ActQuantProxyFromInjector +from brevitas.proxy.runtime_quant import ActQuantProxyFromInjectorBase from brevitas.proxy.runtime_quant import ClampQuantProxyFromInjector from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector from brevitas.quant_tensor import QuantTensor @@ -188,7 +188,7 @@ def disable_act_quantization(self, model, is_training): # will be discarded through the hook. It is useful for collecting activation stats, # for example during activation calibration in PTQ for module in model.modules(): - if isinstance(module, ActQuantProxyFromInjector): + if isinstance(module, ActQuantProxyFromInjectorBase): module.train(is_training) if self.call_act_quantizer_impl: hook = module.register_forward_hook(self.disable_act_quant_hook) @@ -216,7 +216,7 @@ def enable_act_quantization(self, model, is_training): if isinstance(module, _ACC_PROXIES): module.train(is_training) module.disable_quant = False - elif isinstance(module, ActQuantProxyFromInjector): + elif isinstance(module, ActQuantProxyFromInjectorBase): module.disable_quant = False module.train(is_training) for hook in self.disable_act_quant_hooks: diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py index fcb0be367..9d94df12f 100644 --- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py +++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py @@ -141,13 +141,14 @@ class CNNInt8DynamicActPerTensorFloat(Int8DynamicActPerTensorFloat): 'sym': CNNInt8DynamicActPerTensorFloat, 'asym': CNNShiftedUint8DynamicActPerTensorFloat}}}}}, 'float': { - 'float_scale': { - 'stats': { - 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloat}}, - 'mse': { - 'per_tensor': { - 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}} + 'static': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3ActPerTensorFloat}}, + 'mse': { + 'per_tensor': { + 'sym': Fp8e4m3ActPerTensorFloatMSE}}}}}} def quantize_model(