diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py
index 8810e2af2..a70394e07 100644
--- a/src/brevitas/nn/mixin/base.py
+++ b/src/brevitas/nn/mixin/base.py
@@ -147,14 +147,9 @@ def state_dict(self, destination=None, prefix='', keep_vars=False):
 
     @staticmethod
     def gate_params_fwd(gate, quant_input):
-        acc_scale = None
         quant_weight_ih = gate.input_weight()
         quant_weight_hh = gate.hidden_weight()
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight_ih, QuantTensor):
-            acc_scale_shape = compute_channel_view_shape(quant_input.value, channel_dim=1)
-            acc_scale = quant_weight_ih.scale.view(acc_scale_shape)
-            acc_scale = acc_scale * quant_input.scale.view(acc_scale_shape)
-        quant_bias = gate.bias_quant(gate.bias, acc_scale)
+        quant_bias = gate.bias_quant(gate.bias, quant_input, quant_weight_ih)
         return quant_weight_ih, quant_weight_hh, quant_bias
 
     def reset_parameters(self) -> None:
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index 0a82afb9b..cd5f48418 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -118,14 +118,6 @@ def inner_forward_impl(self, x: Tensor, quant_weight: Tensor, quant_bias: Option
     def max_acc_bit_width(self, input_bit_width: Tensor, quant_weight_bit_width: Tensor):
         pass
 
-    def quant_output_scale_impl(
-            self, inp: Tensor, quant_input_scale: Tensor, quant_weight_scale: Tensor):
-        channel_dim = -1 if isinstance(self, torch.nn.Linear) else 1
-        output_scale_shape = compute_channel_view_shape(inp, channel_dim=channel_dim)
-        output_scale = quant_weight_scale.view(output_scale_shape)
-        output_scale = output_scale * quant_input_scale.view(output_scale_shape)
-        return output_scale
-
     @property
     def requires_export_handler(self):
         return (
@@ -150,7 +142,6 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             return out
 
         quant_input = self.input_quant(inp)
-
         quant_weight = self.quant_weight(quant_input)
 
         compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
@@ -159,12 +150,8 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
                 self.output_quant.is_quant_enabled) and self.return_quant_tensor:
             raise RuntimeError("QuantLayer is not correctly configured")
 
-        output_scale = None
-        if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
-            output_scale = self.quant_output_scale_impl(inp, quant_input.scale, quant_weight.scale)
-
         if self.bias is not None:
-            quant_bias = self.bias_quant(self.bias, output_scale)
+            quant_bias = self.bias_quant(self.bias, quant_input, quant_weight)
         else:
             quant_bias = None
         output_tensor = self.inner_forward_impl(quant_input, quant_weight, quant_bias)
diff --git a/src/brevitas/proxy/parameter_quant.py b/src/brevitas/proxy/parameter_quant.py
index 893ff5e30..4a95b5e0f 100644
--- a/src/brevitas/proxy/parameter_quant.py
+++ b/src/brevitas/proxy/parameter_quant.py
@@ -17,6 +17,7 @@
 from brevitas.inject import BaseInjector as Injector
 from brevitas.quant_tensor import QuantTensor
 from brevitas.utils.quant_utils import _CachedIO
+from brevitas.utils.torch_utils import compute_channel_view_shape
 
 from .quant_proxy import QuantProxyFromInjector
 from .quant_proxy import QuantProxyProtocol
@@ -234,10 +235,38 @@ def bit_width(self):
         bit_width = self.__call__(self.tracked_parameter_list[0], zhs).bit_width
         return bit_width
 
-    def forward(self,
-                x: Tensor,
-                input_scale: Optional[Tensor] = None) -> Union[Tensor, QuantTensor]:
+    def quant_output_scale_impl(
+            self, input: QuantTensor, weight: QuantTensor, module: torch.nn.Module) -> Tensor:
+        channel_dim = -1 if isinstance(module, torch.nn.Linear) else 1
+        output_scale_shape = compute_channel_view_shape(input, channel_dim=channel_dim)
+        output_scale = weight.scale.view(output_scale_shape)
+        output_scale = output_scale * input.scale.view(output_scale_shape)
+        return output_scale
+
+    def compute_bias_scale(
+            self,
+            input: Optional[Union[Tensor, QuantTensor]],
+            weight: Optional[Union[Tensor, QuantTensor]]) -> Optional[Tensor]:
+        if not self.requires_input_scale:
+            return None
+        if not isinstance(input, QuantTensor) or not isinstance(weight, QuantTensor):
+            return None
+        if len(self.tracked_module_list) > 1:
+            if not all(
+                [type(self.tracked_module_list[0]) == type(x) for x in self.tracked_module_list]):
+                raise RuntimeError(
+                    "Bias quantizer shared across different type of layers with external scale is not supported."
+                )
+        scale = self.quant_output_scale_impl(input, weight, self.tracked_module_list[0])
+        return scale
+
+    def forward(
+            self,
+            x: Tensor,
+            input: Optional[Union[Tensor, QuantTensor]] = None,
+            weight: Optional[Union[Tensor, QuantTensor]] = None) -> Union[Tensor, QuantTensor]:
         out = x
+        input_scale = self.compute_bias_scale(input, weight)
         if self.is_quant_enabled:
             impl = self.export_handler if self.export_mode else self.tensor_quant
             if self.requires_input_scale and input_scale is None:
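
Usage note (not part of the patch): a minimal sketch of the call convention this diff introduces, where the bias proxy receives the quantized input and weight and derives the bias scale internally instead of being handed a precomputed scale. The QuantLinear configuration below (Int8ActPerTensorFloat input quantizer, Int16Bias bias quantizer) is an illustrative assumption; the proxy-level call itself mirrors forward_impl in quant_layer.py above.

# Illustrative sketch; quantizer choices are assumptions, not part of the diff.
import torch

from brevitas.nn import QuantLinear
from brevitas.quant.scaled_int import Int16Bias, Int8ActPerTensorFloat

layer = QuantLinear(
    8, 4, bias=True, input_quant=Int8ActPerTensorFloat, bias_quant=Int16Bias)
x = torch.randn(2, 8)

quant_input = layer.input_quant(x)              # QuantTensor
quant_weight = layer.quant_weight(quant_input)  # QuantTensor

# Previously the layer passed output_scale = input_scale * weight_scale to the
# proxy; now the proxy computes that scale itself via compute_bias_scale.
quant_bias = layer.bias_quant(layer.bias, quant_input, quant_weight)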