Fix (examples/stable_diffusion): README formatting and clarification (#…
Giuseppe5 authored Apr 12, 2024
1 parent a106a6d commit 6de7d5a
Showing 2 changed files with 17 additions and 20 deletions.
23 changes: 11 additions & 12 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -1,31 +1,33 @@
# Stable Diffusion Quantization

It currently supports Stable Diffusion 2.1 and Stable Diffusion XL.
It supports Stable Diffusion 2.1 and Stable Diffusion XL.

The following PTQ techniques are currently supported:
- Activation Equalization (e.g., SmoothQuant), layerwise (with the addition of Mul ops)
- Activation Calibration, in the case of static activation quantization
- GPTQ
- Bias Correction

These techniques can be applied for both integer and floating point quantization
These techniques can be applied for both integer and floating point quantization.
Activation quantization is optional, and disabled by default. To enable, set both `conv-input-bit-width` and `linear-input-bit-width`.
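
As a minimal sketch of enabling static activation quantization (the flag spellings and the 8-bit values are assumptions inferred from the options listed under Run below):

```bash
# Hypothetical invocation: activation quantization turns on only when both
# input bit-widths are set; weights use the default quantization settings.
python main.py --quantize \
    --conv-input-bit-width 8 \
    --linear-input-bit-width 8
```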

We support ONNX integer export, and we are planning to release soon export for floating point quantization (e.g., FP8).
To export the model in fp16, enable `export-cuda-float16`. This will perform the tracing necessary for export on GPU, leaving the model in fp16.
If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16.

NB: when exporting Stable Diffusion XL, make sure to enable the `is-sd-xl` flag. The flag is not needed when export is not executed.
To export the model with fp16 scale factors, enable `export-cuda-float16`. This will perform the tracing necessary for export on GPU, leaving the model in fp16.
If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16.
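
A hedged sketch of the export path described above (flag names are taken from the usage text below; the exact combination is an assumption, not a verified recipe):

```bash
# Hypothetical invocation: quantize and export to ONNX, tracing on GPU so the
# exported model keeps fp16 scale factors; without --export-cuda-float16 the
# model is moved to CPU and cast to float32 before export.
python main.py --quantize \
    --export-target onnx \
    --export-cuda-float16
```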


## Run

```bash
usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--resolution RESOLUTION]
[--output-path OUTPUT_PATH | --no-output-path]
[--quantize | --no-quantize]
[--activation-equalization | --no-activation-equalization]
[--gptq | --no-gptq] [--float16 | --no-float16]
[--attention-slicing | --no-attention-slicing]
[--is-sd-xl | --no-is-sd-xl] [--export-target {,onnx}]
[--export-target {,onnx}]
[--export-weight-q-node | --no-export-weight-q-node]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
@@ -79,10 +81,6 @@ options:
--attention-slicing Enable Enable attention slicing. Default: Disabled
--no-attention-slicing
Disable Enable attention slicing. Default: Disabled
--is-sd-xl Enable Enable this flag to correctly export SDXL.
Default: Disabled
--no-is-sd-xl Disable Enable this flag to correctly export SDXL.
Default: Disabled
--export-target {,onnx}
Target export flow.
--export-weight-q-node
@@ -117,8 +115,8 @@ options:
Weight quantization type. Either int or eXmY, with
X+Y==weight_bit_width-1. Default: int.
--input-quant-format INPUT_QUANT_FORMAT
Weight quantization type. Either int or eXmY, with
X+Y==weight_bit_width-1. Default: int.
Input quantization type. Either int or eXmY, with
X+Y==input_bit_width-1. Default: int.
--weight-quant-granularity {per_channel,per_tensor,per_group}
Granularity for scales/zero-point of weights. Default:
per_channel.
@@ -139,3 +137,4 @@ options:
Enable Export FP16 on CUDA. Default: Disabled
--no-export-cuda-float16
Disable Export FP16 on CUDA. Default: Disabled
```
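
One reading of the eXmY notation in the help text above (an interpretation, not something stated on this page): X counts exponent bits and Y mantissa bits, with one bit left for the sign, which is why X+Y must equal the bit width minus one. For instance, `e4m3` would pair with an 8-bit input width, since 4+3 == 8-1.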
14 changes: 6 additions & 8 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -11,12 +11,12 @@

from dependencies import value
from diffusers import DiffusionPipeline
from diffusers import StableDiffusionXLPipeline
import torch
from torch import nn
from tqdm import tqdm

from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager
from brevitas.export.torch.qcdq.manager import TorchQCDQManager
from brevitas.graph.calibrate import bias_correction_mode
from brevitas.graph.calibrate import calibration_mode
from brevitas.graph.equalize import activation_equalization_mode
@@ -97,6 +97,9 @@ def main(args):
pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype)
print(f"Model loaded from {args.model}.")

# Detect Stable Diffusion XL pipeline
is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline)

# Enable attention slicing
if args.attention_slicing:
pipe.enable_attention_slicing()
@@ -226,7 +229,7 @@ def input_bit_width(module):
dtype = next(iter(pipe.unet.parameters())).dtype

# Define tracing input
if args.is_sd_xl:
if is_sd_xl:
generate_fn = generate_unet_xl_rand_inputs
shape = SD_XL_EMBEDDINGS_SHAPE
else:
@@ -287,11 +290,6 @@ def input_bit_width(module):
'attention-slicing',
default=False,
help='Enable attention slicing. Default: Disabled')
add_bool_arg(
parser,
'is-sd-xl',
default=False,
help='Enable this flag to correctly export SDXL. Default: Disabled')
parser.add_argument(
'--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.')
add_bool_arg(
@@ -362,7 +360,7 @@ def input_bit_width(module):
type=quant_format_validator,
default='int',
help=
'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. Default: int.')
'Input quantization type. Either int or eXmY, with X+Y==input_bit_width-1. Default: int.')
parser.add_argument(
'--weight-quant-granularity',
type=str,
