diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index ef720d092..85ecd6ef5 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -263,6 +263,9 @@ def single_layer_update(self):
             # No permutation, permutation tensor is a ordered index
             perm = torch.tensor(range(weight.shape[-1]), device=dev)
             permutation_list.append(perm)
+
+        self.reactivate_quantization()
+
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
@@ -358,6 +361,8 @@
         perm = torch.tensor(range(weight.shape[-1]), device=dev)
         permutation_list.append(perm)
 
+        self.reactivate_quantization()
+
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 31d31433b..d0d45a630 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -9,6 +9,9 @@
 from packaging import version
 import torch
 
+from brevitas.graph.calibrate import disable_return_quant_tensor
+from brevitas.graph.calibrate import restore_return_quant_tensor
+
 try:
     from torch.linalg import LinAlgError
 except:
@@ -86,9 +89,11 @@ def catch_stopfwd(self, *args, **kwargs):
         # If we want to return the output of the network, we need to disable all hooks
         for name, gpxq_class in self.gpxq_layers.items():
             gpxq_class.disable_pre_forward_hook = True
+
         out = self.orig_forward(*args, **kwargs)
         for name, gpxq_class in self.gpxq_layers.items():
             gpxq_class.disable_pre_forward_hook = False
+
         return out
 
     def initialize_module_optimizer(
@@ -134,6 +139,7 @@
             device='cpu',
             dtype=torch.float32)
         self.nsamples = 0
+        self.done = False
 
         assert torch_version >= version.parse('1.10'), "GPTQ requires torch 1.10 or higher"
@@ -257,6 +263,8 @@ def single_layer_update(self, percdamp=.01):
         finally:
             del self.H
 
+        self.reactivate_quantization()
+
         for i1 in range(0, self.columns, self.blocksize):
             i2 = min(i1 + self.blocksize, self.columns)
             count = i2 - i1
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index b85ac1188..deb613b1a 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -164,6 +164,10 @@ def __enter__(self):
         return self
 
     def __exit__(self, type, value, traceback):
+        for name, layer in self.gpxq_layers.items():
+            if not layer.done:
+                layer.reactivate_quantization()
+
         if isinstance(self.model, (GraphModule, TorchGraphModule)):
             self.model.__class__.forward = self.orig_forward
         else:
@@ -219,6 +223,10 @@ def __init__(
         self.disable_pre_forward_hook = False
         # Some layers require knowledge from quant inputs to compute quant weights
         self.quant_metadata = None
+        self.disable_quant_inference = DisableEnableQuantization()
+        self.return_quant_tensor_state = disable_return_quant_tensor(self.layer)
+        self.disable_quant_inference.disable_param_quantization(self.layer, False)
+        self.done = False
 
     def process_input(self, inp):
         # Input is a tuple, so we take first element
@@ -255,6 +263,11 @@ def update_batch(self):
         pass
 
     @abstractmethod
     def single_layer_update(self):
         pass
 
+    def reactivate_quantization(self):
+        self.done = True
+        self.disable_quant_inference.enable_param_quantization(self.layer, False)
+        restore_return_quant_tensor(self.layer, self.return_quant_tensor_state)
+
     def get_quant_weights(self, i, i1, permutation_list):
         # We need to recompute quant weights at runtime since our float weights are being updated
         # Add offset in case of blockwise computation
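Note: the lifecycle these gpxq.py hunks implement can be sketched in isolation. Quantization is switched off when each layer wrapper is constructed, switched back on by reactivate_quantization() once single_layer_update() has consumed the float statistics, and __exit__ re-enables it for any layer whose update never ran. The sketch below is a minimal stand-alone illustration of that control flow only; _LayerWrapper, _GpxqMode and quant_enabled are hypothetical stand-ins, not Brevitas APIs.

class _LayerWrapper:
    """Hypothetical stand-in for the per-layer GPxQ wrapper."""

    def __init__(self, name):
        self.name = name
        # As in the patch: quantization is disabled as soon as the wrapper
        # is built, before any calibration forward passes run.
        self.quant_enabled = False
        self.done = False

    def single_layer_update(self):
        # ... consume float statistics, write back updated weights ...
        self.reactivate_quantization()

    def reactivate_quantization(self):
        # Mirrors GPxQ.reactivate_quantization(): mark completion, then
        # restore quantized inference for this layer.
        self.done = True
        self.quant_enabled = True


class _GpxqMode:
    """Hypothetical stand-in for the gpxq context manager."""

    def __init__(self, layers):
        self.gpxq_layers = {layer.name: layer for layer in layers}

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # The __exit__ hunk above: any layer whose update never ran still
        # gets its quantization restored on the way out.
        for _, layer in self.gpxq_layers.items():
            if not layer.done:
                layer.reactivate_quantization()


layers = [_LayerWrapper('fc1'), _LayerWrapper('fc2')]
with _GpxqMode(layers):
    layers[0].single_layer_update()  # fc2's update is deliberately skipped
assert all(layer.quant_enabled for layer in layers)  # __exit__ covered fc2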
diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py
index 31ab57361..a03eabb2b 100644
--- a/src/brevitas_examples/common/generative/quantize.py
+++ b/src/brevitas_examples/common/generative/quantize.py
@@ -4,6 +4,7 @@
 """
 import re
 
+import torch
 from torch import nn
 
 from brevitas import nn as qnn
@@ -190,6 +191,8 @@ def quantize_model(
                 'narrow_range': False,
                 'quantize_zero_point': quantize_weight_zero_point},
             **weight_float_format)
+        if dtype == torch.float16:
+            weight_quant = weight_quant.let(**{'scaling_min_val': 1e-4})
 
     # Set the group_size is we're doing groupwise quantization
     if weight_quant_granularity == 'per_group':
diff --git a/src/brevitas_examples/stable_diffusion/main.py b/src/brevitas_examples/stable_diffusion/main.py
index 3ac49080a..9faad9657 100644
--- a/src/brevitas_examples/stable_diffusion/main.py
+++ b/src/brevitas_examples/stable_diffusion/main.py
@@ -62,14 +62,6 @@
     'close up photo of a rabbit, forest in spring, haze, halation, bloom, dramatic atmosphere, centred, rule of thirds, 200mm 1.4f macro shot'
 }
 
-clip_score_fn = partial(clip_score, model_name_or_path="openai/clip-vit-base-patch16")
-
-
-def calculate_clip_score(images, prompts):
-    images_int = (np.array(images)).astype("uint8")
-    clip_score = clip_score_fn(torch.from_numpy(images_int).permute(2, 0, 1), prompts).detach()
-    return round(float(clip_score), 4)
-
 
 def run_test_inference(
         pipe, resolution, prompts, seeds, output_path, device, dtype, name_prefix=''):
@@ -174,7 +166,11 @@
         for m in pipe.unet.modules():
             if isinstance(m, KwargsForwardHook) and hasattr(m.module, 'in_features'):
                 m.in_features = m.module.in_features
-        prompts = VALIDATION_PROMPTS
+        prompts = dict()
+        for i, (k, v) in enumerate(CALIBRATION_PROMPTS.items()):
+            if i == args.calibration_prompt:
+                break
+            prompts[k] = v
         run_val_inference(
             pipe, args.resolution, prompts, test_seeds, output_dir, args.device, dtype)
 
@@ -310,25 +306,6 @@ def input_bit_width(module):
         fid.update(quant_images_values, real=False)
         print(f"FID: {float(fid.compute())}")
 
-        clip_score_float = 0.
-        for prompt, images in float_images.items():
-            clip_score_prompt = 0.
-            for image in images:
-                clip_score_prompt += calculate_clip_score(image, prompt)
-            clip_score_float += clip_score_prompt / len(images)
-        clip_score_float /= len(float_images)
-
-        clip_score_quant = 0.
-        for prompt, images in quant_images.items():
-            clip_score_prompt = 0.
-            for image in images:
-                clip_score_prompt += calculate_clip_score(image, prompt)
-            clip_score_quant += clip_score_prompt / len(images)
-        clip_score_quant /= len(quant_images)
-
-        print(f"CLIP float: {clip_score_float}")
-        print(f"CLIP quant: {clip_score_quant}")
-
     if args.export_target:
         # Move to cpu and to float32 to enable CPU export
         if not (dtype == torch.float16 and args.export_cuda_float16):
@@ -373,8 +350,8 @@
         '-b',
         '--batch-size',
         type=int,
-        default=4,
-        help='How many seeds to use for each image during validation. Default: 4')
+        default=2,
+        help='How many seeds to use for each image during validation. Default: 2')
     parser.add_argument(
         '--prompt',
         type=int,
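Note: none of the above changes the user-facing calibration flow. A typical driver loop, adapted from the brevitas_examples PTQ flow, looks roughly like the sketch below; treat the exact gptq_mode signature and attributes as assumptions rather than guarantees.

import torch
from brevitas.graph.gptq import gptq_mode


@torch.no_grad()
def apply_gptq(model, calib_loader):
    # Weights run in float while statistics are gathered: the patch turns
    # parameter quantization off when each layer wrapper is created and
    # back on inside single_layer_update() / on context-manager exit.
    model.eval()
    with gptq_mode(model) as gptq:
        gptq_model = gptq.model
        for _ in range(gptq.num_layers):
            for images, _ in calib_loader:
                gptq_model(images)
            gptq.update()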