Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 4, 2024
1 parent aca8019 commit 56c5c7e
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 15 deletions.
16 changes: 8 additions & 8 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def channelwise_separable(self) -> bool:

@property
def requires_export_handler(self):
return self.quant_input.is_quant_enabled or self.act_quant.is_quant_enabled
return self.input_quant.is_quant_enabled or self.act_quant.is_quant_enabled

# @property
# def is_output_quant_enabled(self):
Expand Down Expand Up @@ -87,16 +87,16 @@ def __init__(
QuantOutputMixin.__init__(self, output_quant, **kwargs)
# we have to account for quantization being enabled through kwargs
if tie_input_output_quant:
if self.quant_input.is_quant_enabled and self.act_quant.is_quant_enabled:
if self.input_quant.is_quant_enabled and self.act_quant.is_quant_enabled:
raise RuntimeError("Enable only input or output quant with tie_input_output=True")
if self.quant_input.is_quant_enabled:
if self.input_quant.is_quant_enabled:
self.output_quant = self.input_quant
if self.act_quant.is_quant_enabled:
self.input_quant = self.output_quant

@property
def requires_export_handler(self):
return self.quant_input.is_quant_enabled or self.act_quant.is_quant_enabled
return self.input_quant.is_quant_enabled or self.act_quant.is_quant_enabled


class QuantWeightBiasInputOutputLayer(QuantBiasMixin, QuantWeightMixin, QuantInputOutputLayer):
Expand Down Expand Up @@ -138,8 +138,8 @@ def quant_output_scale_impl(
@property
def requires_export_handler(self):
return (
self.quant_input.is_quant_enabled or self.quant_weight.is_quant_enabled or
self.quant_bias.is_quant_enabled or self.output_quant.is_quant_enabled)
self.input_quant.is_quant_enabled or self.weight_quant.is_quant_enabled or
self.bias_quant.is_quant_enabled or self.output_quant.is_quant_enabled)

@property
def per_elem_ops(self): # optional, so concrete impl + error if not overridden
Expand Down Expand Up @@ -169,7 +169,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
quant_weight, QuantTensor)
if not (compute_output_quant_tensor or
self.quant_output.is_quant_enabled) and self.return_quant_tensor:
self.output_quant.is_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
Expand Down Expand Up @@ -205,7 +205,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)

if not self.quant_output.is_quant_enabled and self.return_quant_tensor:
if not self.output_quant.is_quant_enabled and self.return_quant_tensor:
if compute_output_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
Expand Down
14 changes: 7 additions & 7 deletions tests/brevitas/nn/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,19 @@ def default_weight_tensor_quant(default_wbiol_layer):


def test_default_wbiol_input_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.quant_input.is_quant_enabled
assert not default_wbiol_layer.input_quant.is_quant_enabled


def test_default_wbiol_output_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.quant_output.is_quant_enabled
assert not default_wbiol_layer.output_quant.is_quant_enabled


def test_default_wbiol_bias_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.quant_bias.is_quant_enabled
assert not default_wbiol_layer.bias_quant.is_quant_enabled


def test_default_wbiol_weight_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.quant_weight.is_quant_enabled
assert default_wbiol_layer.weight_quant.is_quant_enabled


def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL):
Expand All @@ -104,7 +104,7 @@ def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):


def test_default_wbiol_quant_weight_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.quant_weight.is_narrow_range
assert default_wbiol_layer.weight_quant.is_narrow_range


def test_default_wbiol_quant_input_signed(default_wbiol_layer: QuantWBIOL):
Expand All @@ -116,11 +116,11 @@ def test_default_wbiol_quant_output_signed(default_wbiol_layer: QuantWBIOL):


def test_default_wbiol_quant_input_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.input_quant.is_quant_enabled is None
assert default_wbiol_layer.input_quant.is_narrow_range is None


def test_default_wbiol_quant_output_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.output_quant.narrow_range is None
assert default_wbiol_layer.output_quant.is_narrow_range is None


def test_default_wbiol_quant_input_zero_point(default_wbiol_layer: QuantWBIOL):
Expand Down

0 comments on commit 56c5c7e

Please sign in to comment.