
Gpfq/act order #729

Merged 2 commits on Nov 15, 2023
40 changes: 27 additions & 13 deletions src/brevitas/graph/gpfq.py
@@ -43,7 +43,7 @@ def __init__(
             inplace: bool = True,
             create_weight_orig: bool = True,
             use_quant_activations: bool = True,
-            p: int = 0.25,
+            p: float = 1.0,
             return_forward_output: bool = False,
             act_order: bool = False) -> None:
         if not inplace:
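The default p for gpfq_mode moves from 0.25 to 1.0, and its type annotation is corrected from int to float. For orientation, a minimal usage sketch of the context manager with the new defaults, mirroring the calibration loop of apply_gpfq further down; model and calib_loader are assumed to exist:

import torch
from brevitas.graph.gpfq import gpfq_mode

with torch.no_grad():
    # p now defaults to 1.0; act_order enables the ordering heuristic added by this PR
    with gpfq_mode(model, p=1.0, use_quant_activations=True, act_order=True) as gpfq:
        gpfq_model = gpfq.model
        for _ in range(gpfq.num_layers):
            for images, _ in calib_loader:
                gpfq_model(images)  # gather statistics for the current layer
            gpfq.update()  # quantize the current layer and move to the next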
@@ -117,12 +117,10 @@ def __init__(
             act_order,
             len_parallel_layers=1,
             create_weight_orig=True,
-            p=0.25) -> None:
-
-        if act_order:
-            raise ValueError("Act_order is not supported in GPFQ")
+            p=1.0) -> None:
+
         super().__init__(layer, name, act_order, len_parallel_layers, create_weight_orig)

         self.float_input = None
         self.quantized_input = None
         self.index_computed = False
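With the ValueError guard removed, constructing the per-layer GPFQ processor with act_order=True no longer raises; the flag is simply forwarded to the shared base-class constructor. A minimal before/after illustration (layer and name assumed to exist):

# Before this PR:
#   GPFQ(layer, name, act_order=True)  ->  ValueError("Act_order is not supported in GPFQ")
# After this PR, the same call succeeds, using the defaults from the new signature:
gpfq_layer = GPFQ(layer, name, act_order=True, len_parallel_layers=1, create_weight_orig=True, p=1.0)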
@@ -220,25 +218,41 @@ def single_layer_update(self):
             weight.shape[0], weight.shape[1], self.float_input.shape[1], device=dev, dtype=dtype)
         self.float_input = self.float_input.to(dev)
         self.quantized_input = self.quantized_input.to(dev)
-        permutation_list = [torch.tensor(range(weight.shape[-1]))]
+        # We don't need the full Hessian, we just need the diagonal
+        self.H_diag = self.quantized_input.transpose(2, 1).square().sum(2)  # summing over the batch dimension
+        permutation_list = []
+        for group_index in range(self.groups):
+            if self.act_order:
+                # Re-order the Hessian diagonal so that weights associated with
+                # higher-magnitude activations are quantized first
+                perm = torch.argsort(self.H_diag[group_index, :], descending=True)
+            else:
+                # No permutation, the permutation tensor is an ordered index
+                perm = torch.tensor(range(weight.shape[-1]), device=dev)
+            permutation_list.append(perm)
         for t in range(weight.shape[-1]):
             for group_index in range(self.groups):
                 U[group_index] += torch.matmul(
-                    weight[group_index, :, t].unsqueeze(1),
-                    self.float_input[group_index, :, t].unsqueeze(0))  # [OC/Groups, 1] * [1, INSHAPE[1]]
-                norm = torch.linalg.norm(self.quantized_input[group_index, :, t], 2) ** 2
+                    weight[group_index, :, permutation_list[group_index][t]].unsqueeze(1),
+                    self.float_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0))  # [OC/Groups, 1] * [1, INSHAPE[1]]
+                norm = torch.linalg.norm(self.quantized_input[group_index, :, permutation_list[group_index][t]], 2) ** 2
                 if norm > 0:
-                    q_arg = U[group_index].matmul(self.quantized_input[group_index, :, t]) / norm
+                    q_arg = U[group_index].matmul(self.quantized_input[group_index, :, permutation_list[group_index][t]]) / norm
                 else:
                     q_arg = torch.zeros_like(U[group_index, :, 0])

-                weight[group_index, :, t] = q_arg
+                weight[group_index, :, permutation_list[group_index][t]] = q_arg
             q = self.get_quant_weights(t, 0, permutation_list)
             for group_index in range(self.groups):
                 U[group_index] -= torch.matmul(
                     q[group_index].unsqueeze(1),
-                    self.quantized_input[group_index, :, t].unsqueeze(0))
+                    self.quantized_input[group_index, :, permutation_list[group_index][t]].unsqueeze(0))

         del self.float_input
         del self.quantized_input
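The heuristic itself is compact: the diagonal of the Hessian proxy is just the per-column sum of squared activations over the batch, and act_order visits columns in decreasing order of that diagonal. A self-contained sketch with hypothetical names (act_order_permutations is illustrative, not Brevitas API); inputs follows the [groups, samples, columns] layout used above:

import torch

def act_order_permutations(inputs: torch.Tensor, act_order: bool) -> list:
    # H_diag[g, t] is the sum over samples of squared activations in column t
    groups, _, columns = inputs.shape
    H_diag = inputs.transpose(2, 1).square().sum(2)  # equivalent to inputs.square().sum(1)
    perms = []
    for g in range(groups):
        if act_order:
            # higher-energy columns are visited (and quantized) first
            perms.append(torch.argsort(H_diag[g], descending=True))
        else:
            perms.append(torch.arange(columns))
    return perms

# Toy check: column energies [1, 9, 4] give visit order [1, 2, 0]
x = torch.tensor([[[1.0, 3.0, 2.0]]])
print(act_order_permutations(x, act_order=True))

Indexing weight, float_input, and quantized_input through permutation_list[group_index][t] keeps the three tensors aligned on the same column order throughout the update.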
@@ -471,12 +471,12 @@ def apply_gptq(calib_loader, model, act_order=False):
             gptq.update()


-def apply_gpfq(calib_loader, model, p=0.25):
+def apply_gpfq(calib_loader, model, act_order, p=0.25):
     model.eval()
     dtype = next(model.parameters()).dtype
     device = next(model.parameters()).device
     with torch.no_grad():
-        with gpfq_mode(model, p=p, use_quant_activations=True) as gpfq:
+        with gpfq_mode(model, p=p, use_quant_activations=True, act_order=act_order) as gpfq:
             gpfq_model = gpfq.model
             for i in tqdm(range(gpfq.num_layers)):
                 for i, (images, target) in enumerate(calib_loader):
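Note that act_order is added to apply_gpfq as a positional parameter without a default, so existing call sites must be updated. A hedged call sketch, with calib_loader and model assumed to exist:

# Passing act_order by keyword keeps the call site self-documenting
apply_gpfq(calib_loader, model, act_order=True, p=1.0)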
@@ -167,7 +167,7 @@
     default=True,
     help='Narrow range for weight quantization (default: enabled)')
 parser.add_argument(
-    '--gpfq-p', default=0.25, type=float, help='P parameter for GPFQ (default: 0.25)')
+    '--gpfq-p', default=1.0, type=float, help='P parameter for GPFQ (default: 1.0)')
 parser.add_argument(
     '--quant-format',
     default='int',
@@ -207,10 +207,12 @@
     default=3,
     type=int,
     help='Exponent bit width used with float quantization for activations (default: 3)')
-add_bool_arg(parser, 'gptq', default=True, help='GPTQ (default: enabled)')
+add_bool_arg(parser, 'gptq', default=False, help='GPTQ (default: disabled)')
 add_bool_arg(parser, 'gpfq', default=False, help='GPFQ (default: disabled)')
 add_bool_arg(
     parser, 'gptq-act-order', default=False, help='GPTQ Act order heuristic (default: disabled)')
+add_bool_arg(
+    parser, 'gpfq-act-order', default=False, help='GPFQ Act order heuristic (default: disabled)')
 add_bool_arg(parser, 'learned-round', default=False, help='Learned round (default: disabled)')
 add_bool_arg(parser, 'calibrate-bn', default=False, help='Calibrate BN (default: disabled)')
@@ -241,6 +243,7 @@ def main():
f"{'gptq_' if args.gptq else ''}"
f"{'gpfq_' if args.gpfq else ''}"
f"{'gptq_act_order_' if args.gptq_act_order else ''}"
f"{'gpfq_act_order_' if args.gpfq_act_order else ''}"
f"{'learned_round_' if args.learned_round else ''}"
f"{'weight_narrow_range_' if args.weight_narrow_range else ''}"
f"{args.bias_bit_width}bias_"
@@ -263,6 +266,7 @@ def main():
f"GPFQ: {args.gpfq} - "
f"GPFQ P: {args.gpfq_p} - "
f"GPTQ Act Order: {args.gptq_act_order} - "
f"GPFQ Act Order: {args.gpfq_act_order} - "
f"Learned Round: {args.learned_round} - "
f"Weight narrow range: {args.weight_narrow_range} - "
f"Bias bit width: {args.bias_bit_width} - "
@@ -359,7 +363,7 @@ def main():

     if args.gpfq:
         print("Performing GPFQ:")
-        apply_gpfq(calib_loader, quant_model, p=args.gpfq_p)
+        apply_gpfq(calib_loader, quant_model, p=args.gpfq_p, act_order=args.gpfq_act_order)

     if args.gptq:
         print("Performing GPTQ:")