From 5e5d9e78121166ceded75ccc3cd8ad237c975cdb Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Mon, 23 Sep 2024 17:49:28 +0100
Subject: [PATCH] fix tests

---
 src/brevitas/core/stats/stats_op.py     | 20 ++++++++++++++++++++
 tests/brevitas/core/test_float_quant.py |  7 ++++---
 2 files changed, 24 insertions(+), 3 deletions(-)

diff --git a/src/brevitas/core/stats/stats_op.py b/src/brevitas/core/stats/stats_op.py
index 461aeb3e6..3cd6172d7 100644
--- a/src/brevitas/core/stats/stats_op.py
+++ b/src/brevitas/core/stats/stats_op.py
@@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
             m.local_loss_mode = enabled
 
 
+def _set_observer_mode(module, enabled, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            previous_observer_mode[m] = m.observer_only
+            m.observer_only = enabled
+
+
+def _restore_observer_mode(module, previous_observer_mode):
+    for m in module.modules():
+        if hasattr(m, 'observer_only'):
+            m.observer_only = previous_observer_mode[m]
+
+
 class MSE(torch.nn.Module):
     # References:
     # https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
@@ -459,7 +472,12 @@ def __init__(
         self.mse_init_op = mse_init_op
         self.input_view_shape_impl = inner_stats_input_view_shape_impl
         self.proxy_forward = proxy_module.forward
+        self.previous_observer_mode = dict()
         self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
+        self.set_observer_mode = lambda enabled: _set_observer_mode(
+            proxy_module, enabled, self.previous_observer_mode)
+        self.restore_observer_mode = lambda: _restore_observer_mode(
+            proxy_module, self.previous_observer_mode)
         self.internal_candidate = None
         self.num = mse_iters
         self.search_method = mse_search_method
@@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
         self.internal_candidate = candidate
         # Set to local_loss_mode before calling the proxy
         self.set_local_loss_mode(True)
+        self.set_observer_mode(False)
         quant_value = self.proxy_forward(x)
         quant_value = _unpack_quant_tensor(quant_value)
         loss = self.mse_loss_fn(x, quant_value)
         self.set_local_loss_mode(False)
+        self.restore_observer_mode()
         return loss
 
     def mse_grid_search(self, xl, x):
diff --git a/tests/brevitas/core/test_float_quant.py b/tests/brevitas/core/test_float_quant.py
index 16b8a4b5f..52352c38b 100644
--- a/tests/brevitas/core/test_float_quant.py
+++ b/tests/brevitas/core/test_float_quant.py
@@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format):
         signed=signed,
         float_clamp_impl=float_clamp)
     expected_out, *_ = float_quant(inp)
-
-    out_quant, scale = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    out_quant = float_quant.quantize(inp, scale)
     exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
     out_quant, *_ = float_quant.float_clamp_impl(
         out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
@@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
         scaling_impl=scaling_impl,
         float_scaling_impl=float_scaling_impl,
         float_clamp_impl=float_clamp)
-    _ = float_quant.quantize(inp)
+    scale = float_quant.scaling_impl(inp)
+    _ = float_quant.quantize(inp, scale)
     # scaling implementations should be called exaclty once on the input
     float_scaling_impl.assert_called_once_with(
         torch.tensor(exponent_bit_width),