Skip to content

Commit

Permalink
test (ex/llm): Added ONNX export and torchscript tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
nickfraser committed Aug 21, 2024
1 parent a86d858 commit d916143
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions tests/brevitas_examples/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
from argparse import Namespace
from dataclasses import dataclass
import logging
import os
import shutil

import numpy as np
import onnx
import pytest
import torch

from brevitas_examples.llm.main import main
from brevitas_examples.llm.main import parse_args
Expand Down Expand Up @@ -87,6 +90,7 @@ def default_run_args(request):
args.dataset = "c4"
args.eval = True
#args.checkpoint = ptid2pathname(request.node.nodeid) + ".pth" # Example filename which won't clash
args.export_prefix = ptid2pathname(request.node.nodeid)
args.weight_bit_width = 8
args.weight_quant_granularity = "per_channel" # "per_tensor", "per_channel", "per_group".
args.input_bit_width = 8
Expand Down Expand Up @@ -284,3 +288,60 @@ def test_small_models_quant_layer(caplog, layer_args):
args, exp_layer_types = layer_args
float_ppl, quant_ppl, model = main(args)
assert_layer_types(model, exp_layer_types)


@pytest.fixture(
params=[
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"quantize_weight_zero_point": True,
"quantize_input_zero_point": True,
"export_target": "onnx_qcdq",},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"weight_quant_type": "sym",
"input_quant_type": "sym",
"export_target": "onnx_qcdq",},])
def onnx_export_args(default_run_args, request):
args = default_run_args
export_dict = request.param
args.update(**export_dict)
yield args


def test_small_models_onnx_export(caplog, onnx_export_args):
caplog.set_level(logging.INFO)
args = onnx_export_args
float_ppl, quant_ppl, model = main(args)
onnx_model = onnx.load(os.path.join(args.export_prefix, "model.onnx"))
shutil.rmtree(args.export_prefix)


@pytest.fixture(
params=[
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"eval": False,
"quantize_weight_zero_point": True,
"quantize_input_zero_point": True,
"export_target": "torch_qcdq",},
{
"model": "hf-internal-testing/tiny-random-LlamaForCausalLM",
"eval": False,
"weight_quant_type": "sym",
"input_quant_type": "sym",
"export_target": "torch_qcdq",},])
def torch_export_args(default_run_args, request):
args = default_run_args
export_dict = request.param
args.update(**export_dict)
yield args


def test_small_models_torch_export(caplog, torch_export_args):
caplog.set_level(logging.INFO)
args = torch_export_args
float_ppl, quant_ppl, model = main(args)
filepath = args.export_prefix + ".pt"
torchscript_model = torch.jit.load(filepath)
os.remove(filepath)

0 comments on commit d916143

Please sign in to comment.