Skip to content

Commit

Permalink
Feat (export/onnx): OCP FP8 export (#907)
Browse files Browse the repository at this point in the history
  • Loading branch information
costigt-dev authored May 29, 2024
1 parent 7a716f7 commit fc4162e
Show file tree
Hide file tree
Showing 20 changed files with 622 additions and 124 deletions.
13 changes: 3 additions & 10 deletions src/brevitas/core/quant/float.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
from brevitas.core.function_wrapper import RoundSte
from brevitas.core.scaling import ConstScaling
from brevitas.core.utils import StatelessBuffer
from brevitas.function.ops import max_float
from brevitas.function.ops_ste import floor_ste
from brevitas.utils.torch_utils import float_internal_scale


class FloatQuant(brevitas.jit.ScriptModule):
Expand Down Expand Up @@ -64,21 +63,15 @@ def __init__(
dtype = torch.get_default_dtype()
self.eps = torch.finfo(dtype).tiny

@brevitas.jit.script_method
def internal_scale(self, x):
    """Return the per-element power-of-two scale that aligns x to the mantissa grid.

    The scale is 2^(floor(log2(|x|)) - mantissa_bit_width), floored at
    fp_internal_scale_min; eps guards log2 against zero inputs.
    """
    # floor_ste keeps the straight-through gradient through the floor.
    exponent = floor_ste(torch.log2(torch.abs(x) + self.eps))
    exponent = torch.clamp_min(
        exponent - self.mantissa_bit_width(), self.fp_internal_scale_min())
    return torch.exp2(exponent)

@brevitas.jit.script_method
def quantize(self, x: torch.Tensor):
    """Quantize x to the configured minifloat grid.

    Returns a tuple (val_fp_quant, scale) where val_fp_quant is the
    quantized value still expressed in the scaled domain and scale is the
    tensor-level scale (caller divides/multiplies as needed).
    """
    scaling_impl_value = self.scaling_impl(x)
    float_scaling_impl_value = self.float_scaling_impl(
        self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
    # Overall scale combines the learned/statistical scale with the max
    # representable value of the target float format.
    scale = scaling_impl_value / float_scaling_impl_value
    scaled_x = x / scale
    # Fix: the previous dead assignment via the removed self.internal_scale
    # helper was immediately overwritten; only the shared
    # float_internal_scale utility is used.
    internal_scale = float_internal_scale(
        scaled_x, self.mantissa_bit_width(), self.fp_internal_scale_min())
    # Round on the mantissa-aligned grid, then restore magnitude.
    val_fp_quant = internal_scale * self.float_to_int_impl(scaled_x / internal_scale)
    return val_fp_quant, scale

Expand Down
23 changes: 22 additions & 1 deletion src/brevitas/export/common/handler/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from abc import ABC
from abc import abstractmethod
import math
from warnings import warn

import torch
from torch import Tensor
Expand All @@ -12,7 +13,8 @@
from brevitas.function.ops import max_int
from brevitas.function.ops import min_int

__all__ = ['BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin']
__all__ = [
'BaseHandler', 'BitWidthHandlerMixin', 'ZeroPointHandlerMixin', 'FloatZeroPointHandlerMixin']


class BaseHandler(Module, ABC):
Expand All @@ -38,6 +40,13 @@ def quant_axis(cls, scale):
return None


class FloatClipMixin(ABC):
    """Clip-node hook for float (minifloat) export handlers.

    The default emits no clip symbolic kwargs; exporters for float formats
    that require clipping are expected to override this hook.
    """

    @classmethod
    def clip_symbolic_kwargs(cls, narrow, signed, exponent_bit_width, mantissa_bit_width):
        # No clipping by default for float quantized export.
        return None


class ClipMixin(ABC):

@classmethod
Expand Down Expand Up @@ -112,6 +121,18 @@ def validate_neg_scalar_int_exponent(cls, scale: Tensor):
return -cls.validate_scalar_int_exponent(scale)


class FloatZeroPointHandlerMixin(ABC):

@classmethod
def zero_point_with_dtype(cls, exponent_bit_width, mantissa_bit_width, zero_point):
if exponent_bit_width == 4 and mantissa_bit_width == 3:
return zero_point.type(torch.float8_e4m3fn)
elif exponent_bit_width == 5 and mantissa_bit_width == 2:
return zero_point.type(torch.float8_e5m2)
else:
return zero_point.type(torch.float32)


class ZeroPointHandlerMixin(ABC):

@classmethod
Expand Down
Loading

0 comments on commit fc4162e

Please sign in to comment.