Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Jun 20, 2024
1 parent 8323acd commit 88184f2
Show file tree
Hide file tree
Showing 4 changed files with 227 additions and 18 deletions.
22 changes: 16 additions & 6 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ def __init__(
add_mul_node=True,
layerwise=True,
enabled=True,
blacklist_layers=None,
co_optimize_act_weights=False) -> None:
self.model = model
self.alpha = alpha
Expand All @@ -210,7 +211,8 @@ def __init__(
if layerwise:
if not self.add_mul_node:
raise ValueError("Layerwise activation equalization requires add_mul_node")
self.graph_act_eq = LayerwiseActivationEqualization(self.model)
self.graph_act_eq = LayerwiseActivationEqualization(
self.model, blacklist_layers=blacklist_layers)
else:
if not isinstance(self.model, (TorchGraphModule, GraphModule)):
raise TypeError(
Expand Down Expand Up @@ -996,36 +998,44 @@ def remove_hooks(self):

class LayerwiseActivationEqualization(ActivationEqualization):

def __init__(
        self,
        model,
        scale_computation_type: str = 'maxabs',
        blacklist_layers: Optional[List[str]] = None):
    """Layerwise activation equalization over ``model``.

    Walks the model to collect single-sink regions (one per supported layer),
    optionally skipping layers whose name appears in ``blacklist_layers``.

    Args:
        model: model whose layers are candidates for equalization.
        scale_computation_type: either ``'maxabs'`` or ``'range'``; selects
            the per-channel scale function used during equalization.
        blacklist_layers: optional list of layer names to exclude from
            equalization. ``None`` (default) excludes nothing.
    """
    super(LayerwiseActivationEqualization, self).__init__(model, scale_computation_type)
    self.float_act_map = {}
    self.batch_dim_act_map = {}
    self.hooks = []
    # Layerwise equalization always inserts an explicit mul node.
    self.add_mul_node = True
    self.blacklist_layers = blacklist_layers

    regions: List[Region] = []
    # Root module has no name within a parent, hence the empty string.
    self.find_module(model, '', regions)
    self.regions = regions

    if self.scale_computation_type == 'maxabs':
        self.scale_fn = _channel_maxabs
    elif self.scale_computation_type == 'range':
        self.scale_fn = _channel_range

def find_module(self, model, name, regions: List):
    """
    Iterate through the model looking at immediate children of every module to look for supported modules.
    This allows us to stop the search when we meet a top-level module that is supported.

    Args:
        model: module currently being inspected.
        name: name of ``model`` within its parent (empty string for the root).
        regions: output list; a single-sink ``Region`` is appended for every
            supported, non-blacklisted layer that is found.
    """
    if isinstance(model,
                  _supported_layers) and not isinstance(model, _batch_norm + (nn.LayerNorm,)):
        # Skip layers that were explicitly excluded by name.
        # NOTE(review): matching uses the immediate child name, not the fully
        # qualified dotted path -- confirm blacklist entries are leaf names.
        if self.blacklist_layers is not None and name in self.blacklist_layers:
            return
        weight = get_weight_sink(model)
        eq_indexes = EqualizationIndexes(0, weight.shape[0], 0)
        region = Region(sinks={'sinks0': eq_indexes}, name_to_module={'sinks0': model})
        regions.append(region)
    else:
        # Recurse into children, passing each child's name down for
        # blacklist matching.
        for name, module in model.named_children():
            self.find_module(module, name, regions)

def setup(self):
for region in self.regions:
Expand Down
33 changes: 31 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZActPerTensorFloat
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_fnuz import Fp8e4m3FNUZWeightPerTensorFloat
from brevitas.quant.experimental.float_quant_fnuz import Fp8e5m2FNUZActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPActPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
Expand Down Expand Up @@ -104,7 +108,15 @@
'per_tensor': {
'sym': Fp8e5m2OCPWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}}}
'sym': Fp8e5m2OCPWeightPerChannelFloat}}}}},
'float_fnuz': {
'e4m3': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3FNUZWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e4m3FNUZWeightPerChannelFloat}}}}}}

INPUT_QUANT_MAP = {
'int': {
Expand Down Expand Up @@ -154,7 +166,19 @@
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e5m2OCPActPerTensorFloat}}}}},
'sym': Fp8e5m2OCPActPerTensorFloat}}}}}},
'float_fnuz': {
'static': {
'e4m3': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3FNUZActPerTensorFloat}}}},
'e5m2': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e5m2FNUZActPerTensorFloat}}}}},
'dynamic': {
'e4m3': {
'float_scale': {
Expand Down Expand Up @@ -183,6 +207,7 @@ def generate_quantizers(
input_group_size=None,
quantize_input_zero_point=False,
use_ocp=False,
use_fnuz=False,
device=None,
weight_kwargs=None,
input_kwargs=None):
Expand All @@ -201,6 +226,8 @@ def generate_quantizers(
weight_quant_format = 'float'
if use_ocp:
weight_quant_format += '_ocp'
elif use_fnuz:
weight_quant_format += '_fnuz'
else:
weight_float_format = {}
if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
Expand All @@ -211,6 +238,8 @@ def generate_quantizers(
input_quant_format = 'float'
if use_ocp:
input_quant_format += '_ocp'
elif use_fnuz:
input_quant_format += '_fnuz'
else:
input_float_format = {}

Expand Down
93 changes: 83 additions & 10 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from dependencies import value
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionXLPipeline
from diffusers.models.attention_processor import Attention
from diffusers.models.attention_processor import AttnProcessor
import numpy as np
import pandas as pd
import torch
Expand All @@ -22,6 +24,7 @@

from brevitas.core.stats.stats_op import NegativeMinOrZero
from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.graph.base import ModuleToModuleByClass
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.calibrate import load_quant_model_mode
Expand All @@ -42,12 +45,15 @@
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE
from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx
from brevitas_examples.stable_diffusion.sd_quant.export import export_quant_params
from brevitas_examples.stable_diffusion.sd_quant.nn import QuantAttention
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs
from brevitas_examples.stable_diffusion.sd_quant.utils import unet_input_shape

TEST_SEED = 123456
# TODO: add deterministic flags

NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"]

Expand Down Expand Up @@ -206,10 +212,12 @@ def main(args):

if args.activation_equalization:
pipe.set_progress_bar_config(disable=True)
with activation_equalization_mode(pipe.unet,
alpha=args.act_eq_alpha,
layerwise=True,
add_mul_node=True):
with activation_equalization_mode(
pipe.unet,
alpha=args.act_eq_alpha,
layerwise=True,
blacklist_layers=blacklist if args.exclude_blacklist_act_eq else None,
add_mul_node=True):
# Workaround to expose `in_features` attribute from the Hook Wrapper
for m in pipe.unet.modules():
if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
Expand Down Expand Up @@ -302,6 +310,7 @@ def input_zp_stats_type():
input_quant_type=args.input_quant_type,
input_quant_granularity=args.input_quant_granularity,
use_ocp=args.use_ocp,
use_fnuz=args.use_fnuz,
input_kwargs=input_kwargs)

layer_map = generate_quant_maps(
Expand All @@ -323,6 +332,55 @@ def input_zp_stats_type():
'weight_quant']
layer_map[torch.nn.Conv2d] = (layer_map[torch.nn.Conv2d][0], conv_qkwargs)

if args.quantize_sdp_1 or args.quantize_sdp_2:
float_sdpa_quantizers = generate_quantizers(
dtype=dtype,
device=args.device,
weight_bit_width=weight_bit_width,
weight_quant_format='e4m3',
weight_quant_type='sym',
weight_param_method=args.weight_param_method,
weight_scale_precision=args.weight_scale_precision,
weight_quant_granularity=args.weight_quant_granularity,
weight_group_size=args.weight_group_size,
quantize_weight_zero_point=args.quantize_weight_zero_point,
quantize_input_zero_point=args.quantize_input_zero_point,
input_bit_width=input_bit_width,
input_quant_format='e4m3',
input_scale_type=args.input_scale_type,
input_scale_precision=args.input_scale_precision,
input_param_method=args.input_param_method,
input_quant_type='sym',
input_quant_granularity=args.input_quant_granularity,
use_ocp=args.use_ocp,
use_fnuz=args.use_fnuz,
input_kwargs=input_kwargs)
input_quant = float_sdpa_quantizers[0]
input_quant = input_quant.let(**{'bit_width': args.linear_output_bit_width})
if args.quantize_sdp_2:
rewriter = ModuleToModuleByClass(
Attention,
QuantAttention,
softmax_output_quant=input_quant,
query_dim=lambda module: module.to_q.in_features,
dim_head=lambda module: int(1 / (module.scale ** 2)),
processor=AttnProcessor(),
is_equalized=args.activation_equalization)
import brevitas.config as config
config.IGNORE_MISSING_KEYS = True
pipe.unet = rewriter.apply(pipe.unet)
config.IGNORE_MISSING_KEYS = False
pipe.unet = pipe.unet.to(args.device)
pipe.unet = pipe.unet.to(dtype)
quant_kwargs = layer_map[torch.nn.Linear][1]
what_to_quantize = []
if args.quantize_sdp_1:
what_to_quantize.extend(['to_q', 'to_k'])
if args.quantize_sdp_2:
what_to_quantize.extend(['to_v'])
quant_kwargs['output_quant'] = lambda module, name: input_quant if any(ending in name for ending in what_to_quantize) else None
layer_map[torch.nn.Linear] = (layer_map[torch.nn.Linear][0], quant_kwargs)

pipe.unet = layerwise_quantize(
model=pipe.unet, compute_layer_map=layer_map, name_blacklist=blacklist)
print("Model quantization applied.")
Expand Down Expand Up @@ -469,6 +527,8 @@ def input_zp_stats_type():
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node)
export_onnx(pipe, trace_inputs, output_dir, export_manager)
if args.export_target == 'params_only':
export_quant_params(pipe, output_dir)


if __name__ == "__main__":
Expand All @@ -488,10 +548,7 @@ def input_zp_stats_type():
default=2,
help='How many seeds to use for each image during validation. Default: 2')
parser.add_argument(
'--prompt',
type=int,
default=4,
help='Number of prompt to use for testing. Default: 4. Max: 4')
'--prompt', type=int, default=4, help='Number of prompt to use for testing. Default: 4')
parser.add_argument(
'--calibration-prompt',
type=int,
Expand Down Expand Up @@ -558,7 +615,11 @@ def input_zp_stats_type():
default=False,
help='Enable attention slicing. Default: Disabled')
parser.add_argument(
'--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.')
'--export-target',
type=str,
default='',
choices=['', 'onnx', 'params_only'],
help='Target export flow.')
add_bool_arg(
parser,
'export-weight-q-node',
Expand Down Expand Up @@ -673,6 +734,11 @@ def input_zp_stats_type():
'quantize-weight-zero-point',
default=True,
help='Quantize weight zero-point. Default: Enabled')
add_bool_arg(
parser,
'exclude-blacklist-act-eq',
default=False,
help='Exclude unquantized layers from activation equalization. Default: Disabled')
add_bool_arg(
parser,
'quantize-input-zero-point',
Expand All @@ -688,8 +754,13 @@ def input_zp_stats_type():
# Help text updated to match the actual default (False); it previously
# claimed "Default: True".
add_bool_arg(
    parser,
    'use-ocp',
    default=False,
    help='Use OCP format for float quantization. Default: Disabled')
# Flag renamed from the misspelled 'use-nfuz': the rest of the script reads
# `args.use_fnuz` (passed as `use_fnuz=` to generate_quantizers), so the
# misspelled dest would raise AttributeError at runtime.
add_bool_arg(
    parser,
    'use-fnuz',
    default=True,
    help='Use FNUZ format for float quantization. Default: Enabled')
add_bool_arg(
parser,
'use-negative-prompts',
Expand All @@ -700,6 +771,8 @@ def input_zp_stats_type():
'dry-run',
default=False,
help='Generate a quantized model without any calibration. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp-1', default=False, help='Quantize SDP. Default: Disabled')
add_bool_arg(parser, 'quantize-sdp-2', default=False, help='Quantize SDP. Default: Disabled')
args = parser.parse_args()
print("Args: " + str(vars(args)))
main(args)
Loading

0 comments on commit 88184f2

Please sign in to comment.