diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index cb1f42920..d370a67df 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -846,6 +846,11 @@ def input_zp_stats_type(): 'vae-fp16-fix', default=False, help='Rescale the VAE to not go NaN with FP16. Default: Disabled') + add_bool_arg( + parser, + 'share-qkv-quant', + default=False, + help='Share QKV/KV quantization. Default: Disabled') args = parser.parse_args() print("Args: " + str(vars(args))) main(args) diff --git a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py index 6fb987967..fe511be11 100644 --- a/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py +++ b/src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py @@ -602,6 +602,7 @@ def compute_mlperf_fid( model_to_replace=None, samples_to_evaluate=500, output_dir=None, + device='cpu', vae_force_upcast=True): assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. 
Check the MLPerf directory for instructions" @@ -611,13 +612,13 @@ def compute_mlperf_fid( dtype = next(iter(model_to_replace.unet.parameters())).dtype res_dict = {} model = BackendPytorch( - path_to_sdxl, 'xl', steps=20, batch_size=1, device='cpu', precision=dtype) + path_to_sdxl, 'xl', steps=20, batch_size=1, device=device, precision=dtype) model.load() if model_to_replace is not None: model.pipe.unet = model_to_replace.unet if not vae_force_upcast: - model.pipe.vae = model.pipe.vae + model.pipe.vae = model_to_replace.vae model.pipe.vae.config.force_upcast = vae_force_upcast ds = Coco( diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index 09c331951..89d846a79 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -31,7 +31,7 @@ def handle_quant_param(layer, layer_dict): weight_scale = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ 'scale'].data weight_zp = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][ - 'zero_point'].data - 128. 
# apply offset to have signed zp + 'zero_point'].data if layer.output_quant.export_handler.symbolic_kwargs is not None: output_scale = layer.output_quant.export_handler.symbolic_kwargs[ 'dequantize_symbolic_kwargs']['scale'].data @@ -43,13 +43,15 @@ def handle_quant_param(layer, layer_dict): layer_dict['input_zp'] = input_zp.numpy().tolist() layer_dict['input_zp_shape'] = input_zp.shape layer_dict['input_zp_dtype'] = str(torch.int8) - layer_dict['weight_scale'] = weight_scale.numpy().tolist() + layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist() nelems = layer.weight.shape[0] weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1) layer_dict['weight_scale_shape'] = weight_scale_shape - layer_dict['weight_zp'] = weight_zp.numpy().tolist() - layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.int8) + if torch.sum(weight_zp) != 0.: + weight_zp = weight_zp - 128. # apply offset to have signed zp + layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist() + layer_dict['weight_zp_shape'] = weight_scale_shape + layer_dict['weight_zp_dtype'] = str(torch.int8) return layer_dict