Adding FP8 weight export #907

Merged

merged 46 commits from feat/export_fp8 into dev on May 29, 2024
Changes from 26 commits

Commits (46)
5b168c5
placeholder version
costigt-dev Apr 11, 2024
d2b7d2d
checkpoint commit
costigt-dev Apr 12, 2024
e10e630
first working flow end to end
costigt-dev Apr 12, 2024
84e70f7
formatting
costigt-dev Apr 12, 2024
ef4c737
changes to tests
costigt-dev Apr 12, 2024
4aa4b21
added version check for test
costigt-dev Apr 12, 2024
3b05883
using existing functionality over homespun
costigt-dev Apr 12, 2024
cad5802
corrected mistake in copying and restored FloatClipMixin
costigt-dev Apr 12, 2024
4848248
fixed mistake
costigt-dev Apr 12, 2024
5188aa6
first pass activation fp8 export
costigt-dev Apr 16, 2024
29cb952
beginnings of activation fp8 export and change name of QCDQCastFloatW…
costigt-dev Apr 16, 2024
9bf9240
more changes to make naming scheme more consistent
costigt-dev Apr 16, 2024
f9406f1
added FloatFusedActivationQuantProxy
costigt-dev Apr 16, 2024
991ddb7
replaced zero_point workaround with placeholder implementation of fp8…
costigt-dev Apr 17, 2024
520db85
removed verbose flag
costigt-dev Apr 17, 2024
2bb2895
created context manager for fp8 workaround
costigt-dev Apr 17, 2024
8ffce48
added check that objects being compared are tensors in the fp8 workar…
costigt-dev Apr 17, 2024
7edf5bd
General equal implementation
Giuseppe5 May 14, 2024
bbd5362
fallback to fp32 if fp8
Giuseppe5 May 14, 2024
4bc126d
Fix for PT < 2.1
Giuseppe5 May 14, 2024
a55dcd0
Remove non existent destroy
Giuseppe5 May 14, 2024
cd6cad6
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 23, 2024
fabc8ae
Remove import
Giuseppe5 May 23, 2024
74b65a9
Fixed imports
Giuseppe5 May 23, 2024
cf1ea02
Fixed imports
Giuseppe5 May 23, 2024
cda7f1f
Fix export
Giuseppe5 May 23, 2024
8349391
more testing
Giuseppe5 May 23, 2024
11387d3
Fix
Giuseppe5 May 24, 2024
592ccd3
Fix
Giuseppe5 May 24, 2024
1fc5642
fix
Giuseppe5 May 25, 2024
58f46bc
Fix minifloat check
Giuseppe5 May 25, 2024
bd657b8
Last fix
Giuseppe5 May 25, 2024
630a3e3
Fix minifloat
Giuseppe5 May 27, 2024
38a37fb
Review
Giuseppe5 May 28, 2024
76b3193
Review 2
Giuseppe5 May 28, 2024
529470f
Merge branch 'dev' into feat/export_fp8
Giuseppe5 May 28, 2024
f2f8969
fix
Giuseppe5 May 28, 2024
44579f8
Typo
Giuseppe5 May 28, 2024
038cba9
fix tests
Giuseppe5 May 28, 2024
198c5af
Typo
Giuseppe5 May 28, 2024
c3d7d3c
fix
Giuseppe5 May 28, 2024
fef531d
last fix
Giuseppe5 May 28, 2024
6431882
Fix JIT
Giuseppe5 May 29, 2024
4b78543
Fix import
Giuseppe5 May 29, 2024
d762c99
Last fix
Giuseppe5 May 29, 2024
ac5e58c
correct skip
Giuseppe5 May 29, 2024
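For context before the diff: a minimal sketch of how the FP8 weight export added by this PR might be driven from user code. The quantizer and export entry-point names used here (Fp8e4m3WeightPerTensorFloat, export_onnx_qcdq) are assumptions based on Brevitas naming conventions and are not part of this diff.

    import torch
    from brevitas.export import export_onnx_qcdq
    from brevitas.nn import QuantLinear
    # Assumed FP8 (e4m3) weight quantizer name; the exact class may differ.
    from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat

    # Linear layer whose weights are quantized to float8 e4m3.
    layer = QuantLinear(16, 32, bias=True, weight_quant=Fp8e4m3WeightPerTensorFloat)
    layer.eval()

    # Export the quantize/(clip)/dequantize (QCDQ) pattern that the new handlers emit.
    export_onnx_qcdq(layer, args=torch.randn(1, 16), export_path='fp8_linear.onnx')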
23 changes: 22 additions & 1 deletion src/brevitas/export/common/handler/base.py
@@ -4,6 +4,7 @@
from abc import ABC
from abc import abstractmethod
import math
from warnings import warn

import torch
from torch import Tensor
@@ -12,7 +13,8 @@
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int

__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']
__all__ = [
'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin']


class BaseHandler(Module, ABC):
@@ -38,6 +40,13 @@ def quant_axis(cls, scale):
return None


class FloatClipMixin(ABC):

@classmethod
def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
return None


class ClipMixin(ABC):

@classmethod
@@ -112,6 +121,18 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
return -cls.validate_scalar_int_exponent(scale)


class FloatZeroPointHandlerMixin(ABC):

@classmethod
def zero_point_with_dtype(cls, signed, exponent_bit_width, mantissa_bit_width, zero_point):
if exponent_bit_width == 4 and mantissa_bit_width == 3:
return zero_point.type(torch.float8_e4m3fn)
elif exponent_bit_width == 5 and mantissa_bit_width == 2:
return zero_point.type(torch.float8_e5m2)
else:
return zero_point.type(torch.float32)


class ZeroPointHandlerMixin(ABC):

@classmethod
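As a standalone illustration of the dtype mapping implemented by FloatZeroPointHandlerMixin above (and reused by FloatQMixin in the next file): exponent/mantissa widths of 4/3 select torch.float8_e4m3fn, 5/2 select torch.float8_e5m2, and everything else falls back to float32. Minimal sketch; the float8 dtypes require PyTorch >= 2.1.

    import torch

    def float8_dtype(exponent_bit_width: int, mantissa_bit_width: int) -> torch.dtype:
        # Mirrors FloatZeroPointHandlerMixin.zero_point_with_dtype / FloatQMixin.signed_dtype.
        if exponent_bit_width == 4 and mantissa_bit_width == 3:
            return torch.float8_e4m3fn
        if exponent_bit_width == 5 and mantissa_bit_width == 2:
            return torch.float8_e5m2
        return torch.float32

    zero_point = torch.zeros(1)
    print(zero_point.type(float8_dtype(4, 3)).dtype)  # torch.float8_e4m3fn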
266 changes: 266 additions & 0 deletions src/brevitas/export/common/handler/qcdq.py
@@ -10,16 +10,20 @@

from brevitas.export.common import to_0dim_if_scalar
from brevitas.export.common import to_item_if_0dim
from brevitas.proxy import ActFloatQuantProxyFromInjector
from brevitas.proxy import ActQuantProxyFromInjector
from brevitas.proxy import BiasQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantProxyFromInjector
from brevitas.proxy import DecoupledWeightQuantWithInputProxyFromInjector
from brevitas.proxy import WeightFloatQuantProxyFromInjector
from brevitas.proxy import WeightQuantProxyFromInjector
from brevitas.proxy.runtime_quant import DynamicActQuantProxyFromInjector
from brevitas.proxy.runtime_quant import TruncQuantProxyFromInjector

from .base import BitWidthHandlerMixin
from .base import ClipMixin
from .base import FloatClipMixin
from .base import FloatZeroPointHandlerMixin
from .base import QuantAxisMixin
from .base import ZeroPointHandlerMixin

@@ -66,6 +70,25 @@ def clip_fn(self, x, min_val, max_val):
pass


class FloatQMixin(ABC):

@abstractmethod
def quantize_fn(self, x, scale, zero_point, dtype, axis):
pass

@classmethod
def signed_dtype(cls, exponent_bit_width, mantissa_bit_width, is_signed):
if exponent_bit_width is None or mantissa_bit_width is None:
return None
if exponent_bit_width == 4 and mantissa_bit_width == 3:
dtype = torch.float8_e4m3fn
elif exponent_bit_width == 5 and mantissa_bit_width == 2:
dtype = torch.float8_e5m2
else:
dtype = torch.float32
return dtype


class QMixin(BitWidthHandlerMixin, ABC):

@classmethod
@@ -110,6 +133,34 @@ def quantize_fn(self, x, dtype):
pass


class FloatCDQCastProxyHandlerMixin(QuantAxisMixin,
FloatClipMixin,
FloatZeroPointHandlerMixin,
CDQCastMixin,
ABC):

def dequantize_symbolic_kwargs(
cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed):
scale_orig_shape = scale.shape
axis = cls.quant_axis(scale)
if cls.flatten_dequantize_params:
scale = scale.flatten()
scale = to_0dim_if_scalar(scale)
if cls.flatten_dequantize_params:
zero_point = zero_point.flatten()
zp = to_0dim_if_scalar(zero_point)
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp)
return {
'scale': scale,
'zero_point': zp,
'axis': axis,
# We save only the scale original shape
# as zero-point is being expanded to the same
# size as the scale
'scale_orig_shape': scale_orig_shape}


class CDQCastProxyHandlerMixin(QuantAxisMixin, ClipMixin, ZeroPointHandlerMixin, CDQCastMixin, ABC):

def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
@@ -133,6 +184,128 @@ def dequantize_symbolic_kwargs(cls, scale, zero_point, bit_width, is_signed):
'scale_orig_shape': scale_orig_shape}


class FloatQCDQCastWeightQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin):
handled_layer = WeightFloatQuantProxyFromInjector

def quantize_symbolic_kwargs(
cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_quantize_from_floating_point(self, module):
quant_weight = module.tracked_module_list[0].quant_weight()
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)
self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale,
quant_weight.zero_point,
quant_weight.exponent_bit_width,
quant_weight.mantissa_bit_width,
module.is_signed)

def prepare_quantize_from_integer(self, module):
int_weights = {
tm.weight.data_ptr(): tm.quant_weight().int(float_datatype=False)
for tm in module.tracked_module_list}
self.symbolic_kwargs['int_weights'] = int_weights

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
if self._export_q_node:
self.prepare_quantize_from_floating_point(module)
else:
self.prepare_quantize_from_integer(module)
# Get the first quant weight as representative
quant_weight = module.tracked_module_list[0].quant_weight()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = quant_weight.scale
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['exponent_bit_width'] = quant_weight.exponent_bit_width
self.symbolic_kwargs['mantissa_bit_width'] = quant_weight.mantissa_bit_width
self.symbolic_kwargs['exponent_bias'] = quant_weight.exponent_bias
self.symbolic_kwargs['saturating'] = quant_weight.saturating
self.symbolic_kwargs['inf_values'] = quant_weight.inf_values
self.symbolic_kwargs['nan_values'] = quant_weight.nan_values
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
module.is_narrow_range,
module.is_signed,
quant_weight.exponent_bit_width,
quant_weight.mantissa_bit_width)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
scale,
quant_weight.zero_point,
quant_weight.exponent_bit_width,
quant_weight.mantissa_bit_width,
module.is_signed)
else:
self.symbolic_kwargs = None

def quantize_from_floating_point(self, x: Tensor):
# Workaround for equal_cpu RuntimeError
quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
# Before quantization, cast input to float32
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
return x

def quantize_from_integer(self, x: Tensor):
return self.symbolic_kwargs['int_weights'][x.data_ptr()]

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'

# Copy dict to allow for popping kwargs even on shared quantizers
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
scale = dequantize_symbolic_kwargs['scale']
zero_point = dequantize_symbolic_kwargs['zero_point']

if self._export_q_node:
x = self.quantize_from_floating_point(x)
else:
x = self.quantize_from_integer(x)
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
exponent_bit_width = self.symbolic_kwargs['exponent_bit_width']
mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width']
exponent_bias = self.symbolic_kwargs['exponent_bias']
saturating = self.symbolic_kwargs['saturating']
inf_values = self.symbolic_kwargs['inf_values']
nan_values = self.symbolic_kwargs['nan_values']
scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')
# Workaround to trick the tracer into believing all return values are used
self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both input and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values


class QCDQCastWeightQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin):
handled_layer = WeightQuantProxyFromInjector

@@ -251,6 +424,99 @@ def symbolic_execution(self, x: Tensor, input_bit_width: torch.Tensor, input_is_
return super().symbolic_execution(x)


class FloatQCDQCastActQuantProxyHandlerMixin(FloatQMixin, FloatCDQCastProxyHandlerMixin, ABC):
handled_layer = ActFloatQuantProxyFromInjector

def quantize_symbolic_kwargs(
cls, scale, zero_point, exponent_bit_width, mantissa_bit_width, is_signed):
# compute axis before redefining scale
axis = cls.quant_axis(scale)
scale = to_0dim_if_scalar(scale.flatten())
zp = to_0dim_if_scalar(zero_point.flatten())
# expand_as must go after 0-dim check
zp = zp.expand_as(scale)
zp = cls.zero_point_with_dtype(is_signed, exponent_bit_width, mantissa_bit_width, zp)
if cls.itemize_quantize_scalar_params:
scale = to_item_if_0dim(scale)
zp = to_item_if_0dim(zp)
dtype = cls.signed_dtype(exponent_bit_width, mantissa_bit_width, is_signed)
return {'scale': scale, 'zero_point': zp, 'dtype': dtype, 'axis': axis}

def prepare_for_export(self, module):
if module.is_quant_enabled:
self.validate(module)
self.symbolic_kwargs['exponent_bit_width'] = module.exponent_bit_width()
self.symbolic_kwargs['mantissa_bit_width'] = module.mantissa_bit_width()
self.symbolic_kwargs['exponent_bias'] = module.exponent_bias()
self.symbolic_kwargs['saturating'] = module.saturating()
self.symbolic_kwargs['inf_values'] = module.inf_values()
self.symbolic_kwargs['nan_values'] = module.nan_values()

# (B)float16 is not supported with standard Q/DQ ops, thus we store the original dtype
# of the scale and we cast it to float32.
# The original dtype is then restored during the forward pass
scale = module.scale()
self.scale_dtype = scale.dtype
if self.scale_dtype == torch.bfloat16 or self.scale_dtype == torch.float16:
scale = self.cast_fn(scale, torch.float32)

self.symbolic_kwargs['quantize_symbolic_kwargs'] = self.quantize_symbolic_kwargs(
scale,
module.zero_point(),
module.exponent_bit_width(),
module.mantissa_bit_width(),
module.is_signed)
self.symbolic_kwargs['dequantize_symbolic_kwargs'] = self.dequantize_symbolic_kwargs(
scale,
module.zero_point(),
module.exponent_bit_width(),
module.mantissa_bit_width(),
module.is_signed)
self.symbolic_kwargs['clip_symbolic_kwargs'] = self.clip_symbolic_kwargs(
module.is_narrow_range,
module.is_signed,
module.exponent_bit_width(),
module.mantissa_bit_width())

else:
self.symbolic_kwargs = None

def symbolic_execution(self, x: Tensor):
assert self.symbolic_kwargs is not None, 'Symbolic execution requires quant to be enabled'

# Copy dict to allow for popping kwargs even on shared quantizers
dequantize_symbolic_kwargs = copy(self.symbolic_kwargs['dequantize_symbolic_kwargs'])
scale = dequantize_symbolic_kwargs['scale']
zero_point = dequantize_symbolic_kwargs['zero_point']
scale_orig_shape = dequantize_symbolic_kwargs.pop('scale_orig_shape')

quantize_symbolic_kwargs = self.symbolic_kwargs['quantize_symbolic_kwargs']
clip_symbolic_kwargs = self.symbolic_kwargs['clip_symbolic_kwargs']
exponent_bit_width = self.symbolic_kwargs['exponent_bit_width']
mantissa_bit_width = self.symbolic_kwargs['mantissa_bit_width']
exponent_bias = self.symbolic_kwargs['exponent_bias']
saturating = self.symbolic_kwargs['saturating']
inf_values = self.symbolic_kwargs['inf_values']
nan_values = self.symbolic_kwargs['nan_values']

self.assert_ge_zero(scale, exponent_bit_width, mantissa_bit_width)
# If original dtype of the input is (b)float16, cast the input to float32
if x.dtype == torch.float16 or x.dtype == torch.bfloat16:
x = self.cast_fn(x, torch.float32)
x = self.quantize_fn(x, *quantize_symbolic_kwargs.values())
if clip_symbolic_kwargs is not None:
x = self.clip_fn(x, *clip_symbolic_kwargs.values())
x = self.dequantize_fn(x, *dequantize_symbolic_kwargs.values())
# After dequantization, cast both output and scale to the correct dtype
if self.scale_dtype == torch.float16 or self.scale_dtype == torch.bfloat16:
x = self.cast_fn(x, self.scale_dtype)
scale = self.cast_fn(scale, self.scale_dtype)
# Restore the original shapes to guarantee correct shape propagation downstream
scale = scale.view(scale_orig_shape)
zero_point = zero_point.view_as(scale)
return x, scale, zero_point, exponent_bit_width, mantissa_bit_width, exponent_bias, saturating, inf_values, nan_values


class QCDQCastActQuantProxyHandlerMixin(QMixin, CDQCastProxyHandlerMixin, ABC):
handled_layer = ActQuantProxyFromInjector

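To close, a simplified reference of the dtype workaround the new handlers apply around the Q/DQ pattern, assuming a symmetric FP8 e4m3 quantizer with the zero-point omitted: a (b)float16 scale is cast to float32 before quantization and the original dtype is restored after dequantization, as in the prepare_for_export/symbolic_execution pair above. This is an illustrative sketch of the numerics, not the exported ONNX graph.

    import torch

    def fp8_qcdq_reference(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
        # Standard Q/DQ ops do not support (b)float16, so cast to float32 first.
        orig_dtype = scale.dtype
        if orig_dtype in (torch.float16, torch.bfloat16):
            x = x.to(torch.float32)
            scale = scale.to(torch.float32)
        # Quantize: divide by the scale and cast to float8 e4m3.
        q = (x / scale).to(torch.float8_e4m3fn)
        # Dequantize: cast back to float32 and rescale.
        deq = q.to(torch.float32) * scale
        # Restore the original dtype, as symbolic_execution does for the output and scale.
        return deq.to(orig_dtype)

    y = fp8_qcdq_reference(
        torch.randn(4, 4, dtype=torch.bfloat16), torch.tensor(0.1, dtype=torch.bfloat16))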