diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index 1e7a96a91..bb9e649b1 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -161,11 +161,14 @@ def input_bit_width(module): print(f"Moving model to {args.device}...") pipe = pipe.to(args.device) - with calibration_mode(pipe.unet): - prompts = VALIDATION_PROMPTS - run_val_inference( - pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) - + if is_input_quantized and args.input_scale_type == 'static': + print("Applying activation calibration") + with calibration_mode(pipe.unet): + prompts = VALIDATION_PROMPTS + run_val_inference( + pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype) + + print("Applying bias correction") with bias_correction_mode(pipe.unet): prompts = VALIDATION_PROMPTS run_val_inference(