feat: add HQQ quantization support (#795)
* feat: add HQQ quantization support

* modify gptq_marlin kernels to support hqq

* fix: windows compilation

* formatting
AlpinDale authored Nov 2, 2024
1 parent 43965f7 commit f98e7b2
Showing 11 changed files with 766 additions and 85 deletions.
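
For context, a minimal usage sketch of the new quantization method, assuming aphrodite's usual LLM entry point and an existing HQQ 4-bit checkpoint (the model name and sampling settings below are placeholders, not something this commit ships):

from aphrodite import LLM, SamplingParams

# Placeholder checkpoint: any HQQ-quantized model whose metadata carries
# nbits/group_size should work; "hqq" selects the new HQQMarlinConfig.
llm = LLM(model="org/some-llama-hqq-4bit", quantization="hqq")

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)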
6 changes: 4 additions & 2 deletions aphrodite/_custom_ops.py
@@ -321,11 +321,13 @@ def gptq_marlin_gemm(a: torch.Tensor,
                     size_k: int,
                     is_k_full: bool,
                     has_zp: bool = False,
-                    use_fp32_reduce: bool = False) -> torch.Tensor:
+                    use_fp32_reduce: bool = False,
+                    is_zp_float: bool = False) -> torch.Tensor:
    return torch.ops._C.gptq_marlin_gemm(a, b_q_weight, b_scales, b_zeros,
                                         g_idx, perm, workspace, b_q_type,
                                         size_m, size_n, size_k, is_k_full,
-                                        has_zp, use_fp32_reduce)
+                                        has_zp, use_fp32_reduce,
+                                        is_zp_float)


# fp8
2 changes: 1 addition & 1 deletion aphrodite/modeling/layers/linear.py
@@ -25,7 +25,7 @@

WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
-   "AWQMarlinLinearMethod", "AWQLinearMethod",
+   "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
]


56 changes: 55 additions & 1 deletion aphrodite/modeling/parameter.py
@@ -1,4 +1,4 @@
-from typing import Callable, Optional, Union
+from typing import Callable, List, Optional, Tuple, Union

import torch
from torch.nn import Parameter
@@ -335,3 +335,57 @@ def _adjust_shard_indexes_for_packing(shard_size, shard_offset, packed_factor,
            shard_offset=shard_offset,
            marlin_tile_size=marlin_tile_size)
    return shard_size, shard_offset

# Qweights in HQQ need to be reshaped such that the shape of the stored tensors
# is the actual shape used in weight multiplication. This is needed to correctly
# repack to Marlin. We also store shard size and offsets in order to be able to
# correctly unpack (shard by shard) from 4-bit to 8-bit.
class HQQQweightParameter(PackedAphroditeParameter):

    def __init__(self, packed_factor: int, packed_dim: int, **kwargs):
        super().__init__(packed_factor, packed_dim, None, **kwargs)
        self.shard_offsets: List[Tuple[int, int]] = []
        self.pack_factor = packed_factor

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        self.shard_offsets.append((0, self.shape[self.output_dim]))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        shard_offset = kwargs.get("shard_offset")
        shard_size = kwargs.get("shard_size")
        shard_size, shard_offset = self.adjust_shard_indexes_for_packing(
            shard_offset=shard_offset, shard_size=shard_size)
        self.shard_offsets.append((shard_offset, shard_size))

        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)


# Zero points and scales in HQQ must also be reshaped to their actual shapes.
class HQQZeroScaleParameter(GroupQuantScaleParameter):

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_merged_column_weight(loaded_weight, **kwargs)

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_row_parallel_weight(loaded_weight)

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        loaded_weight = loaded_weight.reshape(-1, self.shape[1])
        super().load_qkv_weight(loaded_weight, **kwargs)

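To make the class comment above concrete, a toy sketch of the shape bookkeeping (all sizes invented; only the reshape idiom mirrors the loaders above). With 4-bit weights the pack factor is 8 // 4 = 2, so a layer with out_features x in_features logical weights holds a uint8 tensor of shape (out_features // 2, in_features); an HQQ checkpoint may serialize W_q in a different layout (a group-wise shape is assumed here purely for illustration), which reshape(-1, in_features) undoes:

import torch

out_features, in_features, group_size = 8, 16, 8
pack_factor = 8 // 4  # two 4-bit values per uint8

# Shape the Marlin repacking path expects: packed along the output dim.
expected = (out_features // pack_factor, in_features)  # (4, 16)

# Assumed group-wise serialization of W_q, for illustration only.
stored = torch.randint(
    0, 256,
    (out_features * in_features // pack_factor // group_size, group_size),
    dtype=torch.uint8)  # (8, 8)

restored = stored.reshape(-1, in_features)  # (4, 16)
assert restored.shape == expected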
2 changes: 2 additions & 0 deletions aphrodite/quantization/__init__.py
@@ -17,6 +17,7 @@
from aphrodite.quantization.gptq import GPTQConfig
from aphrodite.quantization.gptq_marlin import GPTQMarlinConfig
from aphrodite.quantization.gptq_marlin_24 import GPTQMarlin24Config
from aphrodite.quantization.hqq_marlin import HQQMarlinConfig
from aphrodite.quantization.marlin import MarlinConfig
from aphrodite.quantization.qqq import QQQConfig
from aphrodite.quantization.quip import QuipConfig
@@ -45,6 +46,7 @@
    "compressed-tensors": CompressedTensorsConfig,
    "bitsandbytes": BitsAndBytesConfig,
    "qqq": QQQConfig,
    "hqq": HQQMarlinConfig,
    "experts_int8": ExpertsInt8Config,
    # the quant_llm methods
    "fp2": QuantLLMFPConfig,
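The new "hqq" entry resolves to HQQMarlinConfig, defined in the file below. As a rough sketch of how that config is built from HQQ checkpoint metadata (the dict values here are illustrative assumptions; only the keys read by from_config() are shown):

from aphrodite.quantization.hqq_marlin import HQQMarlinConfig

hqq_meta = {
    "quant_config": {
        "weight_quant_params": {
            "nbits": 4,        # assumed value
            "group_size": 64,  # assumed value
        },
    },
}

config = HQQMarlinConfig.from_config(hqq_meta)
print(config)  # -> HQQMarlinConfig(quant_type=..., group_size=64)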
279 changes: 279 additions & 0 deletions aphrodite/quantization/hqq_marlin.py
@@ -0,0 +1,279 @@
from typing import Any, Dict, List, Optional, Tuple

import torch

from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
                                          HQQQweightParameter,
                                          HQQZeroScaleParameter)
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import QuantizationConfig
from aphrodite.quantization.utils.marlin_utils import (
    GPTQ_MARLIN_MAX_PARALLEL, GPTQ_MARLIN_MIN_THREAD_N,
    marlin_make_empty_g_idx, marlin_permute_scales)
from aphrodite.quantization.utils.marlin_utils_test import MarlinWorkspace
from aphrodite.quantization.utils.quant_utils import gptq_pack
from aphrodite.scalar_type import scalar_types


class HQQMarlinConfig(QuantizationConfig):
    """Config class for HQQ Marlin"""

    # num_bits -> quant_type
    TYPE_MAP = {
        4: scalar_types.uint4,
        8: scalar_types.uint8,
    }

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
    ) -> None:
        self.pack_factor = 8 // weight_bits  # packed into uint8
        self.group_size = group_size
        self.quant_type = self.TYPE_MAP[weight_bits]

    def __repr__(self) -> str:
        return (f"HQQMarlinConfig(quant_type={self.quant_type}, "
                f"group_size={self.group_size})")

    @classmethod
    def get_name(cls) -> str:
        return "hqq"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 80

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "HQQMarlinConfig":
        wq_params = config["quant_config"]["weight_quant_params"]
        weight_bits = cls.get_from_keys(wq_params, ["nbits"])
        group_size = cls.get_from_keys(wq_params, ["group_size"])
        return cls(weight_bits, group_size)

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg,
                                     user_quant) -> Optional[str]:
        # TODO
        return None

    def get_quant_method(self, layer: torch.nn.Module,
                         prefix: str) -> Optional["HQQMarlinMethod"]:
        if isinstance(layer, LinearBase):
            return HQQMarlinMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []


# Empty HQQ parameter, will be ignored during loading
class HQQEmptyParameter(BaseAphroditeParameter):

    def load_merged_column_weight(self, loaded_weight: torch.Tensor, **kwargs):
        pass

    def load_row_parallel_weight(self, loaded_weight: torch.Tensor):
        pass

    def load_qkv_weight(self, loaded_weight: torch.Tensor, **kwargs):
        pass


def error_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None:
    raise ValueError("No loader provided for HQQ parameter!")


class HQQMarlinMethod(LinearMethodBase):
    """Linear method for HQQ Marlin.
    """

    def __init__(
        self,
        quant_config: HQQMarlinConfig,
    ):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        self.output_size_per_partition = sum(output_partition_sizes)

        self.input_size_per_partition = input_size_per_partition

        weight_loader = extra_weight_attrs.get("weight_loader", error_loader)

        self.scales_and_zp_size = (input_size_per_partition //
                                   self.quant_config.group_size)

        # Quantized weights
        qweight = HQQQweightParameter(
            data=torch.empty(
                self.output_size_per_partition //
                self.quant_config.pack_factor,
                input_size_per_partition,
                dtype=torch.uint8,
            ),
            input_dim=1,
            output_dim=0,
            packed_dim=0,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader)

        set_weight_attrs(qweight, {
            "is_hqq_weight": True,
            "shard_offsets:": [],
        })

        zeros = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                      input_dim=1,
                                      output_dim=0,
                                      weight_loader=weight_loader)

        scales = HQQZeroScaleParameter(data=torch.empty(
            self.output_size_per_partition,
            self.scales_and_zp_size,
            dtype=params_dtype,
        ),
                                       input_dim=1,
                                       output_dim=0,
                                       weight_loader=weight_loader)

        layer.register_parameter("W_q", qweight)
        layer.register_parameter("zero", zeros)
        layer.register_parameter("scale", scales)

        # Ignore extra parameters in the HQQ model.
        # To be added as needed.
        ignore_parameters = ("axis", "channel_wise", "compute_dtype",
                             "encoded_state_dict", "group_size", "nbits",
                             "offload_meta", "optimize", "packing",
                             "quant_scale", "quant_zero", "round_zero",
                             "shape", "stores_quant_config",
                             "unpack_view_dtype", "view_as_float")
        for name in ignore_parameters:
            layer.register_parameter(
                name,
                HQQEmptyParameter(data=torch.empty(0),
                                  weight_loader=weight_loader))

    # Unpack weights from the HQQ format and repack them to GPTQ -> Marlin
    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        dev = layer.W_q.device

        # unpack function from https://github.com/mobiusml/hqq
        def unpack_4bit_u8(
            W_q: torch.Tensor,
            shard_offsets: List[Tuple[int, int]],
        ) -> torch.Tensor:  # uint8/2 > uint8
            dtype = torch.uint8
            tmp = torch.empty([2 * W_q.shape[0], W_q.shape[1]],
                              dtype=dtype,
                              device=W_q.device)
            for (offset, size) in shard_offsets:
                tmp_offset = 2 * offset
                tmp[tmp_offset:tmp_offset +
                    size] = (W_q[offset:offset + size] & 0b11110000) >> 4
                tmp[tmp_offset + size:tmp_offset +
                    2 * size] = (W_q[offset:offset + size] & 0b00001111)
            return tmp

        # Unpack from 4-bit to 8-bit
        shard_offsets = getattr(layer.W_q, "shard_offsets", [])
        qweight_t = unpack_4bit_u8(layer.W_q, shard_offsets).transpose(1, 0)

        # Repack to GPTQ
        gptq_w_q = gptq_pack(qweight_t, 4, self.input_size_per_partition,
                             self.output_size_per_partition)

        # Repack to Marlin
        sort_indices = torch.empty(0, dtype=torch.int, device=gptq_w_q.device)
        marlin_w_q = ops.gptq_marlin_repack(
            gptq_w_q,
            sort_indices,
            self.input_size_per_partition,
            self.output_size_per_partition,
            4,
        ).to(dev)
        marlin_s = marlin_permute_scales(layer.scale.transpose(1, 0),
                                         self.input_size_per_partition,
                                         self.output_size_per_partition,
                                         self.quant_config.group_size).to(dev)
        marlin_zp = marlin_permute_scales(layer.zero.transpose(1, 0),
                                          self.input_size_per_partition,
                                          self.output_size_per_partition,
                                          self.quant_config.group_size).to(dev)

        layer.g_idx = marlin_make_empty_g_idx(dev)
        layer.g_idx_sort_indices = marlin_make_empty_g_idx(dev)

        layer.marlin_qweight = marlin_w_q
        layer.marlin_zeros = marlin_zp
        layer.marlin_scales = marlin_s

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        workspace = MarlinWorkspace(self.output_size_per_partition,
                                    GPTQ_MARLIN_MIN_THREAD_N,
                                    GPTQ_MARLIN_MAX_PARALLEL)

        scales = layer.marlin_scales
        zeros = layer.marlin_zeros
        orig_type = x.dtype

        if orig_type != torch.float16:
            x = x.to(torch.float16)
            scales = scales.to(torch.float16)
            zeros = zeros.to(torch.float16)

        marlin_out = ops.gptq_marlin_gemm(
            x,
            layer.marlin_qweight,
            scales,
            zeros,
            layer.g_idx,
            layer.g_idx_sort_indices,
            workspace.scratch,
            scalar_types.uint4,
            x.shape[0],
            self.output_size_per_partition,
            self.input_size_per_partition,
            True,  # is_k_full
            True,  # has_zp
            False,  # use 32-bit reduce
            True,  # use float zp
        )

        if bias is not None:
            marlin_out.add_(bias)

        if orig_type != torch.float16:
            return marlin_out.to(orig_type)
        else:
            return marlin_out
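
To illustrate the unpack step inside process_weights_after_loading above, a standalone toy example (values invented) showing how each uint8 holds two 4-bit weights, high nibble first, and how a shard's rows are unpacked back-to-back:

import torch

packed = torch.tensor([[0x12, 0x34],
                       [0x56, 0x78]], dtype=torch.uint8)  # one shard, 2 rows
shard_offsets = [(0, 2)]  # (offset, size) in packed rows

unpacked = torch.empty((2 * packed.shape[0], packed.shape[1]),
                       dtype=torch.uint8)
for offset, size in shard_offsets:
    dst = 2 * offset
    # High nibbles fill the first half of the shard's unpacked rows...
    unpacked[dst:dst + size] = (packed[offset:offset + size] & 0b11110000) >> 4
    # ...and low nibbles the second half, which is why per-shard offsets matter.
    unpacked[dst + size:dst + 2 * size] = (packed[offset:offset + size]
                                           & 0b00001111)

print(unpacked)
# tensor([[1, 3],
#         [5, 7],
#         [2, 4],
#         [6, 8]], dtype=torch.uint8)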