Skip to content

Commit

Permalink
W&B properties removed
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 5, 2024
1 parent 2e1836e commit 57f66b1
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 93 deletions.
4 changes: 2 additions & 2 deletions docs/tutorials/quant_tensor_quant_conv2d_overview.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,8 @@
}
],
"source": [
"print(f'Is weight quant enabled: {default_quant_conv.is_weight_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.is_bias_quant_enabled}')\n",
"print(f'Is weight quant enabled: {default_quant_conv.weight_quant.is_quant_enabled}')\n",
"print(f'Is bias quant enabled: {default_quant_conv.bias_quant.is_quant_enabled}')\n",
"print(f'Is input quant enabled: {default_quant_conv.is_input_quant_enabled}')\n",
"print(f'Is output quant enabled: {default_quant_conv.output_quant.is_quant_enabled}')"
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def int_bias(module: QuantWBIOL):

@classmethod
def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
assert module.is_weight_quant_enabled, 'Weight quant required'
assert module.weight_quant.is_quant_enabled, 'Weight quant required'
assert module.output_quant.is_quant_enabled, 'Output quant required'
# Handling narrow_range across the network is difficult due to the fact that
# it's not part of QuantTensor, and so it can't be cached
Expand All @@ -52,8 +52,8 @@ def validate(cls, module: QuantWBIOL, requires_quant_bias=True):
cls.validate_8b_bit_width(module.input_quant.bit_width(), le_then=True)
cls.validate_8b_bit_width(module.output_quant.bit_width(), le_then=True)
if module.bias is not None and requires_quant_bias:
assert module.is_bias_quant_enabled
assert module.is_quant_bias_signed
assert module.bias_quant.is_quant_enabled
assert module.bias_quant.is_signed
cls.validate_32b_bit_width(module.bias_quant.bit_width(), le_then=True)

def prepare_for_export(self, module: Union[QuantConv1d, QuantConv2d]):
Expand Down
2 changes: 1 addition & 1 deletion src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def __init__(
@property
def layer_requires_input_quant(self):
# some weight quantizers require a quant input (e.g., A2Q)
check_1 = self.layer.weight_quant_requires_quant_input
check_1 = self.layer.weight_quant.requires_quant_input
# if input_quant is enabled, then we will store its information
check_2 = self.layer.input_quant.is_quant_enabled
# GPFA2Q requires the quantized input to be stored
Expand Down
51 changes: 0 additions & 51 deletions src/brevitas/nn/mixin/act.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,19 +32,6 @@ def __init__(self, input_quant: Optional[ActQuantType], **kwargs):
input_passthrough_act=True,
**kwargs)

# @property
# def is_input_quant_enabled(self):
# return self.input_quant.is_quant_enabled

# @property
# def is_quant_input_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.input_quant.is_narrow_range

# @property
# @abstractmethod
# def is_quant_input_signed(self):
# pass


class QuantOutputMixin(QuantProxyMixin):
__metaclass__ = ABCMeta
Expand All @@ -61,19 +48,6 @@ def __init__(self, output_quant: Optional[ActQuantType], **kwargs):
output_passthrough_act=True,
**kwargs)

# @property
# def is_output_quant_enabled(self):
# return self.output_quant.is_quant_enabled

# @property
# def is_quant_output_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.output_quant.is_narrow_range

# @property
# @abstractmethod
# def is_quant_output_signed(self):
# pass


class QuantNonLinearActMixin(QuantProxyMixin):
__metaclass__ = ABCMeta
Expand All @@ -98,28 +72,3 @@ def __init__(
none_quant_injector=NoneActQuant,
**prefixed_kwargs,
**kwargs)

# @property
# def is_act_quant_enabled(self):
# return self.act_quant.is_quant_enabled

# @property
# def is_quant_act_narrow_range(self): # TODO make abstract once narrow range can be cached
# return self.act_quant.is_narrow_range

# @property
# @abstractmethod
# def is_quant_act_signed(self):
# pass

# @abstractmethod
# def quant_act_scale(self):
# pass

# @abstractmethod
# def quant_act_zero_point(self):
# pass

# @abstractmethod
# def quant_act_bit_width(self):
# pass
2 changes: 1 addition & 1 deletion src/brevitas/nn/mixin/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe
cached_inp = _CachedIO(inp.detach(), self.cache_quant_io_metadata_only)
if hasattr(self, 'input_quant'):
self.input_quant._cached_act = cached_inp
if hasattr(self, 'weight_quant') and self.weight_quant_requires_quant_input:
if hasattr(self, 'weight_quant') and self.weight_quant.requires_quant_input:
self.weight_quant._cached_act = cached_inp
if not torch._C._get_tracing_state():
if isinstance(inp, QuantTensor):
Expand Down
38 changes: 3 additions & 35 deletions src/brevitas/nn/mixin/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,22 +43,6 @@ def __init__(self, weight_quant: Optional[WeightQuantType], **kwargs):
def output_channel_dim(self) -> int:
pass

@property
def is_weight_quant_enabled(self):
return self.weight_quant.is_quant_enabled

@property
def is_quant_weight_narrow_range(self):
return self.weight_quant.is_narrow_range

# @property
# def is_quant_weight_signed(self):
# return self.weight_quant.is_signed

@property
def weight_quant_requires_quant_input(self):
return self.weight_quant.requires_quant_input

def quant_weight(
self,
quant_input: Optional[QuantTensor] = None,
Expand Down Expand Up @@ -86,10 +70,10 @@ def quant_weight(
slice(*s) if s is not None else slice(s) for s in subtensor_slice_list)
else:
weight_slice_tuple = slice(None)
if self.weight_quant_requires_quant_input:
if self.weight_quant.requires_quant_input:
input_bit_width = None
input_is_signed = None
if self.is_weight_quant_enabled:
if self.weight_quant.is_quant_enabled:
if quant_input is None:
input_bit_width = self.input_quant.bit_width()
input_is_signed = self.input_quant.signed()
Expand Down Expand Up @@ -138,24 +122,8 @@ def __init__(
**kwargs)
self.cache_inference_quant_bias = cache_inference_bias

@property
def is_bias_quant_enabled(self):
return self.bias_quant.is_quant_enabled

@property
def is_quant_bias_narrow_range(self):
if self.bias is None:
return None
return self.bias_quant.is_narrow_range

@property
def is_quant_bias_signed(self):
if self.bias is None or not self.is_bias_quant_enabled:
return None
return self.bias_quant.is_signed

def int_bias(self, float_datatype=False):
if self.bias is None or not self.is_bias_quant_enabled:
if self.bias is None or not self.bias_quant.is_quant_enabled:
return None
quant_bias = self.quant_bias()
return quant_bias.int(float_datatype=float_datatype)
Expand Down

0 comments on commit 57f66b1

Please sign in to comment.