diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index cb1f42920..d370a67df 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -846,6 +846,11 @@ def input_zp_stats_type():
         'vae-fp16-fix',
         default=False,
         help='Rescale the VAE to not go NaN with FP16. Default: Disabled')
+    add_bool_arg(
+        parser,
+        'share-qkv-quant',
+        default=False,
+        help='Share QKV/KV quantization. Default: Disabled')
     args = parser.parse_args()
     print("Args: " + str(vars(args)))
     main(args)
diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
index 6fb987967..fe511be11 100644
--- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
+++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
@@ -602,6 +602,7 @@ def compute_mlperf_fid(
         model_to_replace=None,
         samples_to_evaluate=500,
         output_dir=None,
+        device='cpu',
         vae_force_upcast=True):
     assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions"
@@ -611,13 +612,13 @@
     dtype = next(iter(model_to_replace.unet.parameters())).dtype
     res_dict = {}
     model = BackendPytorch(
-        path_to_sdxl, 'xl', steps=20, batch_size=1, device='cpu', precision=dtype)
+        path_to_sdxl, 'xl', steps=20, batch_size=1, device=device, precision=dtype)
     model.load()
     if model_to_replace is not None:
         model.pipe.unet = model_to_replace.unet
     if not vae_force_upcast:
-        model.pipe.vae = model.pipe.vae
+        model.pipe.vae = model_to_replace.vae
     model.pipe.vae.config.force_upcast = vae_force_upcast
     ds = Coco(
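
A minimal usage sketch of the two changes, not part of the patch: it assumes the import path follows the file layout above, that path_to_sdxl and path_to_coco are the leading parameters of compute_mlperf_fid (only the tail of the signature is visible in the hunk), and that quantized_pipe is a hypothetical, already-quantized SDXL pipeline.

    # Hedged sketch. The new CLI flag is assumed to be exposed by add_bool_arg as:
    #   python main.py --share-qkv-quant ...
    #
    # The new `device` argument lets the FID evaluation run off the hard-coded 'cpu';
    # all paths and `quantized_pipe` below are placeholders, not values from the diff.
    from brevitas_examples.stable_diffusion.mlperf_evaluation.accuracy import compute_mlperf_fid

    res_dict = compute_mlperf_fid(
        path_to_sdxl='/path/to/sdxl',          # placeholder checkpoint path
        path_to_coco='/path/to/mlperf/coco',   # placeholder; must contain tools/val2014.npz
        model_to_replace=quantized_pipe,       # placeholder pipeline whose unet (and vae) replace the baseline
        samples_to_evaluate=500,
        device='cuda',                         # new argument, forwarded to BackendPytorch
        vae_force_upcast=False)                # with the fix, model_to_replace.vae is actually swapped in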