Feat (proxy/parameter_quant): cache quant weights (#990)
Giuseppe5 authored Jul 18, 2024
1 parent 7c7d825 commit 77466ad
Showing 1 changed file with 18 additions and 30 deletions.
48 changes: 18 additions & 30 deletions src/brevitas/proxy/parameter_quant.py
@@ -124,6 +124,11 @@ def get_cached(self, attr):
 
 class WeightQuantProxyFromInjector(WeightQuantProxyFromInjectorBase):
 
+    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
+        super().__init__(quant_layer, quant_injector)
+        self._cached_weight = None
+        self.cache_inference_quant_weight = False
+
     @property
     def tracked_parameter_list(self):
         return [m.weight for m in self.tracked_module_list if m.weight is not None]
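
Note: the new _cached_weight attribute is filled with the _CachedIO helper already defined earlier in parameter_quant.py; the forward change in the next hunk shows how it is used. That helper's implementation is not part of this diff, so the following is only a rough, hypothetical stand-in for the interface this commit relies on (a quant_tensor attribute plus pass-through metadata); it is not Brevitas' actual class.

# Hypothetical sketch of a cached-IO holder, for illustration only;
# Brevitas' real _CachedIO (defined earlier in this file) may differ.
class IllustrativeCachedIO:

    def __init__(self, quant_tensor, metadata_only: bool):
        # Keep either the full detached quant tensor or only its metadata.
        self.quant_tensor = None if metadata_only else quant_tensor
        self.scale = quant_tensor.scale
        self.zero_point = quant_tensor.zero_point
        self.bit_width = quant_tensor.bit_width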
@@ -152,11 +157,20 @@ def bit_width(self):
 
     def forward(self, x: torch.Tensor) -> Union[Tensor, IntQuantTensor]:
         if self.is_quant_enabled:
-            impl = self.export_handler if self.export_mode else self.tensor_quant
-            out, scale, zero_point, bit_width = impl(x)
-            return IntQuantTensor(out, scale, zero_point, bit_width, self.is_signed, self.training)
+            if self._cached_weight is not None:
+                out = self._cached_weight.quant_tensor
+            else:
+                impl = self.export_handler if self.export_mode else self.tensor_quant
+                out, scale, zero_point, bit_width = impl(x)
+                out = IntQuantTensor(
+                    out, scale, zero_point, bit_width, self.is_signed, self.training)
         else:  # quantization disabled
-            return x
+            out = x
+        if isinstance(
+                out, IntQuantTensor
+        ) and not self.training and self.cache_inference_quant_weight and self._cached_weight is None:
+            self._cached_weight = _CachedIO(out.detach(), metadata_only=False)
+        return out
 
 
 class DecoupledWeightQuantProxyFromInjector(WeightQuantProxyFromInjector):
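
A minimal usage sketch of the new weight cache, assuming a standard brevitas.nn layer and that the weight proxy is reachable as layer.weight_quant (as in recent Brevitas releases); the flag and attribute names follow the code added above, while the layer construction itself is illustrative.

import torch
from brevitas.nn import QuantLinear

layer = QuantLinear(16, 8, bias=False, weight_bit_width=4)
layer.eval()  # the cache is only filled outside training mode
layer.weight_quant.cache_inference_quant_weight = True

x = torch.randn(2, 16)
layer(x)  # first pass quantizes the weight and stores the IntQuantTensor in _cached_weight
layer(x)  # subsequent passes reuse the cached tensor instead of re-quantizing
assert layer.weight_quant._cached_weight is not None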
@@ -232,32 +246,6 @@ def forward(
 
 class BiasQuantProxyFromInjector(BiasQuantProxyFromInjectorBase):
 
-    def __init__(self, quant_layer: nn.Module, quant_injector: Injector) -> None:
-        super().__init__(quant_layer, quant_injector)
-        self._cached_bias = None
-        self.cache_inference_quant_bias = False
-
-    @property
-    def tracked_parameter_list(self):
-        return [m.bias for m in self.tracked_module_list if m.bias is not None]
-
-    @property
-    def requires_input_scale(self) -> bool:
-        if self.is_quant_enabled:
-            return self.quant_injector.requires_input_scale
-        else:
-            return False
-
-    def get_cached(self, attr):
-        if self._cached_bias is None:
-            warn(
-                "No quant bias cache found, set cache_inference_quant_bias=True and run an "
-                "inference pass first")
-            return None
-        if self.training:
-            warn("Cached quant bias scale is being used in training mode.")
-        return getattr(self._cached_bias, attr)
-
     def scale(self):
         if not self.is_quant_enabled:
             return None
