Skip to content

Commit

Permalink
Fix (examples/llm): change rotation interface
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Nov 18, 2024
1 parent c3208cf commit 619829e
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 19 deletions.
3 changes: 1 addition & 2 deletions src/brevitas/graph/equalize.py
Original file line number Diff line number Diff line change
Expand Up @@ -1367,7 +1367,6 @@ def _apply_rotate(model: nn.Module, regions: List[Region], full_rotation_method=
module.offload_params(module)

if insert_rotation_module and len(region.srcs) == 0:
# print(name, module.in_features, K)
rewriter = ModuleInstanceToModuleInstance(
module, RotatedModule(had_mat=rot_mat, k=K, layer=module))
rewriters.append(rewriter)
Expand Down Expand Up @@ -1467,7 +1466,7 @@ def rotate_matmuls(self, graph_module):

def apply(self,
graph_model: GraphModule) -> Union[Tuple[GraphModule, List[Transform]], GraphModule]:

rewriters = []
regions = _extract_regions(
graph_model,
state_impl_kwargs={
Expand Down
9 changes: 4 additions & 5 deletions src/brevitas_examples/llm/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,8 @@ usage: main.py [-h] [--model MODEL] [--seed SEED] [--nsamples NSAMPLES]
[--convert-layernorm-to-rmsnorm] [--replace-rmsnorm]
[--no-quantize] [--no-float16]
[--scaling-min-val SCALING_MIN_VAL] [--replace-mha]
[--weight-equalization]
[--graph-rotation {fx,layerwise,fused_no_fx}]
[--graph-rotation-mode {had,ort}] [--rotation-orphan-sink]
[--weight-equalization] [--rotation {fx,layerwise,fused_no_fx}]
[--rotation-mode {had,ort}] [--rotation-orphan-sink]
[--act-equalization {None,layerwise,fx}] [--load-awq LOAD_AWQ]
[--export-target {None,onnx_qcdq,torch_qcdq,sharded_torchmlir_group_weight,sharded_packed_torchmlir_group_weight}]
[--export-prefix EXPORT_PREFIX]
Expand Down Expand Up @@ -148,9 +147,9 @@ options:
--weight-equalization
Apply weight equalization. Relevant to ReLU based
models (e.g. OPT).
--graph-rotation {fx,layerwise,fused_no_fx}
--rotation {fx,layerwise,fused_no_fx}
Apply graph rotation equalization
--graph-rotation-mode {had,ort}
--rotation-mode {had,ort}
If GraphRotation is enabled, decide how to compute the
random rotation matrix that is fully fused. Online or
partial rotation will always be Hadamard
Expand Down
18 changes: 10 additions & 8 deletions src/brevitas_examples/llm/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def fused_rotation_no_fx(model, calibration_loader, args):
new_model = offload_model(new_model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink,
full_rotation_method=args.graph_rotation_mode,
full_rotation_method=args.rotation_mode,
return_rewriters=True)
new_model, rewriters = eq.apply(new_model)
rewriters = fix_rewriter(rewriters, model, 'weight')
Expand Down Expand Up @@ -104,10 +104,12 @@ def model_export(model, ref_input, args):


def validate(args):
if args.graph_rotation == 'fx':
if args.rotation == 'fx':
assert args.ln_affine_merge, 'Graph rotation requires to merge LN/RMS norm affine parameters'
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
assert args.convert_layernorm_to_rmsnorm, 'Graph rotation requires to replace LayerNorm with RMSNorm'
elif args.rotation == 'fused_no_fx':
assert args.replace_rmsnorm, 'Graph rotation requires to replace HF RMSNorm with PyTorch ones (torch 2.4+ require)'
if not args.no_quantize:
if args.gptq and args.gpfq:
warn("Both GPTQ and GPFQ are enabled.")
Expand Down Expand Up @@ -259,16 +261,16 @@ def main(args):
apply_layernorm_to_rmsnorm(model)
print("Layernorm To RMSNorm applied.")

if args.graph_rotation == 'fx':
if args.rotation == 'fx':
model = offload_model(model)
eq = GraphRotationEqualization(
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.graph_rotation_mode)
orphan_sink=args.rotation_orphan_sink, full_rotation_method=args.rotation_mode)
model = eq.apply(model)
remove_hooks(model)
elif args.graph_rotation == 'layerwise':
elif args.rotation == 'layerwise':
eq = LayerwiseActivationRotation()
model = eq.apply(model)
elif args.graph_rotation == 'fused_no_fx':
elif args.rotation == 'fused_no_fx':
fused_rotation_no_fx(model, calibration_loader, args)

# Insert standard MHA layers when performing fx based weight/act equalization to avoid dealing
Expand Down Expand Up @@ -600,13 +602,13 @@ def parse_args(args):
action='store_true',
help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).')
parser.add_argument(
'--graph-rotation',
'--rotation',
type=str,
default=None,
choices=['fx', 'layerwise', 'fused_no_fx'],
help='Apply graph rotation equalization')
parser.add_argument(
'--graph-rotation-mode',
'--rotation-mode',
default='had',
choices=['had', 'ort'],
help=
Expand Down
7 changes: 3 additions & 4 deletions tests/brevitas_examples/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"no_quantize": True,
"rotation_orphan_sink": True,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
Expand All @@ -394,7 +394,7 @@ def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4):
"no_quantize": True,
"rotation_orphan_sink": False,
"convert_layernorm_to_rmsnorm": True,
"graph_rotation": "fx",
"rotation": "fx",
"exp_layer_types": {
"L__self___model_layers_0_self_attn_k_proj":
"<class 'torch.nn.modules.linear.Linear'>",
Expand All @@ -417,8 +417,7 @@ def test_small_models_quant_layer(caplog, layer_args):
if args.replace_rmsnorm:
if torch_version < version.parse('2.4'):
pytest.skip("Replacing RMSNorm requires torch 2.4+ or greater")
if hasattr(args, 'graph_rotation') and args.graph_rotation == 'fx' and platform.system(
) == 'Windows':
if hasattr(args, 'rotation') and args.rotation == 'fx' and platform.system() == 'Windows':
pytest.skip("Skipping dynamo + windows")
float_ppl, quant_ppl, model = validate_args_and_run_main(args)
assert_layer_types(model, exp_layer_types)
Expand Down

0 comments on commit 619829e

Please sign in to comment.