Commit

Update
Giuseppe5 committed Jun 14, 2024
1 parent 38fb5cf commit 8471bbd
Showing 5 changed files with 140 additions and 250 deletions.
104 changes: 90 additions & 14 deletions src/brevitas_examples/common/generative/quantize.py
@@ -163,8 +163,7 @@
'sym': Fp8e4m3DynamicOCPActPerTensorFloat}}}}}}}


def quantize_model(
model,
def generate_quantizers(
dtype,
weight_bit_width,
weight_param_method,
@@ -174,7 +173,6 @@ def quantize_model(
weight_group_size,
quantize_weight_zero_point,
weight_quant_format='int',
name_blacklist=None,
input_bit_width=None,
input_quant_format='',
input_scale_precision=None,
@@ -184,7 +182,6 @@ def quantize_model(
input_quant_granularity=None,
input_group_size=None,
quantize_input_zero_point=False,
quantize_embedding=False,
use_ocp=False,
device=None,
weight_kwargs=None,
@@ -200,20 +197,20 @@ def quantize_model(
weight_float_format = {
'exponent_bit_width': int(weight_quant_format[1]),
'mantissa_bit_width': int(weight_quant_format[3])}
ocp_weight_format = weight_quant_format
weight_quant_format = 'float'
if use_ocp:
weight_quant_format += '_ocp'
ocp_weight_format = weight_quant_format
weight_quant_format = 'float'
else:
weight_float_format = {}
if re.compile(r'e[1-8]m[1-8]').match(input_quant_format):
input_float_format = {
'exponent_bit_width': int(input_quant_format[1]),
'mantissa_bit_width': int(input_quant_format[3])}
ocp_input_format = input_quant_format
input_quant_format = 'float'
if use_ocp:
input_quant_format += '_ocp'
ocp_input_format = input_quant_format
input_quant_format = 'float'
else:
input_float_format = {}

@@ -230,15 +227,15 @@ def quantize_model(
input_scale_type][input_quant_type]
elif input_bit_width is not None:
if ocp_input_format:
input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][input_scale_type][
input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ocp_input_format][
input_scale_precision][input_param_method][input_quant_granularity][
input_quant_type]
# Some activations in MHA should always be symmetric
sym_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][
input_scale_type][input_scale_precision][input_param_method][
sym_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
ocp_input_format][input_scale_precision][input_param_method][
input_quant_granularity]['sym']
linear_input_quant = INPUT_QUANT_MAP[input_quant_format][ocp_input_format][
input_scale_type][input_scale_precision][input_param_method][
linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
ocp_input_format][input_scale_precision][input_param_method][
input_quant_granularity][input_quant_type]
else:
input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][
@@ -365,6 +362,21 @@ def quantize_model(
linear_input_quant = linear_input_quant.let(
**{
'group_dim': -1, 'group_size': input_group_size})
return linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant


def generate_quant_maps(
linear_input_quant,
weight_quant,
input_quant,
q_scaled_quant,
k_transposed_quant,
v_quant,
attn_output_weights_quant,
dtype,
device,
input_quant_format,
quantize_embedding):

quant_linear_kwargs = {
'input_quant': linear_input_quant,
@@ -380,7 +392,7 @@ def quantize_model(
'in_proj_bias_quant': None,
'softmax_input_quant': None,
'attn_output_weights_quant': attn_output_weights_quant,
'attn_output_weights_signed': input_quant_format == 'float',
'attn_output_weights_signed': 'float' in input_quant_format,
'q_scaled_quant': q_scaled_quant,
'k_transposed_quant': k_transposed_quant,
'v_quant': v_quant,
@@ -406,7 +418,71 @@ def quantize_model(
if quantize_embedding:
quant_embedding_kwargs = {'weight_quant': weight_quant, 'dtype': dtype, 'device': device}
layer_map[nn.Embedding] = (qnn.QuantEmbedding, quant_embedding_kwargs)
return layer_map


def quantize_model(
model,
dtype,
weight_bit_width,
weight_param_method,
weight_scale_precision,
weight_quant_type,
weight_quant_granularity,
weight_group_size,
quantize_weight_zero_point,
weight_quant_format='int',
name_blacklist=None,
input_bit_width=None,
input_quant_format='',
input_scale_precision=None,
input_scale_type=None,
input_param_method=None,
input_quant_type=None,
input_quant_granularity=None,
input_group_size=None,
quantize_input_zero_point=False,
quantize_embedding=False,
use_ocp=False,
device=None,
weight_kwargs=None,
input_kwargs=None):

linear_input_quant, weight_quant, input_quant, q_scaled_quant, k_transposed_quant, v_quant, attn_output_weights_quant = generate_quantizers(
dtype,
weight_bit_width,
weight_param_method,
weight_scale_precision,
weight_quant_type,
weight_quant_granularity,
weight_group_size,
quantize_weight_zero_point,
weight_quant_format,
input_bit_width,
input_quant_format,
input_scale_precision,
input_scale_type,
input_param_method,
input_quant_type,
input_quant_granularity,
input_group_size,
quantize_input_zero_point,
use_ocp,
device,
weight_kwargs,
input_kwargs)
layer_map = generate_quant_maps(
linear_input_quant,
weight_quant,
input_quant,
q_scaled_quant,
k_transposed_quant,
v_quant,
attn_output_weights_quant,
dtype,
device,
input_quant_format,
quantize_embedding)
model = layerwise_quantize(
model=model, compute_layer_map=layer_map, name_blacklist=name_blacklist)
return model
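
For reference, below is a minimal sketch of how the refactored pieces introduced in this commit could be wired together by hand. The import paths, argument values, and the toy model are assumptions for illustration only; the new `quantize_model` wrapper shown above performs essentially these steps internally.

```python
# Hypothetical usage sketch; import paths and argument values are assumed, not verbatim.
import torch
import torch.nn as nn

from brevitas.graph.quantize import layerwise_quantize  # assumed import path
from brevitas_examples.common.generative.quantize import (  # assumed import path
    generate_quant_maps, generate_quantizers)

toy_model = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 16))

# 1. Build the quantizer classes once; they are independent of any specific model.
(linear_input_quant, weight_quant, input_quant, q_scaled_quant,
 k_transposed_quant, v_quant, attn_output_weights_quant) = generate_quantizers(
    dtype=torch.float32,
    weight_bit_width=8,
    weight_param_method='stats',            # illustrative value
    weight_scale_precision='float_scale',   # illustrative value
    weight_quant_type='sym',
    weight_quant_granularity='per_tensor',  # illustrative value
    weight_group_size=None,
    quantize_weight_zero_point=False)

# 2. Turn the quantizers into a layer map: nn.Module type -> (quant class, kwargs).
layer_map = generate_quant_maps(
    linear_input_quant, weight_quant, input_quant, q_scaled_quant,
    k_transposed_quant, v_quant, attn_output_weights_quant,
    dtype=torch.float32, device='cpu', input_quant_format='int',
    quantize_embedding=False)

# 3. Apply the map layer by layer, optionally skipping modules by name.
quant_model = layerwise_quantize(
    model=toy_model, compute_layer_map=layer_map, name_blacklist=None)
```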
37 changes: 7 additions & 30 deletions src/brevitas_examples/stable_diffusion/README.md
@@ -30,7 +30,7 @@ Activation quantization is optional, and disabled by default. To enable, set bot

We support ONNX integer export, and we plan to release export support for floating-point quantization (e.g., FP8) soon.

To export the model with fp16 scale factors, enable `export-cuda-float16`. This will perform the tracing necessary for export on GPU, leaving the model in fp16.
To export the model with fp16 scale factors, disable `export-cpu-float32`. This will perform the tracing necessary for export on GPU, leaving the model in fp16.
If the flag is enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16.

To use the MLPerf inference setup, check and install the requirements specified in the `requirements.txt` file under `mlperf_evaluation`.
@@ -70,7 +70,7 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--gptq | --no-gptq] [--bias-correction | --no-bias-correction]
[--dtype {float32,float16,bfloat16}]
[--attention-slicing | --no-attention-slicing]
[--export-target {,torch,onnx}]
[--export-target {,onnx}]
[--export-weight-q-node | --no-export-weight-q-node]
[--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH]
[--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH]
@@ -93,15 +93,11 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT]
[--weight-group-size WEIGHT_GROUP_SIZE]
[--quantize-weight-zero-point | --no-quantize-weight-zero-point]
[--quantize-input-zero-point | --no-quantize-input-zero-point]
[--export-cuda-float16 | --no-export-cuda-float16]
[--export-cpu-float32 | --no-export-cpu-float32]
[--use-mlperf-inference | --no-use-mlperf-inference]
[--use-ocp | --no-use-ocp]
[--use-negative-prompts | --no-use-negative-prompts]
[--dry-run | --no-dry-run]
[--quantize-time-emb | --no-quantize-time-emb]
[--quantize-conv-in | --no-quantize-conv-in]
[--quantize-input-time-emb | --no-quantize-input-time-emb]
[--quantize-input-conv-in | --no-quantize-input-conv-in]

Stable Diffusion quantization

@@ -160,7 +156,7 @@ options:
--attention-slicing Enable Enable attention slicing. Default: Disabled
--no-attention-slicing
Disable Enable attention slicing. Default: Disabled
--export-target {,torch,onnx}
--export-target {,onnx}
Target export flow.
--export-weight-q-node
Enable Enable export of floating point weights + QDQ
@@ -224,10 +220,9 @@ options:
Enable Quantize input zero-point. Default: Enabled
--no-quantize-input-zero-point
Disable Quantize input zero-point. Default: Enabled
--export-cuda-float16
Enable Export FP16 on CUDA. Default: Disabled
--no-export-cuda-float16
Disable Export FP16 on CUDA. Default: Disabled
--export-cpu-float32 Enable Export FP32 on CPU. Default: Disabled
--no-export-cpu-float32
Disable Export FP32 on CPU. Default: Disabled
--use-mlperf-inference
Enable Evaluate FID score with MLPerf pipeline.
Default: False
@@ -248,23 +243,5 @@ options:
calibration. Default: Disabled
--no-dry-run Disable Generate a quantized model without any
calibration. Default: Disabled
--quantize-time-emb Enable Quantize time embedding layers. Default: True
--no-quantize-time-emb
Disable Quantize time embedding layers. Default: True
--quantize-conv-in Enable Quantize first conv layer. Default: True
--no-quantize-conv-in
Disable Quantize first conv layer. Default: True
--quantize-input-time-emb
Enable Quantize input to time embedding layers.
Default: Disabled
--no-quantize-input-time-emb
Disable Quantize input to time embedding layers.
Default: Disabled
--quantize-input-conv-in
Enable Quantize input to first conv layer. Default:
Enabled
--no-quantize-input-conv-in
Disable Quantize input to first conv layer. Default:
Enabled

```