Feat (tests): testing various cases for gpxq_modes
i-colbert committed Mar 2, 2024
1 parent 935dd47 commit e3dd0ea
Showing 1 changed file with 70 additions and 60 deletions.
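In brief, the commit replaces the toy-model GPFQ test with focused test_gpfq and test_gptq cases, parametrized over act_order, use_quant_activations, the accumulator bit width, and the A2Q layer filter, including the combinations where GPFA2Q is expected to raise a ValueError. As a minimal sketch of the calibration loop these tests drive (it mirrors the apply_gptq helper in the diff below; the calibrate_gptq name and its defaults are illustrative stand-ins, not part of the commit):

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from brevitas.graph.gptq import gptq_mode  # same entry point the tests exercise


def calibrate_gptq(model: nn.Module, calib_loader: DataLoader,
                   act_order: bool = False, use_quant_activations: bool = True):
    # One pass over the calibration data per layer; gptq.update() finalizes
    # the current layer before the loop moves on to the next one.
    model.eval()
    with torch.no_grad():
        with gptq_mode(model,
                       use_quant_activations=use_quant_activations,
                       act_order=act_order) as gptq:
            for _ in range(gptq.num_layers):
                for images, _ in calib_loader:
                    gptq.model(images)
                gptq.update()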
130 changes: 70 additions & 60 deletions tests/brevitas/graph/test_gpxq.py
@@ -1,17 +1,15 @@
 # Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved.
 # SPDX-License-Identifier: BSD-3-Clause
 
+import pytest
 import torch
-import torch.nn
+import torch.nn as nn
 from torch.utils.data import DataLoader
 from torch.utils.data import TensorDataset
 
 from brevitas.graph.gpfq import gpfq_mode
-from brevitas.graph.quantize import preprocess_for_quantize
+from brevitas.graph.gptq import gptq_mode
 import brevitas.nn as qnn
-from brevitas_examples.imagenet_classification.ptq.ptq_common import quantize_model
 
-from .equalization_fixtures import *
-
 
 class QuantConvModel(nn.Module):
@@ -30,9 +28,10 @@ def forward(self, x):
 
 
 def apply_gpfq(
-        calib_loader,
-        model,
-        act_order,
+        calib_loader: DataLoader,
+        model: nn.Module,
+        act_order: bool,
+        use_quant_activations: bool = True,
         accumulator_bit_width: int = 32,
         a2q_layer_filter_fnc=lambda x: True):
     model.eval()
@@ -42,7 +41,7 @@ def apply_gpfq(
         # use A2GPFQ if an accumulator bit width less than 32 is specified
         with gpfq_mode(
                 model,
-                use_quant_activations=True,
+                use_quant_activations=use_quant_activations,
                 act_order=act_order,
                 use_gpfa2q=accumulator_bit_width < 32,
                 accumulator_bit_width=accumulator_bit_width,
@@ -57,69 +56,80 @@
                 gpfq.update()
 
 
+def apply_gptq(
+        calib_loader: DataLoader, model: nn.Module, act_order: bool, use_quant_activations: bool):
+    model.eval()
+    dtype = next(model.parameters()).dtype
+    device = next(model.parameters()).device
+    with torch.no_grad():
+        with gptq_mode(
+                model,
+                use_quant_activations=use_quant_activations,
+                act_order=act_order,
+        ) as gptq:
+            gptq_model = gptq.model
+            for _ in range(gptq.num_layers):
+                for _, (images, _) in enumerate(calib_loader):
+                    images = images.to(device)
+                    images = images.to(dtype)
+                    gptq_model(images)
+                gptq.update()
+
+
 def custom_layer_filter_fnc(layer: nn.Module) -> bool:
     if isinstance(layer, nn.Conv2d) and layer.in_channels == 3:
         return False
     return True
 
 
-@pytest.mark.parametrize("act_order", [True, False])
-@pytest.mark.parametrize("acc_bit_width", [32, 24, 16])
-def test_toymodels(toy_model, request, act_order: bool, acc_bit_width: int):
-    model_name = request.node.callspec.id.split('-')[0]
-
-    torch.manual_seed(SEED)
-
-    model_class = toy_model
-    model = model_class()
-
-    # preprocess model for quantization, like merge BN etc.
-    model = preprocess_for_quantize(model)
-    # quantize model pretty basic
-    model = quantize_model(
-        model,
-        backend='layerwise',
-        weight_bit_width=8,
-        act_bit_width=8,
-        bias_bit_width=32,
-        scale_factor_type='float_scale',
-        weight_narrow_range=False,
-        weight_param_method='stats',
-        weight_quant_granularity='per_channel',
-        weight_quant_type='sym',
-        layerwise_first_last_bit_width=8,
-        act_param_method='stats',
-        act_quant_percentile=99.999,
-        act_quant_type='sym',
-        quant_format='int')
-
-    if 'mha' in model_name:
-        inp = torch.randn(256, *IN_SIZE_LINEAR[1:])
-    else:
-        inp = torch.randn(256, *IN_SIZE_CONV[1:])
+def identity_layer_filter_func(layer: nn.Module) -> bool:
+    return True
 
-    dataset = TensorDataset(inp, inp)
-    calibloader = DataLoader(dataset, batch_size=32, num_workers=0, pin_memory=True, shuffle=True)
 
-    apply_gpfq(
-        calibloader,
-        model,
-        act_order=act_order,
-        accumulator_bit_width=acc_bit_width,
-        a2q_layer_filter_fnc=custom_layer_filter_fnc)
+filter_func_dict = {
+    "identity": identity_layer_filter_func,
+    "ignore_input": custom_layer_filter_fnc,}
 
 
 @pytest.mark.parametrize("act_order", [True, False])
-@pytest.mark.parametrize("acc_bit_width", [32, 24, 16])
-def test_toymodels(act_order: bool, acc_bit_width: int):
+@pytest.mark.parametrize("use_quant_activations", [True, False])
+@pytest.mark.parametrize("acc_bit_width", [32, 24, 16, 12])
+@pytest.mark.parametrize("filter_func_str", filter_func_dict.keys())
+def test_gpfq(
+        act_order: bool, use_quant_activations: bool, acc_bit_width: int, filter_func_str: str):
+    model = QuantConvModel()
+    inp = torch.randn(100, 3, 32, 32)
+    dataset = TensorDataset(inp, inp)
+    calibloader = DataLoader(dataset, batch_size=32, num_workers=0, pin_memory=True, shuffle=True)
+    filter_func = filter_func_dict[filter_func_str]
+    if (acc_bit_width < 32) and (not use_quant_activations or filter_func_str == "identity"):
+        # GPFA2Q requires that the quant activations are used. GPFA2Q.single_layer_update will
+        # raise a ValueError if GPFA2Q.quant_input is None (also see GPxQ.process_input). This will
+        # happen when `use_quant_activations=False` or when the input to a model is not quantized
+        # and `a2q_layer_filter_fnc` does not properly handle it.
+        with pytest.raises(ValueError):
+            apply_gpfq(
+                calibloader,
+                model,
+                act_order=act_order,
+                use_quant_activations=use_quant_activations,
+                accumulator_bit_width=acc_bit_width,
+                a2q_layer_filter_fnc=filter_func)
+    else:
+        apply_gpfq(
+            calibloader,
+            model,
+            act_order=act_order,
+            use_quant_activations=use_quant_activations,
+            accumulator_bit_width=acc_bit_width,
+            a2q_layer_filter_fnc=filter_func)
+
+
+@pytest.mark.parametrize("act_order", [True, False])
+@pytest.mark.parametrize("use_quant_activations", [True, False])
+def test_gptq(act_order: bool, use_quant_activations: bool):
     model = QuantConvModel()
-    inp = torch.randn(256, *IN_SIZE_CONV[1:])
+    inp = torch.randn(100, 3, 32, 32)
     dataset = TensorDataset(inp, inp)
     calibloader = DataLoader(dataset, batch_size=32, num_workers=0, pin_memory=True, shuffle=True)
-    apply_gpfq(
-        calibloader,
-        model,
-        act_order=act_order,
-        accumulator_bit_width=acc_bit_width,
-        a2q_layer_filter_fnc=custom_layer_filter_fnc)
+    apply_gptq(calibloader, model, act_order=act_order, use_quant_activations=use_quant_activations)
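For reference, one hypothetical way to run only these cases from the repository root (the -k selection below is an assumption, not part of the commit; invoking pytest on the file directly works the same):

import pytest

# Select the GPFQ/GPTQ cases exercised by this commit.
pytest.main(["tests/brevitas/graph/test_gpxq.py", "-k", "gpfq or gptq", "-v"])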
