Skip to content

Commit

Permalink
removed more attributes
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 4, 2024
1 parent 9e72635 commit aca8019
Show file tree
Hide file tree
Showing 11 changed files with 61 additions and 67 deletions.
2 changes: 1 addition & 1 deletion docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@
"print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
"print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
"print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')"
"print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion notebooks/01_quant_tensor_quant_conv2d_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@
"print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
"print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
"print(f'Is output quant enabled: {default_quant_conv.is_output_quant_enabled}')"
"print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
]
},
{
Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/export/onnx/standard/qoperator/handler/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@ class StdQOpONNXQuantNLALHandler(StdQOpONNXQuantLayerHandler, ABC):
@classmethod
def validate(cls, module: QuantNLAL):
if cls.input_quant_supported and module.is_input_quant_enabled:
assert not module.is_quant_input_narrow_range, "Narrow range quant not supported."
assert not module.input_quant.is_quant_enabled, "Narrow range quant not supported."
elif not cls.input_quant_supported and module.is_input_quant_enabled:
raise RuntimeError("Input quant not supported.")
if module.is_act_quant_enabled:
assert not module.is_quant_act_narrow_range, "Narrow range quant not supported."
if module.act_quant.is_quant_enabled:
assert not module.act_quant.is_narrow_range, "Narrow range quant not supported."
input_bit_width = module.input_quant.bit_width()
act_bit_width = module.act_quant.bit_width()
if input_bit_width is not None:
Expand Down
16 changes: 8 additions & 8 deletions src/brevitas/export/onnx/standard/qoperator/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@ def quant_output_shape(cls, module):

@classmethod
def output_quant_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
if quant_proxy.is_quant_enabled:
return {
'output_scale': quant_proxy.scale(),
'output_zero_point': cls.quant_output_zero_point(module),
Expand All @@ -69,10 +69,10 @@ def output_quant_symbolic_kwargs(cls, module):

@classmethod
def output_clip_symbolic_kwargs(cls, module):
if module.is_output_quant_enabled:
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
narrow = module.is_quant_output_narrow_range
quant_proxy = module.act_quant if isinstance(
module, QuantNonLinearActLayer) else module.output_quant
if quant_proxy.is_quant_enabled:
narrow = quant_proxy.is_narrow_range
signed = quant_proxy.signed()
bit_width = quant_proxy.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
Expand All @@ -82,7 +82,7 @@ def output_clip_symbolic_kwargs(cls, module):
@classmethod
def input_clip_symbolic_kwargs(cls, module):
if module.is_input_quant_enabled:
narrow = module.is_quant_input_narrow_range
narrow = module.input_quant.is_quant_enabled
signed = module.input_quant.signed()
bit_width = module.input_quant.bit_width()
return cls.int_clip_symbolic_kwargs(narrow, signed, bit_width)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,12 @@ def int_bias(module: QuantWBIOL):
@classmethod
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_weight_quant_enabled, 'Weight quant required'
assert module.is_output_quant_enabled, 'Output quant required'
assert module.output_quant.is_quant_enabled, 'Output quant required'
# Handling narrow_range is across the network is difficult do to the fact that
# it's not part of QuantTensor, and so it can't be cached
assert not module.is_quant_output_narrow_range, 'Narrow output quant not supported'
assert not module.output_quant.is_narrow_range, 'Narrow output quant not supported'
if module.is_input_quant_enabled:
assert not module.is_quant_input_narrow_range, 'Narrow output quant not supported'
assert not module.input_quant.is_quant_enabled, 'Narrow output quant not supported'
cls.validate_8b_bit_width(module.weight_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/export/torch/qoperator/handler/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def validate(cls, module: QuantNLAL):
def prepare_for_export(self, module: QuantNLAL):
self.validate(module)
self.qf_impl, self.qf_kwargs = self.prepare_qf(module)
if module.is_act_quant_enabled:
if module.act_quant.is_quant_enabled:
self.output_quant_impl, self.output_quant_kwargs = self.prepare_output_quant(module)
self.return_quant_tensor = module.return_quant_tensor

Expand Down
4 changes: 2 additions & 2 deletions src/brevitas/export/torch/qoperator/handler/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@ def __init__(self):

@classmethod
def validate(cls, module: QuantWBIOL):
assert module.is_weight_quant_enabled, 'Weight quantization required'
assert module.is_output_quant_enabled, 'Output quantization required'
assert module.quant_weight.is_quant_enabled, 'Weight quantization required'
assert module.quant_output.is_quant_enabled, 'Output quantization required'

@classmethod
def prepare_bias(cls, module: QuantWBIOL):
Expand Down
36 changes: 18 additions & 18 deletions src/brevitas/nn/mixin/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@ def __init__(self, input_quant: Optional[ActQuantType], **kwargs):
input_passthrough_act=True,
**kwargs)

@property
def is_input_quant_enabled(self):
return self.input_quant.is_quant_enabled
# @property
# def is_input_quant_enabled(self):
# return self.input_quant.is_quant_enabled

@property
def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.input_quant.is_narrow_range
# @property
# def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.input_quant.is_narrow_range

# @property
# @abstractmethod
Expand All @@ -61,13 +61,13 @@ def __init__(self, output_quant: Optional[ActQuantType], **kwargs):
output_passthrough_act=True,
**kwargs)

@property
def is_output_quant_enabled(self):
return self.output_quant.is_quant_enabled
# @property
# def is_output_quant_enabled(self):
# return self.output_quant.is_quant_enabled

@property
def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.output_quant.is_narrow_range
# @property
# def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.output_quant.is_narrow_range

# @property
# @abstractmethod
Expand Down Expand Up @@ -99,13 +99,13 @@ def __init__(
**prefixed_kwargs,
**kwargs)

@property
def is_act_quant_enabled(self):
return self.act_quant.is_quant_enabled
# @property
# def is_act_quant_enabled(self):
# return self.act_quant.is_quant_enabled

@property
def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
return self.act_quant.is_narrow_range
# @property
# def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.act_quant.is_narrow_range

# @property
# @abstractmethod
Expand Down
30 changes: 15 additions & 15 deletions src/brevitas/nn/quant_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ def channelwise_separable(self) -> bool:

@property
def requires_export_handler(self):
return self.is_input_quant_enabled or self.is_act_quant_enabled
return self.quant_input.is_quant_enabled or self.act_quant.is_quant_enabled

@property
def is_output_quant_enabled(self):
return self.is_act_quant_enabled
# @property
# def is_output_quant_enabled(self):
# return self.act_quant.is_quant_enabled

@property
def is_quant_output_narrow_range(self):
return self.is_quant_act_narrow_range
# @property
# def is_quant_output_narrow_range(self):
# return self.act_quant.is_narrow_range

def forward(self, input: Union[Tensor, QuantTensor]):
input = self.unpack_input(input)
Expand Down Expand Up @@ -87,16 +87,16 @@ def __init__(
QuantOutputMixin.__init__(self, output_quant, **kwargs)
# we have to account for quantization being enabled through kwargs
if tie_input_output_quant:
if self.is_input_quant_enabled and self.is_output_quant_enabled:
if self.quant_input.is_quant_enabled and self.act_quant.is_quant_enabled:
raise RuntimeError("Enable only input or output quant with tie_input_output=True")
if self.is_input_quant_enabled:
if self.quant_input.is_quant_enabled:
self.output_quant = self.input_quant
if self.is_output_quant_enabled:
if self.act_quant.is_quant_enabled:
self.input_quant = self.output_quant

@property
def requires_export_handler(self):
return self.is_input_quant_enabled or self.is_output_quant_enabled
return self.quant_input.is_quant_enabled or self.act_quant.is_quant_enabled


class QuantWeightBiasInputOutputLayer(QuantBiasMixin, QuantWeightMixin, QuantInputOutputLayer):
Expand Down Expand Up @@ -138,8 +138,8 @@ def quant_output_scale_impl(
@property
def requires_export_handler(self):
return (
self.is_input_quant_enabled or self.is_weight_quant_enabled or
self.is_bias_quant_enabled or self.is_output_quant_enabled)
self.quant_input.is_quant_enabled or self.quant_weight.is_quant_enabled or
self.quant_bias.is_quant_enabled or self.output_quant.is_quant_enabled)

@property
def per_elem_ops(self): # optional, so concrete impl + error if not overridden
Expand Down Expand Up @@ -169,7 +169,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
compute_output_quant_tensor = isinstance(quant_input, QuantTensor) and isinstance(
quant_weight, QuantTensor)
if not (compute_output_quant_tensor or
self.is_output_quant_enabled) and self.return_quant_tensor:
self.quant_output.is_quant_enabled) and self.return_quant_tensor:
raise RuntimeError("QuantLayer is not correctly configured")

if isinstance(quant_input, QuantTensor) and isinstance(quant_weight, QuantTensor):
Expand Down Expand Up @@ -205,7 +205,7 @@ def forward_impl(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
output_tensor = self.inner_forward_impl(
_unpack_quant_tensor(quant_input), _unpack_quant_tensor(quant_weight), None)

if not self.is_output_quant_enabled and self.return_quant_tensor:
if not self.quant_output.is_quant_enabled and self.return_quant_tensor:
if compute_output_quant_tensor:
if (quant_input.zero_point != 0.0).any() or (quant_weight.zero_point != 0.0).any():
raise RuntimeError(
Expand Down
6 changes: 0 additions & 6 deletions src/brevitas/proxy/runtime_quant.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,12 +181,6 @@ def signed(self, force_eval=True):
return self._cached_act.signed
elif self._cached_act is None:
return None
current_status = self.training
if force_eval:
self.eval()
bit_width = self.__call__(self._zero_hw_sentinel()).bit_width
self.train(current_status)
return bit_width

def forward(self, x: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTensor]:
if self.fused_activation_quant_proxy is not None:
Expand Down
18 changes: 9 additions & 9 deletions tests/brevitas/nn/test_wbiol.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,19 +68,19 @@ def default_weight_tensor_quant(default_wbiol_layer):


def test_default_wbiol_input_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.is_input_quant_enabled
assert not default_wbiol_layer.quant_input.is_quant_enabled


def test_default_wbiol_output_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.is_output_quant_enabled
assert not default_wbiol_layer.quant_output.is_quant_enabled


def test_default_wbiol_bias_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert not default_wbiol_layer.is_bias_quant_enabled
assert not default_wbiol_layer.quant_bias.is_quant_enabled


def test_default_wbiol_weight_quant_enabled(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_weight_quant_enabled
assert default_wbiol_layer.quant_weight.is_quant_enabled


def test_default_wbiol_weight_bit_width_enabled(default_wbiol_layer: QuantWBIOL):
Expand All @@ -92,19 +92,19 @@ def test_default_wbiol_return_quant(default_wbiol_layer: QuantWBIOL):


def test_default_wbiol_quant_bias_signed(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_quant_bias_signed is None
assert default_wbiol_layer.bias_quant.is_signed is None


def test_default_wbiol_quant_weight_signed(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.weight_quant.is_signed


def test_default_wbiol_quant_bias_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_quant_bias_narrow_range is None
assert default_wbiol_layer.bias_quant.is_narrow_range is None


def test_default_wbiol_quant_weight_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_quant_weight_narrow_range
assert default_wbiol_layer.quant_weight.is_narrow_range


def test_default_wbiol_quant_input_signed(default_wbiol_layer: QuantWBIOL):
Expand All @@ -116,11 +116,11 @@ def test_default_wbiol_quant_output_signed(default_wbiol_layer: QuantWBIOL):


def test_default_wbiol_quant_input_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_quant_input_narrow_range is None
assert default_wbiol_layer.input_quant.is_quant_enabled is None


def test_default_wbiol_quant_output_narrow_range(default_wbiol_layer: QuantWBIOL):
assert default_wbiol_layer.is_quant_output_narrow_range is None
assert default_wbiol_layer.output_quant.narrow_range is None


def test_default_wbiol_quant_input_zero_point(default_wbiol_layer: QuantWBIOL):
Expand Down

0 comments on commit aca8019

Please sign in to comment.