Skip to content

Commit

Permalink
Export comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 27, 2024
1 parent c74d98d commit 6c74db3
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ def input_bit_width(module):

# Define tracing input
if args.is_sd_xl:
assert args.export_target != 'torchscript', "Torchscript export of SD-XL not supported"
generate_fn = generate_unet_xl_rand_inputs
shape = SD_XL_EMBEDDINGS_SHAPE
else:
Expand Down
4 changes: 4 additions & 0 deletions src/brevitas_examples/stable_diffusion/sd_quant/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ def generate_unet_xl_rand_inputs(
device='cpu',
dtype=torch.float32,
with_return_dict_false=False):
# We need to pass a combination of args and kwargs to ONNX export
# If we pass all kwargs, something breaks
# If we pass only the last element as kwargs, since it is a dict, it has a weird interaction and something breaks
# The solution is to pass only one argument as args, and everything else as kwargs
unet_rand_inputs = generate_unet_rand_inputs(
embedding_shape, unet_input_shape, batch_size, device, dtype, with_return_dict_false)
sample = unet_rand_inputs['sample']
Expand Down

0 comments on commit 6c74db3

Please sign in to comment.