From ed42d2e999e9208e7bbf7d39c23028b1f1130d8c Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Wed, 17 Apr 2024 15:50:48 +0100
Subject: [PATCH] Rename

---
 src/brevitas/quant_tensor/torch_handler.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/src/brevitas/quant_tensor/torch_handler.py b/src/brevitas/quant_tensor/torch_handler.py
index f0d9e95f8..670439136 100644
--- a/src/brevitas/quant_tensor/torch_handler.py
+++ b/src/brevitas/quant_tensor/torch_handler.py
@@ -219,7 +219,7 @@ def avg_pool2d_handler(
         count_include_pad,
         divisor_override)
 
-    max_acc_bit_width = IMPLS[F.avg_pool2d]
+    max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d]
     # remove avg scaling
     if isinstance(kernel_size, tuple):
         avg_scaling = kernel_size[0] * kernel_size[1]
@@ -242,7 +242,7 @@ def adaptive_avg_pool2d_handler(quant_input, output_shape):
     x = F.adaptive_avg_pool2d(_unpack_quant_tensor(quant_input), output_shape)
     k_size, stride = TruncAdaptiveAvgPool2d.compute_kernel_size_stride(
         quant_input.value.shape[2:], x.shape[2:])
 
-    max_acc_bit_width = IMPLS[F.avg_pool2d]
+    max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d]
     reduce_size = reduce(mul, k_size, 1)
     rescaled_value = x * reduce_size  # remove avg scaling
@@ -251,7 +251,7 @@
     return quant_input
 
 
-def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs):
+def quant_layer(fn, quant_input, quant_weight, bias, *args, **kwargs):
     from brevitas.quant_tensor import _unpack_quant_tensor
     from brevitas.quant_tensor import QuantTensor
 
@@ -259,20 +259,20 @@ def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs):
     output_bit_width = None
     output_zero_point = None
     output_signed = None
-    max_acc_bit_width = IMPLS[cls]
+    max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[fn]
 
     compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
         quant_weight, QuantTensor)
 
     if bias is None:
-        output = cls(
+        output = fn(
             _unpack_quant_tensor(quant_input),
             _unpack_quant_tensor(quant_weight),
             None,
             *args,
             **kwargs)
     else:
-        output = cls(
+        output = fn(
             _unpack_quant_tensor(quant_input),
             _unpack_quant_tensor(quant_weight),
             _unpack_quant_tensor(bias),
@@ -287,7 +287,7 @@ def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs):
             *args,
             **kwargs)
         output_scale = quant_output_scale_impl(
-            cls, quant_input.value, quant_input.scale, quant_weight.scale)
+            fn, quant_input.value, quant_input.scale, quant_weight.scale)
         output_signed = quant_input.signed or quant_weight.signed
         output_training = quant_input.training or quant_weight.training
 
@@ -296,7 +296,7 @@ def quant_layer(cls, quant_input, quant_weight, bias, *args, **kwargs):
             if (isinstance(bias, QuantTensor) and
                     not torch.allclose(bias.scale, output_scale)) or not isinstance(bias,
                                                                                     QuantTensor):
-                channel_dim = -1 if isinstance(cls, torch.nn.Linear) else 1
+                channel_dim = -1 if isinstance(fn, torch.nn.Linear) else 1
                 output_scale_broadcast_shape = compute_channel_view_shape(
                     quant_input, channel_dim=channel_dim)
                 output_zero_point = -_unpack_quant_tensor(bias).view(
@@ -339,8 +339,8 @@ def create_quant_tensor(tensor, scale, bit_width, zero_point, signed, training):
         training=training)
 
 
-def quant_output_scale_impl(cls, inp, quant_input_scale, quant_weight_scale):
-    channel_dim = -1 if cls == F.linear else 1
+def quant_output_scale_impl(fn, inp, quant_input_scale, quant_weight_scale):
+    channel_dim = -1 if fn == F.linear else 1
     output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
     output_scale = quant_weight_scale.view(output_scale_shape)
     output_scale = output_scale * quant_input_scale.view(output_scale_shape)
@@ -390,7 +390,7 @@ def max_acc_bit_width_avg_pool2d(input_bit_width, avg_scaling):
     return max_output_bit_width
 
 
-IMPLS = {
+FN_ACC_BITWIDTH_MAPPING = {
     F.linear: max_acc_bit_width_linear,
     F.conv1d: max_acc_bit_width_convNd,
     F.conv2d: max_acc_bit_width_convNd,
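
A minimal standalone sketch of the renamed lookup's role, assuming only what the hunks above show: FN_ACC_BITWIDTH_MAPPING maps torch functional ops to functions that bound the output accumulator bit width, and handlers fetch the bound via dictionary lookup. The formula below is a generic worst-case stand-in, not Brevitas's exact implementation.

# Illustrative sketch, not part of the patch: reproduces the dict-dispatch
# pattern the diff renames (IMPLS -> FN_ACC_BITWIDTH_MAPPING).
import math

import torch.nn.functional as F


def max_acc_bit_width_avg_pool2d(input_bit_width, avg_scaling):
    # Stand-in bound (assumption, not Brevitas's formula): summing
    # `avg_scaling` values of `input_bit_width` bits can grow the
    # accumulator by up to ceil(log2(avg_scaling)) bits.
    return input_bit_width + math.ceil(math.log2(avg_scaling))


FN_ACC_BITWIDTH_MAPPING = {F.avg_pool2d: max_acc_bit_width_avg_pool2d}

# Mirrors the lookup in avg_pool2d_handler, followed by a call with the
# avg-scaling factor (kernel_size[0] * kernel_size[1] for a tuple kernel).
max_acc_bit_width = FN_ACC_BITWIDTH_MAPPING[F.avg_pool2d]
print(max_acc_bit_width(8, 4))  # 8-bit input, 2x2 kernel -> 10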