From 4b1377a119585e4275a32afe623d9448b5eb9260 Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Sun, 1 Sep 2024 23:25:23 +0100
Subject: [PATCH] Fix tests + JIT

---
 src/brevitas/core/function_wrapper/shape.py |  4 ++--
 src/brevitas/utils/torch_utils.py           |  4 ++--
 tests/brevitas/nn/nn_quantizers_fixture.py  | 12 ++++++------
 tests/brevitas/nn/test_nn_quantizers.py     |  5 -----
 4 files changed, 10 insertions(+), 15 deletions(-)

diff --git a/src/brevitas/core/function_wrapper/shape.py b/src/brevitas/core/function_wrapper/shape.py
index 67795485a..f1dfc7796 100644
--- a/src/brevitas/core/function_wrapper/shape.py
+++ b/src/brevitas/core/function_wrapper/shape.py
@@ -166,7 +166,7 @@ def __init__(self, expanded_groupwise_shape, group_size, group_dim) -> None:
 
     @brevitas.jit.script_method
     def forward(self, x: torch.Tensor):
         y = torch.nn.functional.pad(
-            x, padding(x, self.group_size, self.group_dim), mode='constant', value=0)
+            x, padding(x, self.group_size, self.group_dim), mode='constant', value=0.)
         y = y.view(self.expanded_groupwise_shape)
         return y
@@ -186,7 +186,7 @@ def forward(self, x):
         tensor_shape_list = list(tensor_shape)
 
         pad = padding(x, self.group_size, self.group_dim)
-        x = torch.nn.functional.pad(x, pad, mode='constant', value=0)
+        x = torch.nn.functional.pad(x, pad, mode='constant', value=0.)
 
         tensor_shape = x.shape
         tensor_shape_list = list(tensor_shape)
diff --git a/src/brevitas/utils/torch_utils.py b/src/brevitas/utils/torch_utils.py
index 225cacaac..2f0d34fba 100644
--- a/src/brevitas/utils/torch_utils.py
+++ b/src/brevitas/utils/torch_utils.py
@@ -2,7 +2,7 @@
 # SPDX-License-Identifier: BSD-3-Clause
 
 import copy
-from typing import Optional, Tuple
+from typing import List, Optional, Tuple
 
 import torch
 from torch.nn import Sequential
@@ -105,7 +105,7 @@ def float_internal_scale(
 
 
 @brevitas.jit.ignore
-def padding(x, group_size, group_dim):
+def padding(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
     # Given a tensor x, compute the padding along group_dim so that groupwise shaping is possible
     padding = [0, 0] * len(x.shape)
     size = x.shape
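The two `value=0` -> `value=0.` changes above matter under TorchScript, where `F.pad` takes a float `value`, and the new annotations on `padding` give its scripted callers a typed `List[int]` interface even though the function itself is `@brevitas.jit.ignore`d. For illustration, a minimal standalone sketch of the pad-then-view groupwise pattern these hunks touch; `pad_for_groups` is a hypothetical stand-in, not the Brevitas `padding` implementation:

```python
from typing import List

import torch


def pad_for_groups(x: torch.Tensor, group_size: int, group_dim: int) -> List[int]:
    # F.pad consumes (left, right) pairs ordered from the LAST dimension
    # backwards, so the list has 2 * ndim entries.
    pad = [0, 0] * len(x.shape)
    remainder = x.shape[group_dim] % group_size
    if remainder != 0:
        # Position of group_dim's "right" slot in F.pad's reversed ordering.
        idx = 2 * (len(x.shape) - 1 - group_dim)
        pad[idx + 1] = group_size - remainder
    return pad


x = torch.randn(4, 10)
group_size, group_dim = 4, 1
# Note the float literal for `value`: under torch.jit.script an int here
# can trip F.pad's `value: float` signature, which is what the patch fixes.
y = torch.nn.functional.pad(
    x, pad_for_groups(x, group_size, group_dim), mode='constant', value=0.)
# 10 columns padded to 12, then split into 3 groups of 4 along dim 1.
y = y.view(4, 3, group_size)
print(y.shape)  # torch.Size([4, 3, 4])
```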
diff --git a/tests/brevitas/nn/nn_quantizers_fixture.py b/tests/brevitas/nn/nn_quantizers_fixture.py
index 1789fc3fe..a6b1c05af 100644
--- a/tests/brevitas/nn/nn_quantizers_fixture.py
+++ b/tests/brevitas/nn/nn_quantizers_fixture.py
@@ -124,7 +124,7 @@ def build_case_model(
     weight_quant_name, weight_quantizer = weight_quantizer
     bias_quant_name, bias_quantizer = bias_quantizer
     io_quant_name, io_quantizer = io_quantizer
-    print(io_quant_name)
+
     if ((io_quantizer is None and not input_quantized) or
             'float' in io_quant_name) and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER:
         pytest.skip(
@@ -134,8 +134,6 @@ def build_case_model(
             'mx' not in io_quant_name) or ('mx' not in weight_quant_name and 'mx' in io_quant_name):
         pytest.skip("MX requires input and weights quantization to be aligned")
     elif weight_quantizer == MXInt8Weight:
-        if config.JIT_ENABLED:
-            pytest.skip("Dynamic act quant is not compatible with JIT")
         if bias_quant_name != 'quant_internal':
             pytest.skip("MX quant does not support external scaled bias")
     elif weight_quantizer == Fp8e4m3WeightPerTensorFloat or io_quantizer == Fp8e4m3ActPerTensorFloat:
@@ -640,16 +638,18 @@ def case_mha(
     # Change the case_id based on current value of Parameters
     set_case_id(request.node.callspec.id, case_mha)
 
-    k, weight_quantizer = weight_quantizer
+    weight_quant_name, weight_quantizer = weight_quantizer
     _, bias_quantizer = bias_quantizer
     _, io_quantizer = io_quantizer
 
-    if io_quantizer is None and k in A2Q_WBIOL_WEIGHT_QUANTIZER:
+    if io_quantizer is None and weight_quant_name in A2Q_WBIOL_WEIGHT_QUANTIZER:
         # Can't rely on a QuantTensor input for quant_mha at this point
         pytest.skip(
             "A2Q uses an input-aware decoupled weight proxy that requires a quantized input tensor."
         )
-
+    # TODO: restore compatibility
+    if ('mx' in weight_quant_name or 'float' in weight_quant_name):
+        pytest.skip("MX/Float quant not supported for MHA")
     # BatchQuant1d works over 3d input but not 2d, so we have a separate quantizer for out_proj
     if isinstance(io_quantizer, tuple):
         io_quantizer, out_proj_io_quantizer = io_quantizer
diff --git a/tests/brevitas/nn/test_nn_quantizers.py b/tests/brevitas/nn/test_nn_quantizers.py
index 194fe7aae..db4f21e02 100644
--- a/tests/brevitas/nn/test_nn_quantizers.py
+++ b/tests/brevitas/nn/test_nn_quantizers.py
@@ -173,11 +173,6 @@ def test_quant_mha(model_input, current_cases):
     args = case_id.split('-')[1:]  # Exclude first argument
     kwargs = parse_args(args)
 
-    # TODO: restore compatibility
-    skipped_quant = ['quant_mx', 'quant_float']
-    if kwargs['io_quant'] in skipped_quant or kwargs['weight_quant'] in skipped_quant:
-        pytest.skip("MX and Float quant not supported for MHA")
-
    is_input_quanttensor = kwargs['io_quant'] is not None or kwargs['input_quantized']
    if (not is_input_quanttensor or
            kwargs['weight_quant'] is None) and kwargs['bias_quant'] == 'quant_external':
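The fixture hunks rename the unpacked key `k` to the descriptive `weight_quant_name` and move the MX/Float MHA skip from `test_quant_mha` into case construction. A minimal sketch of that parametrize-then-skip pattern, assuming hypothetical `(name, quantizer)` pairs in place of the real fixture tables:

```python
import pytest

# Hypothetical (name, quantizer-class) pairs standing in for the fixture's
# real parametrization; the names mirror what the skip logic matches on.
WBIOL_WEIGHT_QUANTIZERS = [
    ('quant_sym', 'Int8WeightPerTensorFloat'),
    ('quant_mx', 'MXInt8Weight'),
    ('quant_float', 'Fp8e4m3WeightPerTensorFloat'),
]


@pytest.fixture(params=WBIOL_WEIGHT_QUANTIZERS, ids=lambda p: p[0])
def weight_quantizer(request):
    return request.param


def test_mha_case(weight_quantizer):
    # Unpacking into a descriptive name instead of `k` keeps the skip
    # conditions below self-documenting.
    weight_quant_name, quantizer = weight_quantizer
    # Skipping at case-construction time (as the fixture now does) means
    # the test body no longer needs its own skipped_quant list.
    if 'mx' in weight_quant_name or 'float' in weight_quant_name:
        pytest.skip("MX/Float quant not supported for MHA")
    assert quantizer is not None
```

Skipping while the case is built keeps the compatibility knowledge next to the quantizer definitions, so the test body stays purely behavioral.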