From d84502082e5be7079ea75acf93efe99f751e821b Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 16 Mar 2024 15:24:54 +0000 Subject: [PATCH] Cleanup --- .../stable_diffusion/main.py | 26 +++++++++++-------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index bb9e649b1..199f30570 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -161,19 +161,20 @@ def input_bit_width(module): print(f"Moving model to {args.device}...") pipe = pipe.to(args.device) - if is_input_quantized and args.input_scale_type == 'static': - print("Applying activation calibration") - with calibration_mode(pipe.unet): + if args.quantize: + if is_input_quantized and args.input_scale_type == 'static': + print("Applying activation calibration") + with calibration_mode(pipe.unet): + prompts = VALIDATION_PROMPTS + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + + print("Applying bias correction") + with bias_correction_mode(pipe.unet): prompts = VALIDATION_PROMPTS run_val_inference( pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) - print("Applying bias correction") - with bias_correction_mode(pipe.unet): - prompts = VALIDATION_PROMPTS - run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) - # Perform inference if args.prompt: print(f"Running inference with prompt '{args.prompt}' ...") @@ -182,14 +183,16 @@ def input_bit_width(module): pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) if args.export_target: + assert args.weight_quant_format == 'int', "Only integer quantization supported for export." + # Move to cpu and to float32 to enable CPU export if not (args.float16 and args.export_cuda_float16): pipe.unet.to('cpu').to(torch.float32) pipe.unet.eval() device = next(iter(pipe.unet.parameters())).device dtype = next(iter(pipe.unet.parameters())).dtype - if args.export_target: - assert args.weight_quant_format == 'int', "Only integer quantization supported for export." + + # Define tracing input if args.is_sd_xl: generate_fn = generate_unet_xl_rand_inputs shape = SD_XL_EMBEDDINGS_SHAPE @@ -201,6 +204,7 @@ def input_bit_width(module): unet_input_shape=unet_input_shape(args.resolution), device=device, dtype=dtype) + if args.export_target == 'torchscript': if args.weight_quant_granularity == 'per_group': export_manager = BlockQuantProxyLevelManager