Feat (examples/llm): Convert BertAttention to a quantizable version #961

Open · wants to merge 12 commits into base: dev
184 changes: 168 additions & 16 deletions src/brevitas_examples/llm/llm_quant/mha_layers.py
@@ -76,6 +76,41 @@ def num_heads(self):
def batch_first(self):
return self.wrapped_mha.batch_first


class QuantizableOPTAttention(MultiheadAttentionWrapper):

def forward(
self,
hidden_states: torch.Tensor,
key_value_states: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
attention_mask: Optional[torch.Tensor] = None,
layer_head_mask: Optional[torch.Tensor] = None,
output_attentions: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
if key_value_states is None:
key_value_states = hidden_states
if layer_head_mask is not None:
raise RuntimeError("layer_head_mask is not supported.")
if self.mha.batch_first:
batch_size, query_seq_length = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[1]
else:
query_seq_length, batch_size = hidden_states.shape[:2]
key_value_seq_length = key_value_states.shape[0]
num_heads = self.mha.num_heads
attention_mask = attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
attn_output, attn_output_weights = self.mha(
hidden_states,
key_value_states,
key_value_states,
attn_mask=attention_mask,
need_weights=output_attentions,
average_attn_weights=False)
past_key_value = None
return attn_output, attn_output_weights, past_key_value

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):
@@ -142,36 +177,153 @@ def set_weight(value):
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)


class QuantizableBertAttention(MultiheadAttentionWrapper):

def __init__(
self,
all_head_size,
num_attention_heads,
ln_normalized_shape,
dropout=0.,
bias=True,
add_bias_kv=False,
add_zero_attn=False,
kdim=None,
vdim=None,
batch_first=False,
ln_eps=1e-05,
ln_elementwise_affine=True,
ln_bias=True,
device=None,
dtype=None) -> None:
super().__init__(
embed_dim=all_head_size,
num_heads=num_attention_heads,
dropout=dropout,
bias=bias,
add_bias_kv=add_bias_kv,
add_zero_attn=add_zero_attn,
kdim=kdim,
vdim=vdim,
batch_first=batch_first,
device=device,
dtype=dtype)
self.ln = nn.LayerNorm(
normalized_shape=ln_normalized_shape,
eps=ln_eps,
elementwise_affine=ln_elementwise_affine,
bias=ln_bias,
device=device,
dtype=dtype)

def forward(
self,
hidden_states: torch.Tensor,
            attention_mask: Optional[torch.FloatTensor] = None,
            head_mask: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
            past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
            output_attentions: Optional[bool] = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        if encoder_attention_mask is not None:
            attention_mask = encoder_attention_mask
        if head_mask is not None:
            raise RuntimeError("head_mask is not supported.")
        if self.mha.batch_first:
batch_size, query_seq_length = hidden_states.shape[:2]
            key_value_seq_length = encoder_hidden_states.shape[1]
else:
query_seq_length, batch_size = hidden_states.shape[:2]
            key_value_seq_length = encoder_hidden_states.shape[0]
        num_heads = self.mha.num_heads
attention_mask = attention_mask_handler(
attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length)
attn_output, attn_output_weights = self.mha(
hidden_states,
            encoder_hidden_states,
            encoder_hidden_states,
attn_mask=attention_mask,
need_weights=output_attentions,
average_attn_weights=False)
ln_output = self.ln(attn_output + hidden_states)
past_key_value = None
        return ln_output, attn_output_weights, past_key_value

def _load_from_state_dict(
self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys,
error_msgs):

def set_bias(value):
bias_name = f'{prefix}mha.in_proj_bias'
if bias_name in state_dict:
state_dict[bias_name] += value
else:
state_dict[bias_name] = value

def set_weight(value):
weight_name = f'{prefix}mha.in_proj_weight'
if weight_name in state_dict:
state_dict[weight_name] += value
else:
state_dict[weight_name] = value

embed_dim = self.mha.embed_dim
for name, value in list(state_dict.items()):
if prefix + 'self.query.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[:embed_dim] = value
set_weight(weight)
del state_dict[name]
elif prefix + 'self.key.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[embed_dim:2 * embed_dim] = value
set_weight(weight)
del state_dict[name]
elif prefix + 'self.value.weight' in name:
weight = torch.zeros((3 * embed_dim, embed_dim),
device=value.device,
dtype=value.dtype)
weight[2 * embed_dim:3 * embed_dim] = value
set_weight(weight)
del state_dict[name]
if prefix + 'self.query.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[:embed_dim] = value
set_bias(bias)
del state_dict[name]
elif prefix + 'self.key.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[embed_dim:2 * embed_dim] = value
set_bias(bias)
del state_dict[name]
elif prefix + 'self.value.bias' in name:
bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype)
bias[2 * embed_dim:3 * embed_dim] = value
set_bias(bias)
del state_dict[name]
if prefix + 'output.dense.weight' in name:
weight_name = f'{prefix}mha.out_proj.weight'
state_dict[weight_name] = value
del state_dict[name]
if prefix + 'output.dense.bias' in name:
weight_name = f'{prefix}mha.out_proj.bias'
state_dict[weight_name] = value
del state_dict[name]
if prefix + 'output.LayerNorm.weight' in name:
weight_name = f'{prefix}ln.weight'
state_dict[weight_name] = value
del state_dict[name]
if prefix + 'output.LayerNorm.bias' in name:
weight_name = f'{prefix}ln.bias'
state_dict[weight_name] = value
del state_dict[name]
return super()._load_from_state_dict(
state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs)
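
(Not part of the diff.) The remapping above folds BERT's separate self.query/self.key/self.value projections into the single fused in_proj_weight and in_proj_bias that torch.nn.MultiheadAttention expects, and routes output.dense and output.LayerNorm onto mha.out_proj and ln. A minimal sketch of the fused weight layout, with an illustrative embed_dim:

import torch

embed_dim = 768  # illustrative only; the real value comes from self.mha.embed_dim
q_w = torch.randn(embed_dim, embed_dim)  # stands in for self.query.weight
k_w = torch.randn(embed_dim, embed_dim)  # stands in for self.key.weight
v_w = torch.randn(embed_dim, embed_dim)  # stands in for self.value.weight

# Same layout as set_weight() above: query rows first, then key, then value.
in_proj_weight = torch.zeros(3 * embed_dim, embed_dim)
in_proj_weight[:embed_dim] = q_w
in_proj_weight[embed_dim:2 * embed_dim] = k_w
in_proj_weight[2 * embed_dim:] = v_w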
29 changes: 28 additions & 1 deletion src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py
@@ -1,11 +1,38 @@
import warnings

from transformers.models.bert.modeling_bert import BertAttention
from transformers.models.opt.modeling_opt import OPTAttention

from brevitas.graph import ModuleToModuleByClass
from brevitas_examples.llm.llm_quant.mha_layers import QuantizableBertAttention
from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention

QUANTIZABLE_MHA_MAP = {
    OPTAttention: (QuantizableOPTAttention, {
        'batch_first': True}),
    BertAttention: (QuantizableBertAttention, {
        'batch_first': True}),}
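
(Not part of the diff.) The body of replace_mha_with_quantizable_layers is collapsed below, so the following is only a hedged sketch of how a map like this is typically consumed with brevitas's ModuleToModuleByClass rewriter; the apply(model) call and keyword handling are assumptions, and the attributes copied by _set_bert_mha_attributes presumably let the rewriter fill the replacement module's constructor arguments.

# Hypothetical sketch only -- not the collapsed replace_mha_with_quantizable_layers body.
def apply_quantizable_mha_map(model):
    for src_cls, (quant_cls, extra_kwargs) in QUANTIZABLE_MHA_MAP.items():
        # One class-to-class rewriter per entry; extra kwargs (e.g. batch_first=True)
        # are forwarded to the quantizable replacement's constructor.
        rewriter = ModuleToModuleByClass(src_cls, quant_cls, **extra_kwargs)
        model = rewriter.apply(model)
    return model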


def _set_bert_mha_attributes(module):
module.all_head_size = module._modules['self'].all_head_size
module.num_attention_heads = module._modules['self'].num_attention_heads
module.ln_normalized_shape = module._modules['output'].LayerNorm.normalized_shape
module.ln_eps = module._modules['output'].LayerNorm.eps
module.ln_elementwise_affine = module._modules['output'].LayerNorm.elementwise_affine
    module.ln_bias = module._modules['output'].LayerNorm.bias is not None


_SET_ATTRIBUTES_MAP = {
BertAttention: _set_bert_mha_attributes,}


def set_mha_attributes(model):
for name, module in model.named_modules():
mod_type = type(module)
        if mod_type in _SET_ATTRIBUTES_MAP:
_SET_ATTRIBUTES_MAP[mod_type](module)
return model


def replace_mha_with_quantizable_layers(model, dtype):
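
(Not part of the diff.) A hedged end-to-end usage sketch, assuming replace_mha_with_quantizable_layers returns the rewritten model (its body is collapsed above); the checkpoint name is just an example.

import torch
from transformers import BertModel

from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.prepare_for_quantize import set_mha_attributes

model = BertModel.from_pretrained('bert-base-uncased')
# Copy all_head_size, num_attention_heads and the LayerNorm settings onto each
# BertAttention so they are available when the quantizable wrapper is constructed.
model = set_mha_attributes(model)
# Swap BertAttention (and OPTAttention, if present) for the quantizable wrappers.
model = replace_mha_with_quantizable_layers(model, dtype=torch.float32)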