Extend equalization
Giuseppe5 committed May 4, 2024
1 parent 1c32492 commit b6f6b06
Showing 5 changed files with 26 additions and 28 deletions.
29 changes: 23 additions & 6 deletions src/brevitas/graph/equalize.py
@@ -42,6 +42,7 @@
     nn.Conv3d,
     nn.Linear,
     nn.LayerNorm,
+    nn.GroupNorm,
     nn.BatchNorm1d,
     nn.BatchNorm2d,
     nn.BatchNorm3d)
@@ -65,6 +66,8 @@
 
 _scale_invariant_op = (
     torch.mul,
+    operator.truediv,
+    operator.__truediv__,
     operator.mul,
     operator.imul,
     operator.__mul__,
@@ -73,7 +76,15 @@
 
 _select_op = (operator.getitem, operator.__getitem__)
 
-_reshaping_op = ('view', 'reshape', 'flatten', 'contiguous', torch.reshape, torch.flatten)
+_reshaping_op = (
+    'view',
+    'reshape',
+    'flatten',
+    'contiguous',
+    torch.reshape,
+    torch.flatten,
+    torch.permute,
+    'permute')
 
 _scale_varying_activations = (
     torch.nn.Sigmoid, torch.nn.Tanh, torch.nn.ReLU6, torch.nn.GELU, torch.nn.SiLU)
@@ -269,7 +280,7 @@ def _get_input_axis(module: nn.Module) -> Optional[int]:
             return 0
         elif module.groups == module.out_channels:
             return 1
-    elif isinstance(module, nn.LayerNorm):
+    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
         # We assume normalization happens only along the channel dimension
         if len(module.weight.shape) == 1:
             return 0
@@ -296,7 +307,7 @@ def _get_output_axis(module: nn.Module) -> Optional[int]:
         return 0
     elif isinstance(module, (nn.ConvTranspose1d, nn.ConvTranspose2d, nn.ConvTranspose3d)):
         return 1
-    elif isinstance(module, nn.LayerNorm):
+    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
         # We assume normalization happens only along the channel dimension
         if len(module.weight.shape) == 1:
             return 0
@@ -809,7 +820,10 @@ def find_sinks(graph_model: GraphModule, starting_node: Node,
             weight = get_weight_sink(module)
             eq_indexes = EqualizationIndexes(0, weight.shape[0], state.offset)
             # It is not possible to equalize through LayerNorm as sink
-            if isinstance(module, (nn.LayerNorm,) + _batch_norm):
+            if isinstance(module, (
+                    nn.LayerNorm,
+                    nn.GroupNorm,
+            ) + _batch_norm):
                 state.sinks[_UNSUPPORTED_OP] = _UNSUPPORTED_OP
             else:
                 state.add_sinks(node.target, module, eq_indexes)
@@ -1017,8 +1031,11 @@ def find_module(self, model, regions: List):
         Iterate through the model looking at immediate children of every module to look for supported modules.
         This allows us to stop the search when we meet a top-level module that is supported.
         """
-        if isinstance(model,
-                      _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
+        if isinstance(model, _supported_layers) and not isinstance(model,
+                                                                    _batch_norm + (
+                                                                        nn.LayerNorm,
+                                                                        nn.GroupNorm,
+                                                                    )):
             weight = get_weight_sink(model)
             eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
             region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
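The equalize.py changes above extend graph equalization so that `nn.GroupNorm` is handled like `nn.LayerNorm`: a supported layer whose per-channel affine weight lives on axis 0, but not a valid sink, while the new entries in `_scale_invariant_op` and `_reshaping_op` let the region search walk through division and `permute` nodes. A minimal sketch of the channel-axis reasoning in plain PyTorch (illustrative only, not part of the commit):

```python
import torch.nn as nn

# GroupNorm and (channel-wise) LayerNorm both keep a 1-D affine weight of
# shape [num_channels], so the `len(module.weight.shape) == 1` branch in
# _get_input_axis / _get_output_axis maps their channel axis to 0.
gn = nn.GroupNorm(num_groups=4, num_channels=32)
ln = nn.LayerNorm(normalized_shape=32)

for m in (gn, ln):
    assert m.weight.ndim == 1 and m.weight.shape[0] == 32
```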
4 changes: 0 additions & 4 deletions src/brevitas/graph/gpfq.py
@@ -264,8 +264,6 @@ def single_layer_update(self):
             perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
 
-        self.reactivate_quantization()
-
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
@@ -361,8 +359,6 @@ def single_layer_update(self):
             perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
 
-        self.reactivate_quantization()
-
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
3 changes: 0 additions & 3 deletions src/brevitas/graph/gptq.py
@@ -139,7 +139,6 @@ def __init__(
                              device='cpu',
                              dtype=torch.float32)
         self.nsamples = 0
-        self.done = False
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
 
@@ -263,8 +262,6 @@ def single_layer_update(self, percdamp=.01):
         finally:
             del self.H
 
-        self.reactivate_quantization()
-
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
             count = i2 - i1
12 changes: 0 additions & 12 deletions src/brevitas/graph/gpxq.py
@@ -164,9 +164,6 @@ def __enter__(self):
         return self
 
     def __exit__(self, type, value, traceback):
-        for name, layer in self.gpxq_layers.items():
-            if not layer.done:
-                layer.reactivate_quantization()
 
         if isinstance(self.model, (GraphModule, TorchGraphModule)):
             self.model.__class__.forward = self.orig_forward
@@ -223,10 +220,6 @@ def __init__(
         self.disable_pre_forward_hook = False
         # Some layers require knowledge from quant inputs to compute quant weights
         self.quant_metadata = None
-        self.disable_quant_inference = DisableEnableQuantization()
-        self.return_quant_tensor_state = disable_return_quant_tensor(self.layer)
-        self.disable_quant_inference.disable_param_quantization(self.layer, False)
-        self.done = False
 
     def process_input(self, inp):
         # Input is a tuple, so we take first element
@@ -263,11 +256,6 @@ def update_batch(self):
     def single_layer_update(self):
         pass
 
-    def reactivate_quantization(self):
-        self.done = True
-        self.disable_quant_inference.enable_param_quantization(self.layer, False)
-        restore_return_quant_tensor(self.layer, self.return_quant_tensor_state)
-
     def get_quant_weights(self, i, i1, permutation_list):
         # We need to recompute quant weights at runtime since our float weights are being updated
         # Add offset in case of blockwise computation
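The gpfq.py, gptq.py and gpxq.py edits drop the per-layer `done` flag and the `reactivate_quantization` bookkeeping, so parameter quantization is no longer toggled off and back on inside each layer object. For orientation, this is roughly the loop those hooks run under, following the pattern in the Brevitas PTQ examples (a hedged sketch, not code from this commit; `model` and `calib_loader` are assumed to exist):

```python
import torch

from brevitas.graph.gptq import gptq_mode


def apply_gptq(model, calib_loader):
    model.eval()
    device = next(model.parameters()).device
    with torch.no_grad():
        # gptq_mode.__enter__ wraps the supported layers; __exit__ (edited
        # above) restores the original forward once every layer is updated.
        with gptq_mode(model, use_quant_activations=True) as gptq:
            gptq_model = gptq.model
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    gptq_model(images.to(device))
                gptq.update()  # runs single_layer_update for the current layer(s)
    return model
```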
6 changes: 3 additions & 3 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -161,7 +161,7 @@ def main(args):
         raise RuntimeError("LoRA layers should be fused in before calling into quantization.")
 
     if args.activation_equalization:
-        with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True):
+        with activation_equalization_mode(pipe.unet, alpha=0.2, layerwise=True, add_mul_node=True):
             # Workaround to expose `in_features` attribute from the Hook Wrapper
             for m in pipe.unet.modules():
                 if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
@@ -176,8 +176,8 @@
 
         # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper
         for m in pipe.unet.modules():
-            if isinstance(m, EqualizedModule) and hasattr(m.module, 'in_features'):
-                m.in_features = m.module.in_features
+            if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
+                m.in_features = m.layer.in_features
 
     # Quantize model
     if args.quantize:
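In the Stable Diffusion example, activation equalization now uses a milder alpha (0.2 instead of 0.5) and the `in_features` workaround reads the wrapped layer through `EqualizedModule.layer` rather than `m.module`. A condensed sketch of that flow (illustrative; the import paths and the `run_calibration` callable are assumptions, not taken from the commit):

```python
from brevitas.graph.equalize import activation_equalization_mode
from brevitas.nn.equalized_layer import EqualizedModule


def equalize_unet(unet, run_calibration):
    # Forward passes inside the context collect activation statistics and
    # insert the equalization scales.
    with activation_equalization_mode(unet, alpha=0.2, layerwise=True, add_mul_node=True):
        run_calibration(unet)
    # EqualizedModule hides attributes of the wrapped layer, so copy
    # `in_features` onto the wrapper for later passes that query it.
    for m in unet.modules():
        if isinstance(m, EqualizedModule) and hasattr(m.layer, 'in_features'):
            m.in_features = m.layer.in_features
    return unet
```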
