diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/export.py b/src/brevitas_examples/stable_diffusion/sd_quant/export.py index ac737f276..64bcac34f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/export.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/export.py @@ -49,7 +49,7 @@ def handle_quant_param(layer, layer_dict): layer_dict['weight_scale_shape'] = weight_scale_shape layer_dict['weight_zp'] = weight_zp.numpy().tolist() layer_dict['weight_zp_shape'] = weight_scale_shape - layer_dict['weight_zp_dtype'] = str(torch.uint8) + layer_dict['weight_zp_dtype'] = str(torch.int8) return layer_dict