Skip to content

Commit

Permalink
Early exit in case of incompatibility
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Apr 8, 2024
1 parent 940cbd7 commit 0aab438
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,11 @@ def run_val_inference(pipe, resolution, prompts, seeds, output_path, device, dty

def main(args):

if args.export_target:
assert args.weight_quant_format == 'int', "Currently only integer quantization supported for export."
if args.is_sd_xl:
assert args.export_target != 'torchscript', "Torchscript export of SD-XL not supported"

# Select dtype
if args.float16:
dtype = torch.float16
Expand Down Expand Up @@ -216,8 +221,6 @@ def input_bit_width(module):
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

if args.export_target:
assert args.weight_quant_format == 'int', "Only integer quantization supported for export."

# Move to cpu and to float32 to enable CPU export
if not (args.float16 and args.export_cuda_float16):
pipe.unet.to('cpu').to(torch.float32)
Expand All @@ -227,7 +230,6 @@ def input_bit_width(module):

# Define tracing input
if args.is_sd_xl:
assert args.export_target != 'torchscript', "Torchscript export of SD-XL not supported"
generate_fn = generate_unet_xl_rand_inputs
shape = SD_XL_EMBEDDINGS_SHAPE
else:
Expand Down

0 comments on commit 0aab438

Please sign in to comment.