diff --git a/.github/workflows/examples_llm_pytest.yml b/.github/workflows/examples_llm_pytest.yml new file mode 100644 index 000000000..e939a93b2 --- /dev/null +++ b/.github/workflows/examples_llm_pytest.yml @@ -0,0 +1,65 @@ +name: Examples LLM Pytest + +on: + push: + branches: [ master, dev ] + pull_request: + types: + - review_requested + +jobs: + build: + runs-on: ${{ matrix.platform }} + strategy: + fail-fast: false + + + matrix: + python_version: ['3.8', '3.9'] + pytorch_version: ['2.2.2', '2.3.1', '2.4.0'] + platform: ['windows-latest', 'ubuntu-latest', 'macos-latest'] + jit_status: ['jit_disabled', 'jit_enabled'] + + + exclude: + - pytorch_version: '1.9.1' + platform: 'macos-latest' + + - pytorch_version: '1.9.1' + jit_status: 'jit_enabled' + + + + if: ${{ !github.event.pull_request.draft }} + steps: + + - name: Checkout repo + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version }} + + - name: Install Nox dependencies + shell: bash + run: pip install -r requirements/requirements-nox.txt + + - name: Install update + shell: bash + run: sudo apt-get update + if: startsWith(runner.os, 'Linux') == true + + - name: Install libsndfile and libgomp1 on Ubuntu + shell: bash + run: sudo apt-get install -y libsndfile-dev libgomp1 + if: startsWith(runner.os, 'Linux') == true + + - name: Install libomp on macOS + shell: bash + run: brew install libomp + if: startsWith(runner.os, 'macOS') == true + + - name: Run Nox session for brevitas_examples pytest + shell: bash + run: nox -v -s tests_brevitas_examples_llm-${{ matrix.python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\) diff --git a/.github/workflows/gen_github_actions.py b/.github/workflows/gen_github_actions.py index 4cd6c6827..2c4908a6c 100644 --- a/.github/workflows/gen_github_actions.py +++ b/.github/workflows/gen_github_actions.py @@ -8,6 +8,7 @@ BASE_YML_REDUCED_TEMPLATE = 'base_reduced.yml.template' PYTEST_YML = 'pytest.yml' EXAMPLES_PYTEST_YML = 'examples_pytest.yml' +EXAMPLES_LLM_PYTEST_YML = 'examples_llm_pytest.yml' DEVELOP_INSTALL_YML = 'develop_install.yml' FINN_INTEGRATION_YML = 'finn_integration.yml' ORT_INTEGRATION_YML = 'ort_integration.yml' @@ -25,6 +26,10 @@ ('pytorch_version', list(PYTORCH_LIST_REDUCED)), ('platform', PLATFORM_LIST_REDUCED)]) +EXAMPLES_LLM_PYTEST_MATRIX_REDUCED = od([('python_version', list(PYTHON_VERSIONS_REDUCED)), + ('pytorch_version', list( + ('2.4.0',))), ('platform', PLATFORM_LIST_REDUCED)]) + FINN_MATRIX_REDUCED = od([('python_version', list(PYTHON_VERSIONS_REDUCED)), ('pytorch_version', list(PYTORCH_LIST_REDUCED)), ('platform', PLATFORM_LIST_REDUCED)]) @@ -61,6 +66,11 @@ MATRIX = od([('python_version', list(PYTHON_VERSIONS)), ('pytorch_version', list(PYTORCH_VERSIONS)), ('platform', PLATFORM_LIST)]) +EXAMPLES_LLM_PYTEST_PYTORCH_VERSIONS = ('2.2.2', '2.3.1', '2.4.0') +EXAMPLES_LLM_PYTEST_MATRIX = od([('python_version', list(PYTHON_VERSIONS)), + ('pytorch_version', list(EXAMPLES_LLM_PYTEST_PYTORCH_VERSIONS)), + ('platform', PLATFORM_LIST)]) + FINN_MATRIX = od([('python_version', list(PYTHON_VERSIONS)), ('pytorch_version', list(PYTORCH_VERSIONS)), ('platform', FINN_PLATFORM_LIST)]) @@ -80,6 +90,13 @@ 'nox -v -s tests_brevitas_examples_cpu-${{ matrix.python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\)' )]),] +EXAMPLES_LLM_PYTEST_STEP_LIST = [ + od([('name', 'Run Nox session for brevitas_examples pytest'), ('shell', 'bash'), + ( + 'run', + 'nox -v -s 
tests_brevitas_examples_llm-${{ matrix.python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\)' + )]),] + FINN_INTEGRATION_STEP_LIST = [ od([('name', 'Install protobuf on Ubuntu'), ('shell', 'bash'), ('run', 'sudo apt-get install protobuf-compiler libprotoc-dev'), @@ -167,6 +184,23 @@ def gen_examples_pytest_yml(): pytest.gen_yaml(BASE_YML_REDUCED_TEMPLATE, 'reduced_' + EXAMPLES_PYTEST_YML) +def gen_examples_llm_pytest_yml(): + pytest = Action( + 'Examples LLM Pytest', + EXCLUDE_LIST + JIT_EXCLUDE_LIST, + combine_od_list([EXAMPLES_LLM_PYTEST_MATRIX, PYTEST_MATRIX_EXTRA]), + EXAMPLES_LLM_PYTEST_STEP_LIST, + STRATEGY) + pytest.gen_yaml(BASE_YML_TEMPLATE, EXAMPLES_LLM_PYTEST_YML) + pytest = Action( + 'Examples LLM Pytest', + EXCLUDE_LIST, + combine_od_list([EXAMPLES_LLM_PYTEST_MATRIX_REDUCED, PYTEST_MATRIX_EXTRA_REDUCED]), + EXAMPLES_LLM_PYTEST_STEP_LIST, + STRATEGY) + pytest.gen_yaml(BASE_YML_REDUCED_TEMPLATE, 'reduced_' + EXAMPLES_LLM_PYTEST_YML) + + def gen_test_develop_install_yml(): test_develop_install = Action( 'Test develop install', EXCLUDE_LIST, MATRIX, TEST_INSTALL_DEV_STEP_LIST, STRATEGY) @@ -243,6 +277,7 @@ def gen_test_brevitas_end_to_end(): if __name__ == '__main__': gen_pytest_yml() gen_examples_pytest_yml() + gen_examples_llm_pytest_yml() gen_test_develop_install_yml() gen_test_brevitas_finn_integration() gen_test_brevitas_ort_integration() diff --git a/.github/workflows/reduced_examples_llm_pytest.yml b/.github/workflows/reduced_examples_llm_pytest.yml new file mode 100644 index 000000000..b9c3deffe --- /dev/null +++ b/.github/workflows/reduced_examples_llm_pytest.yml @@ -0,0 +1,64 @@ +name: Examples LLM Pytest + +on: + pull_request: + types: + - opened + - reopened + - synchronize + - ready_for_review + + +jobs: + build: + runs-on: ${{ matrix.platform }} + strategy: + fail-fast: false + + + matrix: + python_version: ['3.8'] + pytorch_version: ['2.4.0'] + platform: ['ubuntu-latest'] + jit_status: ['jit_disabled'] + + + exclude: + - pytorch_version: '1.9.1' + platform: 'macos-latest' + + + + if: ${{ !github.event.pull_request.draft }} + steps: + + - name: Checkout repo + uses: actions/checkout@v2 + + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python_version }} + + - name: Install Nox dependencies + shell: bash + run: pip install -r requirements/requirements-nox.txt + + - name: Install update + shell: bash + run: sudo apt-get update + if: startsWith(runner.os, 'Linux') == true + + - name: Install libsndfile and libgomp1 on Ubuntu + shell: bash + run: sudo apt-get install -y libsndfile-dev libgomp1 + if: startsWith(runner.os, 'Linux') == true + + - name: Install libomp on macOS + shell: bash + run: brew install libomp + if: startsWith(runner.os, 'macOS') == true + + - name: Run Nox session for brevitas_examples pytest + shell: bash + run: nox -v -s tests_brevitas_examples_llm-${{ matrix.python_version }}\(${{ matrix.jit_status }}\,\ pytorch_${{ matrix.pytorch_version }}\) diff --git a/noxfile.py b/noxfile.py index ffb1c5fbd..17a38789d 100644 --- a/noxfile.py +++ b/noxfile.py @@ -7,15 +7,20 @@ import nox from packaging import version +from packaging.version import parse sys.path.append(os.path.join(os.path.dirname(__file__), os.path.join('.', '.github', 'workflows'))) +from gen_github_actions import EXAMPLES_LLM_PYTEST_PYTORCH_VERSIONS from gen_github_actions import JIT_STATUSES from gen_github_actions import PYTHON_VERSIONS from gen_github_actions import PYTORCH_VERSIONS IS_OSX = system() == 'Darwin' 
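# Note on the change just below: the wheel-source constant is split so that
# install_pytorch()/install_torchvision() can keep resolving PyTorch < 2.4.0
# from the legacy torch_stable.html '-f' source, while 2.4.0 and newer are
# installed through the CPU wheel index passed via --index-url.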
-PYTORCH_STABLE_WHEEL_SRC = 'https://download.pytorch.org/whl/torch_stable.html' +PYTORCH_STABLE_WHEEL_SRC = 'https://download.pytorch.org/whl/cpu' +PYTORCH_STABLE_WHEEL_SRC_LEGACY = 'https://download.pytorch.org/whl/torch_stable.html' PYTORCH_IDS = tuple([f'pytorch_{i}' for i in PYTORCH_VERSIONS]) +EXAMPLES_LLM_PYTEST_PYTORCH_IDS = tuple([ + f'pytorch_{i}' for i in EXAMPLES_LLM_PYTEST_PYTORCH_VERSIONS]) JIT_IDS = tuple([f'{i}'.lower() for i in JIT_STATUSES]) LSTM_EXPORT_MIN_PYTORCH = '1.10.1' @@ -26,14 +31,21 @@ '1.12.1': '0.13.1', '1.13.0': '0.14.0', '2.0.1': '0.15.2', - '2.1.0': '0.16.0'} + '2.1.0': '0.16.0', + '2.2.2': '0.17.2', + '2.3.1': '0.18.1', + '2.4.0': '0.19.0'} PARSED_TORCHVISION_VERSION_DICT = {version.parse(k): v for k, v in TORCHVISION_VERSION_DICT.items()} def install_pytorch(pytorch, session): if not IS_OSX: - cmd = [f'torch=={pytorch}+cpu', '-f', PYTORCH_STABLE_WHEEL_SRC] + if parse(pytorch) < parse('2.4.0'): + cmd = [f'torch=={pytorch}+cpu', '-f', PYTORCH_STABLE_WHEEL_SRC_LEGACY] + else: + cmd = [f'torch=={pytorch}', '--index-url', PYTORCH_STABLE_WHEEL_SRC] + else: cmd = [f'torch=={pytorch}'] session.install(*cmd) @@ -42,11 +54,18 @@ def install_pytorch(pytorch, session): def install_torchvision(pytorch, session): torchvision = PARSED_TORCHVISION_VERSION_DICT[version.parse(pytorch)] if not IS_OSX: - cmd = [ - f'torch=={pytorch}+cpu', # make sure correct pytorch version is kept - f'torchvision=={torchvision}+cpu', - '-f', - PYTORCH_STABLE_WHEEL_SRC] + if parse(pytorch) < parse('2.4.0'): + cmd = [ + f'torch=={pytorch}+cpu', # make sure correct pytorch version is kept + f'torchvision=={torchvision}+cpu', + '-f', + PYTORCH_STABLE_WHEEL_SRC_LEGACY] + else: + cmd = [ + f'torch=={pytorch}', + f'torchvision=={torchvision}', + '--index-url', + PYTORCH_STABLE_WHEEL_SRC] else: cmd = [f'torch=={pytorch}', f'torchvision=={torchvision}'] session.install(*cmd) @@ -105,7 +124,25 @@ def tests_brevitas_examples_cpu(session, pytorch, jit_status): install_pytorch(pytorch, session) install_torchvision(pytorch, session) # For CV eval scripts session.install('--upgrade', '.[test, tts, stt, vision]') - session.run('pytest', '-n', 'logical', 'tests/brevitas_examples') + session.run( + 'pytest', + '-n', + 'logical', + '--ignore-glob', + 'tests/brevitas_examples/*llm*', + 'tests/brevitas_examples') + + +@nox.session(python=PYTHON_VERSIONS) +@nox.parametrize( + "pytorch", EXAMPLES_LLM_PYTEST_PYTORCH_VERSIONS, ids=EXAMPLES_LLM_PYTEST_PYTORCH_IDS) +@nox.parametrize("jit_status", JIT_STATUSES, ids=JIT_IDS) +def tests_brevitas_examples_llm(session, pytorch, jit_status): + session.env['BREVITAS_JIT'] = '{}'.format(int(jit_status == 'jit_enabled')) + install_pytorch(pytorch, session) + install_torchvision(pytorch, session) # Optimum seems to require torchvision + session.install('-e', '.[test, llm, export]') + session.run('pytest', '-n', 'logical', '-k', 'llm', 'tests/brevitas_examples/test_llm.py') @nox.session(python=PYTHON_VERSIONS) diff --git a/pytest.ini b/pytest.ini index a560d3e16..a3cd14b59 100644 --- a/pytest.ini +++ b/pytest.ini @@ -7,3 +7,6 @@ log_cli_format = %(message)s # pytest-mock should use Pypi's mock rather than Python's built-in mock_use_standalone_module = true + +markers = + llm: mark a test which tests brevitas_examples/llm diff --git a/requirements/requirements-llm.txt b/requirements/requirements-llm.txt new file mode 100644 index 000000000..7070cc9c6 --- /dev/null +++ b/requirements/requirements-llm.txt @@ -0,0 +1,3 @@ +optimum-amd[brevitas] @ 
git+https://github.com/huggingface/optimum-amd.git@main +tqdm +transformers diff --git a/setup.py b/setup.py index 10e920981..4a756962d 100644 --- a/setup.py +++ b/setup.py @@ -39,6 +39,7 @@ def read_requirements(filename): "test": read_requirements('requirements-test.txt'), "tts": read_requirements('requirements-tts.txt'), "stt": read_requirements('requirements-stt.txt'), + "llm": read_requirements('requirements-llm.txt'), "vision": read_requirements('requirements-vision.txt'), "finn_integration": read_requirements('requirements-finn-integration.txt'), "ort_integration": read_requirements('requirements-ort-integration.txt')}, diff --git a/src/brevitas_examples/common/generative/quantize.py b/src/brevitas_examples/common/generative/quantize.py index 57670f6f6..73831d5fa 100644 --- a/src/brevitas_examples/common/generative/quantize.py +++ b/src/brevitas_examples/common/generative/quantize.py @@ -153,7 +153,8 @@ 'sym': Int8DynamicActPerGroupFloat}}}, 'po2_scale': { 'stats': { - 'per_group': MXInt8Act}}}}, + 'per_group': { + 'sym': MXInt8Act}}}}}, 'float': { 'static': { 'float_scale': { @@ -175,7 +176,8 @@ 'dynamic': { 'po2_scale': { 'stats': { - 'per_group': MXFloat8e4m3Act}}}}, + 'per_group': { + 'sym': MXFloat8e4m3Act}}}}}, 'float_fnuz': { 'static': { 'float_scale': { diff --git a/src/brevitas_examples/llm/main.py b/src/brevitas_examples/llm/main.py index 05d84f647..6060ef498 100644 --- a/src/brevitas_examples/llm/main.py +++ b/src/brevitas_examples/llm/main.py @@ -5,6 +5,7 @@ import argparse import re +import sys import numpy as np from optimum.amd.brevitas.accelerate_utils import offload_model @@ -40,157 +41,6 @@ from brevitas_examples.llm.llm_quant.run_utils import get_fx from brevitas_examples.llm.llm_quant.run_utils import modify_dataloader -parser = argparse.ArgumentParser() -parser.add_argument( - '--model', - type=str, - default="facebook/opt-125m", - help='HF model name. Default: facebook/opt-125m.') -parser.add_argument( - '--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.') -parser.add_argument( - '--nsamples', type=int, default=128, help='Number of calibration data samples. Default: 128.') -parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. Default: 2048.') -parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.') -parser.add_argument( - '--dataset', - type=str, - choices=['wikitext2', 'c4'], - default='wikitext2', - help='Dataset to use for quantization (default: %(default)s)') -parser.add_argument('--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') -parser.add_argument( - '--weight-param-method', - type=str, - default='stats', - choices=['stats', 'mse'], - help='How scales/zero-point are determined. Default: stats.') -parser.add_argument( - '--weight-scale-precision', - type=str, - default='float_scale', - choices=['float_scale', 'po2_scale'], - help='Whether scale is a float value or a po2. Default: po2.') -parser.add_argument( - '--weight-quant-type', - type=str, - default='sym', - choices=['sym', 'asym'], - help='Weight quantization type. Default: asym.') -parser.add_argument( - '--weight-quant-format', - type=quant_format_validator, - default='int', - help= - 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' 
-) -parser.add_argument( - '--weight-quant-granularity', - type=str, - default='per_group', - choices=['per_channel', 'per_tensor', 'per_group'], - help='Granularity for scales/zero-point of weights. Default: per_group.') -parser.add_argument( - '--weight-group-size', - type=int, - default=128, - help='Group size for per_group weight quantization. Default: 128.') -parser.add_argument( - '--quantize-weight-zero-point', action='store_true', help='Quantize weight zero-point.') -parser.add_argument( - '--input-bit-width', - type=int, - default=None, - help='Input bit width. Default: None (disables input quantization).') -parser.add_argument( - '--input-quant-format', - type=quant_format_validator, - default='int', - help= - 'Input quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' -) -parser.add_argument( - '--input-param-method', - type=str, - default='stats', - choices=['stats', 'mse'], - help= - 'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).' -) -parser.add_argument( - '--input-scale-precision', - type=str, - default='float_scale', - choices=['float_scale', 'po2_scale'], - help='Whether input scale is a float value or a po2. Default: float.') -parser.add_argument( - '--input-scale-type', - type=str, - default='static', - choices=['static', 'dynamic', 'no_scale'], - help='Whether input scale is a static value or a dynamic value.') -parser.add_argument( - '--input-quant-type', - type=str, - default='asym', - choices=['sym', 'asym'], - help='Input quantization type. Default: asym.') -parser.add_argument( - '--input-quant-granularity', - type=str, - default='per_tensor', - choices=['per_tensor', 'per_row', 'per_group'], - help='Granularity for scales/zero-point of inputs. Default: per_tensor.') -parser.add_argument( - '--input-group-size', - type=int, - default=64, - help='Group size for per_group input quantization. Default: 64.') -parser.add_argument( - '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') -parser.add_argument( - '--quantize-last-layer', action='store_true', help='Quantize last nn.Linear layer.') -parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') -parser.add_argument('--act-calibration', action='store_true', help='Apply activation calibration.') -parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') -parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.') -parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.') -parser.add_argument( - '--no-float16', - action='store_true', - help='Disable float16 as base datatype and switch to float32.') -parser.add_argument( - '--replace-mha', - action='store_true', - help='Replace HuggingFace Attention with a quantizable version') -parser.add_argument( - '--weight-equalization', - action='store_true', - help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).') -parser.add_argument( - '--act-equalization', - default=None, - choices=[None, 'layerwise', 'fx'], - help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,' - 'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).' 
-) -parser.add_argument('--load-awq', type=str, default=None, help="Load the awq search results.") -parser.add_argument( - '--export-target', - default=None, - choices=[ - None, - 'onnx_qcdq', - 'torch_qcdq', - 'sharded_torchmlir_group_weight', - 'sharded_packed_torchmlir_group_weight'], - help='Model export.') -parser.add_argument( - '--checkpoint-name', - type=str, - default=None, - help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)") - def set_seed(seed): np.random.seed(seed) @@ -213,16 +63,15 @@ def model_export(model, ref_input, args): export_manager = StdQCDQONNXManager export_manager.change_weight_export(export_weight_q_node=True) - print(f"Exporting the model in ./quantized_onnx/{args.model.replace('/', '-')}") + print(f"Exporting the model in ./{args.export_prefix}") with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager): onnx_export_from_model( model, - f"./quantized_onnx/{args.model.replace('/', '-')}", + f"./{args.export_prefix}", task="text-generation-with-past", do_validation=False) elif args.export_target == 'torch_qcdq': - export_torch_qcdq( - model, ref_input['input_ids'], export_path=f"{args.model.replace('/', '-')}.pt") + export_torch_qcdq(model, ref_input['input_ids'], export_path=f"{args.export_prefix}.pt") def validate(args): @@ -261,11 +110,13 @@ def validate(args): assert args.export_target != 'torch_qcdq', "Cannot export Torch QCDQ with FX" -def main(): - args = parser.parse_args() +def main(args): validate(args) set_seed(args.seed) + if args.export_prefix is None: + args.export_prefix = f"{args.model.replace('/', '--')}" + if args.no_float16: dtype = torch.float32 else: @@ -281,6 +132,8 @@ def main(): print("Model loaded.") model.eval() tokenizer = AutoTokenizer.from_pretrained(args.model) + float_ppl = None + quant_ppl = None if args.load_awq: from brevitas_examples.llm.llm_quant.awq.pre_quant import apply_awq @@ -325,10 +178,10 @@ def main(): assert args.export_target != 'torch_qcdq', "TorchScript QCDQ export and Evaluation simultaneously" print("Float model eval...") model = offload_model(model) - ppl = compute_perplexity( + float_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) remove_hooks(model) - print(f"Float perplexity ({args.dataset}): {ppl}") + print(f"Float perplexity ({args.dataset}): {float_ppl}") if require_fx: model = get_fx(model) @@ -432,9 +285,9 @@ def main(): if args.eval: print("Model eval...") - ppl = compute_perplexity( + quant_ppl = compute_perplexity( model, validation_loader, context_length=args.seqlen // 2, tokenizer=tokenizer) - print(f"Quantized perplexity ({args.dataset}): {ppl}") + print(f"Quantized perplexity ({args.dataset}): {quant_ppl}") remove_hooks(model) if args.checkpoint_name is not None: @@ -447,6 +300,176 @@ def main(): model = model.to(dtype=torch.float32) model_export(model, calibration_loader[0], args) + return float_ppl, quant_ppl, model + + +def parse_args(args): + parser = argparse.ArgumentParser() + parser.add_argument( + '--model', + type=str, + default="facebook/opt-125m", + help='HF model name. Default: facebook/opt-125m.') + parser.add_argument( + '--seed', type=int, default=0, help='Seed for sampling the calibration data. Default: 0.') + parser.add_argument( + '--nsamples', + type=int, + default=128, + help='Number of calibration data samples. Default: 128.') + parser.add_argument('--seqlen', type=int, default=2048, help='Sequence length. 
Default: 2048.') + parser.add_argument('--eval', action='store_true', help='Eval model PPL on the chosen Dataset.') + parser.add_argument( + '--dataset', + type=str, + choices=['wikitext2', 'c4'], + default='wikitext2', + help='Dataset to use for quantization (default: %(default)s)') + parser.add_argument( + '--weight-bit-width', type=int, default=8, help='Weight bit width. Default: 8.') + parser.add_argument( + '--weight-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help='How scales/zero-point are determined. Default: stats.') + parser.add_argument( + '--weight-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help='Whether scale is a float value or a po2. Default: po2.') + parser.add_argument( + '--weight-quant-type', + type=str, + default='sym', + choices=['sym', 'asym'], + help='Weight quantization type. Default: asym.') + parser.add_argument( + '--weight-quant-format', + type=quant_format_validator, + default='int', + help= + 'Weight quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' + ) + parser.add_argument( + '--weight-quant-granularity', + type=str, + default='per_group', + choices=['per_channel', 'per_tensor', 'per_group'], + help='Granularity for scales/zero-point of weights. Default: per_group.') + parser.add_argument( + '--weight-group-size', + type=int, + default=128, + help='Group size for per_group weight quantization. Default: 128.') + parser.add_argument( + '--quantize-weight-zero-point', action='store_true', help='Quantize weight zero-point.') + parser.add_argument( + '--input-bit-width', + type=int, + default=None, + help='Input bit width. Default: None (disables input quantization).') + parser.add_argument( + '--input-quant-format', + type=quant_format_validator, + default='int', + help= + 'Input quantization type. Either int or eXmY, with X+Y==weight_bit_width-1. It\'s possible to add float_ocp_ or float_fnuz_ before the exponent/mantissa bitwidth. Default: int.' + ) + parser.add_argument( + '--input-param-method', + type=str, + default='stats', + choices=['stats', 'mse'], + help= + 'How scales/zero-point are determined. Default: stats (percentile for static, absmax or minmax for dynamic).' + ) + parser.add_argument( + '--input-scale-precision', + type=str, + default='float_scale', + choices=['float_scale', 'po2_scale'], + help='Whether input scale is a float value or a po2. Default: float.') + parser.add_argument( + '--input-scale-type', + type=str, + default='static', + choices=['static', 'dynamic', 'no_scale'], + help='Whether input scale is a static value or a dynamic value.') + parser.add_argument( + '--input-quant-type', + type=str, + default='asym', + choices=['sym', 'asym'], + help='Input quantization type. Default: asym.') + parser.add_argument( + '--input-quant-granularity', + type=str, + default='per_tensor', + choices=['per_tensor', 'per_row', 'per_group'], + help='Granularity for scales/zero-point of inputs. Default: per_tensor.') + parser.add_argument( + '--input-group-size', + type=int, + default=64, + help='Group size for per_group input quantization. 
Default: 64.') + parser.add_argument( + '--quantize-input-zero-point', action='store_true', help='Quantize input zero-point.') + parser.add_argument( + '--quantize-last-layer', action='store_true', help='Quantize last nn.Linear layer.') + parser.add_argument('--gptq', action='store_true', help='Apply GPTQ.') + parser.add_argument( + '--act-calibration', action='store_true', help='Apply activation calibration.') + parser.add_argument('--bias-corr', action='store_true', help='Apply bias correction.') + parser.add_argument('--ln-affine-merge', action='store_true', help='Merge LN affine params.') + parser.add_argument('--no-quantize', action='store_true', help='Disable quantization.') + parser.add_argument( + '--no-float16', + action='store_true', + help='Disable float16 as base datatype and switch to float32.') + parser.add_argument( + '--replace-mha', + action='store_true', + help='Replace HuggingFace Attention with a quantizable version') + parser.add_argument( + '--weight-equalization', + action='store_true', + help='Apply weight equalization. Relevant to ReLU based models (e.g. OPT).') + parser.add_argument( + '--act-equalization', + default=None, + choices=[None, 'layerwise', 'fx'], + help='Apply activation equalization (SmoothQuant). Layerwise introduces standalone mul nodes,' + 'while fx merges them whenever possible into previous tensors, which is possible on ReLU based models (e.g. OPT).' + ) + parser.add_argument('--load-awq', type=str, default=None, help="Load the awq search results.") + parser.add_argument( + '--export-target', + default=None, + choices=[ + None, + 'onnx_qcdq', + 'torch_qcdq', + 'sharded_torchmlir_group_weight', + 'sharded_packed_torchmlir_group_weight'], + help='Model export.') + parser.add_argument( + '--export-prefix', + type=str, + default=None, + help= + "Path prefix to use for the various export flows. If None, a path will be derived from the model name (default: %(default)s)" + ) + parser.add_argument( + '--checkpoint-name', + type=str, + default=None, + help="Filename to save checkpoint. If `None`, no checkpoint is saved (default: %(default)s)" + ) + return parser.parse_args(args) + if __name__ == '__main__': - main() + args = parse_args(sys.argv[1:]) + main(args) diff --git a/tests/brevitas_examples/test_llm.py b/tests/brevitas_examples/test_llm.py new file mode 100644 index 000000000..6a98911a9 --- /dev/null +++ b/tests/brevitas_examples/test_llm.py @@ -0,0 +1,478 @@ +# Copyright (C) 2024, Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause + +from argparse import Namespace +from dataclasses import dataclass +import logging +import os +import shutil + +import numpy as np +import onnx +import pytest +import pytest_cases +import torch + +from brevitas import config +# LLM example depends on optimum-amd, which requires PyTorch>=2.2 +from brevitas_examples.llm.main import main +from brevitas_examples.llm.main import parse_args +from tests.marker import jit_disabled_for_export +from tests.marker import requires_pt_ge + + +def ptid2pathname(string): + return string.replace("/", "-").replace(":", "-") + + +def allclose(x, y): + return np.allclose(x, y, rtol=1e-03, atol=1e+01, equal_nan=False) + + +def allveryclose(x, y): + return np.allclose(x, y, rtol=1e-04, atol=2e+02, equal_nan=False) + + +def allexact(x, y): + return np.allclose(x, y, rtol=0.0, atol=0.0, equal_nan=False) + + +# Check that all args in args are used +def validate_args(args): + a = vars(args) + da = vars(parse_args([])) + for k in a.keys(): + assert k in da.keys(), f"Key {k} does not seem to be a valid argument for `main`" + + +def validate_args_and_run_main(args): + validate_args(args) + float_ppl, quant_ppl, model = main(args) + return float_ppl, quant_ppl, model + + +def assert_layer_types(model, exp_layer_types): + for key, string in exp_layer_types.items(): + matched = False + layer_names = [] + for name, layer in model.named_modules(): + layer_names += [name] + if name == key: + matched = True + ltype = str(type(layer)) + assert ltype == string, f"Expected layer type: {string}, found {ltype} for key: {key}" + continue + assert matched, f"Layer key: {key} not found in {layer_names}" + + +class UpdatableNamespace(Namespace): + + def update(self, **kwargs): + self.__dict__.update(**kwargs) + + +def requires_fx(args): + return args.act_equalization == "fx" or args.weight_equalization or args.ln_affine_merge + + +@dataclass +class ModelAndPpl: + name: str + float_ppl: float + supports_fx: bool + + +@pytest_cases.fixture( + scope="session", + ids=[ + "llama", + "mistral", #"mixtral", + ], + params=[ + ModelAndPpl( + name="hf-internal-testing/tiny-random-LlamaForCausalLM", + float_ppl=None, + supports_fx=True, + ), + ModelAndPpl( + name="hf-internal-testing/tiny-random-MistralForCausalLM", + float_ppl=None, + supports_fx=False, + ), + #ModelAndPpl( # Ready for MoE support + # name="dacorvo/Mixtral-tiny", + # float_ppl=None, + # supports_fx=True, + #), + ]) +def small_models_with_ppl(request): + yield request.param + + +@pytest_cases.fixture() +def default_run_args(request): + args = UpdatableNamespace(**vars(parse_args([]))) + args.nsamples = 2 + args.seqlen = 2 + args.model = "hf-internal-testing/tiny-random-MistralForCausalLM" + args.dataset = "c4" + args.eval = True + #args.checkpoint = ptid2pathname(request.node.nodeid) + ".pth" # Example filename which won't clash + args.export_prefix = ptid2pathname(request.node.nodeid) + args.weight_bit_width = 8 + args.weight_quant_granularity = "per_channel" # "per_tensor", "per_channel", "per_group". 
+ args.input_bit_width = 8 + args.act_calibration = True + return args + + +def run_test_models_run_args(args, model_with_ppl): + args.model = model_with_ppl.name + exp_float_ppl = model_with_ppl.float_ppl + use_fx = requires_fx(args) + if use_fx and not model_with_ppl.supports_fx: + pytest.xfail(f"{model_with_ppl.name} does not support FX") + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + + +# yapf: disable +@pytest_cases.fixture( + ids=[ + "defaults", + "bias_corr=True", + "act_equalization=layerwise", + "act_equalization=fx", + "weight_equalization=True", + "gptq=True", + "ln_affine_merge=True",], + params=[ + {}, + {"bias_corr": True}, + {"act_equalization": "layerwise"}, + {"act_equalization": "fx"}, + {"weight_equalization": True}, + {"gptq": True}, + {"ln_affine_merge": True},]) +# yapf: enable +def toggle_run_args(default_run_args, request): + args = default_run_args + args.update(**request.param) + yield args + + +@pytest.mark.llm +@requires_pt_ge('2.2') +def test_small_models_toggle_run_args(caplog, toggle_run_args, small_models_with_ppl): + caplog.set_level(logging.INFO) + run_test_models_run_args(toggle_run_args, small_models_with_ppl) + + +@pytest_cases.fixture( + scope="session", + ids=[ + "opt",], + params=[ + ModelAndPpl( + name="hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + float_ppl=None, + supports_fx=True, + ),]) +def small_models_with_ppl_pt_ge_2_4(request): + yield request.param + + +@pytest.mark.llm +@requires_pt_ge('2.4') +def test_small_models_toggle_run_args_pt_ge_2_4( + caplog, toggle_run_args, small_models_with_ppl_pt_ge_2_4): + caplog.set_level(logging.INFO) + run_test_models_run_args(toggle_run_args, small_models_with_ppl_pt_ge_2_4) + + +@pytest_cases.fixture( + ids=[ + "llama", + "mistral",], + params=[ + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "act_equalization": "layerwise", + "gptq": True, + "float_ppl": 31274.05078125, + "quant_ppl": 33139.23046875}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_equalization": "fx", + "bias_corr": True, + "float_ppl": 33239.5, + "quant_ppl": 33283.75390625},]) +def acc_args_and_acc(default_run_args, request): + args = default_run_args + run_dict = request.param + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del run_dict["quant_ppl"] + args.update(**run_dict) + yield args, float_ppl, quant_ppl + + +@pytest.mark.llm +@requires_pt_ge('2.2') +def test_small_models_acc(caplog, acc_args_and_acc): + caplog.set_level(logging.INFO) + args, exp_float_ppl, exp_quant_ppl = acc_args_and_acc + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@pytest_cases.fixture( + ids=[ + "opt-replace-mha",], + params=[ + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "weight_equalization": True, + "ln_affine_merge": True, + "replace_mha": True, + "float_ppl": 50016.0, + "quant_ppl": 50016.0},]) +def acc_args_and_acc_pt_ge_2_4(default_run_args, request): + args = default_run_args + run_dict = request.param + float_ppl = run_dict["float_ppl"] + quant_ppl = run_dict["quant_ppl"] + del run_dict["float_ppl"] + del 
run_dict["quant_ppl"] + args.update(**run_dict) + yield args, float_ppl, quant_ppl + + +@pytest.mark.llm +@requires_pt_ge('2.4') +def test_small_models_acc_pt_ge_2_4(caplog, acc_args_and_acc_pt_ge_2_4): + caplog.set_level(logging.INFO) + args, exp_float_ppl, exp_quant_ppl = acc_args_and_acc_pt_ge_2_4 + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + float_ppl = float_ppl.detach().cpu().numpy() + quant_ppl = quant_ppl.detach().cpu().numpy() + assert allveryclose(exp_float_ppl, float_ppl), f"Expected float PPL {exp_float_ppl}, measured PPL {float_ppl}" + assert allveryclose(exp_quant_ppl, quant_ppl), f"Expected quant PPL {exp_quant_ppl}, measured PPL {quant_ppl}" + + +@pytest_cases.fixture( + ids=[ + "mistral-int8", + "mistral-weight-only", + "mistral-fp8_ocp", + "mistral-fp8_fnuz", + "llama-mxfp8", + "llama-int8-act_equalization=layerwise", + "mistral-int8-quant-last-layer",], + params=[ + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "exp_layer_types": { + "lm_head": + "", + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant": + "",}}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "input_bit_width": None, + "act_calibration": False, + "exp_layer_types": { + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.input_quant": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant": + "",}}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "weight_quant_format": "float_ocp_e4m3", + "weight_quant_type": "sym", + "input_quant_format": "float_ocp_e5m2", + "input_quant_type": "sym", + "exp_layer_types": { + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant": + "",}}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "weight_quant_format": "float_fnuz_e4m3", + "weight_quant_type": "sym", + "input_quant_format": "float_fnuz_e5m2", + "input_quant_type": "sym", + "exp_layer_types": { + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant": + "",}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "weight_quant_format": "float_ocp_e4m3", + "weight_scale_precision": "po2_scale", + "weight_param_method": "stats", + "weight_quant_granularity": "per_group", + "weight_group_size": 16, + "weight_quant_type": "sym", + "input_quant_format": "float_ocp_e5m2", + "input_scale_type": "dynamic", + "input_scale_precision": "po2_scale", + "input_param_method": "stats", + "input_quant_granularity": "per_group", + "input_group_size": 16, + "input_quant_type": "sym", + "act_calibration": False, + "exp_layer_types": { + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant": + "", + "model.layers.0.self_attn.q_proj.input_quant.fused_activation_quant_proxy.tensor_quant.input_view_impl": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant": + "", + "model.layers.0.self_attn.q_proj.weight_quant.tensor_quant.input_view_impl": + "",}}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "act_equalization": "layerwise", + 
"exp_layer_types": { + "model.layers.0.self_attn.q_proj": + "", + "model.layers.0.self_attn.q_proj.layer": + "",}}, + { + "model": "hf-internal-testing/tiny-random-MistralForCausalLM", + "quantize_last_layer": True, + "exp_layer_types": { + "lm_head": ""}},]) +def layer_args(default_run_args, request): + args = default_run_args + layer_dict = request.param + exp_layer_types = layer_dict["exp_layer_types"] + del layer_dict["exp_layer_types"] + args.update(**layer_dict) + yield args, exp_layer_types + + +@pytest.mark.llm +@requires_pt_ge('2.2') +def test_small_models_quant_layer(caplog, layer_args): + caplog.set_level(logging.INFO) + args, exp_layer_types = layer_args + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + assert_layer_types(model, exp_layer_types) + + +@pytest_cases.fixture( + ids=[ + "opt-replace-mha",], + params=[ + { + "model": "hf-internal-testing/tiny-random-OPTForCausalLM", # Requires PT>=2.4 to run + "replace_mha": True, + "exp_layer_types": { + "model.decoder.layers.0.self_attn": + "", + "model.decoder.layers.0.self_attn.mha": + "",}},]) +def layer_args_pt_ge_2_4(default_run_args, request): + args = default_run_args + layer_dict = request.param + exp_layer_types = layer_dict["exp_layer_types"] + del layer_dict["exp_layer_types"] + args.update(**layer_dict) + yield args, exp_layer_types + + +@pytest.mark.llm +@requires_pt_ge('2.4') +def test_small_models_quant_layer_pt_ge_2_4(caplog, layer_args_pt_ge_2_4): + caplog.set_level(logging.INFO) + args, exp_layer_types = layer_args_pt_ge_2_4 + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + assert_layer_types(model, exp_layer_types) + + +@pytest_cases.fixture( + ids=[ + "qcdq-asym", + "qcdq-sym",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "quantize_weight_zero_point": True, + "quantize_input_zero_point": True, + "export_target": "onnx_qcdq",}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "weight_quant_type": "sym", + "input_quant_type": "sym", + "export_target": "onnx_qcdq",},]) +def onnx_export_args(default_run_args, request): + args = default_run_args + export_dict = request.param + args.update(**export_dict) + yield args + + +@pytest.mark.llm +@jit_disabled_for_export() +@requires_pt_ge('2.2') +def test_small_models_onnx_export(caplog, onnx_export_args): + caplog.set_level(logging.INFO) + args = onnx_export_args + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + onnx_model = onnx.load(os.path.join(args.export_prefix, "model.onnx")) + shutil.rmtree(args.export_prefix) + + +@pytest_cases.fixture( + ids=[ + "qcdq-asym", + "qcdq-sym",], + params=[ + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "eval": False, + "quantize_weight_zero_point": True, + "quantize_input_zero_point": True, + "export_target": "torch_qcdq",}, + { + "model": "hf-internal-testing/tiny-random-LlamaForCausalLM", + "eval": False, + "weight_quant_type": "sym", + "input_quant_type": "sym", + "export_target": "torch_qcdq",},]) +def torch_export_args(default_run_args, request): + args = default_run_args + export_dict = request.param + args.update(**export_dict) + yield args + + +@pytest.mark.llm +@jit_disabled_for_export() +@requires_pt_ge('2.2') +def test_small_models_torch_export(caplog, torch_export_args): + caplog.set_level(logging.INFO) + args = torch_export_args + float_ppl, quant_ppl, model = validate_args_and_run_main(args) + filepath = args.export_prefix + ".pt" + torchscript_model = torch.jit.load(filepath) + 
os.remove(filepath)
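For reference, a minimal sketch of how the new LLM test plumbing in this patch might be exercised locally. This assumes a CPU-only environment; the exact torch/torchvision pins are normally handled by the Nox session, and the tiny Hugging Face models used by the tests are downloaded on first run.

    # Install the Nox requirements and Brevitas with the new 'llm' extra plus
    # test/export dependencies, roughly mirroring the Nox session setup.
    pip install -r requirements/requirements-nox.txt
    pip install -e ".[test, llm, export]"

    # Run the dedicated session generated by gen_github_actions.py
    # (quoted because the session id contains parentheses and a space).
    nox -v -s "tests_brevitas_examples_llm-3.9(jit_disabled, pytorch_2.4.0)"

    # Or invoke pytest directly via the 'llm' marker registered in pytest.ini.
    pytest -m llm tests/brevitas_examples/test_llm.py

Since main() now takes parsed arguments and returns the measured perplexities, the entry point can also be driven programmatically. A hypothetical invocation, using the default facebook/opt-125m checkpoint and the options exposed by parse_args():

    from brevitas_examples.llm.main import main, parse_args

    # Quantize weights to 8 bit, calibrate activations, and evaluate
    # float vs. quantized perplexity on the default wikitext2 split.
    args = parse_args(["--eval", "--act-calibration", "--weight-bit-width", "8"])
    float_ppl, quant_ppl, model = main(args)
    print(float_ppl, quant_ppl)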