Fix (brevitas_examples/sdxl): Various fixes (#991)
Giuseppe5 authored Jul 18, 2024
1 parent 072a02b commit 7c7d825
Showing 3 changed files with 15 additions and 7 deletions.
5 changes: 5 additions & 0 deletions src/brevitas_examples/stable_diffusion/main.py
@@ -846,6 +846,11 @@ def input_zp_stats_type():
         'vae-fp16-fix',
         default=False,
         help='Rescale the VAE to not go NaN with FP16. Default: Disabled')
+    add_bool_arg(
+        parser,
+        'share-qkv-quant',
+        default=False,
+        help='Share QKV/KV quantization. Default: Disabled')
     args = parser.parse_args()
     print("Args: " + str(vars(args)))
     main(args)
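
Note: add_bool_arg is the script's own helper for registering paired boolean flags; its implementation is not part of this diff. A minimal sketch of how such a helper typically works (behavior assumed, not taken from main.py):

import argparse

def add_bool_arg(parser, name, default, help):
    # Hypothetical helper: registers --<name> / --no-<name> flags
    # that write to a single destination with the given default.
    dest = name.replace('-', '_')
    group = parser.add_mutually_exclusive_group()
    group.add_argument('--' + name, dest=dest, action='store_true', help=help)
    group.add_argument('--no-' + name, dest=dest, action='store_false', help='Disable: ' + help)
    parser.set_defaults(**{dest: default})

parser = argparse.ArgumentParser()
add_bool_arg(parser, 'share-qkv-quant', default=False, help='Share QKV/KV quantization. Default: Disabled')
print(parser.parse_args(['--share-qkv-quant']).share_qkv_quant)  # True

Under that assumption, the new option is exposed as --share-qkv-quant / --no-share-qkv-quant and defaults to disabled.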
5 changes: 3 additions & 2 deletions src/brevitas_examples/stable_diffusion/mlperf_evaluation/accuracy.py
@@ -602,6 +602,7 @@ def compute_mlperf_fid(
         model_to_replace=None,
         samples_to_evaluate=500,
         output_dir=None,
+        device='cpu',
         vae_force_upcast=True):
 
     assert os.path.isfile(path_to_coco + '/tools/val2014.npz'), "Val2014.npz file required. Check the MLPerf directory for instructions"
@@ -611,13 +612,13 @@
     dtype = next(iter(model_to_replace.unet.parameters())).dtype
     res_dict = {}
     model = BackendPytorch(
-        path_to_sdxl, 'xl', steps=20, batch_size=1, device='cpu', precision=dtype)
+        path_to_sdxl, 'xl', steps=20, batch_size=1, device=device, precision=dtype)
     model.load()
 
     if model_to_replace is not None:
         model.pipe.unet = model_to_replace.unet
         if not vae_force_upcast:
-            model.pipe.vae = model.pipe.vae
+            model.pipe.vae = model_to_replace.vae
 
     model.pipe.vae.config.force_upcast = vae_force_upcast
     ds = Coco(
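
Note: threading device through compute_mlperf_fid un-pins the MLPerf backend from CPU, and the second change makes vae_force_upcast=False actually swap in the replacement VAE (the old line was a self-assignment no-op). A hypothetical call, with placeholder paths and a placeholder quantized_pipe object:

results = compute_mlperf_fid(
    path_to_sdxl='/path/to/sdxl',          # placeholder path
    path_to_coco='/path/to/mlperf-coco',   # placeholder path
    model_to_replace=quantized_pipe,       # pipeline whose unet (and vae) replace the backend's
    samples_to_evaluate=500,
    output_dir='./fid_out',
    device='cuda',                         # new: backend no longer hard-coded to 'cpu'
    vae_force_upcast=False)                # fixed: now uses quantized_pipe.vae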
12 changes: 7 additions & 5 deletions src/brevitas_examples/stable_diffusion/sd_quant/export.py
@@ -31,7 +31,7 @@ def handle_quant_param(layer, layer_dict):
     weight_scale = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][
         'scale'].data
     weight_zp = layer.weight_quant.export_handler.symbolic_kwargs['dequantize_symbolic_kwargs'][
-        'zero_point'].data - 128.  # apply offset to have signed zp
+        'zero_point'].data
     if layer.output_quant.export_handler.symbolic_kwargs is not None:
         output_scale = layer.output_quant.export_handler.symbolic_kwargs[
             'dequantize_symbolic_kwargs']['scale'].data
@@ -43,13 +43,15 @@ def handle_quant_param(layer, layer_dict):
     layer_dict['input_zp'] = input_zp.numpy().tolist()
     layer_dict['input_zp_shape'] = input_zp.shape
     layer_dict['input_zp_dtype'] = str(torch.int8)
-    layer_dict['weight_scale'] = weight_scale.numpy().tolist()
+    layer_dict['weight_scale'] = weight_scale.cpu().numpy().tolist()
     nelems = layer.weight.shape[0]
     weight_scale_shape = [nelems] + [1] * (layer.weight.data.ndim - 1)
     layer_dict['weight_scale_shape'] = weight_scale_shape
-    layer_dict['weight_zp'] = weight_zp.numpy().tolist()
-    layer_dict['weight_zp_shape'] = weight_scale_shape
-    layer_dict['weight_zp_dtype'] = str(torch.int8)
+    if torch.sum(weight_zp) != 0.:
+        weight_zp = weight_zp - 128.  # apply offset to have signed zp
+        layer_dict['weight_zp'] = weight_zp.cpu().numpy().tolist()
+        layer_dict['weight_zp_shape'] = weight_scale_shape
+        layer_dict['weight_zp_dtype'] = str(torch.int8)
     return layer_dict


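Note: two fixes here. First, scale and zero-point tensors are moved to CPU before .numpy(), since torch.Tensor.numpy() raises on CUDA tensors. Second, the -128 offset that remaps an unsigned zero-point into the signed int8 range is now applied only when the zero-point is non-zero, so symmetric weights (all-zero zp) are no longer serialized with a spurious offset. A standalone sketch of the offset logic (example values, not from the commit):

import torch

weight_zp = torch.tensor([128., 130., 125.])  # unsigned (uint8-style) zero-points
if torch.sum(weight_zp) != 0.:                # all-zero zp means symmetric quantization: skip
    weight_zp = weight_zp - 128.              # remap [0, 255] -> [-128, 127]
    print(weight_zp.cpu().numpy().tolist())   # .cpu() first so .numpy() also works on CUDA tensors
# prints [0.0, 2.0, -3.0]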
