Skip to content

Commit

Permalink
Cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Mar 16, 2024
1 parent 845abb1 commit d845020
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/brevitas_examples/stable_diffusion/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,19 +161,20 @@ def input_bit_width(module):
print(f"Moving model to {args.device}...")
pipe = pipe.to(args.device)

if is_input_quantized and args.input_scale_type == 'static':
print("Applying activation calibration")
with calibration_mode(pipe.unet):
if args.quantize:
if is_input_quantized and args.input_scale_type == 'static':
print("Applying activation calibration")
with calibration_mode(pipe.unet):
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

print("Applying bias correction")
with bias_correction_mode(pipe.unet):
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

print("Applying bias correction")
with bias_correction_mode(pipe.unet):
prompts = VALIDATION_PROMPTS
run_val_inference(
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

# Perform inference
if args.prompt:
print(f"Running inference with prompt '{args.prompt}' ...")
Expand All @@ -182,14 +183,16 @@ def input_bit_width(module):
pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)

if args.export_target:
assert args.weight_quant_format == 'int', "Only integer quantization supported for export."

# Move to cpu and to float32 to enable CPU export
if not (args.float16 and args.export_cuda_float16):
pipe.unet.to('cpu').to(torch.float32)
pipe.unet.eval()
device = next(iter(pipe.unet.parameters())).device
dtype = next(iter(pipe.unet.parameters())).dtype
if args.export_target:
assert args.weight_quant_format == 'int', "Only integer quantization supported for export."

# Define tracing input
if args.is_sd_xl:
generate_fn = generate_unet_xl_rand_inputs
shape = SD_XL_EMBEDDINGS_SHAPE
Expand All @@ -201,6 +204,7 @@ def input_bit_width(module):
unet_input_shape=unet_input_shape(args.resolution),
device=device,
dtype=dtype)

if args.export_target == 'torchscript':
if args.weight_quant_granularity == 'per_group':
export_manager = BlockQuantProxyLevelManager
Expand Down

0 comments on commit d845020

Please sign in to comment.