diff --git a/src/brevitas/quant_tensor/int_torch_handler.py b/src/brevitas/quant_tensor/int_torch_handler.py index fb5db7ca1..3258b8914 100644 --- a/src/brevitas/quant_tensor/int_torch_handler.py +++ b/src/brevitas/quant_tensor/int_torch_handler.py @@ -234,8 +234,6 @@ def quant_output_scale_impl( output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim) quant_weight_scale = quant_weight_scale.view(output_scale_shape) - if len(quant_input_scale.shape) == 0: - quant_input_scale = quant_input_scale.view(output_scale_shape) quant_input_scale = quant_input_scale.view(output_scale_shape) if not is_broadcastable(quant_weight_scale.shape, quant_input_scale.shape): return None