From 9818eec0d66136afef71d25bbb778d61461ddf25 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Thu, 29 Feb 2024 18:27:06 +0000
Subject: [PATCH] Feat (tests): add new tests for proxy

---
 tests/brevitas/proxy/test_proxy.py          | 79 +++++++++++++++++++++
 tests/brevitas/proxy/test_weight_scaling.py |  1 -
 2 files changed, 79 insertions(+), 1 deletion(-)
 create mode 100644 tests/brevitas/proxy/test_proxy.py

diff --git a/tests/brevitas/proxy/test_proxy.py b/tests/brevitas/proxy/test_proxy.py
new file mode 100644
index 000000000..eafdd37c3
--- /dev/null
+++ b/tests/brevitas/proxy/test_proxy.py
@@ -0,0 +1,79 @@
+import pytest
+
+from brevitas.nn import QuantLinear
+from brevitas.nn.quant_activation import QuantReLU
+from brevitas.quant.scaled_int import Int8AccumulatorAwareWeightQuant
+from brevitas.quant.scaled_int import Int8BiasPerTensorFloatInternalScaling
+from brevitas.quant.scaled_int import Int8WeightPerChannelFloatDecoupled
+from brevitas.quant.scaled_int import Int8WeightPerTensorFloat
+from brevitas_examples.common.generative.quantizers import Int8DynamicActPerTensorFloat
+
+
+class TestProxy:
+
+    def test_bias_proxy(self):
+        model = QuantLinear(10, 5, bias_quant=Int8BiasPerTensorFloatInternalScaling)
+        assert model.bias_quant.scale() is not None
+        assert model.bias_quant.zero_point() is not None
+        assert model.bias_quant.bit_width() is not None
+
+        model.bias_quant.disable_quant = True
+        assert model.bias_quant.scale() is None
+        assert model.bias_quant.zero_point() is None
+        assert model.bias_quant.bit_width() is None
+
+    def test_weight_proxy(self):
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerTensorFloat)
+        assert model.weight_quant.scale() is not None
+        assert model.weight_quant.zero_point() is not None
+        assert model.weight_quant.bit_width() is not None
+
+        model.weight_quant.disable_quant = True
+        assert model.weight_quant.scale() is None
+        assert model.weight_quant.zero_point() is None
+        assert model.weight_quant.bit_width() is None
+
+    def test_weight_decoupled_proxy(self):
+        model = QuantLinear(10, 5, weight_quant=Int8WeightPerChannelFloatDecoupled)
+        assert model.weight_quant.pre_scale() is not None
+        assert model.weight_quant.pre_zero_point() is not None
+
+        model.weight_quant.disable_quant = True
+        assert model.weight_quant.pre_scale() is None
+        assert model.weight_quant.pre_zero_point() is None
+
+    def test_weight_decoupled_with_input_proxy(self):
+        model = QuantLinear(10, 5, weight_quant=Int8AccumulatorAwareWeightQuant)
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.scale()
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.zero_point()
+
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.pre_scale()
+        with pytest.raises(NotImplementedError):
+            model.weight_quant.pre_zero_point()
+
+    def test_act_proxy(self):
+        model = QuantReLU()
+        assert model.act_quant.scale() is not None
+        assert model.act_quant.zero_point() is not None
+        assert model.act_quant.bit_width() is not None
+
+        model.act_quant.disable_quant = True
+        assert model.act_quant.scale() is None
+        assert model.act_quant.zero_point() is None
+        assert model.act_quant.bit_width() is None
+
+    def test_dynamic_act_proxy(self):
+        model = QuantReLU(Int8DynamicActPerTensorFloat)
+
+        with pytest.raises(NotImplementedError):
+            model.act_quant.scale()
+        with pytest.raises(NotImplementedError):
+            model.act_quant.zero_point()
+
+        assert model.act_quant.bit_width() is not None
+
+        model.act_quant.disable_quant = True
+        assert model.act_quant.bit_width() is None
diff --git a/tests/brevitas/proxy/test_weight_scaling.py b/tests/brevitas/proxy/test_weight_scaling.py
index 074ca7c61..49a7f20fe 100644
--- a/tests/brevitas/proxy/test_weight_scaling.py
+++ b/tests/brevitas/proxy/test_weight_scaling.py
@@ -1,7 +1,6 @@
 # Copyright (C) 2023, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
-import pytest
 from torch import nn
 
 from brevitas import config