From 8dcde7bd7569a8fd07cc5832540dbcbdd93447a0 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 27 Apr 2024 19:22:55 +0100 Subject: [PATCH] Feat (examples/stable_diffusion): improvements to SD --- src/brevitas/graph/gpfq.py | 5 + src/brevitas/graph/gptq.py | 5 + src/brevitas/graph/gpxq.py | 13 + .../common/generative/quantize.py | 30 +- .../stable_diffusion/README.md | 63 +- .../stable_diffusion/main.py | 367 +++++++++-- .../mlperf_evaluation/accuracy.py | 526 +++++++++++++++ .../mlperf_evaluation/backend.py | 610 ++++++++++++++++++ .../mlperf_evaluation/dataset.py | 359 +++++++++++ .../mlperf_evaluation/requirements.txt | 8 + .../stable_diffusion/sd_quant/export.py | 14 +- .../stable_diffusion/sd_quant/utils.py | 73 +++ 12 files changed, 1991 insertions(+), 82 deletions(-) create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py create mode 100644 src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py index ef720d092..85ecd6ef5 100644 --- a/src/brevitas/graph/gpfq.py +++ b/src/brevitas/graph/gpfq.py @@ -263,6 +263,9 @@ def single_layer_update(self): # No permutation, permutation tensor is a ordered index perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) + + self.reactivate_quantization() + for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( @@ -358,6 +361,8 @@ def single_layer_update(self): perm = torch.tensor(range(weight.shape[-1]), device=dev) permutation_list.append(perm) + self.reactivate_quantization() + for t in range(weight.shape[-1]): for group_index in range(self.groups): U[group_index] += torch.matmul( diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py index 31d31433b..0861fd15c 100644 --- a/src/brevitas/graph/gptq.py +++ b/src/brevitas/graph/gptq.py @@ -86,9 +86,11 @@ def catch_stopfwd(self, *args, **kwargs): # If we want to return the output of the network, we need to disable all hooks for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = True + out = self.orig_forward(*args, **kwargs) for name, gpxq_class in self.gpxq_layers.items(): gpxq_class.disable_pre_forward_hook = False + return out def initialize_module_optimizer( @@ -134,6 +136,7 @@ def __init__( device='cpu', dtype=torch.float32) self.nsamples = 0 + self.done = False assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher" @@ -257,6 +260,8 @@ def single_layer_update(self, percdamp=.01): finally: del self.H + self.reactivate_quantization() + for i1 in range(0, self.columns, self.blocksize): i2 = min(i1 + self.blocksize, self.columns) count = i2 - i1 diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py index b85ac1188..deb613b1a 100644 --- a/src/brevitas/graph/gpxq.py +++ b/src/brevitas/graph/gpxq.py @@ -164,6 +164,10 @@ def __enter__(self): return self def __exit__(self, type, value, traceback): + for name, layer in self.gpxq_layers.items(): + if not layer.done: + layer.reactivate_quantization() + if isinstance(self.model, (GraphModule, TorchGraphModule)): self.model.__class__.forward = self.orig_forward else: @@ -219,6 +223,10 @@ def __init__( self.disable_pre_forward_hook = False # Some layers require knowledge from quant inputs to compute quant weights self.quant_metadata = None + self.disable_quant_inference = DisableEnableQuantization() + self.return_quant_tensor_state = disable_return_quant_tensor(self.layer) + self.disable_quant_inference.disable_param_quantization(self.layer, False) + self.done = False def process_input(self, inp): # Input is a tuple, so we take first element @@ -255,6 +263,11 @@ def update_batch(self): def single_layer_update(self): pass + def reactivate_quantization(self): + self.done = True + self.disable_quant_inference.enable_param_quantization(self.layer, False) + restore_return_quant_tensor(self.layer, self.return_quant_tensor_state) + def get_quant_weights(self, i, i1, permutation_list): # We need to recompute quant weights at runtime since our float weights are being updated # Add offset in case of blockwise computation diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 31ab57361..0b1614e74 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -4,6 +4,7 @@ """ import re +import torch from torch import nn from brevitas import nn as qnn @@ -13,6 +14,8 @@ from brevitas.quant.experimental.float import Fp8e4m3ActPerTensorFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerChannelFloat from brevitas.quant.experimental.float import Fp8e4m3WeightPerTensorFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerChannelFloat +from brevitas.quant.experimental.float_quant_ocp import Fp8e4m3OCPWeightPerTensorFloat from brevitas.quant.fixed_point import Int8ActPerTensorFixedPoint from brevitas.quant.fixed_point import Int8ActPerTensorFixedPointMSE from brevitas.quant.fixed_point import Int8WeightPerChannelFixedPoint @@ -79,7 +82,14 @@ 'per_channel': { 'sym': Fp8e4m3WeightPerChannelFloat}, 'per_group': { - 'sym': Fp8e4m3WeightSymmetricGroupQuant}},}}} + 'sym': Fp8e4m3WeightSymmetricGroupQuant}}}}, + 'float_ocp': { + 'float_scale': { + 'stats': { + 'per_tensor': { + 'sym': Fp8e4m3OCPWeightPerTensorFloat}, + 'per_channel': { + 'sym': Fp8e4m3OCPWeightPerChannelFloat}}}}} INPUT_QUANT_MAP = { 'int': { @@ -142,7 +152,10 @@ def quantize_model( input_group_size=None, quantize_input_zero_point=False, quantize_embedding=False, - device=None): + use_ocp=False, + device=None, + weight_kwargs=None, + input_kwargs=None): """ Replace float layers with quant layers in the target model """ @@ -154,6 +167,8 @@ def quantize_model( 'exponent_bit_width': int(weight_quant_format[1]), 'mantissa_bit_width': int(weight_quant_format[3])} weight_quant_format = 'float' + if use_ocp: + weight_quant_format += '_ocp' else: weight_float_format = {} if re.compile(r'e[1-8]m[1-8]').match(input_quant_format): @@ -161,6 +176,8 @@ def quantize_model( 'exponent_bit_width': int(input_quant_format[1]), 'mantissa_bit_width': int(input_quant_format[3])} input_quant_format = 'float' + if use_ocp: + input_quant_format += '_ocp' else: input_float_format = {} @@ -178,6 +195,11 @@ def quantize_model( linear_input_quant = INPUT_QUANT_MAP[input_quant_format][input_scale_type][ input_scale_precision][input_param_method][input_quant_granularity][input_quant_type] + if input_kwargs is not None: + input_quant = input_quant.let(**input_kwargs) + sym_input_quant = sym_input_quant.let(**input_kwargs) + linear_input_quant = linear_input_quant.let(**input_kwargs) + else: input_quant = None sym_input_quant = None @@ -190,6 +212,10 @@ def quantize_model( 'narrow_range': False, 'quantize_zero_point': quantize_weight_zero_point}, **weight_float_format) + if dtype == torch.float16: + weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4}) + if weight_kwargs is not None: + weight_quant = weight_quant.let(**weight_kwargs) # Set the group_size is we're doing groupwise quantization if weight_quant_granularity == 'per_group': diff --git a/src/brevitas_examples/stable_diffusion/README.md b/src/brevitas_examples/stable_diffusion/README.md index 30754b3d8..1c71e4431 100644 --- a/src/brevitas_examples/stable_diffusion/README.md +++ b/src/brevitas_examples/stable_diffusion/README.md @@ -16,18 +16,25 @@ We support ONNX integer export, and we are planning to release soon export for f To export the model with fp16 scale factors, enable `export-cuda-float16`. This will performing the tracing necessary for export on GPU, leaving the model in fp16. If the flag is not enabled, the model will be moved to CPU and cast to float32 before export because of missing CPU kernels in fp16. +To use MLPerf inference setup, check and install the correct requirements specified in the `requirements.txt` file under mlperf_evaluation. ## Run ```bash usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] - [--resolution RESOLUTION] + [--calibration-prompt CALIBRATION_PROMPT] + [--calibration-prompt-path CALIBRATION_PROMPT_PATH] + [--checkpoint-name CHECKPOINT_NAME] + [--path-to-latents PATH_TO_LATENTS] [--resolution RESOLUTION] + [--guidance-scale GUIDANCE_SCALE] + [--calibration-steps CALIBRATION_STEPS] [--output-path OUTPUT_PATH | --no-output-path] [--quantize | --no-quantize] [--activation-equalization | --no-activation-equalization] - [--gptq | --no-gptq] [--float16 | --no-float16] + [--gptq | --no-gptq] [--bias-correction | --no-bias-correction] + [--dtype {float32,float16,bfloat16}] [--attention-slicing | --no-attention-slicing] - [--export-target {,onnx}] + [--export-target {,torch,onnx}] [--export-weight-q-node | --no-export-weight-q-node] [--conv-weight-bit-width CONV_WEIGHT_BIT_WIDTH] [--linear-weight-bit-width LINEAR_WEIGHT_BIT_WIDTH] @@ -47,6 +54,9 @@ usage: main.py [-h] [-m MODEL] [-d DEVICE] [-b BATCH_SIZE] [--prompt PROMPT] [--weight-group-size WEIGHT_GROUP_SIZE] [--quantize-weight-zero-point | --no-quantize-weight-zero-point] [--export-cuda-float16 | --no-export-cuda-float16] + [--use-mlperf-inference | --no-use-mlperf-inference] + [--use-ocp | --no-use-ocp] + [--use-negative-prompts | --no-use-negative-prompts] Stable Diffusion quantization @@ -57,12 +67,27 @@ options: -d DEVICE, --device DEVICE Target device for quantized model. -b BATCH_SIZE, --batch-size BATCH_SIZE - Batch size. Default: 4 - --prompt PROMPT Manual prompt for testing. Default: An austronaut - riding a horse on Mars. + How many seeds to use for each image during + validation. Default: 2 + --prompt PROMPT Number of prompt to use for testing. Default: 4. Max: + 4 + --calibration-prompt CALIBRATION_PROMPT + Number of prompt to use for calibration. Default: 2 + --calibration-prompt-path CALIBRATION_PROMPT_PATH + Path to calibration prompt + --checkpoint-name CHECKPOINT_NAME + Name to use to store the checkpoint. If not provided, + no checkpoint is saved. + --path-to-latents PATH_TO_LATENTS + Load pre-defined latents. If not provided, they are + generated based on an internal seed. --resolution RESOLUTION Resolution along height and width dimension. Default: 512. + --guidance-scale GUIDANCE_SCALE + Guidance scale. + --calibration-steps CALIBRATION_STEPS + Percentage of steps used during calibration --output-path OUTPUT_PATH Path where to generate output folder. --no-output-path Disable Path where to generate output folder. @@ -76,12 +101,15 @@ options: Disabled --gptq Enable Toggle gptq. Default: Disabled --no-gptq Disable Toggle gptq. Default: Disabled - --float16 Enable Enable float16 execution. Default: Enabled - --no-float16 Disable Enable float16 execution. Default: Enabled + --bias-correction Enable Toggle bias-correction. Default: Enabled + --no-bias-correction Disable Toggle bias-correction. Default: Enabled + --dtype {float32,float16,bfloat16} + Model Dtype, choices are float32, float16, bfloat16. + Default: float16 --attention-slicing Enable Enable attention slicing. Default: Disabled --no-attention-slicing Disable Enable attention slicing. Default: Disabled - --export-target {,onnx} + --export-target {,torch,onnx} Target export flow. --export-weight-q-node Enable Enable export of floating point weights + QDQ @@ -137,4 +165,21 @@ options: Enable Export FP16 on CUDA. Default: Disabled --no-export-cuda-float16 Disable Export FP16 on CUDA. Default: Disabled + --use-mlperf-inference + Enable Evaluate FID score with MLPerf pipeline. + Default: False + --no-use-mlperf-inference + Disable Evaluate FID score with MLPerf pipeline. + Default: False + --use-ocp Enable Use OCP format for float quantization. Default: + True + --no-use-ocp Disable Use OCP format for float quantization. + Default: True + --use-negative-prompts + Enable Use negative prompts during + generation/calibration. Default: Enabled + --no-use-negative-prompts + Disable Use negative prompts during + generation/calibration. Default: Enabled + ``` diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 5d626accb..b5c2eee5f 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -5,6 +5,7 @@ import argparse from datetime import datetime +from functools import partial import json import os import time @@ -12,24 +13,32 @@ from dependencies import value from diffusers import DiffusionPipeline from diffusers import StableDiffusionXLPipeline +import numpy as np +import pandas as pd import torch from torch import nn +from torchmetrics.image.fid import FrechetInceptionDistance from tqdm import tqdm from brevitas.export.onnx.standard.qcdq.manager import StdQCDQONNXManager +from brevitas.export.torch.qcdq.manager import TorchQCDQManager from brevitas.graph.calibrate import bias_correction_mode from brevitas.graph.calibrate import calibration_mode from brevitas.graph.equalize import activation_equalization_mode from brevitas.graph.gptq import gptq_mode +from brevitas.inject.enum import QuantType from brevitas.nn.equalized_layer import EqualizedModule +from brevitas.nn.forward_handlers import brevitas_proxy_inference_mode from brevitas.utils.torch_utils import KwargsForwardHook from brevitas_examples.common.generative.quantize import quantize_model from brevitas_examples.common.parse_utils import add_bool_arg from brevitas_examples.common.parse_utils import quant_format_validator from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager +from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid from brevitas_examples.stable_diffusion.sd_quant.constants import SD_2_1_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.constants import SD_XL_EMBEDDINGS_SHAPE from brevitas_examples.stable_diffusion.sd_quant.export import export_onnx +from brevitas_examples.stable_diffusion.sd_quant.export import export_torch_export from brevitas_examples.stable_diffusion.sd_quant.utils import generate_latents from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_21_rand_inputs from brevitas_examples.stable_diffusion.sd_quant.utils import generate_unet_xl_rand_inputs @@ -37,35 +46,88 @@ TEST_SEED = 123456 -VALIDATION_PROMPTS = { - 'validation_prompt_0': 'A cat playing with a ball', - 'validation_prompt_1': 'A dog running on the beach'} +NEGATIVE_PROMPTS = ["normal quality, low quality, worst quality, low res, blurry, nsfw, nude"] + +CALIBRATION_PROMPTS = [ + 'A man in a space suit playing a guitar, inspired by Cyril Rolando, highly detailed illustration, full color illustration, very detailed illustration, dan mumford and alex grey style', + 'a living room, bright modern Scandinavian style house, large windows, magazine photoshoot, 8k, studio lighting', + 'cute rabbit in a spacesuit', + 'minimalistic plolygon geometric car in brutalism warehouse, Rick Owens'] + +TESTING_PROMPTS = [ + 'batman, cute modern disney style, Pixar 3d portrait, ultra detailed, gorgeous, 3d zbrush, trending on dribbble, 8k render', + 'A beautiful stack of rocks sitting on top of a beach, a picture, red black white golden colors, chakras, packshot, stock photo', + 'A painting of a fish on a black background, a digital painting, by Jason Benjamin, colorful vector illustration, mixed media style illustration, epic full color illustration, mascot illustration', + 'close up photo of a rabbit, forest in spring, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot' +] + + +def load_calib_prompts(calib_data_path, sep="\t"): + df = pd.read_csv(calib_data_path, sep=sep) + lst = df["caption"].tolist() + return lst def run_test_inference( - pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + pipe, + resolution, + prompts, + seeds, + output_path, + device, + dtype, + use_negative_prompts, + guidance_scale, + name_prefix=''): + images = dict() with torch.no_grad(): if not os.path.exists(output_path): os.mkdir(output_path) test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) - - for name, prompt in prompts.items(): - print(f"Generating: {name}") - images = pipe([prompt] * len(seeds), latents=test_latents).images - for i, seed in enumerate(seeds): - file_path = os.path.join(output_path, f"{name_prefix}{name}_{seed}.png") + neg_prompts = NEGATIVE_PROMPTS * len(seeds) if use_negative_prompts else [] + for prompt in prompts: + prompt_images = pipe([prompt] * len(seeds), + latents=test_latents, + negative_prompt=neg_prompts, + guidance_scale=guidance_scale).images + images[prompt] = prompt_images + + i = 0 + for prompt, prompt_images in images.items(): + for image in prompt_images: + file_path = os.path.join(output_path, f"{name_prefix}{i}.png") print(f"Saving to {file_path}") - images[i].save(file_path) - - -def run_val_inference(pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''): + image.save(file_path) + i += 1 + return images + + +def run_val_inference( + pipe, + resolution, + prompts, + seeds, + device, + dtype, + use_negative_prompts, + guidance_scale, + total_steps, + test_latents=None): with torch.no_grad(): - test_latents = generate_latents(seeds, device, dtype, unet_input_shape(resolution)) - for name, prompt in prompts.items(): - print(f"Generating: {name}") + if test_latents is None: + test_latents = generate_latents(seeds[0], device, dtype, unet_input_shape(resolution)) + + neg_prompts = NEGATIVE_PROMPTS if use_negative_prompts else [] + for prompt in prompts: # We don't want to generate any image, so we return only the latent encoding pre VAE - pipe([prompt] * len(seeds), latents=test_latents, output_type='latent') + pipe( + prompt, + negative_prompt=neg_prompts[0], + latents=test_latents, + output_type='latent', + guidance_scale=guidance_scale, + num_inference_steps=total_steps) def main(args): @@ -73,11 +135,21 @@ def main(args): if args.export_target: assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export." - # Select dtype - if args.float16: - dtype = torch.float16 - else: - dtype = torch.float32 + dtype = getattr(torch, args.dtype) + + calibration_prompts = CALIBRATION_PROMPTS + if args.calibration_prompt_path is not None: + calibration_prompts = load_calib_prompts(args.calibration_prompt_path) + prompts = list() + for i, v in enumerate(calibration_prompts): + if i == args.calibration_prompt: + break + prompts.append(v) + calibration_prompts = prompts + + latents = None + if args.path_to_latents is not None: + latents = torch.load(args.path_to_latents).to(torch.float16) # Create output dir. Move to tmp if None ts = datetime.fromtimestamp(time.time()) @@ -97,6 +169,29 @@ def main(args): pipe = DiffusionPipeline.from_pretrained(args.model, torch_dtype=dtype) print(f"Model loaded from {args.model}.") + # Move model to target device + print(f"Moving model to {args.device}...") + pipe = pipe.to(args.device) + + if args.prompt > 0 and not args.use_mlperf_inference: + print(f"Running inference with prompt ...") + prompts = [] + for i, v in enumerate(TESTING_PROMPTS): + if i == args.prompt: + break + prompts.append(v) + float_images = run_test_inference( + pipe, + args.resolution, + prompts, + test_seeds, + output_dir, + args.device, + dtype, + guidance_scale=args.guidance_scale, + use_negative_prompts=args.use_negative_prompts, + name_prefix='float_') + # Detect Stable Diffusion XL pipeline is_sd_xl = isinstance(pipe, StableDiffusionXLPipeline) @@ -116,19 +211,23 @@ def main(args): if hasattr(m, 'lora_layer') and m.lora_layer is not None: raise RuntimeError("LoRA layers should be fused in before calling into quantization.") - # Move model to target device - print(f"Moving model to {args.device}...") - pipe = pipe.to(args.device) - if args.activation_equalization: with activation_equalization_mode(pipe.unet, alpha=0.5, layerwise=True, add_mul_node=True): # Workaround to expose `in_features` attribute from the Hook Wrapper for m in pipe.unet.modules(): if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'): m.in_features = m.module.in_features - prompts = VALIDATION_PROMPTS run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) # Workaround to expose `in_features` attribute from the EqualizedModule Wrapper for m in pipe.unet.modules(): @@ -147,22 +246,28 @@ def weight_bit_width(module): else: raise RuntimeError(f"Module {module} not supported.") - # XOR between the two input_bit_width. Either they are both None, or neither of them is - assert (args.linear_input_bit_width is None) == (args.conv_input_bit_width is None), 'Both input bit width must be specified or left to None' + @value + def input_bit_width(module): + if isinstance(module, nn.Linear): + return args.linear_input_bit_width + elif isinstance(module, nn.Conv2d): + return args.conv_input_bit_width + else: + raise RuntimeError(f"Module {module} not supported.") - is_input_quantized = args.linear_input_bit_width is not None and args.conv_input_bit_width is not None - if is_input_quantized: + input_kwargs = dict() + if args.linear_input_bit_width is None or args.conv_input_bit_width is None: @value - def input_bit_width(module): - if isinstance(module, nn.Linear): - return args.linear_input_bit_width - elif isinstance(module, nn.Conv2d): - return args.conv_input_bit_width + def input_quant_type(module): + if args.linear_input_bit_width is None and isinstance(module, nn.Linear): + return QuantType.FP + elif args.conv_input_bit_width is None and isinstance(module, nn.Conv2d): + return QuantType.FP else: - raise RuntimeError(f"Module {module} not supported.") - else: - input_bit_width = None + return QuantType.INT + + input_kwargs['quant_type'] = input_quant_type print("Applying model quantization...") quantize_model( @@ -184,46 +289,115 @@ def input_bit_width(module): input_scale_precision=args.input_scale_precision, input_param_method=args.input_param_method, input_quant_type=args.input_quant_type, - input_quant_granularity=args.input_quant_granularity) + input_quant_granularity=args.input_quant_granularity, + use_ocp=args.use_ocp, + input_kwargs=input_kwargs) print("Model quantization applied.") - if is_input_quantized and args.input_scale_type == 'static': + if (args.linear_input_bit_width is not None or + args.conv_input_bit_width is not None) and args.input_scale_type == 'static': print("Applying activation calibration") - with calibration_mode(pipe.unet): - prompts = VALIDATION_PROMPTS + with brevitas_proxy_inference_mode(pipe.unet), torch.no_grad(), calibration_mode(pipe.unet): run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + pipe.set_progress_bar_config(disable=True) if args.gptq: print("Applying GPTQ. It can take several hours") - with gptq_mode(pipe.unet, + with torch.no_grad(), gptq_mode(pipe.unet, create_weight_orig=False, use_quant_activations=False, return_forward_output=True, act_order=True) as gptq: - prompts = VALIDATION_PROMPTS for _ in tqdm(range(gptq.num_layers)): run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) gptq.update() + torch.cuda.empty_cache() + pipe.set_progress_bar_config(disable=False) - print("Applying bias correction") - with bias_correction_mode(pipe.unet): - prompts = VALIDATION_PROMPTS - run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + if args.bias_correction: + print("Applying bias correction") + with brevitas_proxy_inference_mode(pipe.unet), bias_correction_mode(pipe.unet): + run_val_inference( + pipe, + args.resolution, + calibration_prompts, + test_seeds, + args.device, + dtype, + total_steps=args.calibration_steps, + use_negative_prompts=args.use_negative_prompts, + test_latents=latents, + guidance_scale=args.guidance_scale) + + if args.checkpoint_name is not None: + torch.save(pipe.unet.state_dict(), args.checkpoint_name) # Perform inference - if args.prompt: - print(f"Running inference with prompt '{args.prompt}' ...") - prompts = {'manual_prompt': args.prompt} - run_test_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + if args.prompt > 0: + with brevitas_proxy_inference_mode(pipe.unet): + if args.use_mlperf_inference: + print(f"Computing accuracy with MLPerf pipeline") + compute_mlperf_fid(pipe.unet, args.prompt) + else: + print(f"Computing accuracy on default prompt") + prompts = list() + for i, v in enumerate(TESTING_PROMPTS): + if i == args.prompt: + break + prompts.append(v) + quant_images = run_test_inference( + pipe, + args.resolution, + prompts, + test_seeds, + output_dir, + args.device, + dtype, + use_negative_prompts=args.use_negative_prompts, + guidance_scale=args.guidance_scale, + name_prefix='quant_') + + float_images_values = float_images.values() + float_images_values = [x for x_nested in float_images_values for x in x_nested] + float_images_values = torch.tensor([ + np.array(image) for image in float_images_values]) + float_images_values = float_images_values.permute(0, 3, 1, 2) + + quant_images_values = quant_images.values() + quant_images_values = [x for x_nested in quant_images_values for x in x_nested] + quant_images_values = torch.tensor([ + np.array(image) for image in quant_images_values]) + quant_images_values = quant_images_values.permute(0, 3, 1, 2) + + fid = FrechetInceptionDistance(normalize=False) + fid.update(float_images_values, real=True) + fid.update(quant_images_values, real=False) + print(f"FID: {float(fid.compute())}") if args.export_target: # Move to cpu and to float32 to enable CPU export - if not (args.float16 and args.export_cuda_float16): - pipe.unet.to('cpu').to(torch.float32) + if not (dtype == torch.float16 and args.export_cuda_float16): + pipe.unet.to('cpu').to(dtype) pipe.unet.eval() device = next(iter(pipe.unet.parameters())).device dtype = next(iter(pipe.unet.parameters())).dtype @@ -248,6 +422,13 @@ def input_bit_width(module): export_manager = StdQCDQONNXManager export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) export_onnx(pipe, trace_inputs, output_dir, export_manager) + if args.export_target == 'torch': + if args.weight_quant_granularity == 'per_group': + export_manager = BlockQuantProxyLevelManager + else: + export_manager = TorchQCDQManager + export_manager.change_weight_export(export_weight_q_node=args.export_weight_q_node) + export_torch_export(pipe, trace_inputs, output_dir, export_manager) if __name__ == "__main__": @@ -260,17 +441,46 @@ def input_bit_width(module): help='Path or name of the model.') parser.add_argument( '-d', '--device', type=str, default='cuda:0', help='Target device for quantized model.') - parser.add_argument('-b', '--batch-size', type=int, default=4, help='Batch size. Default: 4') + parser.add_argument( + '-b', + '--batch-size', + type=int, + default=2, + help='How many seeds to use for each image during validation. Default: 2') parser.add_argument( '--prompt', + type=int, + default=4, + help='Number of prompt to use for testing. Default: 4. Max: 4') + parser.add_argument( + '--calibration-prompt', + type=int, + default=2, + help='Number of prompt to use for calibration. Default: 2') + parser.add_argument( + '--calibration-prompt-path', type=str, default=None, help='Path to calibration prompt') + parser.add_argument( + '--checkpoint-name', type=str, - default='An austronaut riding a horse on Mars.', - help='Manual prompt for testing. Default: An austronaut riding a horse on Mars.') + default=None, + help='Name to use to store the checkpoint. If not provided, no checkpoint is saved.') + parser.add_argument( + '--path-to-latents', + type=str, + default=None, + help= + 'Load pre-defined latents. If not provided, they are generated based on an internal seed.') parser.add_argument( '--resolution', type=int, default=512, help='Resolution along height and width dimension. Default: 512.') + parser.add_argument('--guidance-scale', type=float, default=7.5, help='Guidance scale.') + parser.add_argument( + '--calibration-steps', + type=float, + default=8, + help='Percentage of steps used during calibration') add_bool_arg( parser, 'output-path', @@ -284,14 +494,24 @@ def input_bit_width(module): default=False, help='Toggle Activation Equalization. Default: Disabled') add_bool_arg(parser, 'gptq', default=False, help='Toggle gptq. Default: Disabled') - add_bool_arg(parser, 'float16', default=True, help='Enable float16 execution. Default: Enabled') + add_bool_arg( + parser, 'bias-correction', default=True, help='Toggle bias-correction. Default: Enabled') + parser.add_argument( + '--dtype', + default='float16', + choices=['float32', 'float16', 'bfloat16'], + help='Model Dtype, choices are float32, float16, bfloat16. Default: float16') add_bool_arg( parser, 'attention-slicing', default=False, help='Enable attention slicing. Default: Disabled') parser.add_argument( - '--export-target', type=str, default='', choices=['', 'onnx'], help='Target export flow.') + '--export-target', + type=str, + default='', + choices=['', 'torch', 'onnx'], + help='Target export flow.') add_bool_arg( parser, 'export-weight-q-node', @@ -391,6 +611,21 @@ def input_bit_width(module): help='Quantize weight zero-point. Default: Enabled') add_bool_arg( parser, 'export-cuda-float16', default=False, help='Export FP16 on CUDA. Default: Disabled') + add_bool_arg( + parser, + 'use-mlperf-inference', + default=False, + help='Evaluate FID score with MLPerf pipeline. Default: False') + add_bool_arg( + parser, + 'use-ocp', + default=True, + help='Use OCP format for float quantization. Default: True') + add_bool_arg( + parser, + 'use-negative-prompts', + default=True, + help='Use negative prompts during generation/calibration. Default: Enabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py new file mode 100644 index 000000000..71353e9cb --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -0,0 +1,526 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" + +import logging +import os +import pathlib +import random + +import numpy as np +from PIL import Image +from scipy import linalg +import torch +from torch.nn.functional import adaptive_avg_pool2d +import torchvision.transforms as TF +from tqdm import tqdm + +from brevitas_examples.inception import InceptionV3 +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import BackendPytorch +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import Item +from brevitas_examples.stable_diffusion.mlperf_evaluation.backend import RunnerBase +from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import Coco +from brevitas_examples.stable_diffusion.mlperf_evaluation.dataset import ImagesDataset + +IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm', 'tif', 'tiff', 'webp'} + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6): + """Numpy implementation of the Frechet Distance. + The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1) + and X_2 ~ N(mu_2, C_2) is + d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)). + + Stable version by Dougal J. Sutherland. + + Params: + -- mu1 : Numpy array containing the activations of a layer of the + inception net (like returned by the function 'get_predictions') + for generated samples. + -- mu2 : The sample mean over activations, precalculated on an + representative data set. + -- sigma1: The covariance matrix over activations for generated samples. + -- sigma2: The covariance matrix over activations, precalculated on an + representative data set. + + Returns: + -- : The Frechet Distance. + """ + + mu1 = np.atleast_1d(mu1) + mu2 = np.atleast_1d(mu2) + + sigma1 = np.atleast_2d(sigma1) + sigma2 = np.atleast_2d(sigma2) + + assert (mu1.shape == mu2.shape), "Training and test mean vectors have different lengths" + assert (sigma1.shape == sigma2.shape), "Training and test covariances have different dimensions" + + diff = mu1 - mu2 + + # Product might be almost singular + covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False) + if not np.isfinite(covmean).all(): + msg = ( + "fid calculation produces singular product; " + "adding %s to diagonal of cov estimates") % eps + print(msg) + offset = np.eye(sigma1.shape[0]) * eps + covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset)) + + # Numerical error might give slight imaginary component + if np.iscomplexobj(covmean): + if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3): + m = np.max(np.abs(covmean.imag)) + raise ValueError("Imaginary component {}".format(m)) + covmean = covmean.real + + tr_covmean = np.trace(covmean) + + return diff.dot(diff) + np.trace(sigma1) + np.trace(sigma2) - 2 * tr_covmean + + +def get_activations(files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): + """Calculates the activations of the pool_3 layer for all images. + + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : Batch size of images for the model to process at once. + Make sure that the number of samples is a multiple of + the batch size, otherwise some samples are ignored. This + behavior is retained to match the original FID score + implementation. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- A numpy array of dimension (num images, dims) that contains the + activations of the given tensor when feeding inception with the + query tensor. + """ + model.eval() + + if batch_size > len(files): + print(( + "Warning: batch size is bigger than the data size. " + "Setting batch size to data size")) + batch_size = len(files) + + dataset = ImagesDataset(files, transforms=TF.ToTensor()) + dataloader = torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + shuffle=False, + drop_last=False, + num_workers=num_workers, + ) + + pred_arr = np.empty((len(files), dims)) + + start_idx = 0 + + for batch in tqdm(dataloader): + batch = batch.to(device) + + with torch.no_grad(): + pred = model(batch)[0] + + # If model output is not scalar, apply global spatial average pooling. + # This happens if you choose a dimensionality not equal 2048. + if pred.size(2) != 1 or pred.size(3) != 1: + pred = adaptive_avg_pool2d(pred, output_size=(1, 1)) + + pred = pred.squeeze(3).squeeze(2).cpu().numpy() + + pred_arr[start_idx:start_idx + pred.shape[0]] = pred + + start_idx = start_idx + pred.shape[0] + + return pred_arr + + +def calculate_activation_statistics( + files, model, batch_size=50, dims=2048, device="cpu", num_workers=1): + """Calculation of the statistics used by the FID. + Params: + -- files : List of image files paths + -- model : Instance of inception model + -- batch_size : The images numpy array is split into batches with + batch size batch_size. A reasonable batch size + depends on the hardware. + -- dims : Dimensionality of features returned by Inception + -- device : Device to run calculations + -- num_workers : Number of parallel dataloader workers + + Returns: + -- mu : The mean over samples of the activations of the pool_3 layer of + the inception model. + -- sigma : The covariance matrix of the activations of the pool_3 layer of + the inception model. + """ + act = get_activations(files, model, batch_size, dims, device, num_workers) + mu = np.mean(act, axis=0) + sigma = np.cov(act, rowvar=False) + return mu, sigma + + +def compute_statistics_of_path( + path, + model, + batch_size, + dims, + device, + num_workers=1, + subset_size=None, + shuffle_seed=None, + ds=None): + if path.endswith(".npz"): + with np.load(path) as f: + m, s = f["mu"][:], f["sigma"][:] + else: + path = pathlib.Path(path) + files = [file for ext in IMAGE_EXTENSIONS for file in path.glob("*.{}".format(ext))] + + files = ds.get_imgs([i for i in range(10)]) + files = [file.permute(1, 2, 0).numpy() for file in files] + if subset_size is not None: + random.seed(shuffle_seed) + files = random.sample(files, subset_size) + m, s = calculate_activation_statistics(files, model, batch_size, dims, device, num_workers) + + return m, s + + +def compute_fid( + results, + statistics_path, + device, + dims=2048, + num_workers=1, + batch_size=1, + subset_size=None, + shuffle_seed=None, + ds=None, +): + imgs = [Image.fromarray(e).convert("RGB") for e in results] + device = torch.device(device if torch.cuda.is_available() else "cpu") + if num_workers is None: + try: + num_cpus = len(os.sched_getaffinity(0)) + except AttributeError: + # os.sched_getaffinity is not available under Windows, use + # os.cpu_count instead (which may not return the *available* number + # of CPUs). + num_cpus = os.cpu_count() + + num_workers = min(num_cpus, 8) if num_cpus is not None else 0 + else: + num_workers = num_workers + # assert statistics_path.endswith(".npz") + + block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims] + + model = InceptionV3([block_idx]).to(device) + + m1, s1 = compute_statistics_of_path( + statistics_path, + model, + batch_size, + dims, + device, + num_workers, + subset_size, + shuffle_seed, + ds=ds + ) + + m2, s2 = calculate_activation_statistics(imgs, model, batch_size, dims, device, num_workers) + + fid_value = calculate_frechet_distance(m1, s1, m2, s2) + + return fid_value + + +class PostProcessCoco: + + def __init__( + self, + device="cpu", + dtype="uint8", + statistics_path=os.path.join(os.path.dirname(__file__), "tools", "val2014.npz")): + self.results = [] + self.good = 0 + self.total = 0 + self.content_ids = [] + self.clip_scores = [] + self.fid_scores = [] + self.device = device if torch.cuda.is_available() else "cpu" + if dtype == "uint8": + self.dtype = torch.uint8 + self.numpy_dtype = np.uint8 + else: + raise ValueError(f"dtype must be one of: uint8") + self.statistics_path = statistics_path + + def add_results(self, results): + self.results.extend(results) + + def __call__(self, results, ids, expected=None, result_dict=None): + self.content_ids.extend(ids) + return [(t.cpu().permute(1, 2, 0).float().numpy() * 255).round().astype(self.numpy_dtype) + for t in results] + + def save_images(self, ids, ds): + info = [] + idx = {} + for i, id in enumerate(self.content_ids): + if id in ids: + idx[id] = i + if not os.path.exists("images/"): + os.makedirs("images/", exist_ok=True) + for id in ids: + caption = ds.get_caption(id) + generated = Image.fromarray(self.results[idx[id]]) + image_path_tmp = f"images/{self.content_ids[idx[id]]}.png" + generated.save(image_path_tmp) + info.append((self.content_ids[idx[id]], caption)) + with open("images/captions.txt", "w+") as f: + for id, caption in info: + f.write(f"{id} {caption}\n") + + def start(self): + self.results = [] + + def finalize(self, result_dict, ds=None, output_dir=None): + log.info("Accumulating results") + + fid_score = compute_fid(self.results, self.statistics_path, self.device, ds=ds) + result_dict["FID_SCORE"] = fid_score + + return result_dict + + +def compute_mlperf_fid(model_to_replace=None, samples_to_evaluate=500): + + post_proc = PostProcessCoco( + statistics_path='/scratch/users/gfranco/datasets/coco/tools/val2014.npz') + + dtype = next(iter(model_to_replace.parameters())).dtype + res_dict = {} + model = BackendPytorch( + '/scratch/hf_models/stable-diffusion-xl-base-1.0/stable-diffusion-xl-base-1.0/', + 'xl', + steps=20, + batch_size=1, + precision=dtype) + model.load() + + if model_to_replace is not None: + model.pipe.unet = model_to_replace + + ds = Coco( + data_path='/scratch/users/gfranco/datasets/coco', + name="coco-1024", + pre_process=torch.nn.Identity, + count=None, + threads=1, + pipe_tokenizer=model.pipe.tokenizer, + pipe_tokenizer_2=model.pipe.tokenizer_2, + latent_dtype=dtype, + latent_device='cuda', + latent_framework='torch', + **{"image_size": [3, 1024, 1024]}, + ) + model.pipe.set_progress_bar_config(disable=True) + with torch.no_grad(): + runner = RunnerBase(model, ds, 1, post_proc=post_proc, max_batchsize=1) + runner.start_run(res_dict, True) + idx = list(range(0, samples_to_evaluate)) + ds.load_query_samples(idx) + data, label = ds.get_samples(idx) + runner.run_one_item(Item(idx, idx, data, label)) + post_proc.finalize(res_dict, ds=ds) + log.info(res_dict) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py new file mode 100644 index 000000000..cf79421a0 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/backend.py @@ -0,0 +1,610 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" +import array +import logging +import time +from typing import Optional + +from diffusers import EulerDiscreteScheduler +from diffusers import StableDiffusionXLPipeline +import numpy as np +import torch +from tqdm import tqdm + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +class Item: + """An item that we queue for processing by the thread pool.""" + + def __init__(self, query_id, content_id, inputs, img=None): + self.query_id = query_id + self.content_id = content_id + self.img = img + self.inputs = inputs + self.start = time.time() + + +class RunnerBase: + + def __init__(self, model, ds, threads, post_proc=None, max_batchsize=128): + self.take_accuracy = False + self.ds = ds + self.model = model + self.post_process = post_proc + self.threads = threads + self.take_accuracy = False + self.max_batchsize = max_batchsize + self.result_timing = [] + + def handle_tasks(self, tasks_queue): + pass + + def start_run(self, result_dict, take_accuracy): + self.result_dict = result_dict + self.result_timing = [] + self.take_accuracy = take_accuracy + self.post_process.start() + + def run_one_item(self, qitem: Item): + # run the prediction + processed_results = [] + try: + results = self.model.predict(qitem.inputs) + processed_results = self.post_process( + results, qitem.content_id, qitem.inputs, self.result_dict) + if self.take_accuracy: + self.post_process.add_results(processed_results) + self.result_timing.append(time.time() - qitem.start) + except Exception as ex: # pylint: disable=broad-except + src = [self.ds.get_item_loc(i) for i in qitem.content_id] + log.error("thread: failed on contentid=%s, %s", src, ex) + # since post_process will not run, fake empty responses + processed_results = [[]] * len(qitem.query_id) + finally: + response_array_refs = [] + response = [] + for idx, query_id in enumerate(qitem.query_id): + response_array = array.array( + "B", np.array(processed_results[idx], np.uint8).tobytes()) + response_array_refs.append(response_array) + bi = response_array.buffer_info() + response.append((query_id, bi[0], bi[1])) + # lg.QuerySamplesComplete(response) + + def enqueue(self, query_samples): + idx = [q.index for q in query_samples] + query_id = [q.id for q in query_samples] + if len(query_samples) < self.max_batchsize: + data, label = self.ds.get_samples(idx) + self.run_one_item(Item(query_id, idx, data, label)) + else: + bs = self.max_batchsize + for i in range(0, len(idx), bs): + data, label = self.ds.get_samples(idx[i:i + bs]) + self.run_one_item(Item(query_id[i:i + bs], idx[i:i + bs], data, label)) + + def finish(self): + pass + + +class BackendPytorch: + + def __init__( + self, + model_path=None, + model_id="xl", + guidance=8, + steps=20, + batch_size=1, + device="cuda", + precision=torch.float32, + negative_prompt="normal quality, low quality, worst quality, low res, blurry, nsfw, nude", + ): + self.inputs = [] + self.outputs = [] + + self.model_path = model_path + if model_id == "xl": + self.model_id = "stabilityai/stable-diffusion-xl-base-1.0" + else: + raise ValueError(f"{model_id} is not a valid model id") + + self.device = device if torch.cuda.is_available() else "cpu" + self.dtype = precision + + if torch.cuda.is_available(): + self.local_rank = 0 + self.world_size = 1 + + self.guidance = guidance + self.steps = steps + self.negative_prompt = negative_prompt + self.max_length_neg_prompt = 77 + self.batch_size = batch_size + + def version(self): + return torch.__version__ + + # def name(self): + # return "pytorch-SUT" + + def image_format(self): + return "NCHW" + + def load(self): + if self.model_path is None: + log.warning( + "Model path not provided, running with default hugging face weights\n" + "This may not be valid for official submissions") + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler") + self.pipe = StableDiffusionXLPipeline.from_pretrained( + self.model_id, + scheduler=self.scheduler, + safety_checker=None, + add_watermarker=False, + variant="fp16" if (self.dtype == torch.float16) else None, + torch_dtype=self.dtype, + ) + # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True) + else: + self.scheduler = EulerDiscreteScheduler.from_pretrained( + self.model_id, subfolder="scheduler") + self.pipe = StableDiffusionXLPipeline.from_pretrained( + self.model_path, + scheduler=self.scheduler, + safety_checker=None, + add_watermarker=False, + variant="fp16" if (self.dtype == torch.float16) else None, + torch_dtype=self.dtype, + ) + # self.pipe.unet = torch.compile(self.pipe.unet, mode="reduce-overhead", fullgraph=True) + + self.pipe.to(self.device) + #self.pipe.set_progress_bar_config(disable=True) + + self.negative_prompt_tokens = self.pipe.tokenizer( + self.convert_prompt(self.negative_prompt, self.pipe.tokenizer), + padding="max_length", + max_length=self.max_length_neg_prompt, + truncation=True, + return_tensors="pt", + ) + self.negative_prompt_tokens_2 = self.pipe.tokenizer_2( + self.convert_prompt(self.negative_prompt, self.pipe.tokenizer_2), + padding="max_length", + max_length=self.max_length_neg_prompt, + truncation=True, + return_tensors="pt", + ) + return self + + def convert_prompt(self, prompt, tokenizer): + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def encode_tokens( + self, + pipe: StableDiffusionXLPipeline, + text_input: torch.Tensor, + text_input_2: Optional[torch.Tensor] = None, + device: Optional[torch.device] = None, + num_images_per_prompt: int = 1, + do_classifier_free_guidance: bool = True, + negative_prompt: Optional[torch.Tensor] = None, + negative_prompt_2: Optional[torch.Tensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None, + lora_scale: Optional[float] = None, + clip_skip: Optional[int] = None, + ): + r""" + Encodes the input tokens into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is + used in both text-encoders + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + negative_prompt_2 (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation to be sent to `tokenizer_2` and + `text_encoder_2`. If not defined, `negative_prompt` is used in both text-encoders + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. + If not provided, pooled text embeddings will be generated from `prompt` input argument. + negative_pooled_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, pooled negative_prompt_embeds will be generated from `negative_prompt` + input argument. + lora_scale (`float`, *optional*): + A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded. + clip_skip (`int`, *optional*): + Number of layers to be skipped from CLIP while computing the prompt embeddings. A value of 1 means that + the output of the pre-final layer will be used for computing the prompt embeddings. + """ + device = device or pipe._execution_device + batch_size = text_input.input_ids.shape[0] + + # Define tokenizers and text encoders + tokenizers = ([pipe.tokenizer, pipe.tokenizer_2] if pipe.tokenizer is not None else [ + pipe.tokenizer_2]) + text_encoders = ([pipe.text_encoder, pipe.text_encoder_2] + if pipe.text_encoder is not None else [pipe.text_encoder_2]) + + if prompt_embeds is None: + text_input_2 = text_input_2 or text_input + + # textual inversion: procecss multi-vector tokens if necessary + prompt_embeds_list = [] + text_inputs_list = [text_input, text_input_2] + for text_inputs, tokenizer, text_encoder in zip( + text_inputs_list, tokenizers, text_encoders + ): + text_input_ids = text_inputs.input_ids + prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True) + + # We are only ALWAYS interested in the pooled output of the final text encoder + pooled_prompt_embeds = prompt_embeds[0] + if clip_skip is None: + prompt_embeds = prompt_embeds.hidden_states[-2] + else: + # "2" because SDXL always indexes from the penultimate layer. + prompt_embeds = prompt_embeds.hidden_states[-(clip_skip + 2)] + + prompt_embeds_list.append(prompt_embeds) + + prompt_embeds = torch.concat(prompt_embeds_list, dim=-1) + + # get unconditional embeddings for classifier free guidance + zero_out_negative_prompt = ( + negative_prompt is None and pipe.config.force_zeros_for_empty_prompt) + if (do_classifier_free_guidance and negative_prompt_embeds is None and + zero_out_negative_prompt): + negative_prompt_embeds = torch.zeros_like(prompt_embeds) + negative_pooled_prompt_embeds = torch.zeros_like(pooled_prompt_embeds) + elif do_classifier_free_guidance and negative_prompt_embeds is None: + negative_prompt = negative_prompt or "" + negative_prompt_2 = negative_prompt_2 or negative_prompt + + # normalize str to list + negative_prompt_inputs = ( + negative_prompt.input_ids.repeat(batch_size, 1) if + (len(negative_prompt.input_ids.shape) == 1) else negative_prompt.input_ids) + negative_prompt_2_inputs = ( + negative_prompt_2.input_ids.repeat(batch_size, 1) if + (len(negative_prompt_2.input_ids.shape) == 1) else negative_prompt_2.input_ids) + + uncond_inputs = [negative_prompt_inputs, negative_prompt_2_inputs] + + negative_prompt_embeds_list = [] + for uncond_input, tokenizer, text_encoder in zip( + uncond_inputs, tokenizers, text_encoders + ): + negative_prompt_embeds = text_encoder( + uncond_input.to(device), + output_hidden_states=True, + ) + # We are only ALWAYS interested in the pooled output of the final text encoder + negative_pooled_prompt_embeds = negative_prompt_embeds[0] + negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2] + + negative_prompt_embeds_list.append(negative_prompt_embeds) + + negative_prompt_embeds = torch.concat(negative_prompt_embeds_list, dim=-1) + + if pipe.text_encoder_2 is not None: + prompt_embeds = prompt_embeds.to(dtype=pipe.text_encoder_2.dtype, device=device) + else: + prompt_embeds = prompt_embeds.to(dtype=pipe.unet.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + if pipe.text_encoder_2 is not None: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=pipe.text_encoder_2.dtype, device=device) + else: + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=pipe.unet.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1) + + pooled_prompt_embeds = pooled_prompt_embeds.repeat(1, num_images_per_prompt).view( + bs_embed * num_images_per_prompt, -1) + if do_classifier_free_guidance: + negative_pooled_prompt_embeds = negative_pooled_prompt_embeds.repeat( + 1, num_images_per_prompt).view(bs_embed * num_images_per_prompt, -1) + return ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) + + def prepare_inputs(self, inputs, i): + if self.batch_size == 1: + return self.encode_tokens( + self.pipe, + inputs[i]["input_tokens"], + inputs[i]["input_tokens_2"], + negative_prompt=self.negative_prompt_tokens, + negative_prompt_2=self.negative_prompt_tokens_2, + ) + else: + prompt_embeds = [] + negative_prompt_embeds = [] + pooled_prompt_embeds = [] + negative_pooled_prompt_embeds = [] + for prompt in inputs[i:min(i + self.batch_size, len(inputs))]: + assert isinstance(prompt, dict) + text_input = prompt["input_tokens"] + text_input_2 = prompt["input_tokens_2"] + ( + p_e, + n_p_e, + p_p_e, + n_p_p_e, + ) = self.encode_tokens( + self.pipe, + text_input, + text_input_2, + negative_prompt=self.negative_prompt_tokens, + negative_prompt_2=self.negative_prompt_tokens_2, + ) + prompt_embeds.append(p_e) + negative_prompt_embeds.append(n_p_e) + pooled_prompt_embeds.append(p_p_e) + negative_pooled_prompt_embeds.append(n_p_p_e) + + prompt_embeds = torch.cat(prompt_embeds) + negative_prompt_embeds = torch.cat(negative_prompt_embeds) + pooled_prompt_embeds = torch.cat(pooled_prompt_embeds) + negative_pooled_prompt_embeds = torch.cat(negative_pooled_prompt_embeds) + return prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds, negative_pooled_prompt_embeds + + def predict(self, inputs): + images = [0] * len(inputs) + with torch.no_grad(): + for i in tqdm(range(0, len(inputs), self.batch_size)): + max_index = min(i + self.batch_size, len(inputs)) + latents_input = [inputs[idx]["latents"] for idx in range(i, max_index)] + latents_input = torch.cat(latents_input).to(self.device) + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.prepare_inputs(inputs, i) + generated = self.pipe( + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + pooled_prompt_embeds=pooled_prompt_embeds, + negative_pooled_prompt_embeds=negative_pooled_prompt_embeds, + guidance_scale=self.guidance, + num_inference_steps=self.steps, + output_type="pt", + latents=latents_input, + ).images + images[i:i + max_index] = generated.cpu() + # images.extend(generated) + return images diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py new file mode 100644 index 000000000..ab7f92ea1 --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/dataset.py @@ -0,0 +1,359 @@ +""" +Code is adapted from the MLPerf text-to-image pipeline: https://github.com/mlcommons/inference/tree/master/text_to_image +Available under the following LICENSE: + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS +""" + +import logging +import os +import time + +import numpy as np +import pandas as pd +from PIL import Image +import torch + +logging.basicConfig(level=logging.INFO) +log = logging.getLogger() + + +class Dataset: + + def __init__(self): + self.arrival = None + self.image_list = [] + self.caption_list = [] + self.items_inmemory = {} + self.last_loaded = -1 + + def preprocess(self, use_cache=True): + raise NotImplementedError("Dataset:preprocess") + + def get_item_count(self): + return len(self.image_list) + + def get_list(self): + raise NotImplementedError("Dataset:get_list") + + def load_query_samples(self, sample_list): + self.items_inmemory = {} + for sample in sample_list: + self.items_inmemory[sample] = self.get_item(sample) + self.last_loaded = time.time() + + def unload_query_samples(self, sample_list): + if sample_list: + for sample in sample_list: + if sample in self.items_inmemory: + del self.items_inmemory[sample] + else: + self.items_inmemory = {} + + def get_samples(self, id_list): + data = [{ + "input_tokens": self.items_inmemory[id]["input_tokens"], + "input_tokens_2": self.items_inmemory[id]["input_tokens_2"], + "latents": self.items_inmemory[id]["latents"],} for id in id_list] + images = [self.items_inmemory[id]["file_name"] for id in id_list] + return data, images + + def get_item(self, id): + raise NotImplementedError("Dataset:get_item") + + +class ImagesDataset(torch.utils.data.Dataset): + + def __init__(self, imgs, transforms=None): + self.imgs = imgs + self.transforms = transforms + + def __len__(self): + return len(self.imgs) + + def __getitem__(self, i): + img = self.imgs[i] + if self.transforms is not None: + img = self.transforms(img) + return img + + +class Coco(Dataset): + + def __init__( + self, + data_path, + name=None, + image_size=None, + pre_process=None, + pipe_tokenizer=None, + pipe_tokenizer_2=None, + latent_dtype=torch.float32, + latent_device="cuda", + latent_framework="torch", + **kwargs, + ): + super().__init__() + self.captions_df = pd.read_csv(f"{data_path}/captions/captions.tsv", sep="\t") + self.image_size = image_size + self.preprocessed_dir = os.path.abspath(f"{data_path}/preprocessed/") + self.img_dir = os.path.abspath(f"{data_path}/validation/data/") + self.name = name + + # Preprocess prompts + self.captions_df["input_tokens"] = self.captions_df["caption"].apply( + lambda x: self.preprocess(x, pipe_tokenizer)) + self.captions_df["input_tokens_2"] = self.captions_df["caption"].apply( + lambda x: self.preprocess(x, pipe_tokenizer_2)) + self.latent_dtype = latent_dtype + self.latent_device = latent_device if torch.cuda.is_available() else "cpu" + if latent_framework == "torch": + self.latents = ( + torch.load(f"{data_path}/latents/latents.pt").to(latent_dtype).to(latent_device)) + elif latent_framework == "numpy": + self.latents = ( + torch.Tensor( + np.load(f"{data_path}/latents/latents.npy")).to(latent_dtype).to(latent_device)) + + def preprocess(self, prompt, tokenizer): + converted_prompt = self.convert_prompt(prompt, tokenizer) + return tokenizer( + converted_prompt, + padding="max_length", + max_length=tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + + def image_to_tensor(self, img): + img = np.asarray(img) + if len(img.shape) == 2: + img = np.expand_dims(img, axis=-1) + tensor = torch.Tensor(img.transpose([2, 0, 1])).to(torch.uint8) + if tensor.shape[0] == 1: + tensor = tensor.repeat(3, 1, 1) + return tensor + + def preprocess_images(self, file_name): + img = Image.open(self.img_dir + "/" + file_name) + tensor = self.image_to_tensor(img) + target_name = file_name.split(".")[0] + target_path = self.preprocessed_dir + "/" + target_name + ".pt" + if not os.path.exists(target_path): + torch.save(tensor, target_path) + return target_path + + def convert_prompt(self, prompt, tokenizer): + tokens = tokenizer.tokenize(prompt) + unique_tokens = set(tokens) + for token in unique_tokens: + if token in tokenizer.added_tokens_encoder: + replacement = token + i = 1 + while f"{token}_{i}" in tokenizer.added_tokens_encoder: + replacement += f" {token}_{i}" + i += 1 + + prompt = prompt.replace(token, replacement) + + return prompt + + def get_item(self, id): + return dict(self.captions_df.loc[id], latents=self.latents) + + def get_item_count(self): + return len(self.captions_df) + + def get_img(self, id): + img = Image.open(self.img_dir + "/" + self.captions_df.loc[id]["file_name"]) + return self.image_to_tensor(img) + + def get_imgs(self, id_list): + image_list = [] + for id in id_list: + image_list.append(self.get_img(id)) + return image_list + + def get_caption(self, i): + return self.get_item(i)["caption"] + + def get_captions(self, id_list): + return [self.get_caption(id) for id in id_list] + + def get_item_loc(self, id): + return self.img_dir + "/" + self.captions_df.loc[id]["file_name"] diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt new file mode 100644 index 000000000..3b453267e --- /dev/null +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/requirements.txt @@ -0,0 +1,8 @@ +accelerate==0.23.0 +diffusers==0.21.2 +open-clip-torch==2.7.0 +opencv-python==4.8.1.78 +pycocotools==2.0.7 +scipy==1.9.1 +torchmetrics[image]==1.2.0 +transformers==4.33.2 diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index b466d6303..7ce70e783 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -7,12 +7,7 @@ import torch from torch import nn -from torch._decomp import get_decompositions -from brevitas.backport.fx.experimental.proxy_tensor import make_fx -from brevitas.export.manager import _force_requires_grad_false -from brevitas.export.manager import _JitTraceExportWrapper -from brevitas_examples.llm.llm_quant.export import BlockQuantProxyLevelManager from brevitas_examples.llm.llm_quant.export import brevitas_proxy_export_mode @@ -31,3 +26,12 @@ def export_onnx(pipe, trace_inputs, output_dir, export_manager): print(f"Saving unet to {output_path} ...") with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): torch.onnx.export(pipe.unet, args=trace_inputs, f=output_path) + + +def export_torch_export(pipe, trace_inputs, output_dir, export_manager): + output_path = os.path.join(output_dir, 'unet.onnx') + print(trace_inputs[1]) + print(f"Saving unet to {output_path} ...") + with torch.no_grad(), brevitas_proxy_export_mode(pipe.unet, export_manager): + torch.export.export( + UnetExportWrapper(pipe.unet), args=(trace_inputs[0],), kwargs=trace_inputs[1]) diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index b2c30176f..a5af383ef 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -3,8 +3,81 @@ SPDX-License-Identifier: MIT """ +from contextlib import contextmanager + import torch +from brevitas.export.common.handler.base import BaseHandler +from brevitas.export.manager import _set_proxy_export_handler +from brevitas.export.manager import _set_proxy_export_mode +from brevitas.export.manager import BaseManager +from brevitas.nn.quant_layer import QuantWeightBiasInputOutputLayer +from brevitas.proxy.parameter_quant import WeightQuantProxyFromInjector + + +class InferenceWeightProxyHandler(BaseHandler): + handled_layer = WeightQuantProxyFromInjector + + def __init__(self): + super(InferenceWeightProxyHandler, self).__init__() + self.scale = None + self.zero_point = None + self.bit_width = None + self.float_weight = None + + def prepare_for_export(self, module): + assert len(module.tracked_module_list) == 1, "Shared quantizers not supported." + quant_layer = module.tracked_module_list[0] + self.float_weight = quant_layer.quant_weight() + quant_layer.weight.data = quant_layer.weight.data.cpu() + self.scale = module.scale() + self.zero_point = module.zero_point() + self.bit_width = module.bit_width() + + def forward(self, x): + return self.float_weight, self.scale, self.zero_point, self.bit_width + + +class InferenceWeightProxyManager(BaseManager): + handlers = [InferenceWeightProxyHandler] + + @classmethod + def set_export_handler(cls, module): + if hasattr(module, + 'requires_export_handler') and module.requires_export_handler and not isinstance( + module, (WeightQuantProxyFromInjector)): + return + _set_proxy_export_handler(cls, module) + + +def store_mapping_tensor_state_dict(model): + mapping = dict() + for module in model.modules(): + if isinstance(module, QuantWeightBiasInputOutputLayer): + mapping[module.weight.data_ptr()] = module.weight.device + return mapping + + +def restore_mapping(model, mapping): + for module in model.modules(): + if isinstance(module, QuantWeightBiasInputOutputLayer): + module.weight.data = module.weight.data.to(mapping[module.weight.data_ptr()]) + + +@contextmanager +def brevitas_proxy_inference_mode(model): + mapping = store_mapping_tensor_state_dict(model) + is_training = model.training + model.eval() + model.apply(InferenceWeightProxyManager.set_export_handler) + _set_proxy_export_mode(model, enabled=True) + try: + yield model + finally: + restore_mapping(model, mapping) + _set_proxy_export_mode(model, enabled=False) + model.train(is_training) + def unet_input_shape(resolution): return (4, resolution // 8, resolution // 8)