From 38fb5cfe80e3e408439ef68460cb14dc9579c3da Mon Sep 17 00:00:00 2001
From: Giuseppe Franco
Date: Fri, 31 May 2024 14:19:25 +0100
Subject: [PATCH] update

---
 src/brevitas_examples/stable_diffusion/README.md |  3 +++
 src/brevitas_examples/stable_diffusion/main.py   | 13 +++++++++----
 2 files changed, 12 insertions(+), 4 deletions(-)

diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md
index 01c05b2e8..ad2dc57a8 100644
--- a/src/brevitas_examples/stable_diffusion/README.md
+++ b/src/brevitas_examples/stable_diffusion/README.md
@@ -75,6 +75,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
                [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
                [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
                [--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
+               [--act-eq-alpha ACT_EQ_ALPHA]
                [--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
                [--weight-param-method {stats,mse}]
                [--input-param-method {stats,mse}]
@@ -173,6 +174,8 @@ options:
                         Weight bit width. Default: 8.
   --conv-input-bit-width CONV_INPUT_BIT_WIDTH
                         Input bit width. Default: None (not quantized)
+  --act-eq-alpha ACT_EQ_ALPHA
+                        Alpha for activation equalization. Default: 0.9
   --linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
                         Input bit width. Default: None (not quantized).
   --weight-param-method {stats,mse}
diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index cf6889dfd..8aa960d0f 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -136,9 +136,6 @@ def run_val_inference(
 
 
 def main(args):
-    if args.export_target:
-        assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export."
-
     dtype = getattr(torch, args.dtype)
 
     calibration_prompts = CALIBRATION_PROMPTS
@@ -219,7 +216,10 @@ def main(args):
 
     if args.activation_equalization:
         pipe.set_progress_bar_config(disable=True)
-        with activation_equalization_mode(pipe.unet, alpha=0.9, layerwise=True, add_mul_node=True):
+        with activation_equalization_mode(pipe.unet,
+                                          alpha=args.act_eq_alpha,
+                                          layerwise=True,
+                                          add_mul_node=True):
             # Workaround to expose `in_features` attribute from the Hook Wrapper
             for m in pipe.unet.modules():
                 if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
@@ -604,6 +604,11 @@ def input_zp_stats_type():
         type=int,
         default=None,
         help='Input bit width. Default: None (not quantized)')
+    parser.add_argument(
+        '--act-eq-alpha',
+        type=float,
+        default=0.9,
+        help='Alpha for activation equalization. Default: 0.9')
     parser.add_argument(
         '--linear-input-bit-width',
         type=int,
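
The sketch below shows how the new option might be exercised from the command line once the patch is applied. It is an illustrative example, not part of the patch: the `--activation-equalization` flag name is inferred from the `args.activation_equalization` check in main.py, and the model id, batch size, and bit widths are placeholder values.

    # Hypothetical invocation: run layerwise activation equalization with a
    # custom alpha instead of the previously hard-coded 0.9, then quantize
    # weights to 8 bits. Model id and batch size are placeholders.
    python main.py \
        -m stabilityai/stable-diffusion-2-1-base \
        -b 2 \
        --activation-equalization \
        --act-eq-alpha 0.7 \
        --conv-weight-bit-width 8 \
        --linear-weight-bit-width 8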