diff --git a/tests/brevitas/core/test_stats.py b/tests/brevitas/core/test_stats.py index 24586131e..69a206dd8 100644 --- a/tests/brevitas/core/test_stats.py +++ b/tests/brevitas/core/test_stats.py @@ -68,5 +68,7 @@ def test_interval_percentile(self): out = interval_percentile(values) range = self.compute_percentile(values, low_q=0.01, high_q=99.9) - expected_out = torch.abs(range[1] - range[0]) + # Clamp is to make sure the lower bound is not positive to align with zero-point statistics + low_result = torch.clamp(range[0], max=torch.tensor(0.0)) + expected_out = torch.abs(range[1] - low_result) assert torch.allclose(out, expected_out) diff --git a/tests/brevitas/fx/test_tracer.py b/tests/brevitas/fx/test_tracer.py index 23a8efc95..be5698d2c 100644 --- a/tests/brevitas/fx/test_tracer.py +++ b/tests/brevitas/fx/test_tracer.py @@ -238,4 +238,4 @@ def test_quant_module(module): out = mod(x) graph_model = value_trace(mod, value_args={'x': x_trace}) graph_out = graph_model(x) - assert graph_out.value.isclose(out.value).all().item() + assert graph_out.isclose(out).all().item()