Feat (examples/stable_diffusion): improvements to SD
Giuseppe5 committed May 16, 2024
1 parent a1926f0 commit 8dcde7b
Showing 12 changed files with 1,991 additions and 82 deletions.
5 changes: 5 additions & 0 deletions src/brevitas/graph/gpfq.py
@@ -263,6 +263,9 @@ def single_layer_update(self):
# No permutation, permutation tensor is an ordered index
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

self.reactivate_quantization()

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
@@ -358,6 +361,8 @@ def single_layer_update(self):
perm = torch.tensor(range(weight.shape[-1]), device=dev)
permutation_list.append(perm)

self.reactivate_quantization()

for t in range(weight.shape[-1]):
for group_index in range(self.groups):
U[group_index] += torch.matmul(
5 changes: 5 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -86,9 +86,11 @@ def catch_stopfwd(self, *args, **kwargs):
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = True

out = self.orig_forward(*args, **kwargs)
for name, gpxq_class in self.gpxq_layers.items():
gpxq_class.disable_pre_forward_hook = False

return out

def initialize_module_optimizer(
@@ -134,6 +136,7 @@ def __init__(
device='cpu',
dtype=torch.float32)
self.nsamples = 0
self.done = False

assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"

@@ -257,6 +260,8 @@ def single_layer_update(self, percdamp=.01):
finally:
del self.H

self.reactivate_quantization()

for i1 in range(0, self.columns, self.blocksize):
i2 = min(i1 + self.blocksize, self.columns)
count = i2 - i1
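The `catch_stopfwd` hunk above shows the pattern of flipping every layer's `disable_pre_forward_hook` flag around the call to `orig_forward`, so the real network output can be collected without re-capturing layer inputs. A minimal standalone sketch of that toggle, not the Brevitas implementation (the `try/finally` is an extra safety net, and all names are hypothetical):

```python
# Sketch of the hook-toggle pattern used in catch_stopfwd (hypothetical names).
def run_without_capture_hooks(gpxq_layers, orig_forward, *args, **kwargs):
    # Temporarily silence the per-layer capture hooks...
    for layer in gpxq_layers.values():
        layer.disable_pre_forward_hook = True
    try:
        out = orig_forward(*args, **kwargs)
    finally:
        # ...and always switch them back on afterwards.
        for layer in gpxq_layers.values():
            layer.disable_pre_forward_hook = False
    return out
```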
13 changes: 13 additions & 0 deletions src/brevitas/graph/gpxq.py
@@ -164,6 +164,10 @@ def __enter__(self):
return self

def __exit__(self, type, value, traceback):
for name, layer in self.gpxq_layers.items():
if not layer.done:
layer.reactivate_quantization()

if isinstance(self.model, (GraphModule, TorchGraphModule)):
self.model.__class__.forward = self.orig_forward
else:
@@ -219,6 +223,10 @@ def __init__(
self.disable_pre_forward_hook = False
# Some layers require knowledge from quant inputs to compute quant weights
self.quant_metadata = None
self.disable_quant_inference = DisableEnableQuantization()
self.return_quant_tensor_state = disable_return_quant_tensor(self.layer)
self.disable_quant_inference.disable_param_quantization(self.layer, False)
self.done = False

def process_input(self, inp):
# Input is a tuple, so we take first element
@@ -255,6 +263,11 @@ def update_batch(self):
def single_layer_update(self):
pass

def reactivate_quantization(self):
self.done = True
self.disable_quant_inference.enable_param_quantization(self.layer, False)
restore_return_quant_tensor(self.layer, self.return_quant_tensor_state)

def get_quant_weights(self, i, i1, permutation_list):
# We need to recompute quant weights at runtime since our float weights are being updated
# Add offset in case of blockwise computation
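The gpxq.py hunks above introduce a per-layer `done` flag and a `reactivate_quantization()` hook, with `__exit__` acting as a safety net that re-enables quantization for any layer whose `single_layer_update` never ran. A standalone sketch of that pattern, with hypothetical class names standing in for the real Brevitas objects:

```python
# Sketch of the done / reactivate_quantization pattern added above (hypothetical classes).
class SketchLayer:

    def __init__(self, name):
        self.name = name
        self.done = False

    def single_layer_update(self):
        # ... the weight update runs while quantization is disabled ...
        self.reactivate_quantization()

    def reactivate_quantization(self):
        # Re-enable parameter quantization exactly once and mark the layer as done.
        self.done = True


class SketchContext:

    def __init__(self, layers):
        self.gpxq_layers = {layer.name: layer for layer in layers}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        # Any layer that was never updated still gets its quantization restored on exit.
        for name, layer in self.gpxq_layers.items():
            if not layer.done:
                layer.reactivate_quantization()


with SketchContext([SketchLayer('conv'), SketchLayer('fc')]) as ctx:
    ctx.gpxq_layers['conv'].single_layer_update()
# On exit, 'fc' is reactivated even though its update never ran.
```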
30 changes: 28 additions & 2 deletions src/brevitas_examples/common/generative/quantize.py
@@ -4,6 +4,7 @@
"""
import re

import torch
from torch import nn

from brevitas import nn as qnn
@@ -13,6 +14,8 @@
from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat
from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat
from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint
from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE
from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint
@@ -79,7 +82,14 @@
'per_channel': {
'sym': Fp8e4m3WeightPerChannelFloat},
'per_group': {
'sym': Fp8e4m3WeightSymmetricGroupQuant}},}}}
'sym': Fp8e4m3WeightSymmetricGroupQuant}}}},
'float_ocp': {
'float_scale': {
'stats': {
'per_tensor': {
'sym': Fp8e4m3OCPWeightPerTensorFloat},
'per_channel': {
'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}}

INPUT_QUANT_MAP = {
'int': {
@@ -142,7 +152,10 @@ def quantize_model(
input_group_size=None,
quantize_input_zero_point=False,
quantize_embedding=False,
device=None):
use_ocp=False,
device=None,
weight_kwargs=None,
input_kwargs=None):
"""
Replace float layers with quant layers in the target model
"""
@@ -154,13 +167,17 @@
'exponent_bit_width': int(weight_quant_format[1]),
'mantissa_bit_width': int(weight_quant_format[3])}
weight_quant_format = 'float'
if use_ocp:
weight_quant_format += '_ocp'
else:
weight_float_format = {}
if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
input_float_format = {
'exponent_bit_width': int(input_quant_format[1]),
'mantissa_bit_width': int(input_quant_format[3])}
input_quant_format = 'float'
if use_ocp:
input_quant_format += '_ocp'
else:
input_float_format = {}

@@ -178,6 +195,11 @@
linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
input_scale_precision][input_param_method][input_quant_granularity][input_quant_type]

if input_kwargs is not None:
input_quant = input_quant.let(**input_kwargs)
sym_input_quant = sym_input_quant.let(**input_kwargs)
linear_input_quant = linear_input_quant.let(**input_kwargs)

else:
input_quant = None
sym_input_quant = None
@@ -190,6 +212,10 @@
'narrow_range': False,
'quantize_zero_point': quantize_weight_zero_point},
**weight_float_format)
if dtype == torch.float16:
weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
if weight_kwargs is not None:
weight_quant = weight_quant.let(**weight_kwargs)

# Set the group_size if we're doing groupwise quantization
if weight_quant_granularity == 'per_group':
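The quantize.py hunks above parse an `eXmY` format string into exponent and mantissa bit widths and, when `use_ocp` is set, switch to the `_ocp` branch of the quantizer maps; `weight_kwargs`/`input_kwargs` and the float16 `scaling_min_val` override are then applied with `.let()`. A simplified, self-contained re-implementation of just the format-string handling, for illustration only (no Brevitas imports; not the exact code path):

```python
# Simplified sketch of the eXmY format parsing shown above (illustration only).
import re


def parse_quant_format(quant_format: str, use_ocp: bool = False):
    """Turn e.g. 'e4m3' into a quantizer-family key plus float bit widths."""
    if re.compile(r'e[1-8]m[1-8]').match(quant_format):
        float_format = {
            'exponent_bit_width': int(quant_format[1]),
            'mantissa_bit_width': int(quant_format[3])}
        quant_format = 'float_ocp' if use_ocp else 'float'
    else:
        float_format = {}
    return quant_format, float_format


# 'e4m3' with use_ocp=True selects the Fp8e4m3OCP* quantizers from the maps above.
print(parse_quant_format('e4m3', use_ocp=True))
# -> ('float_ocp', {'exponent_bit_width': 4, 'mantissa_bit_width': 3})
```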
63 changes: 54 additions & 9 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -16,18 +16,25 @@ We support ONNX integer export, and we are planning to release soon export for f
To export the model with fp16 scale factors, enable `export-cuda-float16`. This will perform the tracing necessary for export on GPU, leaving the model in fp16.
If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16.

To use the MLPerf inference setup, check and install the requirements specified in the `requirements.txt` file under mlperf_evaluation.

## Run

```bash
usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--resolution RESOLUTION]
[--calibration-prompt CALIBRATION_PROMPT]
[--calibration-prompt-path CALIBRATION_PROMPT_PATH]
[--checkpoint-name CHECKPOINT_NAME]
[--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION]
[--guidance-scale GUIDANCE_SCALE]
[--calibration-steps CALIBRATION_STEPS]
[--output-path OUTPUT_PATH | --no-output-path]
[--quantize | --no-quantize]
[--activation-equalization | --no-activation-equalization]
[--gptq | --no-gptq] [--float16 | --no-float16]
[--gptq | --no-gptq] [--bias-correction | --no-bias-correction]
[--dtype {float32,float16,bfloat16}]
[--attention-slicing | --no-attention-slicing]
[--export-target {,onnx}]
[--export-target {,torch,onnx}]
[--export-weight-q-node | --no-export-weight-q-node]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
@@ -47,6 +54,9 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point | --no-quantize-weight-zero-point]
[--export-cuda-float16 | --no-export-cuda-float16]
[--use-mlperf-inference | --no-use-mlperf-inference]
[--use-ocp | --no-use-ocp]
[--use-negative-prompts | --no-use-negative-prompts]

Stable Diffusion quantization

@@ -57,12 +67,27 @@ options:
-d DEVICE, --device DEVICE
Target device for quantized model.
-b BATCH_SIZE, --batch-size BATCH_SIZE
Batch size. Default: 4
--prompt PROMPT Manual prompt for testing. Default: An astronaut
riding a horse on Mars.
How many seeds to use for each image during
validation. Default: 2
--prompt PROMPT Number of prompts to use for testing. Default: 4. Max:
4
--calibration-prompt CALIBRATION_PROMPT
Number of prompts to use for calibration. Default: 2
--calibration-prompt-path CALIBRATION_PROMPT_PATH
Path to calibration prompt
--checkpoint-name CHECKPOINT_NAME
Name to use to store the checkpoint. If not provided,
no checkpoint is saved.
--path-to-latents PATH_TO_LATENTS
Load pre-defined latents. If not provided, they are
generated based on an internal seed.
--resolution RESOLUTION
Resolution along height and width dimension. Default:
512.
--guidance-scale GUIDANCE_SCALE
Guidance scale.
--calibration-steps CALIBRATION_STEPS
Percentage of steps used during calibration
--output-path OUTPUT_PATH
Path where to generate output folder.
--no-output-path Disable Path where to generate output folder.
@@ -76,12 +101,15 @@ options:
Disabled
--gptq Enable Toggle gptq. Default: Disabled
--no-gptq Disable Toggle gptq. Default: Disabled
--float16 Enable Enable float16 execution. Default: Enabled
--no-float16 Disable Enable float16 execution. Default: Enabled
--bias-correction Enable Toggle bias-correction. Default: Enabled
--no-bias-correction Disable Toggle bias-correction. Default: Enabled
--dtype {float32,float16,bfloat16}
Model Dtype, choices are float32, float16, bfloat16.
Default: float16
--attention-slicing Enable Enable attention slicing. Default: Disabled
--no-attention-slicing
Disable Enable attention slicing. Default: Disabled
--export-target {,onnx}
--export-target {,torch,onnx}
Target export flow.
--export-weight-q-node
Enable Enable export of floating point weights + QDQ
@@ -137,4 +165,21 @@
Enable Export FP16 on CUDA. Default: Disabled
--no-export-cuda-float16
Disable Export FP16 on CUDA. Default: Disabled
--use-mlperf-inference
Enable Evaluate FID score with MLPerf pipeline.
Default: False
--no-use-mlperf-inference
Disable Evaluate FID score with MLPerf pipeline.
Default: False
--use-ocp Enable Use OCP format for float quantization. Default:
True
--no-use-ocp Disable Use OCP format for float quantization.
Default: True
--use-negative-prompts
Enable Use negative prompts during
generation/calibration. Default: Enabled
--no-use-negative-prompts
Disable Use negative prompts during
generation/calibration. Default: Enabled

```