From 06dc443ccc1e661b2f14da0ffcf730707f1f7a07 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Fri, 30 Jun 2023 22:54:48 +0100 Subject: [PATCH 01/10] Feat (llm): QuantizableBert --- .../llm/llm_quant/mha_layers.py | 159 ++++++++++++++++-- 1 file changed, 144 insertions(+), 15 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index cf694d4eb..b73622c01 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -7,6 +7,26 @@ from brevitas.utils.torch_utils import KwargsForwardHook +def attention_mask_handler( + attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length): + """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D + (implicit batch_size and n_heads).""" + if len(attention_mask.shape) == 4: + if attention_mask.shape[0] == 1: + attention_mask = attention_mask.repeat(batch_size, 1, 1, 1) + if attention_mask.shape[1] == 1: + attention_mask = attention_mask.repeat(1, num_heads, 1, 1) + if attention_mask.shape[2] == 1: + attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1) + attention_mask = attention_mask.view( + batch_size * num_heads, query_seq_length, key_value_seq_length) + elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1: + # This could happen in Encoder-like architecture + assert query_seq_length == key_value_seq_length + attention_mask = attention_mask.repeat(query_seq_length, 1) + return attention_mask + + def attention_mask_handler( attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length): """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D @@ -76,6 +96,41 @@ def num_heads(self): def batch_first(self): return self.wrapped_mha.batch_first + +class QuantizableOPTAttention(MultiheadAttentionWrapper): + + def forward( + self, + hidden_states: torch.Tensor, + key_value_states: Optional[torch.Tensor] = None, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + attention_mask: Optional[torch.Tensor] = None, + layer_head_mask: Optional[torch.Tensor] = None, + output_attentions: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if key_value_states is None: + key_value_states = hidden_states + if layer_head_mask is not None: + raise RuntimeError("layer_head_mask is not supported.") + if self.mha.batch_first: + batch_size, query_seq_length = hidden_states.shape[:2] + key_value_seq_length = key_value_states.shape[1] + else: + query_seq_length, batch_size = hidden_states.shape[:2] + key_value_seq_length = key_value_states.shape[0] + num_heads = self.mha.num_heads + attention_mask = attention_mask_handler( + attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length) + attn_output, attn_output_weights = self.mha( + hidden_states, + key_value_states, + key_value_states, + attn_mask=attention_mask, + need_weights=output_attentions, + average_attn_weights=False) + past_key_value = None + return attn_output, attn_output_weights, past_key_value + def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): @@ -142,36 +197,110 @@ def set_weight(value): state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) -class QuantizableOPTAttention(MultiheadAttentionWrapper): +class QuantizableBertAttention(MultiheadAttentionWrapper): + + def __init__(self, 
*args, **kwargs) -> None: + super().__init__(*args, **kwargs) def forward( self, hidden_states: torch.Tensor, - key_value_states: Optional[torch.Tensor] = None, - past_key_value: Optional[Tuple[torch.Tensor]] = None, - attention_mask: Optional[torch.Tensor] = None, - layer_head_mask: Optional[torch.Tensor] = None, - output_attentions: bool = False, + attention_mask: Optional[torch.FloatTensor] = None, + head_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + output_attentions: Optional[bool] = False, ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: - if key_value_states is None: - key_value_states = hidden_states - if layer_head_mask is not None: + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + if encoder_attention_mask is not None: + attention_mask = encoder_attention_mask + if head_mask is not None: raise RuntimeError("layer_head_mask is not supported.") - if self.batch_first: + + if self.mha.batch_first: batch_size, query_seq_length = hidden_states.shape[:2] - key_value_seq_length = key_value_states.shape[1] + key_value_seq_length = encoder_hidden_states.shape[1] else: query_seq_length, batch_size = hidden_states.shape[:2] - key_value_seq_length = key_value_states.shape[0] - num_heads = self.num_heads + key_value_seq_length = encoder_hidden_states.shape[0] + num_heads = self.mha.num_heads attention_mask = attention_mask_handler( attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length) attn_output, attn_output_weights = self.mha( hidden_states, - key_value_states, - key_value_states, + encoder_hidden_states, + encoder_hidden_states, attn_mask=attention_mask, need_weights=output_attentions, average_attn_weights=False) past_key_value = None return attn_output, attn_output_weights, past_key_value + + def _load_from_state_dict( + self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, + error_msgs): + + def set_bias(value): + bias_name = f'{prefix}mha.in_proj_bias' + if bias_name in state_dict: + state_dict[bias_name] += value + else: + state_dict[bias_name] = value + + def set_weight(value): + weight_name = f'{prefix}mha.in_proj_weight' + if weight_name in state_dict: + state_dict[weight_name] += value + else: + state_dict[weight_name] = value + + embed_dim = self.mha.embed_dim + for name, value in list(state_dict.items()): + if prefix + 'query.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[:embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'key.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[embed_dim:2 * embed_dim] = value + set_weight(weight) + del state_dict[name] + elif prefix + 'value.weight' in name: + weight = torch.zeros((3 * embed_dim, embed_dim), + device=value.device, + dtype=value.dtype) + weight[2 * embed_dim:3 * embed_dim] = value + set_weight(weight) + del state_dict[name] + if prefix + 'query.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[:embed_dim] = value + set_bias(bias) + del state_dict[name] + elif prefix + 'key.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[embed_dim:2 * embed_dim] = value + set_bias(bias) + 
del state_dict[name] + elif prefix + 'value.bias' in name: + bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) + bias[2 * embed_dim:3 * embed_dim] = value + set_bias(bias) + del state_dict[name] + state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0]) + state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape) + # elif prefix + 'self.output.dense.weight' in name: + # state_dict[prefix + 'mha.out_proj.weight'] = value + # del state_dict[name] + # elif prefix + 'self.output.dense.bias' in name: + # state_dict[prefix + 'mha.out_proj.bias'] = value + # del state_dict[name] + return super()._load_from_state_dict( + state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) From d32fca464bc7b00d89120e6f5f57d1ba35c0eea1 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 18 Jul 2023 12:56:22 +0100 Subject: [PATCH 02/10] Remove comments --- src/brevitas_examples/llm/llm_quant/mha_layers.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index b73622c01..9b56c0e63 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -296,11 +296,5 @@ def set_weight(value): del state_dict[name] state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0]) state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape) - # elif prefix + 'self.output.dense.weight' in name: - # state_dict[prefix + 'mha.out_proj.weight'] = value - # del state_dict[name] - # elif prefix + 'self.output.dense.bias' in name: - # state_dict[prefix + 'mha.out_proj.bias'] = value - # del state_dict[name] return super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) From 31568492a963cbc180b99b31852be4b31dccfa72 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 23 May 2024 10:59:30 +0100 Subject: [PATCH 03/10] Fix (llm): Add all rewriters to the list --- src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index e96a6d946..2a9505227 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -13,7 +13,7 @@ def replace_mha_with_quantizable_layers(model, dtype): for src_module, (quantizable_module, quantizable_module_kwargs) in QUANTIZABLE_MHA_MAP.items(): rewriter = ModuleToModuleByClass( src_module, quantizable_module, **quantizable_module_kwargs, dtype=dtype) - rewriters.append(rewriter) + rewriters.append(rewriter) if not rewriters: warnings.warn( f"No module to replace was found. 
Supported modules are {list(QUANTIZABLE_MHA_MAP.keys())}" From 8dd94745275c23cb1bf38a455f4cb60a9befd3d1 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 23 May 2024 16:47:24 +0100 Subject: [PATCH 04/10] Fix (example/llm): Remove redundant mask handler --- .../llm/llm_quant/mha_layers.py | 20 ------------------- 1 file changed, 20 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index 9b56c0e63..4b9eafb71 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -7,26 +7,6 @@ from brevitas.utils.torch_utils import KwargsForwardHook -def attention_mask_handler( - attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length): - """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D - (implicit batch_size and n_heads).""" - if len(attention_mask.shape) == 4: - if attention_mask.shape[0] == 1: - attention_mask = attention_mask.repeat(batch_size, 1, 1, 1) - if attention_mask.shape[1] == 1: - attention_mask = attention_mask.repeat(1, num_heads, 1, 1) - if attention_mask.shape[2] == 1: - attention_mask = attention_mask.repeat(1, 1, query_seq_length, 1) - attention_mask = attention_mask.view( - batch_size * num_heads, query_seq_length, key_value_seq_length) - elif len(attention_mask.shape) == 2 and attention_mask.shape[0] == 1: - # This could happen in Encoder-like architecture - assert query_seq_length == key_value_seq_length - attention_mask = attention_mask.repeat(query_seq_length, 1) - return attention_mask - - def attention_mask_handler( attention_mask, batch_size, num_heads, query_seq_length, key_value_seq_length): """Re-arrange attention mask to go from 4D to 3D (explicit batch_size and n_heads) or 2D From 6ebccd68ec8da5130d3cd5241485ce9c3eafe6e4 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 23 May 2024 16:48:23 +0100 Subject: [PATCH 05/10] Feat (example/llm): Removed init function from quantizable BERT --- src/brevitas_examples/llm/llm_quant/mha_layers.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index 4b9eafb71..a92ad8c3d 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -179,9 +179,6 @@ def set_weight(value): class QuantizableBertAttention(MultiheadAttentionWrapper): - def __init__(self, *args, **kwargs) -> None: - super().__init__(*args, **kwargs) - def forward( self, hidden_states: torch.Tensor, From 556162fd1c6e65048373f6c88bb59e98bdb704c1 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Thu, 23 May 2024 16:50:49 +0100 Subject: [PATCH 06/10] Feat (example/llm): Add BertSelfAttention to the replace MHA dictionary --- .../llm/llm_quant/prepare_for_quantize.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index 2a9505227..b983cdccf 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,11 +1,15 @@ import warnings from transformers.models.opt.modeling_opt import OPTAttention +from transformers.models.bert.modeling_bert import BertSelfAttention from brevitas.graph import ModuleToModuleByClass -from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention 
+from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention, QuantizableBertAttention -QUANTIZABLE_MHA_MAP = {OPTAttention: (QuantizableOPTAttention, {'batch_first': True})} +QUANTIZABLE_MHA_MAP = { + OPTAttention: (QuantizableOPTAttention, {'batch_first': True}), + BertSelfAttention: (QuantizableBertAttention, {'batch_first': True}), +} def replace_mha_with_quantizable_layers(model, dtype): From 87c0756fdf8b3a5bf012d189ae97d5cb022913a2 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 May 2024 16:31:49 +0100 Subject: [PATCH 07/10] Feat (examples/llm): Added correct forward signature to `QuantizableBertAttention` --- .../llm/llm_quant/mha_layers.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index a92ad8c3d..27f89c2d3 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -179,6 +179,32 @@ def set_weight(value): class QuantizableBertAttention(MultiheadAttentionWrapper): + def __init__( + self, + all_head_size, + num_attention_heads, + dropout=0., + bias=True, + add_bias_kv=False, + add_zero_attn=False, + kdim=None, + vdim=None, + batch_first=False, + device=None, + dtype=None) -> None: + super().__init__( + embed_dim=all_head_size, + num_heads=num_attention_heads, + dropout=dropout, + bias=bias, + add_bias_kv=add_bias_kv, + add_zero_attn=add_zero_attn, + kdim=kdim, + vdim=vdim, + batch_first=batch_first, + device=device, + dtype=dtype) + def forward( self, hidden_states: torch.Tensor, From 9ef4018fb0d3ffd5f6778964ae9b423ea65999bc Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 27 May 2024 16:43:37 +0100 Subject: [PATCH 08/10] Fixed formatting. 
--- .../llm/llm_quant/prepare_for_quantize.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index b983cdccf..a02236147 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,15 +1,17 @@ import warnings -from transformers.models.opt.modeling_opt import OPTAttention from transformers.models.bert.modeling_bert import BertSelfAttention +from transformers.models.opt.modeling_opt import OPTAttention from brevitas.graph import ModuleToModuleByClass -from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention, QuantizableBertAttention +from brevitas_examples.llm.llm_quant.mha_layers import QuantizableBertAttention +from brevitas_examples.llm.llm_quant.mha_layers import QuantizableOPTAttention QUANTIZABLE_MHA_MAP = { - OPTAttention: (QuantizableOPTAttention, {'batch_first': True}), - BertSelfAttention: (QuantizableBertAttention, {'batch_first': True}), -} + OPTAttention: (QuantizableOPTAttention, { + 'batch_first': True}), + BertSelfAttention: (QuantizableBertAttention, { + 'batch_first': True}),} def replace_mha_with_quantizable_layers(model, dtype): From f624ad57cf8e9c2d1d9c3b957cd4ad73bd51d2f5 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 31 May 2024 19:38:26 +0100 Subject: [PATCH 09/10] Feat (llm/bert): Replace BertAttention, not BertSelfAttention --- .../llm/llm_quant/mha_layers.py | 44 +++++++++++++++---- .../llm/llm_quant/prepare_for_quantize.py | 26 ++++++++++- 2 files changed, 59 insertions(+), 11 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/mha_layers.py b/src/brevitas_examples/llm/llm_quant/mha_layers.py index 27f89c2d3..151d86a84 100644 --- a/src/brevitas_examples/llm/llm_quant/mha_layers.py +++ b/src/brevitas_examples/llm/llm_quant/mha_layers.py @@ -183,6 +183,7 @@ def __init__( self, all_head_size, num_attention_heads, + ln_normalized_shape, dropout=0., bias=True, add_bias_kv=False, @@ -190,6 +191,9 @@ def __init__( kdim=None, vdim=None, batch_first=False, + ln_eps=1e-05, + ln_elementwise_affine=True, + ln_bias=True, device=None, dtype=None) -> None: super().__init__( @@ -204,6 +208,13 @@ def __init__( batch_first=batch_first, device=device, dtype=dtype) + self.ln = nn.LayerNorm( + normalized_shape=ln_normalized_shape, + eps=ln_eps, + elementwise_affine=ln_elementwise_affine, + bias=ln_bias, + device=device, + dtype=dtype) def forward( self, @@ -238,8 +249,9 @@ def forward( attn_mask=attention_mask, need_weights=output_attentions, average_attn_weights=False) + ln_output = self.ln(attn_output + hidden_states) past_key_value = None - return attn_output, attn_output_weights, past_key_value + return ln_output, attn_output_weights, past_key_value def _load_from_state_dict( self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, @@ -261,43 +273,57 @@ def set_weight(value): embed_dim = self.mha.embed_dim for name, value in list(state_dict.items()): - if prefix + 'query.weight' in name: + if prefix + 'self.query.weight' in name: weight = torch.zeros((3 * embed_dim, embed_dim), device=value.device, dtype=value.dtype) weight[:embed_dim] = value set_weight(weight) del state_dict[name] - elif prefix + 'key.weight' in name: + elif prefix + 'self.key.weight' in name: weight = torch.zeros((3 * embed_dim, embed_dim), device=value.device, dtype=value.dtype) weight[embed_dim:2 * embed_dim] = 
value set_weight(weight) del state_dict[name] - elif prefix + 'value.weight' in name: + elif prefix + 'self.value.weight' in name: weight = torch.zeros((3 * embed_dim, embed_dim), device=value.device, dtype=value.dtype) weight[2 * embed_dim:3 * embed_dim] = value set_weight(weight) del state_dict[name] - if prefix + 'query.bias' in name: + if prefix + 'self.query.bias' in name: bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) bias[:embed_dim] = value set_bias(bias) del state_dict[name] - elif prefix + 'key.bias' in name: + elif prefix + 'self.key.bias' in name: bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) bias[embed_dim:2 * embed_dim] = value set_bias(bias) del state_dict[name] - elif prefix + 'value.bias' in name: + elif prefix + 'self.value.bias' in name: bias = torch.zeros(3 * embed_dim, device=value.device, dtype=value.dtype) bias[2 * embed_dim:3 * embed_dim] = value set_bias(bias) del state_dict[name] - state_dict[prefix + 'mha.out_proj.weight'] = torch.eye(self.mha.out_proj.weight.shape[0]) - state_dict[prefix + 'mha.out_proj.bias'] = torch.zeros(self.mha.out_proj.bias.shape) + if prefix + 'output.dense.weight' in name: + weight_name = f'{prefix}mha.out_proj.weight' + state_dict[weight_name] = value + del state_dict[name] + if prefix + 'output.dense.bias' in name: + weight_name = f'{prefix}mha.out_proj.bias' + state_dict[weight_name] = value + del state_dict[name] + if prefix + 'output.LayerNorm.weight' in name: + weight_name = f'{prefix}ln.weight' + state_dict[weight_name] = value + del state_dict[name] + if prefix + 'output.LayerNorm.bias' in name: + weight_name = f'{prefix}ln.bias' + state_dict[weight_name] = value + del state_dict[name] return super()._load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index a02236147..77b4782e5 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -1,6 +1,6 @@ import warnings -from transformers.models.bert.modeling_bert import BertSelfAttention +from transformers.models.bert.modeling_bert import BertAttention from transformers.models.opt.modeling_opt import OPTAttention from brevitas.graph import ModuleToModuleByClass @@ -10,10 +10,32 @@ QUANTIZABLE_MHA_MAP = { OPTAttention: (QuantizableOPTAttention, { 'batch_first': True}), - BertSelfAttention: (QuantizableBertAttention, { + BertAttention: (QuantizableBertAttention, { 'batch_first': True}),} +def _set_bert_mha_attributes(module): + module.all_head_size = module._modules['self'].all_head_size + module.num_attention_heads = module._modules['self'].num_attention_heads + module.ln_normalized_shape = module._modules['output'].LayerNorm.normalized_shape + module.ln_eps = module._modules['output'].LayerNorm.eps + module.ln_elementwise_affine = module._modules['output'].LayerNorm.elementwise_affine + module.ln_bias = False if module._modules['output'].LayerNorm.bias is None else True + + +_SET_ATTRIBUTES_MAP = { + BertAttention: _set_bert_mha_attributes, +} + + +def set_mha_attributes(model): + for name, module in model.named_modules(): + mod_type = type(module) + if mod_type in _SET_ATTRIBUTES_MAP.keys(): + _SET_ATTRIBUTES_MAP[mod_type](module) + return model + + def replace_mha_with_quantizable_layers(model, dtype): rewriters = [] for src_module, (quantizable_module, 
quantizable_module_kwargs) in QUANTIZABLE_MHA_MAP.items(): From 4e926f924dd1a16238fbd28c5c9440e225cf66b2 Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Fri, 31 May 2024 19:57:51 +0100 Subject: [PATCH 10/10] Fix formatting. --- src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py index 77b4782e5..b29c83809 100644 --- a/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py +++ b/src/brevitas_examples/llm/llm_quant/prepare_for_quantize.py @@ -24,8 +24,7 @@ def _set_bert_mha_attributes(module): _SET_ATTRIBUTES_MAP = { - BertAttention: _set_bert_mha_attributes, -} + BertAttention: _set_bert_mha_attributes,} def set_mha_attributes(model):
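
Usage sketch (editor's note, not part of the patch series): a minimal example of how the BERT path added above might be exercised. The checkpoint name, the dtype, and the assumption that replace_mha_with_quantizable_layers applies the collected rewriters and returns the rewritten model (that part of the function body lies outside the hunks shown) are assumptions for illustration only.

import torch
from transformers import BertModel

from brevitas_examples.llm.llm_quant.prepare_for_quantize import replace_mha_with_quantizable_layers
from brevitas_examples.llm.llm_quant.prepare_for_quantize import set_mha_attributes

# Assumed checkpoint; any BERT-style model with BertAttention blocks should do.
model = BertModel.from_pretrained('bert-base-uncased')

# Copy the constructor arguments QuantizableBertAttention expects (all_head_size,
# num_attention_heads, LayerNorm normalized_shape/eps/elementwise_affine/bias)
# from the BertAttention submodules onto the BertAttention module itself, so the
# ModuleToModuleByClass rewriter can pick them up when building the replacement.
model = set_mha_attributes(model)

# Swap BertAttention (and OPTAttention) instances for the quantizable
# nn.MultiheadAttention-based wrappers; the custom _load_from_state_dict above is
# what repacks query/key/value, output.dense and output.LayerNorm parameters into
# mha.in_proj_*, mha.out_proj.* and ln.* when the original weights are loaded.
model = replace_mha_with_quantizable_layers(model, dtype=torch.float32)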