diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py index e42b739e3..538f1a0d7 100644 --- a/src/brevitas_examples/stable_diffusion/main.py +++ b/src/brevitas_examples/stable_diffusion/main.py @@ -227,6 +227,7 @@ def input_bit_width(module): # Define tracing input if args.is_sd_xl: + assert args.export_target != 'torchscript', "Torchscript export of SD-XL not supported" generate_fn = generate_unet_xl_rand_inputs shape = SD_XL_EMBEDDINGS_SHAPE else: diff --git a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py index 8e5bd886c..b2c30176f 100644 --- a/src/brevitas_examples/stable_diffusion/sd_quant/utils.py +++ b/src/brevitas_examples/stable_diffusion/sd_quant/utils.py @@ -68,6 +68,10 @@ def generate_unet_xl_rand_inputs( device='cpu', dtype=torch.float32, with_return_dict_false=False): + # We need to pass a combination of args and kwargs to ONNX export + # If we pass all kwargs, something breaks + # If we pass only the last element as kwargs, since it is a dict, it has a weird interaction and something breaks + # The solution is to pass only one argument as args, and everything else as kwargs unet_rand_inputs = generate_unet_rand_inputs( embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false) sample = unet_rand_inputs['sample']