first working flow end to end
costigt-dev committed Apr 12, 2024
1 parent 3f14f53 commit 98bc90f
Showing 3 changed files with 45 additions and 36 deletions.
6 changes: 0 additions & 6 deletions src/brevitas/export/common/handler/base.py
@@ -39,12 +39,6 @@ def quant_axis(cls, scale):
                 return i
         return None
 
-class FloatClipMixin(ABC):
-    @classmethod
-    def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
-        warn("Not implemented for floating point")
-        return None
-
 class ClipMixin(ABC):
 
     @classmethod
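The deleted FloatClipMixin.clip_symbolic_kwargs only warned and returned None, so the float export path never produced clip parameters in the first place; removing the mixin makes that explicit. A minimal sketch (maybe_clip and clip_fn are hypothetical stand-ins for illustration, not Brevitas APIs) of how a None clip spec is treated downstream:

    # Hypothetical sketch: a None clip spec means no Clip node is emitted.
    def maybe_clip(x, clip_symbolic_kwargs, clip_fn):
        if clip_symbolic_kwargs is None:  # the float path always hit this case
            return x
        return clip_fn(x, *clip_symbolic_kwargs.values())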
50 changes: 41 additions & 9 deletions src/brevitas/export/common/handler/qcdq.py
@@ -21,7 +21,6 @@
 
 from .base import BitWidthHandlerMixin
 from .base import ClipMixin
-from .base import FloatClipMixin
 from .base import QuantAxisMixin
 from .base import ZeroPointHandlerMixin
 from .base import FloatZeroPointHandlerMixin
@@ -129,7 +128,7 @@ def quantize_fn(self, x, dtype):
     pass
 
 
-class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatClipMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC):
+class FloatCDQCastProxyHandlerMixin(QuantAxisMixin, FloatZeroPointHandlerMixin, CDQCastMixin, ABC):
 
     def dequantize_symbolic_kwargs(cls, scale, zero_point, exponent_bit_width, mantisssa_bit_width, is_signed):
         scale_orig_shape = scale.shape
@@ -221,22 +220,55 @@ def prepare_for_export(self, module):
             if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
                 scale = self.cast_fn(scale, torch.float32)
 
-            self.symbolic_kwargs['bit_width'] = quant_weight.bit_width
-            self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
-                module.is_narrow_range, module.is_signed, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width)
+            self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width
+            self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width
             self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
                 scale, quant_weight.zero_point, quant_weight.exponent_bit_width, quant_weight.mantissa_bit_width, module.is_signed)
         else:
             self.symbolic_kwargs = None
 
-    def quantize_from_floating_point(self, x: Tensor):
-        raise NotImplementedError()
+    def quantize_from_floating_point(self, x: Tensor, zp):
+        quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
+        quantize_symbolic_kwargs['zero_point'] = zp
+        # Before quantization, cast input to float32
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, torch.float32)
+        x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
+        return x
 
     def quantize_from_integer(self, x: Tensor):
-        raise NotImplementedError()
+        int_weights = {
+            tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
+            for tm in module.tracked_module_list}
+        self.symbolic_kwargs['int_weights'] = int_weights
 
     def symbolic_execution(self, x: Tensor):
-        raise NotImplementedError()
+        assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'
+
+        # Copy dict to allow for popping kwargs even on shared quantizers
+        dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
+        scale = dequantize_symbolic_kwargs['scale']
+        zero_point = dequantize_symbolic_kwargs['zero_point']
+
+        if self._export_q_node:
+            x = self.quantize_from_floating_point(x, zero_point)
+        else:
+            x = self.quantize_from_integer(x)
+
+        exponent_bit_width = self.symbolic_kwargs['exponent_bit_width']
+        mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width']
+        scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
+        # Workaround to trick the tracer into believing all return values are used
+        self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
+        x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
+        # After dequantization, cast both input and scale to the correct dtype
+        if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
+            x = self.cast_fn(x, self.scale_dtype)
+            scale = self.cast_fn(scale, self.scale_dtype)
+        # Restore the original shapes to guarantee correct shape propagation downstream
+        scale = scale.view(scale_orig_shape)
+        zero_point = zero_point.view_as(scale)
+        return x, scale, zero_point, exponent_bit_width, mantissa_bit_width
 
 class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):
     handled_layer = WeightQuantProxyFromInjector
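The bulk of this file's change is filling in the three NotImplementedError stubs so the float proxy handler mirrors the integer QCDQ flow: fetch the cached kwargs, quantize (from floating point or from pre-extracted integer weights), dequantize, then undo the fp16/bf16-to-fp32 casts. A self-contained sketch of that dataflow, not the handler itself (plain torch ops stand in for quantize_fn/dequantize_fn/cast_fn, and the round-based quantizer is an illustrative assumption):

    import torch

    # Rough model of the new symbolic_execution dataflow.
    def qdq_sketch(x, scale, zero_point, scale_dtype):
        scale_orig_shape = scale.shape
        if scale_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)               # cast before quantization
            scale = scale.to(torch.float32)
        q = torch.round(x / scale) + zero_point   # stand-in for quantize_fn
        y = (q - zero_point) * scale              # stand-in for dequantize_fn
        if scale_dtype in (torch.float16, torch.bfloat16):
            y = y.to(scale_dtype)                 # cast back after dequantization
            scale = scale.to(scale_dtype)
        # restore the original scale shape for downstream shape propagation
        return y, scale.view(scale_orig_shape), zero_point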
25 changes: 4 additions & 21 deletions src/brevitas/export/onnx/standard/qcdq/handler.py
@@ -30,25 +30,6 @@
 
 from warnings import warn
 
-
-class StdFloatDQCastONNXMixin(DQCastMixin, ABC):
-
-    def dequantize_fn(self, x, scale, zero_point, axis):
-        raise NotImplementedError()
-
-    def cast_fn(self, x, dtype):
-        raise NotImplementedError()
-
-    @property
-    def flatten_dequantize_params(self):
-        raise NotImplementedError()
-
-    @property
-    def itemize_quantize_scalar_params(self):
-        return False
-
-    def validate(self, module):
-        raise NotImplementedError()
 
 class StdDQCastONNXMixin(DQCastMixin, ABC):
 
@@ -69,6 +50,9 @@ def itemize_quantize_scalar_params(self):
 
     def validate(self, module):
         assert module.bit_width() > 1., 'Binary quant not supported'
 
+class StdFloatDQCastONNXMixin(StdDQCastONNXMixin, ABC):
+    def validate(self, module):
+        pass
 
 class StdFloatCDQCastONNXMixin(CDQCastMixin, StdFloatDQCastONNXMixin, ABC):
@@ -82,11 +66,10 @@ def clip_fn(self, x, min_val, max_val):
 
 
 class StdFloatQCDQCastONNXMixin(FloatQMixin, StdFloatCDQCastONNXMixin, ABC):
     def validate(self, module):
-        warn("Needs to be implemented")
         pass
 
     def quantize_fn(self, x, scale, zero_point, dtype, axis):
-        raise NotImplementedError()
+        return QuantizeLinearFn.apply(x, scale, zero_point, dtype, axis)
 
 class StdQCDQCastONNXMixin(QMixin, StdCDQCastONNXMixin, ABC):
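With this last hunk, StdFloatQCDQCastONNXMixin.quantize_fn stops raising and delegates to QuantizeLinearFn.apply, so the float flow now emits the same QuantizeLinear export node the integer path uses. For reference, a rough sketch of ONNX QuantizeLinear semantics (the int8 saturation bounds are an illustrative assumption; the float quantizers in this commit would target a float8 dtype instead):

    import numpy as np

    # Approximate QuantizeLinear: y = saturate(round(x / scale) + zero_point)
    def quantize_linear(x, scale, zero_point, qmin=-128, qmax=127):
        q = np.rint(x / scale) + zero_point   # rint = round half to even
        return np.clip(q, qmin, qmax).astype(np.int8)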
