Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Sep 23, 2024
1 parent d3b4d5f commit 5e5d9e7
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 3 deletions.
20 changes: 20 additions & 0 deletions src/brevitas/core/stats/stats_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,19 @@ def _set_local_loss_mode(module, enabled):
m.local_loss_mode = enabled


def _set_observer_mode(module, enabled, previous_observer_mode):
for m in module.modules():
if hasattr(m, 'observer_only'):
previous_observer_mode[m] = m.observer_only
m.observer_only = enabled


def _restore_observer_mode(module, previous_observer_mode):
for m in module.modules():
if hasattr(m, 'observer_only'):
m.observer_only = previous_observer_mode[m]


class MSE(torch.nn.Module):
# References:
# https://github.com/cornell-zhang/dnn-quant-ocs/blob/master/distiller/quantization/clip.py
Expand All @@ -459,7 +472,12 @@ def __init__(
self.mse_init_op = mse_init_op
self.input_view_shape_impl = inner_stats_input_view_shape_impl
self.proxy_forward = proxy_module.forward
self.previous_observer_mode = dict()
self.set_local_loss_mode = lambda enabled: _set_local_loss_mode(proxy_module, enabled)
self.set_observer_mode = lambda enabled: _set_observer_mode(
proxy_module, enabled, self.previous_observer_mode)
self.restore_observer_mode = lambda: _restore_observer_mode(
proxy_module, self.previous_observer_mode)
self.internal_candidate = None
self.num = mse_iters
self.search_method = mse_search_method
Expand All @@ -480,10 +498,12 @@ def evaluate_loss(self, x, candidate):
self.internal_candidate = candidate
# Set to local_loss_mode before calling the proxy
self.set_local_loss_mode(True)
self.set_observer_mode(False)
quant_value = self.proxy_forward(x)
quant_value = _unpack_quant_tensor(quant_value)
loss = self.mse_loss_fn(x, quant_value)
self.set_local_loss_mode(False)
self.restore_observer_mode()
return loss

def mse_grid_search(self, xl, x):
Expand Down
7 changes: 4 additions & 3 deletions tests/brevitas/core/test_float_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ def test_float_to_quant_float(inp, minifloat_format):
signed=signed,
float_clamp_impl=float_clamp)
expected_out, *_ = float_quant(inp)

out_quant, scale = float_quant.quantize(inp)
scale = float_quant.scaling_impl(inp)
out_quant = float_quant.quantize(inp, scale)
exponent_bit_width, mantissa_bit_width, exponent_bias = torch.tensor(exponent_bit_width, dtype=torch.float), torch.tensor(mantissa_bit_width, dtype=torch.float), torch.tensor(exponent_bias, dtype=torch.float)
out_quant, *_ = float_quant.float_clamp_impl(
out_quant, exponent_bit_width, mantissa_bit_width, exponent_bias)
Expand Down Expand Up @@ -142,7 +142,8 @@ def test_scaling_impls_called_once(inp, minifloat_format):
scaling_impl=scaling_impl,
float_scaling_impl=float_scaling_impl,
float_clamp_impl=float_clamp)
_ = float_quant.quantize(inp)
scale = float_quant.scaling_impl(inp)
_ = float_quant.quantize(inp, scale)
# scaling implementations should be called exactly once on the input
float_scaling_impl.assert_called_once_with(
torch.tensor(exponent_bit_width),
Expand Down

0 comments on commit 5e5d9e7

Please sign in to comment.