Feat (gpfq): separate float and quant forward pass for speedup
fabianandresgrob committed May 16, 2024
1 parent 3464ec7 commit 2e53cc8
Showing 5 changed files with 178 additions and 30 deletions.
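In short: with collect_float_first=True, GPFQ now registers float-collection hooks in __enter__, runs a single pass with quantization disabled to cache each layer's float input, and only afterwards attaches the usual GPxQ hooks for the quantized calibration passes. The sketch below shows the intended calling pattern, reconstructed from the apply_gpfq changes in this commit; it is a minimal illustration rather than authoritative API documentation, and model, calib_loader, device, dtype and the update() call are assumed to be defined as in the existing example code.

# Hedged sketch of the new two-phase GPFQ flow (reconstructed from this diff).
from brevitas.graph.gpfq import gpfq_mode

with gpfq_mode(model,
               p=1.0,
               use_quant_activations=True,
               act_order=False,
               collect_float_first=True) as gpfq:
    # Phase 1: a single float pass. __enter__ has disabled quantization and
    # attached collect_float_input hooks, so each layer caches its float input.
    for images, _ in calib_loader:
        gpfq.orig_forward(images.to(device).to(dtype))
    # Remove the float-collection hooks, re-enable quantization and attach the
    # regular GPxQ pre-forward hooks.
    gpfq.finalize_float_collection()

    # Phase 2: the usual layer-by-layer quantized calibration loop; only the
    # quant input still needs to be collected, the float input is already cached.
    gpfq_model = gpfq.model
    for _ in range(gpfq.num_layers):
        for images, _ in calib_loader:
            gpfq_model(images.to(device).to(dtype))
        gpfq.update()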
142 changes: 126 additions & 16 deletions src/brevitas/graph/gpfq.py
@@ -2,6 +2,7 @@
# SPDX-License-Identifier: BSD-3-Clause

from copy import deepcopy
import tempfile
from typing import Callable, List, Optional

import numpy as np
@@ -64,7 +65,8 @@ def __init__(
act_order: bool = False,
use_gpfa2q: bool = False,
accumulator_bit_width: Optional[int] = None,
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True) -> None:
a2q_layer_filter_fnc: Optional[Callable[[nn.Module], bool]] = lambda x: True,
collect_float_first: bool = False) -> None:
if not inplace:
model = deepcopy(model)
super().__init__(
@@ -83,22 +85,34 @@ def __init__(
self.accumulator_bit_width = accumulator_bit_width
self.a2q_layer_filter_fnc = a2q_layer_filter_fnc # returns true when to use GPFA2Q

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass
# speeding up by collecting float input first so we don't need to do it later
self.collect_float_first = collect_float_first

def __enter__(self):
# initialize gpxq layers
self.setup_gpxq_layers()
if self.collect_float_first:
self.float_collection_hooks = dict()
# set up hooks for collecting the float input
for name, layer in self.gpxq_layers.items():
# Attach float collecting hook
self.float_collection_hooks[name] = layer.layer.register_forward_hook(
layer.collect_float_input)

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)

return self
else:
# if we're not collecting, setup original hooks
return self.setup_gpxq_hooks()

# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass
def finalize_float_collection(self):
# remove the hooks we attached during the float collection
for name, hook in self.float_collection_hooks.items():
hook.remove()

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
@@ -109,6 +123,37 @@ def catch_stopfwd(self, *args, **kwargs):
self.disable_quant_inference.disable_bias_quantization(self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

# setup the original hooks
self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
# Collect quant input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

if not self.collect_float_first:
# Disable quantization
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_param_quantization(self.model, is_training=False)
self.disable_quant_inference.disable_act_quantization(self.model, is_training=False)
# Collect float input
try:
self.orig_forward(*args, **kwargs)
except StopFwdException:
pass

# Re-enable quantization. If activation quantization is disabled,
# we also disable bias quantization
self.disable_quant_inference.enable_param_quantization(self.model, is_training=False)
if self.use_quant_activations:
self.disable_quant_inference.enable_act_quantization(self.model, is_training=False)
else:
self.disable_quant_inference.disable_bias_quantization(
self.model, is_training=False)
restore_return_quant_tensor(self.model, self.return_quant_tensor_state)

if self.return_forward_output:
# If we want to return the output of the network, we need to disable all hooks
for name, gpxq_class in self.gpxq_layers.items():
@@ -152,6 +197,71 @@ def __init__(self, layer, name, act_order, len_parallel_layers, create_weight_or
self.quantized_input = None
self.index_computed = False
self.p = p
self.save_dir = None

def collect_float_input(self, module, args, output):
# forward hook that collects the float input of this layer (moved to CPU)
inp = self.process_input(args)
batch_size = inp.shape[0]

# Preprocess the input to compute the Hessian
if isinstance(self.layer, qnn.QuantLinear):
if len(inp.shape) > 2:
inp = inp.reshape((-1, sum(inp.shape[2:])))
# For QuantLinear layer, groups will be 1
inp_processed = inp.unsqueeze(0)

if isinstance(self.layer, SUPPORTED_CONV_OP):
# Pick the correct unfoldNd class
if isinstance(
self.layer,
(qnn.QuantConvTranspose1d, qnn.QuantConvTranspose2d, qnn.QuantConvTranspose3d)):
unfold_impl = unfoldNd.UnfoldTransposeNd
else:
unfold_impl = unfoldNd.UnfoldNd

unfold = unfold_impl(
self.layer.kernel_size,
dilation=self.layer.dilation,
padding=self.layer.padding,
stride=self.layer.kernel_size)

# Split input based on how many groups in convolution
inp_by_group = torch.chunk(inp, self.groups, 1)
inp_processed = []
# Preprocess input by group
for i, inp in enumerate(inp_by_group):

inp = unfold(inp)

batch_size, num_blocks = inp.shape[0], inp.shape[-1]
inp = torch.transpose(inp, 1, 2) # shape (B, L, C*kernel_size[0]*kernel_size[1])
inp = inp.reshape(-1, inp.size(-1)) # shape (B*L, C*kernel_size[0]*kernel_size[1])

if not self.index_computed:
self.index_computed = True
self.rand_indices = np.concatenate([
np.random.choice(
np.arange(num_blocks * i, num_blocks * (i + 1)),
size=int(
self.p * num_blocks + 1 if self.p != 1 else self.p * num_blocks))
for i in range(batch_size)]) # need to define self.p (probability)

indexes = self.rand_indices
if np.max(self.rand_indices) > inp.shape[0]:
indexes = self.rand_indices < inp.shape[0]
indexes = self.rand_indices[indexes]

inp = inp[indexes]
inp_processed.append(inp)
inp_processed = torch.stack(inp_processed)

inp_processed = inp_processed.cpu()

if self.float_input is None:
self.float_input = inp_processed
else:
self.float_input = torch.cat([self.float_input, inp_processed], dim=1)

def update_batch(self, module, input, current_layer):
if self.disable_pre_forward_hook:
4 changes: 4 additions & 0 deletions src/brevitas/graph/gptq.py
@@ -76,6 +76,10 @@ def __init__(
# How many subblock to use during GPTQ for each layer
self.num_blocks = num_blocks

def __enter__(self):
self.setup_gpxq_layers()
return self.setup_gpxq_hooks()

def catch_stopfwd(self, *args, **kwargs):
try:
self.orig_forward(*args, **kwargs)
27 changes: 17 additions & 10 deletions src/brevitas/graph/gpxq.py
@@ -113,54 +113,61 @@ def _is_module_supported(self, module):
else:
return False

@abstractmethod
def __enter__(self):
pass

def setup_gpxq_layers(self):
# The user can specify on which layers to apply gptq in parallel.
# All the others will be executed sequentially
dict_of_layers = {
self.dict_of_layers = {
name: [(name, module)] for name,
module in self.model.named_modules() if self._is_module_supported(module)}
if self.group_of_parallel_layers is not None:
for parallel_layers in self.group_of_parallel_layers:
for name in parallel_layers:
if name not in dict_of_layers:
if name not in self.dict_of_layers:
raise ValueError(
"The layer {} is not present in the model or it is not supported for GPTQ"
.format(name))
del dict_of_layers[name]
del self.dict_of_layers[name]
names = '_'.join(parallel_layers)
dict_of_layers[names] = [
self.dict_of_layers[names] = [
(name, attrgetter(name)(self.model)) for name in parallel_layers]

# Print warning if hooks are attached to any module, since the normal forward flow of the
# network is highly disrupted during GPxQ
for _, parallel_layers in dict_of_layers.items():
for _, parallel_layers in self.dict_of_layers.items():
for name, module in parallel_layers:
if len(module._forward_hooks) > 0 or len(module._forward_pre_hooks):
warnings.warn(
f'Hooks detected during setup for GPxQ. '
f'Behaviour might deviate from what expected.')

# Attach hooks for GPTQ
# initialize GPxQ
if self._is_module_supported(module):
gpxq_module_optimizer = self.initialize_module_optimizer(
module,
name,
act_order=self.act_order,
len_parallel_layers=len(parallel_layers),
create_weight_orig=self.create_weight_orig)
hook_fn = partial(
gpxq_module_optimizer.update_batch, current_layer=self.current_layer)
self.hook_dict[name] = module.register_forward_pre_hook(hook_fn)
self.gpxq_layers[name] = gpxq_module_optimizer

def setup_gpxq_hooks(self):
for name, module in self.gpxq_layers.items():
# Attach hooks for GPxQ
hook_fn = partial(module.update_batch, current_layer=self.current_layer)
self.hook_dict[name] = module.layer.register_forward_pre_hook(hook_fn)

if not self.use_quant_activations:
self.return_quant_tensor_state = disable_return_quant_tensor(self.model)
self.disable_quant_inference.disable_act_quantization(
self.model, is_training=self.model.training)
self.disable_quant_inference.disable_bias_quantization(
self.model, is_training=self.model.training)

self.num_layers = len(dict_of_layers)
self.num_layers = len(self.dict_of_layers)
return self

def __exit__(self, type, value, traceback):
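The base-class change above makes __enter__ abstract and splits the old setup into setup_gpxq_layers() (build the per-layer module optimizers) and setup_gpxq_hooks() (attach the update_batch pre-forward hooks and return self). Below is a minimal sketch of the resulting contract, assuming the gpxq_mode base class defined in this file; the subclass name is hypothetical. GPTQ chains the two calls directly, while GPFQ with collect_float_first=True defers the hook step to finalize_float_collection().

# Minimal sketch of a subclass under the new split (class name is hypothetical).
class ExampleGpxqMode(gpxq_mode):

    def __enter__(self):
        # Build self.gpxq_layers, one module optimizer per supported layer.
        self.setup_gpxq_layers()
        # Attach the update_batch pre-forward hooks and return self.
        return self.setup_gpxq_hooks()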
20 changes: 18 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_common.py
@@ -535,7 +535,14 @@ def apply_gptq(calib_loader, model, act_order=False):
gptq.update()


def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumulator_bit_width=None):
def apply_gpfq(
calib_loader,
model,
act_order,
p=1.0,
use_gpfa2q=False,
accumulator_bit_width=None,
collect_float_first=True):
model.eval()
dtype = next(model.parameters()).dtype
device = next(model.parameters()).device
@@ -545,7 +552,16 @@ def apply_gpfq(calib_loader, model, act_order, p=1.0, use_gpfa2q=False, accumula
use_quant_activations=True,
act_order=act_order,
use_gpfa2q=use_gpfa2q,
accumulator_bit_width=accumulator_bit_width) as gpfq:
accumulator_bit_width=accumulator_bit_width,
collect_float_first=collect_float_first) as gpfq:
if collect_float_first:
print('Collecting float input first...')
for i, (images, target) in tqdm(enumerate(calib_loader)):
images = images.to(device)
images = images.to(dtype)
gpfq.orig_forward(images)
gpfq.finalize_float_collection()

gpfq_model = gpfq.model
for i in tqdm(range(gpfq.num_layers)):
for i, (images, target) in enumerate(calib_loader):
15 changes: 13 additions & 2 deletions src/brevitas_examples/imagenet_classification/ptq/ptq_evaluate.py
@@ -239,6 +239,11 @@ def parse_type(v, default_type):
help=
'Split Ratio for Channel Splitting. When set to 0.0, Channel Splitting will not be applied. (default: 0.0)'
)
add_bool_arg(
parser,
'collect-float-first',
default=False,
help='In GPFQ, separate float and quant forward pass for speed up. (default: False)')
add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
add_bool_arg(parser, 'gpfa2q', default=False, help='GPFA2Q (default: disabled)')
@@ -426,7 +431,12 @@ def main():

if args.gpfq:
print("Performing GPFQ:")
apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpxq_act_order)
apply_gpfq(
calib_loader,
quant_model,
p=args.gpfq_p,
act_order=args.gpxq_act_order,
collect_float_first=args.collect_float_first)

if args.gpfa2q:
print("Performing GPFA2Q:")
@@ -436,7 +446,8 @@
p=args.gpfq_p,
act_order=args.gpxq_act_order,
use_gpfa2q=args.gpfa2q,
accumulator_bit_width=args.accumulator_bit_width)
accumulator_bit_width=args.accumulator_bit_width,
collect_float_first=args.collect_float_first)

if args.gptq:
print("Performing GPTQ:")
