Query quantization flags on the quant proxies instead of on the layers.

The layer-level helper properties is_{input,output,act}_quant_enabled and
is_quant_{input,output,act}_narrow_range are disabled on the activation
mixins and on QuantNonLinearActLayer.  Callers now read `is_quant_enabled`
and `is_narrow_range` from the proxy objects themselves.

Conventions followed by every hunk below:
- proxies are addressed as `input_quant`, `output_quant`, `act_quant`,
  `weight_quant` and `bias_quant`, the names already used for them in
  these files;
- a narrow-range check reads `is_narrow_range`, never `is_quant_enabled`;
- QuantInputOutputLayer works on `output_quant`; only the non-linear
  activation layer has `act_quant`;
- remaining uses of `is_input_quant_enabled` in the touched functions and
  notebook cells are migrated too, because that property is disabled here.

diff --git a/docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb b/docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
index 2e9ef9179..6ed3cadf3 100644
--- a/docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
+++ b/docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
@@ -132,7 +132,7 @@
     "print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
     "print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
-    "print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
-    "print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')"
+    "print(f'Is input quant enabled: {default_quant_conv.input_quant.is_quant_enabled}')\n",
+    "print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
    ]
   },
   {
diff --git a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
index ef04daf60..0b482c614 100644
--- a/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
+++ b/notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
@@ -145,7 +145,7 @@
     "print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
     "print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
-    "print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
-    "print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')"
+    "print(f'Is input quant enabled: {default_quant_conv.input_quant.is_quant_enabled}')\n",
+    "print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
    ]
   },
   {
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/act.py b/src/brevitas/export/onnx/standard/qoperator/handler/act.py
index 91f4b18db..418c96b0a 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/act.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/act.py
@@ -21,11 +21,11 @@ class StdQOpONNXQuantNLALHandler(StdQOpONNXQuantLayerHandler, ABC):
     @classmethod
     def validate(cls, module: QuantNLAL):
-        if cls.input_quant_supported and module.is_input_quant_enabled:
-            assert not module.is_quant_input_narrow_range, "Narrow range quant not supported."
-        elif not cls.input_quant_supported and module.is_input_quant_enabled:
+        if cls.input_quant_supported and module.input_quant.is_quant_enabled:
+            assert not module.input_quant.is_narrow_range, "Narrow range quant not supported."
+        elif not cls.input_quant_supported and module.input_quant.is_quant_enabled:
             raise RuntimeError("Input quant not supported.")
-        if module.is_act_quant_enabled:
-            assert not module.is_quant_act_narrow_range, "Narrow range quant not supported."
+        if module.act_quant.is_quant_enabled:
+            assert not module.act_quant.is_narrow_range, "Narrow range quant not supported."
         input_bit_width = module.input_quant.bit_width()
         act_bit_width = module.act_quant.bit_width()
         if input_bit_width is not None:
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/base.py b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
index 449e059ef..fde6a4b75 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/base.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/base.py
@@ -56,9 +56,9 @@ def quant_output_shape(cls, module):
 
     @classmethod
     def output_quant_symbolic_kwargs(cls, module):
-        if module.is_output_quant_enabled:
-            quant_proxy = module.act_quant if isinstance(
-                module, QuantNonLinearActLayer) else module.output_quant
+        quant_proxy = module.act_quant if isinstance(
+            module, QuantNonLinearActLayer) else module.output_quant
+        if quant_proxy.is_quant_enabled:
             return {
                 'output_scale': quant_proxy.scale(),
                 'output_zero_point': cls.quant_output_zero_point(module),
@@ -69,10 +69,10 @@ def output_quant_symbolic_kwargs(cls, module):
 
     @classmethod
     def output_clip_symbolic_kwargs(cls, module):
-        if module.is_output_quant_enabled:
-            quant_proxy = module.act_quant if isinstance(
-                module, QuantNonLinearActLayer) else module.output_quant
-            narrow = module.is_quant_output_narrow_range
+        quant_proxy = module.act_quant if isinstance(
+            module, QuantNonLinearActLayer) else module.output_quant
+        if quant_proxy.is_quant_enabled:
+            narrow = quant_proxy.is_narrow_range
             signed = quant_proxy.signed()
             bit_width = quant_proxy.bit_width()
             return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
@@ -82,7 +82,7 @@ def output_clip_symbolic_kwargs(cls, module):
     @classmethod
     def input_clip_symbolic_kwargs(cls, module):
-        if module.is_input_quant_enabled:
-            narrow = module.is_quant_input_narrow_range
+        if module.input_quant.is_quant_enabled:
+            narrow = module.input_quant.is_narrow_range
             signed = module.input_quant.signed()
             bit_width = module.input_quant.bit_width()
             return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
diff --git a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
index 376820102..2784af56c 100644
--- a/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
+++ b/src/brevitas/export/onnx/standard/qoperator/handler/parameter.py
@@ -42,12 +42,12 @@ def int_bias(module: QuantWBIOL):
     @classmethod
     def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
         assert module.is_weight_quant_enabled, 'Weight quant required'
-        assert module.is_output_quant_enabled, 'Output quant required'
+        assert module.output_quant.is_quant_enabled, 'Output quant required'
         # Handling narrow_range is across the network is difficult do to the fact that
         # it's not part of QuantTensor, and so it can't be cached
-        assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported'
-        if module.is_input_quant_enabled:
-            assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported'
+        assert not module.output_quant.is_narrow_range, 'Narrow output quant not supported'
+        if module.input_quant.is_quant_enabled:
+            assert not module.input_quant.is_narrow_range, 'Narrow input quant not supported'
         cls.validate_8b_bit_width(module.weight_quant.bit_width(), le_then=True)
         cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
         cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
diff --git a/src/brevitas/export/torch/qoperator/handler/act.py b/src/brevitas/export/torch/qoperator/handler/act.py
index d2b7f8061..502e8eee6 100644
--- a/src/brevitas/export/torch/qoperator/handler/act.py
+++ b/src/brevitas/export/torch/qoperator/handler/act.py
@@ -29,7 +29,7 @@ def validate(cls, module: QuantNLAL):
     def prepare_for_export(self, module: QuantNLAL):
         self.validate(module)
         self.qf_impl, self.qf_kwargs = self.prepare_qf(module)
-        if module.is_act_quant_enabled:
+        if module.act_quant.is_quant_enabled:
             self.output_quant_impl, self.output_quant_kwargs = self.prepare_output_quant(module)
         self.return_quant_tensor = module.return_quant_tensor
 
diff --git a/src/brevitas/export/torch/qoperator/handler/parameter.py b/src/brevitas/export/torch/qoperator/handler/parameter.py
index 6314320cf..772e5a90e 100644
--- a/src/brevitas/export/torch/qoperator/handler/parameter.py
+++ b/src/brevitas/export/torch/qoperator/handler/parameter.py
@@ -32,8 +32,8 @@ def __init__(self):
 
     @classmethod
     def validate(cls, module: QuantWBIOL):
-        assert module.is_weight_quant_enabled, 'Weight quantization required'
-        assert module.is_output_quant_enabled, 'Output quantization required'
+        assert module.weight_quant.is_quant_enabled, 'Weight quantization required'
+        assert module.output_quant.is_quant_enabled, 'Output quantization required'
 
     @classmethod
     def prepare_bias(cls, module: QuantWBIOL):
diff --git a/src/brevitas/nn/mixin/act.py b/src/brevitas/nn/mixin/act.py
index e1c5f9393..6b1492de8 100644
--- a/src/brevitas/nn/mixin/act.py
+++ b/src/brevitas/nn/mixin/act.py
@@ -32,13 +32,13 @@ def __init__(self, input_quant: Optional[ActQuantType], **kwargs):
             input_passthrough_act=True,
             **kwargs)
 
-    @property
-    def is_input_quant_enabled(self):
-        return self.input_quant.is_quant_enabled
+    # @property
+    # def is_input_quant_enabled(self):
+    #     return self.input_quant.is_quant_enabled
 
-    @property
-    def is_quant_input_narrow_range(self):  # TODO make abstract once narrow range can be cached
-        return self.input_quant.is_narrow_range
+    # @property
+    # def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
+    #     return self.input_quant.is_narrow_range
 
     # @property
     # @abstractmethod
@@ -61,13 +61,13 @@ def __init__(self, output_quant: Optional[ActQuantType], **kwargs):
             output_passthrough_act=True,
             **kwargs)
 
-    @property
-    def is_output_quant_enabled(self):
-        return self.output_quant.is_quant_enabled
+    # @property
+    # def is_output_quant_enabled(self):
+    #     return self.output_quant.is_quant_enabled
 
-    @property
-    def is_quant_output_narrow_range(self):  # TODO make abstract once narrow range can be cached
-        return self.output_quant.is_narrow_range
+    # @property
+    # def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
+    #     return self.output_quant.is_narrow_range
 
     # @property
     # @abstractmethod
@@ -99,13 +99,13 @@ def __init__(
             **prefixed_kwargs,
             **kwargs)
 
-    @property
-    def is_act_quant_enabled(self):
-        return self.act_quant.is_quant_enabled
+    # @property
+    # def is_act_quant_enabled(self):
+    #     return self.act_quant.is_quant_enabled
 
-    @property
-    def is_quant_act_narrow_range(self):  # TODO make abstract once narrow range can be cached
-        return self.act_quant.is_narrow_range
+    # @property
+    # def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
+    #     return self.act_quant.is_narrow_range
 
     # @property
     # @abstractmethod
diff --git a/src/brevitas/nn/quant_layer.py b/src/brevitas/nn/quant_layer.py
index 110c1a394..e0a78c6bd 100644
--- a/src/brevitas/nn/quant_layer.py
+++ b/src/brevitas/nn/quant_layer.py
@@ -41,15 +41,15 @@ def channelwise_separable(self) -> bool:
 
     @property
     def requires_export_handler(self):
-        return self.is_input_quant_enabled or self.is_act_quant_enabled
+        return self.input_quant.is_quant_enabled or self.act_quant.is_quant_enabled
 
-    @property
-    def is_output_quant_enabled(self):
-        return self.is_act_quant_enabled
+    # @property
+    # def is_output_quant_enabled(self):
+    #     return self.act_quant.is_quant_enabled
 
-    @property
-    def is_quant_output_narrow_range(self):
-        return self.is_quant_act_narrow_range
+    # @property
+    # def is_quant_output_narrow_range(self):
+    #     return self.act_quant.is_narrow_range
 
     def forward(self, input: Union[Tensor, QuantTensor]):
         input = self.unpack_input(input)
@@ -87,16 +87,16 @@ def __init__(
         QuantOutputMixin.__init__(self, output_quant, **kwargs)
         # we have to account for quantization being enabled through kwargs
         if tie_input_output_quant:
-            if self.is_input_quant_enabled and self.is_output_quant_enabled:
+            if self.input_quant.is_quant_enabled and self.output_quant.is_quant_enabled:
                 raise RuntimeError("Enable only input or output quant with tie_input_output=True")
-            if self.is_input_quant_enabled:
+            if self.input_quant.is_quant_enabled:
                 self.output_quant = self.input_quant
-            if self.is_output_quant_enabled:
+            if self.output_quant.is_quant_enabled:
                 self.input_quant = self.output_quant
 
     @property
     def requires_export_handler(self):
-        return self.is_input_quant_enabled or self.is_output_quant_enabled
+        return self.input_quant.is_quant_enabled or self.output_quant.is_quant_enabled
 
 
 class QuantWeightBiasInputOutputLayer(QuantBiasMixin, QuantWeightMixin, QuantInputOutputLayer):
@@ -138,8 +138,8 @@ def quant_output_scale_impl(
     @property
     def requires_export_handler(self):
         return (
-            self.is_input_quant_enabled or self.is_weight_quant_enabled or
-            self.is_bias_quant_enabled or self.is_output_quant_enabled)
+            self.input_quant.is_quant_enabled or self.weight_quant.is_quant_enabled or
+            self.bias_quant.is_quant_enabled or self.output_quant.is_quant_enabled)
 
     @property
     def per_elem_ops(self):  # optional, so concrete impl + error if not overridden
@@ -169,7 +169,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
         compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
             quant_weight, QuantTensor)
         if not (compute_output_quant_tensor or
-                self.is_output_quant_enabled) and self.return_quant_tensor:
+                self.output_quant.is_quant_enabled) and self.return_quant_tensor:
             raise RuntimeError("QuantLayer is not correctly configured")
 
         if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
@@ -205,7 +205,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
             output_tensor = self.inner_forward_impl(
                 _unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)
 
-        if not self.is_output_quant_enabled and self.return_quant_tensor:
+        if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
             if compute_output_quant_tensor:
                 if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
                     raise RuntimeError(
diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py
index 654b130f6..de71403f1 100644
--- a/src/brevitas/proxy/runtime_quant.py
+++ b/src/brevitas/proxy/runtime_quant.py
@@ -181,12 +181,6 @@ def signed(self, force_eval=True):
             return self._cached_act.signed
         elif self._cached_act is None:
             return None
-        current_status = self.training
-        if force_eval:
-            self.eval()
-        bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
-        self.train(current_status)
-        return bit_width
 
     def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
         if self.fused_activation_quant_proxy is not None:
diff --git a/tests/brevitas/nn/test_wbiol.py b/tests/brevitas/nn/test_wbiol.py
index d3fc4bdce..beb95c1b2 100644
--- a/tests/brevitas/nn/test_wbiol.py
+++ b/tests/brevitas/nn/test_wbiol.py
@@ -68,19 +68,19 @@ def default_weight_tensor_quant(default_wbiol_layer):
 
 
 def test_default_wbiol_input_quant_enabled(default_wbiol_layer: QuantWBIOL):
-    assert not default_wbiol_layer.is_input_quant_enabled
+    assert not default_wbiol_layer.input_quant.is_quant_enabled
 
 
 def test_default_wbiol_output_quant_enabled(default_wbiol_layer: QuantWBIOL):
-    assert not default_wbiol_layer.is_output_quant_enabled
+    assert not default_wbiol_layer.output_quant.is_quant_enabled
 
 
 def test_default_wbiol_bias_quant_enabled(default_wbiol_layer: QuantWBIOL):
-    assert not default_wbiol_layer.is_bias_quant_enabled
+    assert not default_wbiol_layer.bias_quant.is_quant_enabled
 
 
 def test_default_wbiol_weight_quant_enabled(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_weight_quant_enabled
+    assert default_wbiol_layer.weight_quant.is_quant_enabled
 
 
 def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL):
@@ -92,7 +92,7 @@ def test_default_wbiol_return_quant(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_bias_signed is None
+    assert default_wbiol_layer.bias_quant.is_signed is None
 
 
 def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
@@ -100,11 +100,11 @@ def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_bias_narrow_range is None
+    assert default_wbiol_layer.bias_quant.is_narrow_range is None
 
 
 def test_default_wbiol_quant_weight_narrow_range(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_weight_narrow_range
+    assert default_wbiol_layer.weight_quant.is_narrow_range
 
 
 def test_default_wbiol_quant_input_signed(default_wbiol_layer: QuantWBIOL):
@@ -116,11 +116,11 @@ def test_default_wbiol_quant_output_signed(default_wbiol_layer: QuantWBIOL):
 
 
 def test_default_wbiol_quant_input_narrow_range(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_input_narrow_range is None
+    assert default_wbiol_layer.input_quant.is_narrow_range is None
 
 
 def test_default_wbiol_quant_output_narrow_range(default_wbiol_layer: QuantWBIOL):
-    assert default_wbiol_layer.is_quant_output_narrow_range is None
+    assert default_wbiol_layer.output_quant.is_narrow_range is None
 
 
 def test_default_wbiol_quant_input_zero_point(default_wbiol_layer: QuantWBIOL):