Feat (calibrate/activation_calibration): speed-up by skipping quantization
Giuseppe5 committed Sep 23, 2024
1 parent b28ac0f commit d3b4d5f
Showing 4 changed files with 39 additions and 23 deletions.
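In short, every activation quantizer implementation gains an observer_only flag. During activation calibration the flag is switched on, so the forward pass still drives the scale/zero-point observers but returns the input unchanged, skipping the fake-quantization math that previously ran (and was then discarded) on every calibration batch. A minimal, self-contained sketch of the pattern; ObserverOnlyQuant and scaling_impl below are illustrative stand-ins, not Brevitas classes:

import torch
import torch.nn as nn


class ObserverOnlyQuant(nn.Module):
    # Sketch: always update statistics, optionally skip the quantization itself.

    def __init__(self, scaling_impl: nn.Module):
        super().__init__()
        self.scaling_impl = scaling_impl    # observer that tracks running statistics
        self.observer_only = False          # flipped to True while calibrating

    def forward(self, x: torch.Tensor):
        scale = self.scaling_impl(x)        # observers always see the data
        if self.observer_only:
            return x, scale                 # calibration: pass the input through
        y = torch.round(x / scale) * scale  # stand-in for the real quantize/dequantize
        return y, scale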
18 changes: 11 additions & 7 deletions src/brevitas/core/quant/float.py
@@ -64,11 +64,10 @@ def __init__(
         if dtype is None:
             dtype = torch.get_default_dtype()
         self.eps = torch.finfo(dtype).tiny
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
-    def quantize(self, x: torch.Tensor):
-        scale = self.scaling_impl(x)
-
+    def quantize(self, x: torch.Tensor, scale: torch.Tensor) -> Tuple[torch.Tensor]:
         if self.float_scaling_impl is not None:
             float_scaling_impl_value = self.float_scaling_impl(
                 self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
@@ -86,10 +85,15 @@ def dequantize(self, y, scale):
 
     @brevitas.jit.script_method
     def forward(self, x):
-        y, scale = self.quantize(x)
-        # after quantizing, clamp to special cases like NaN/inf if they are set
-        y, saturating, inf_values, nan_values = self.float_clamp_impl(
-            y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
+        scale = self.scaling_impl(x)
+        if self.observer_only:
+            y = x
+            saturating, inf_values, nan_values = self.float_clamp_impl.saturating, self.float_clamp_impl.inf_values, self.float_clamp_impl.nan_values
+        else:
+            y, scale = self.quantize(x, scale)
+            # after quantizing, clamp to special cases like NaN/inf if they are set
+            y, saturating, inf_values, nan_values = self.float_clamp_impl(
+                y, self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias())
         y = self.dequantize(y, scale)
         # This is to respect the current interface of proxies
         return y, scale, self.zero_point_impl(), self.exponent_bit_width(), self.mantissa_bit_width(), self.exponent_bias(), saturating, inf_values, nan_values
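One detail worth noting: the flag is declared through brevitas.jit.Attribute and then read inside a script_method, which keeps the class compatible with Brevitas's optional TorchScript compilation. A hedged sketch of the same idea written against torch.jit directly; the class and forward body are illustrative, not the Brevitas implementation:

import torch


class ScriptableQuantSketch(torch.jit.ScriptModule):

    def __init__(self):
        super().__init__()
        # Declare the flag with an explicit type so TorchScript treats it as a
        # regular mutable bool attribute of the scripted module.
        self.observer_only = torch.jit.Attribute(False, bool)

    @torch.jit.script_method
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.observer_only:
            return x               # observer-only: skip quantization entirely
        return torch.round(x)      # stand-in for the real integer quantization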
17 changes: 14 additions & 3 deletions src/brevitas/core/quant/int.py
@@ -145,6 +145,7 @@ def __init__(
         self.int_scaling_impl = int_scaling_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
@@ -153,7 +154,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
         int_threshold = self.int_scaling_impl(bit_width)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.int_quant(scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.int_quant(scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width
 
 
@@ -176,6 +180,7 @@ def __init__(
         self.pre_zero_point_impl = pre_zero_point_impl
         self.zero_point_impl = zero_point_impl
         self.msb_clamp_bit_width_impl = bit_width_impl
+        self.observer_only = brevitas.jit.Attribute(False, bool)
 
     @brevitas.jit.script_method
     def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
@@ -187,7 +192,10 @@ def forward(self, x: Tensor) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Tensor]:
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
 
 
@@ -253,5 +261,8 @@ def forward(self, x: Tensor, input_bit_width: Tensor,
         threshold = self.scaling_impl(x)
         scale = threshold / int_threshold
         zero_point = self.zero_point_impl(x, scale, bit_width)
-        y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
+        if self.observer_only:
+            y = x
+        else:
+            y = self.decoupled_int_quant(pre_scale, pre_zero_point, scale, zero_point, bit_width, x)
         return y, scale, zero_point, bit_width, pre_scale, pre_zero_point
16 changes: 8 additions & 8 deletions src/brevitas/function/ops.py
@@ -189,16 +189,16 @@ def min_int(signed: bool, narrow_range: bool, bit_width: Tensor) -> Tensor:
     return value
 
 
+@brevitas.jit.ignore
+def max_mantissa_func(val):
+    return torch.sum((2. ** torch.arange(0, -1. * val - 1., -1.)))
+
+
+MAX_MANTISSA_DICT = {x: max_mantissa_func(x) for x in range(0, 16)}
+
+
 def max_float(exponent_bit_width: Tensor, mantissa_bit_width: Tensor, exponent_bias: Tensor):
     max_exponent = (2. ** exponent_bit_width) - 1. - exponent_bias
-    max_mantissa = torch.sum((
-        2. ** torch.arange(
-            0,
-            -1. * mantissa_bit_width - 1.,
-            -1.,
-            dtype=mantissa_bit_width.dtype,
-            device=mantissa_bit_width.device)))
+    max_mantissa = MAX_MANTISSA_DICT[mantissa_bit_width.item()]
     max_val = max_mantissa * (2 ** max_exponent)
     return max_val
 
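The other speed-up is arithmetic: for a mantissa width m, the largest mantissa value is sum_{i=0}^{m} 2^-i, which is what the removed torch.arange/torch.sum computed on every max_float call. Since only small integer widths occur in practice, the values are now precomputed once into MAX_MANTISSA_DICT and looked up by mantissa_bit_width.item(). A small standalone check of the cached values; the closed-form comparison 2 - 2^-m is added here purely for illustration:

import torch


def max_mantissa_func(val):
    # 1 + 1/2 + ... + 2**(-val): largest mantissa value for `val` mantissa bits
    return torch.sum(2. ** torch.arange(0, -1. * val - 1., -1.))


# Precomputed once at import time for mantissa widths 0..15.
MAX_MANTISSA_DICT = {m: max_mantissa_func(m) for m in range(0, 16)}

# Geometric series check: sum_{i=0}^{m} 2**-i == 2 - 2**-m
for m in range(16):
    assert torch.isclose(MAX_MANTISSA_DICT[m], torch.tensor(2. - 2. ** -m))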
11 changes: 6 additions & 5 deletions src/brevitas/graph/calibrate.py
@@ -201,8 +201,9 @@ def disable_act_quantization(self, model, is_training):
             if isinstance(module, ActQuantProxyFromInjectorBase):
                 module.train(is_training)
                 if self.call_act_quantizer_impl:
-                    hook = module.register_forward_hook(self.disable_act_quant_hook)
-                    self.disable_act_quant_hooks.append(hook)
+                    for m in module.modules():
+                        if hasattr(m, 'observer_only'):
+                            m.observer_only = True
                 else:
                     module.disable_quant = True
             elif isinstance(module, _ACC_PROXIES):
@@ -229,9 +230,9 @@ def enable_act_quantization(self, model, is_training):
             elif isinstance(module, ActQuantProxyFromInjectorBase):
                 module.disable_quant = False
                 module.train(is_training)
-                for hook in self.disable_act_quant_hooks:
-                    hook.remove()
-                self.disable_act_quant_hooks = []
+                for m in module.modules():
+                    if hasattr(m, 'observer_only'):
+                        m.observer_only = False
 
     def enable_param_quantization(self, model, is_training):
         for module in model.modules():
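On the calibration side, the forward hooks that previously overrode each quantizer's output are replaced by walking the activation proxy's submodules and toggling observer_only, so the skipped work is never done in the first place rather than being computed and thrown away. Existing calibration code keeps working unchanged; a typical loop, with quant_model and calib_loader as placeholders, still looks like:

import torch
from brevitas.graph.calibrate import calibration_mode


def calibrate(quant_model, calib_loader):
    quant_model.eval()
    with torch.no_grad(), calibration_mode(quant_model):
        # Inside the context the activation quantizers run observer-only:
        # scales and zero-points are collected, no fake-quantization is applied.
        for images, _ in calib_loader:
            quant_model(images)
    return quant_model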
