Feat (examples/stable_diffusion): improvements to SD quantization (#965)
Giuseppe5 authored Jun 21, 2024
1 parent b4e9287 · commit 0c12bbc
Showing 14 changed files with 3,213 additions and 151 deletions.
src/brevitas/core/function_wrapper/clamp.py (1 addition, 1 deletion)
@@ -87,7 +87,7 @@ class FloatClamp(brevitas.jit.ScriptModule):
     I.e. setting inf to 1101.111 (E4M3) is not a valid code.
     """

-    __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed', 'max_available_float']
+    __constants__ = ['saturating', 'inf_values', 'nan_values', 'signed']

     def __init__(
             self,
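For readers less familiar with minifloat formats, the E4M3 remark in the FloatClamp docstring above can be made concrete with some quick arithmetic. The sketch below is a self-contained illustration, not Brevitas code; the helper name and the bias-7 E4M3 assumption are mine. Reserving the all-ones exponent/mantissa code for NaN shrinks the largest representable magnitude from 1.875 * 2^8 = 480 to 1.75 * 2^8 = 448.

    def minifloat_max(exp_bits: int, mant_bits: int, bias: int, reserve_top_code_for_nan: bool) -> float:
        # Largest exponent field, interpreted with the given bias (no code reserved for inf).
        max_exp = (2 ** exp_bits - 1) - bias
        # A full mantissa contributes 2 - 2**-mant_bits; reserving the top code costs one mantissa step.
        max_mant = 2.0 - 2.0 ** -mant_bits
        if reserve_top_code_for_nan:
            max_mant -= 2.0 ** -mant_bits
        return max_mant * 2.0 ** max_exp

    # E4M3 with bias 7: 448.0 when S.1111.111 is reserved for NaN, 480.0 otherwise.
    print(minifloat_max(4, 3, 7, True), minifloat_max(4, 3, 7, False))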
src/brevitas/graph/equalize.py (16 additions, 6 deletions)
@@ -201,6 +201,7 @@ def __init__(
             add_mul_node=True,
             layerwise=True,
             enabled=True,
+            blacklist_layers=None,
             co_optimize_act_weights=False) -> None:
         self.model = model
         self.alpha = alpha
@@ -210,7 +211,8 @@ def __init__(
         if layerwise:
             if not self.add_mul_node:
                 raise ValueError("Layerwise activation equalization requires add_mul_node")
-            self.graph_act_eq = LayerwiseActivationEqualization(self.model)
+            self.graph_act_eq = LayerwiseActivationEqualization(
+                self.model, blacklist_layers=blacklist_layers)
         else:
             if not isinstance(self.model, (TorchGraphModule, GraphModule)):
                 raise TypeError(
@@ -1004,36 +1006,44 @@ def remove_hooks(self):

 class LayerwiseActivationEqualization(ActivationEqualization):

-    def __init__(self, model, scale_computation_type: str = 'maxabs'):
+    def __init__(
+            self,
+            model,
+            scale_computation_type: str = 'maxabs',
+            blacklist_layers: Optional[List[str]] = None):
         super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type)
         self.float_act_map = {}
         self.batch_dim_act_map = {}
         self.hooks = []
         self.add_mul_node = True
+        self.blacklist_layers = blacklist_layers

         regions: List[Region] = []
-        self.find_module(model, regions)
+        name = ''
+        self.find_module(model, name, regions)
         self.regions = regions

         if self.scale_computation_type == 'maxabs':
             self.scale_fn = _channel_maxabs
         elif self.scale_computation_type == 'range':
             self.scale_fn = _channel_range

-    def find_module(self, model, regions: List):
+    def find_module(self, model, name, regions: List):
         """
         Iterate through the model looking at immediate children of every module to look for supported modules.
         This allows us to stop the search when we meet a top-level module that is supported.
         """
         if isinstance(model,
                       _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
+            if self.blacklist_layers is not None and name in self.blacklist_layers:
+                return
             weight = get_weight_sink(model)
             eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
             region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
             regions.append(region)
         else:
-            for module in model.children():
-                self.find_module(module, regions)
+            for name, module in model.named_children():
+                self.find_module(module, name, regions)

     def setup(self):
         for region in self.regions:
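A note on how the new blacklist_layers option appears to be matched: find_module now recurses with the name yielded by named_children(), so the strings in blacklist_layers are compared against immediate-child attribute names at each level of the recursion, not fully-qualified dotted paths. Below is a minimal, self-contained sketch of that traversal pattern using plain PyTorch (stand-in names and modules, not the Brevitas classes):

    import torch.nn as nn

    SUPPORTED = (nn.Linear, nn.Conv2d)  # stand-in for Brevitas' _supported_layers

    def find_supported(module, name, blacklist, found):
        # Mirrors the recursion pattern in LayerwiseActivationEqualization.find_module above.
        if isinstance(module, SUPPORTED):
            if blacklist is not None and name in blacklist:
                return  # skipped: this child's local name is blacklisted
            found.append((name, module))
        else:
            for child_name, child in module.named_children():
                find_supported(child, child_name, blacklist, found)

    model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.ReLU(), nn.Linear(8, 4))
    found = []
    find_supported(model, '', blacklist=['2'], found=found)
    print([n for n, _ in found])  # ['0']: the Linear registered as child '2' was skipped

If this reading is right, blacklisting a supported layer nested inside an unsupported container takes the child's local name (for example 'proj'), not its dotted path.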
src/brevitas/nn/quant_layer.py (5 additions, 2 deletions)
@@ -166,8 +166,11 @@ def _load_from_state_dict(
         bias_key = prefix + 'bias'
         # If the state dict has a bias and the module does not, bias correction was used
         # We add a bias module to prevent failing during the load of the state dict
-        if bias_key in state_dict and self.bias is None and self._quant_load_model_mode:
+        if (bias_key in state_dict) and (self.bias is None) and self._quant_load_model_mode:
             self.register_parameter(
-                'bias', torch.nn.Parameter(torch.zeros(self.out_channels)).to(self.weight.device))
+                'bias',
+                torch.nn.Parameter(
+                    torch.zeros(
+                        self.out_channels, device=self.weight.device, dtype=self.weight.dtype)))
         super(QuantWeightBiasInputOutputLayer, self)._load_from_state_dict(
             state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
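The comment in the hunk above states the intent: when bias correction has baked a bias into the checkpoint but the live module was built without one, a zero-initialized bias is registered (now created directly on the weight's device and dtype) so that loading the state dict does not fail. A minimal, self-contained sketch of the same idea with plain PyTorch modules (the Conv2d layers and key names here are illustrative, not Brevitas's quant layers):

    import torch
    import torch.nn as nn

    # Checkpoint from a model whose conv ended up with a bias (e.g. after bias correction).
    src = nn.Conv2d(3, 8, 3, bias=True)
    state_dict = src.state_dict()  # contains 'weight' and 'bias'

    # Live model was instantiated without a bias.
    dst = nn.Conv2d(3, 8, 3, bias=False)

    # Register a zero bias with matching device/dtype before loading, so the
    # checkpoint's 'bias' entry has a parameter to land in.
    if 'bias' in state_dict and dst.bias is None:
        dst.register_parameter(
            'bias',
            nn.Parameter(
                torch.zeros(dst.out_channels, device=dst.weight.device, dtype=dst.weight.dtype)))

    dst.load_state_dict(state_dict)  # loads cleanly instead of raising on an unexpected 'bias' key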
(Diffs for the remaining 11 changed files are not shown here.)
