diff --git a/tests/brevitas_examples/llm.py b/tests/brevitas_examples/llm.py
index 675bf2df2..a05a2e2c5 100644
--- a/tests/brevitas_examples/llm.py
+++ b/tests/brevitas_examples/llm.py
@@ -4,10 +4,13 @@
 from argparse import Namespace
 from dataclasses import dataclass
 import logging
+import os
 import shutil
 
 import numpy as np
+import onnx
 import pytest
+import torch
 
 from brevitas_examples.llm.main import main
 from brevitas_examples.llm.main import parse_args
@@ -87,6 +90,7 @@ def default_run_args(request):
     args.dataset = "c4"
     args.eval = True
     #args.checkpoint = ptid2pathname(request.node.nodeid) + ".pth"  # Example filename which won't clash
+    args.export_prefix = ptid2pathname(request.node.nodeid)
     args.weight_bit_width = 8
     args.weight_quant_granularity = "per_channel"  # "per_tensor", "per_channel", "per_group".
     args.input_bit_width = 8
@@ -284,3 +288,60 @@ def test_small_models_quant_layer(caplog, layer_args):
     args, exp_layer_types = layer_args
     float_ppl, quant_ppl, model = main(args)
     assert_layer_types(model, exp_layer_types)
+
+
+@pytest.fixture(
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "quantize_weight_zero_point": True,
+            "quantize_input_zero_point": True,
+            "export_target": "onnx_qcdq",},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "weight_quant_type": "sym",
+            "input_quant_type": "sym",
+            "export_target": "onnx_qcdq",},])
+def onnx_export_args(default_run_args, request):
+    args = default_run_args
+    export_dict = request.param
+    args.update(**export_dict)
+    yield args
+
+
+def test_small_models_onnx_export(caplog, onnx_export_args):
+    caplog.set_level(logging.INFO)
+    args = onnx_export_args
+    float_ppl, quant_ppl, model = main(args)
+    onnx_model = onnx.load(os.path.join(args.export_prefix, "model.onnx"))
+    shutil.rmtree(args.export_prefix)
+
+
+@pytest.fixture(
+    params=[
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "eval": False,
+            "quantize_weight_zero_point": True,
+            "quantize_input_zero_point": True,
+            "export_target": "torch_qcdq",},
+        {
+            "model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
+            "eval": False,
+            "weight_quant_type": "sym",
+            "input_quant_type": "sym",
+            "export_target": "torch_qcdq",},])
+def torch_export_args(default_run_args, request):
+    args = default_run_args
+    export_dict = request.param
+    args.update(**export_dict)
+    yield args
+
+
+def test_small_models_torch_export(caplog, torch_export_args):
+    caplog.set_level(logging.INFO)
+    args = torch_export_args
+    float_ppl, quant_ppl, model = main(args)
+    filepath = args.export_prefix + ".pt"
+    torchscript_model = torch.jit.load(filepath)
+    os.remove(filepath)