diff --git a/src/brevitas/graph/gpfq.py b/src/brevitas/graph/gpfq.py
index fd7df9223..d04811871 100644
--- a/src/brevitas/graph/gpfq.py
+++ b/src/brevitas/graph/gpfq.py
@@ -3,17 +3,18 @@
 from copy import deepcopy
 import math
-from math import pi
+import os
+from tempfile import TemporaryDirectory
 from typing import Callable, List, Optional
 
 import numpy as np
 import torch
-from torch.fft import fft
-from torch.fft import fftn
+from torch.fx import GraphModule as TorchGraphModule
 import torch.nn as nn
 import unfoldNd
 
 from brevitas.function import get_upper_bound_on_l1_norm
+from brevitas.fx import GraphModule
 from brevitas.graph.calibrate import disable_return_quant_tensor
 from brevitas.graph.calibrate import restore_return_quant_tensor
 from brevitas.graph.gpxq import GPxQ
@@ -65,8 +66,13 @@ class gpfq_mode(gpxq_mode):
     Example:
         >>> with torch.no_grad():
-        >>>     with gpfq_mode(model) as gpfq:
+        >>>     with gpfq_mode(model, collect_float_first=collect_float_first) as gpfq:
         >>>         gpfq_model = gpfq.model
+        >>>         if collect_float_first:
+        >>>             for img, t in calib_loader:
+        >>>                 img = img.cuda()
+        >>>                 gpfq_model(img)
+        >>>             gpfq.finalize_float_collection()
         >>>         for i in tqdm(range(gpfq.num_layers)):
         >>>             for img, t in calib_loader:
         >>>                 img = img.cuda()
@@ -87,7 +93,8 @@ def __init__(
             use_gpfa2q: bool = False,
             accumulator_bit_width: Optional[int] = None,
             a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
-            compression_rate: Optional[float] = 0.0) -> None:
+            compression_rate: Optional[float] = 0.0,
+            collect_float_first: bool = False) -> None:
         if not inplace:
             model = deepcopy(model)
         super().__init__(
@@ -111,22 +118,47 @@ def __init__(
         if self.compression_rate < 0.0 or self.compression_rate > 1.0:
             raise ValueError('Compression rate for random projection must be between 0 and 1.')
 
-    def catch_stopfwd(self, *args, **kwargs):
-        # Collect quant input
-        try:
-            self.orig_forward(*args, **kwargs)
-        except StopFwdException:
-            pass
+        # Speed-up: optionally collect the float input once up front, so it does not
+        # have to be re-collected on every layer update later
+        self.collect_float_first = collect_float_first
+
+    def __enter__(self):
+        # Initialize GPxQ layers
+        self.setup_gpxq_layers()
+        if self.collect_float_first:
+            self.float_collection_hooks = dict()
+            # Set up hooks for collecting the float input and storing it on disk
+            for name, layer in self.gpxq_layers.items():
+                # Attach float collecting hook
+                self.float_collection_hooks[name] = layer.layer.register_forward_hook(
+                    layer.collect_float_input)
+
+            # Disable quantization
+            self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+            self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
+            self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
+
+            return self
+        else:
+            # If we are not collecting the float input first, patch the forward with
+            # catch_stopfwd and set up the original hooks
+            self.orig_forward = self.model.forward
+            if isinstance(self.model, (GraphModule, TorchGraphModule)):
+                self.model.__class__.forward = self.catch_stopfwd
+            else:
+                self.model.forward = self.catch_stopfwd
+            return self.setup_gpxq_hooks()
 
-        # Disable quantization
-        self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
-        self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
-        self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
-        # Collect float input
-        try:
-            self.orig_forward(*args, **kwargs)
-        except StopFwdException:
-            pass
+    def finalize_float_collection(self):
+        # Remove the hooks we attached during the float collection
+        for name, hook in self.float_collection_hooks.items():
+            hook.remove()
+
+        # Create a temporary directory
+        self.tmp_dir = TemporaryDirectory()
+
+        # Save all float activations to disk and delete them in the layers
+        for name, layer in self.gpxq_layers.items():
+            layer.offload_float_input(tmp_dir=self.tmp_dir.name)
 
         # Re-enable quantization. If activation quantization is disabled,
         # we also disable bias quantization
@@ -136,6 +168,48 @@ def catch_stopfwd(self, *args, **kwargs):
         else:
             self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
         restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
+        # Patch the forward with catch_stopfwd
+        self.orig_forward = self.model.forward
+        if isinstance(self.model, (GraphModule, TorchGraphModule)):
+            self.model.__class__.forward = self.catch_stopfwd
+        else:
+            self.model.forward = self.catch_stopfwd
+        # Set up the original hooks
+        self.setup_gpxq_hooks()
+
+    def __exit__(self, type, value, traceback):
+        # Delete the temporary directory
+        if self.collect_float_first:
+            self.tmp_dir.cleanup()
+        return super().__exit__(type, value, traceback)
+
+    def catch_stopfwd(self, *args, **kwargs):
+        # Collect quant input
+        try:
+            self.orig_forward(*args, **kwargs)
+        except StopFwdException:
+            pass
+
+        if not self.collect_float_first:
+            # Disable quantization
+            self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
+            self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
+            self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
+            # Collect float input
+            try:
+                self.orig_forward(*args, **kwargs)
+            except StopFwdException:
+                pass
+
+            # Re-enable quantization. If activation quantization is disabled,
+            # we also disable bias quantization
+            self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
+            if self.use_quant_activations:
+                self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
+            else:
+                self.disable_quant_inference.disable_bias_quantization(
+                    self.model, is_training=False)
+            restore_return_quant_tensor(self.model, self.return_quant_tensor_state)
 
         if self.return_forward_output:
             # If we want to return the output of the network, we need to disable all hooks
@@ -156,7 +230,8 @@ def initialize_module_optimizer(
                 len_parallel_layers=len_parallel_layers,
                 create_weight_orig=create_weight_orig,
                 p=self.p,
-                compression_rate=self.compression_rate)
+                compression_rate=self.compression_rate,
+                collect_float_first=self.collect_float_first)
         else:
             return GPFA2Q(
                 layer=layer,
@@ -166,7 +241,8 @@
                 create_weight_orig=create_weight_orig,
                 p=self.p,
                 accumulator_bit_width=self.accumulator_bit_width,
-                compression_rate=self.compression_rate)
+                compression_rate=self.compression_rate,
+                collect_float_first=self.collect_float_first)
 
 
 class GPFQ(GPxQ):
@@ -175,8 +251,15 @@ class GPFQ(GPxQ):
     """
 
     def __init__(
-            self, layer, name, act_order, len_parallel_layers, create_weight_orig, p,
-            compression_rate) -> None:
+            self,
+            layer,
+            name,
+            act_order,
+            len_parallel_layers,
+            create_weight_orig,
+            p,
+            compression_rate,
+            collect_float_first) -> None:
 
         super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)
 
@@ -185,6 +268,81 @@ def __init__(
         self.index_computed = False
         self.p = p
         self.compression_rate = compression_rate
+        self.collect_float_first = collect_float_first
+
+    def collect_float_input(self, module, args, output):
+        # Forward hook that collects the float input of this layer;
+        # offload_float_input later moves it to disk
+        inp = self.process_input(args)
+        batch_size = inp.shape[0]
+
+        # Preprocess the input to compute the Hessian
+        if isinstance(self.layer, qnn.QuantLinear):
+            if len(inp.shape) > 2:
+                inp = inp.reshape((-1, sum(inp.shape[2:])))
+            # For QuantLinear layer, groups will be 1
+            inp_processed = inp.unsqueeze(0)
+
+        if isinstance(self.layer, SUPPORTED_CONV_OP):
+            # Pick the correct unfoldNd class
+            if isinstance(
+                    self.layer,
+                    (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
+                unfold_impl = unfoldNd.UnfoldTransposeNd
+            else:
+                unfold_impl = unfoldNd.UnfoldNd
+
+            unfold = unfold_impl(
+                self.layer.kernel_size,
+                dilation=self.layer.dilation,
+                padding=self.layer.padding,
+                stride=self.layer.kernel_size)
+
+            # Split input based on how many groups in convolution
+            inp_by_group = torch.chunk(inp, self.groups, 1)
+            inp_processed = []
+            # Preprocess input by group
+            for i, inp in enumerate(inp_by_group):
+
+                inp = unfold(inp)
+
+                batch_size, num_blocks = inp.shape[0], inp.shape[-1]
+                inp = torch.transpose(inp, 1, 2)  # shape (B, L, C*kernel_size[0]*kernel_size[1])
+                inp = inp.reshape(-1, inp.size(-1))  # shape (B*L, C*kernel_size[0]*kernel_size[1])
+
+                if not self.index_computed:
+                    self.index_computed = True
+                    self.rand_indices = np.concatenate([
+                        np.random.choice(
+                            np.arange(num_blocks * i, num_blocks * (i + 1)),
+                            size=int(
+                                self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
+                        for i in range(batch_size)])  # sample a random subset of blocks with probability p
+
+                indexes = self.rand_indices
+                if np.max(self.rand_indices) > inp.shape[0]:
+                    indexes = self.rand_indices < inp.shape[0]
+                    indexes = self.rand_indices[indexes]
+
+                inp = inp[indexes]
+                inp_processed.append(inp)
+            inp_processed = torch.stack(inp_processed)
+
+        inp_processed = inp_processed.cpu()
+
+        if self.float_input is None:
+            self.float_input = inp_processed
+        else:
+            self.float_input = torch.cat([self.float_input, inp_processed], dim=1)
+
+    def offload_float_input(self, tmp_dir):
+        # Create a directory for this layer inside the temporary directory
+        self.save_dir = os.path.join(tmp_dir, self.name)
+        os.makedirs(self.save_dir, exist_ok=True)
+        self.float_input_file = os.path.join(self.save_dir, 'float_input.pt')
+        # Offload the float input to disk
+        torch.save(self.float_input, self.float_input_file)
+        # Then delete float_input to save memory
+        del self.float_input
 
     def update_batch(self, module, input, current_layer):
         if self.disable_pre_forward_hook:
@@ -272,6 +430,9 @@ def single_layer_update(self):
         weight = self.layer.weight.data
         dev = weight.device
         dtype = weight.dtype
+        # Load the float input from disk if needed
+        if self.collect_float_first:
+            self.float_input = torch.load(self.float_input_file)
         if isinstance(self.layer, SUPPORTED_CONV_OP):
             if isinstance(
                     self.layer,
@@ -336,7 +497,8 @@ def __init__(
             create_weight_orig,
             accumulator_bit_width,
             p,
-            compression_rate) -> None:
+            compression_rate,
+            collect_float_first) -> None:
         GPFQ.__init__(
             self,
             layer=layer,
@@ -345,7 +507,8 @@ def __init__(
             len_parallel_layers=len_parallel_layers,
             create_weight_orig=create_weight_orig,
             p=p,
-            compression_rate=compression_rate)
+            compression_rate=compression_rate,
+            collect_float_first=collect_float_first)
         self.accumulator_bit_width = accumulator_bit_width
         assert self.accumulator_bit_width is not None
@@ -359,6 +522,10 @@ def single_layer_update(self):
         weight = self.layer.weight.data
         dev = weight.device
         dtype = weight.dtype
+        # Load the float input from disk if needed
+        if self.collect_float_first:
+            # Load float_input from disk
+            self.float_input = torch.load(self.float_input_file)
         if isinstance(self.layer, SUPPORTED_CONV_OP):
             if isinstance(self.layer, (qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d)):
                 weight = weight.transpose(1, 0)  # This performs a view
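
Note: the gpfq.py changes above implement a collect-then-offload pattern: a forward hook accumulates each layer's float input on the CPU, finalize_float_collection removes the hooks and moves the buffers into a TemporaryDirectory, and single_layer_update reloads them with torch.load. A minimal, self-contained sketch of that pattern (the toy model, hook, and names below are illustrative, not part of the patch; only the torch/tempfile APIs are real):

    import os
    from tempfile import TemporaryDirectory

    import torch
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 4))
    cache = {'inp': None}

    def collect(module, args, output):
        # Forward hook: stash this layer's input on the CPU, concatenated across batches
        inp = args[0].detach().cpu()
        cache['inp'] = inp if cache['inp'] is None else torch.cat([cache['inp'], inp])

    handle = model[0].register_forward_hook(collect)
    for _ in range(3):  # stand-in for the calibration loader
        model(torch.randn(2, 8))
    handle.remove()  # analogous to finalize_float_collection

    tmp_dir = TemporaryDirectory()  # deleted on cleanup(), as in gpfq_mode.__exit__
    path = os.path.join(tmp_dir.name, 'float_input.pt')
    torch.save(cache.pop('inp'), path)  # offload to disk to free host memory
    float_input = torch.load(path)  # reloaded later, as in single_layer_update
    tmp_dir.cleanup()
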
diff --git a/src/brevitas/graph/gptq.py b/src/brevitas/graph/gptq.py
index 31d31433b..fad7dc0ff 100644
--- a/src/brevitas/graph/gptq.py
+++ b/src/brevitas/graph/gptq.py
@@ -13,9 +13,11 @@
     from torch.linalg import LinAlgError
 except:
     LinAlgError = RuntimeError
+from torch.fx import GraphModule as TorchGraphModule
 import unfoldNd
 
 from brevitas import torch_version
+from brevitas.fx import GraphModule
 from brevitas.graph.gpxq import GPxQ
 from brevitas.graph.gpxq import gpxq_mode
 from brevitas.graph.gpxq import StopFwdException
@@ -76,6 +78,15 @@ def __init__(
         # How many subblocks to use during GPTQ for each layer
         self.num_blocks = num_blocks
 
+    def __enter__(self):
+        self.orig_forward = self.model.forward
+        if isinstance(self.model, (GraphModule, TorchGraphModule)):
+            self.model.__class__.forward = self.catch_stopfwd
+        else:
+            self.model.forward = self.catch_stopfwd
+        self.setup_gpxq_layers()
+        return self.setup_gpxq_hooks()
+
     def catch_stopfwd(self, *args, **kwargs):
         try:
             self.orig_forward(*args, **kwargs)
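
Note: the new gptq_mode.__enter__ above takes over the forward patching that used to live in gpxq_mode.__init__ (see the gpxq.py hunks below), so that gpfq_mode can defer it until after float collection. The class-versus-instance distinction matters because torch.fx gives each GraphModule its own generated class, so assigning to __class__.forward stays local to that instance, while a plain nn.Module can be patched directly. A sketch of the idiom (StopFwd and the toy modules are stand-ins, not the patch's code):

    import torch
    import torch.fx
    import torch.nn as nn

    class StopFwd(Exception):  # stand-in for brevitas' StopFwdException
        pass

    def patched_forward(self, *args, **kwargs):
        # In brevitas, the GPxQ pre-forward hooks have collected their data
        # by this point, so the pass can be cut short
        raise StopFwd

    net = nn.Linear(4, 4)
    traced = torch.fx.symbolic_trace(net)

    net.forward = patched_forward.__get__(net)  # instance-level patch for nn.Module
    traced.__class__.forward = patched_forward  # class-level patch for fx GraphModule

    for m in (net, traced):
        try:
            m(torch.randn(1, 4))
        except StopFwd:
            pass  # forward was intercepted as intended
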
diff --git a/src/brevitas/graph/gpxq.py b/src/brevitas/graph/gpxq.py
index fdbaee52f..31407d7c3 100644
--- a/src/brevitas/graph/gpxq.py
+++ b/src/brevitas/graph/gpxq.py
@@ -99,12 +99,6 @@ def __init__(
         self.group_of_parallel_layers = group_of_parallel_layers
         self.return_forward_output = return_forward_output
 
-        self.orig_forward = self.model.forward
-        if isinstance(self.model, (GraphModule, TorchGraphModule)):
-            self.model.__class__.forward = self.catch_stopfwd
-        else:
-            self.model.forward = self.catch_stopfwd
-
     def _is_module_supported(self, module):
         if isinstance(module, SUPPORTED_CONV_OP):
             return True
@@ -113,34 +107,38 @@ def _is_module_supported(self, module):
         else:
             return False
 
+    @abstractmethod
     def __enter__(self):
+        pass
+
+    def setup_gpxq_layers(self):
         # The user can specify on which layers to apply gptq in parallel.
         # All the others will be executed sequentially
-        dict_of_layers = {
+        self.dict_of_layers = {
             name: [(name, module)] for name,
             module in self.model.named_modules() if self._is_module_supported(module)}
         if self.group_of_parallel_layers is not None:
             for parallel_layers in self.group_of_parallel_layers:
                 for name in parallel_layers:
-                    if name not in dict_of_layers:
+                    if name not in self.dict_of_layers:
                         raise ValueError(
                             "The layer {} is not present in the model or it is not supported for GPTQ"
                             .format(name))
-                    del dict_of_layers[name]
+                    del self.dict_of_layers[name]
                 names = '_'.join(parallel_layers)
-                dict_of_layers[names] = [
+                self.dict_of_layers[names] = [
                     (name, attrgetter(name)(self.model)) for name in parallel_layers]
 
         # Print warning if hooks are attached to any module, since the normal forward flow of the
         # network is highly disrupted during GPxQ
-        for _, parallel_layers in dict_of_layers.items():
+        for _, parallel_layers in self.dict_of_layers.items():
             for name, module in parallel_layers:
                 if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks):
                     warnings.warn(
                         f'Hooks detected during setup for GPxQ. '
                         f'Behaviour might deviate from what is expected.')
-                # Attach hooks for GPTQ
+                # Initialize GPxQ
                 if self._is_module_supported(module):
                     gpxq_module_optimizer = self.initialize_module_optimizer(
                         module,
@@ -148,11 +146,14 @@ def __enter__(self):
                         act_order=self.act_order,
                         len_parallel_layers=len(parallel_layers),
                         create_weight_orig=self.create_weight_orig)
-                    hook_fn = partial(
-                        gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
-                    self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
                     self.gpxq_layers[name] = gpxq_module_optimizer
 
+    def setup_gpxq_hooks(self):
+        for name, module in self.gpxq_layers.items():
+            # Attach hooks for GPxQ
+            hook_fn = partial(module.update_batch, current_layer=self.current_layer)
+            self.hook_dict[name] = module.layer.register_forward_pre_hook(hook_fn)
+
         if not self.use_quant_activations:
             self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
             self.disable_quant_inference.disable_act_quantization(
@@ -160,7 +161,7 @@ def __enter__(self):
                 self.model, is_training=self.model.training)
             self.disable_quant_inference.disable_bias_quantization(
                 self.model, is_training=self.model.training)
 
-        self.num_layers = len(dict_of_layers)
+        self.num_layers = len(self.dict_of_layers)
         return self
 
     def __exit__(self, type, value, traceback):
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
index 9d94df12f..afa44e0ef 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -543,7 +543,8 @@ def apply_gpfq(
         p=1.0,
         use_gpfa2q=False,
         accumulator_bit_width=None,
-        compression_rate=0.0):
+        compression_rate=0.0,
+        collect_float_first=True):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
@@ -554,8 +555,17 @@ def apply_gpfq(
             act_order=act_order,
             use_gpfa2q=use_gpfa2q,
             accumulator_bit_width=accumulator_bit_width,
-            compression_rate=compression_rate) as gpfq:
+            compression_rate=compression_rate,
+            collect_float_first=collect_float_first) as gpfq:
         gpfq_model = gpfq.model
+        if collect_float_first:
+            print('Collecting float input first...')
+            for images, target in tqdm(calib_loader):
+                images = images.to(device)
+                images = images.to(dtype)
+                gpfq_model(images)
+            gpfq.finalize_float_collection()
+
         for i in tqdm(range(gpfq.num_layers)):
             for i, (images, target) in enumerate(calib_loader):
                 images = images.to(device)
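
Note on the speed-up this flag buys in apply_gpfq above: without it, catch_stopfwd runs both a quant and a float forward for every batch at every layer update, roughly 2*N*L truncated passes for N layers and L calibration batches; with it, the float activations are collected in a single pass and reused, roughly (N + 1)*L. A toy illustration of that accounting (the numbers are assumptions for the example):

    # Pass-count comparison; assumes each layer update sweeps all calibration
    # batches, as in the loops of apply_gpfq above
    num_layers, num_batches = 4, 10
    interleaved = 2 * num_layers * num_batches  # quant + float forward per batch per layer
    float_first = (num_layers + 1) * num_batches  # one float epoch, then quant-only epochs
    print(interleaved, float_first)  # 80 50
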
diff --git a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
index 7e2bf6ee5..eed92b6ec 100644
--- a/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
+++ b/src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -245,6 +245,11 @@ def parse_type(v, default_type):
     type=float,
     help='Specify compression rate < 1.0 for random projection. Default is 0.0 and does not use RP.'
 )
+add_bool_arg(
+    parser,
+    'collect-float-first',
+    default=False,
+    help='In GPFQ, run separate float and quant forward passes for a speed-up (default: False)')
 add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -437,7 +442,8 @@ def main():
             quant_model,
             p=args.gpfq_p,
             act_order=args.gpxq_act_order,
-            compression_rate=args.compression_rate)
+            compression_rate=args.compression_rate,
+            collect_float_first=args.collect_float_first)
 
     if args.gpfa2q:
         print("Performing GPFA2Q:")
@@ -448,7 +454,8 @@ def main():
             act_order=args.gpxq_act_order,
             use_gpfa2q=args.gpfa2q,
             accumulator_bit_width=args.accumulator_bit_width,
-            compression_rate=args.compression_rate)
+            compression_rate=args.compression_rate,
+            collect_float_first=args.collect_float_first)
 
     if args.gptq:
         print("Performing GPTQ:")
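
Note: add_bool_arg is defined earlier in ptq_evaluate.py and is not part of this diff. For readers unfamiliar with the file, a typical implementation of such a helper (an assumption for illustration, not necessarily the file's exact code) registers a --flag/--no-flag pair writing to one destination:

    import argparse

    def add_bool_arg(parser, name, default, help):
        # Register --name / --no-name that both write to the same destination
        dest = name.replace('-', '_')
        group = parser.add_mutually_exclusive_group(required=False)
        group.add_argument('--' + name, dest=dest, action='store_true', help=help)
        group.add_argument('--no-' + name, dest=dest, action='store_false', help=help)
        parser.set_defaults(**{dest: default})

    parser = argparse.ArgumentParser()
    add_bool_arg(parser, 'collect-float-first', default=False, help='...')
    print(parser.parse_args(['--collect-float-first']).collect_float_first)  # True
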