Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed May 31, 2024
1 parent 724842c commit 38fb5cf
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 4 deletions.
3 changes: 3 additions & 0 deletions src/brevitas_examples/stable_diffusion/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
[--conv-input-bit-width CONV_INPUT_BIT_WIDTH]
[--act-eq-alpha ACT_EQ_ALPHA]
[--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH]
[--weight-param-method {stats,mse}]
[--input-param-method {stats,mse}]
Expand Down Expand Up @@ -173,6 +174,8 @@ options:
Weight bit width. Default: 8.
--conv-input-bit-width CONV_INPUT_BIT_WIDTH
Input bit width. Default: None (not quantized).
--act-eq-alpha ACT_EQ_ALPHA
Alpha for activation equalization. Default: 0.9.
--linear-input-bit-width LINEAR_INPUT_BIT_WIDTH
Input bit width. Default: None (not quantized).
--weight-param-method {stats,mse}
Expand Down
13 changes: 9 additions & 4 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,9 +136,6 @@ def run_val_inference(

def main(args):

if args.export_target:
assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export."

dtype = getattr(torch, args.dtype)

calibration_prompts = CALIBRATION_PROMPTS
Expand Down Expand Up @@ -219,7 +216,10 @@ def main(args):

if args.activation_equalization:
pipe.set_progress_bar_config(disable=True)
with activation_equalization_mode(pipe.unet, alpha=0.9, layerwise=True, add_mul_node=True):
with activation_equalization_mode(pipe.unet,
alpha=args.act_eq_alpha,
layerwise=True,
add_mul_node=True):
# Workaround to expose `in_features` attribute from the Hook Wrapper
for m in pipe.unet.modules():
if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
Expand Down Expand Up @@ -604,6 +604,11 @@ def input_zp_stats_type():
type=int,
default=None,
help='Input bit width. Default: None (not quantized)')
parser.add_argument(
'--act-eq-alpha',
type=float,
default=0.9,
help='Alpha for activation equalization. Default: 0.9')
parser.add_argument(
'--linear-input-bit-width',
type=int,
Expand Down

0 comments on commit 38fb5cf

Please sign in to comment.