Skip to content

Commit

Permalink
Feat (ex/llm): Allow supplying a prefix for the exported ONNX/TS model
Browse files: browse the repository at this point in the history
  • Loading branch information
nickfraser committed Aug 21, 2024
1 parent 277a8b0 commit a86d858
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/brevitas_examples/llm/main.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -63,16 +63,15 @@ def model_export(model, ref_input, args):
export_manager = StdQCDQONNXManager
export_manager.change_weight_export(export_weight_q_node=True)

print(f"Exporting the model in ./quantized_onnx/{args.model.replace('/', '-')}")
print(f"Exporting the model in ./{args.export_prefix}")
with torch.no_grad(), brevitas_proxy_export_mode(model, export_manager=export_manager):
onnx_export_from_model(
model,
f"./quantized_onnx/{args.model.replace('/', '-')}",
f"./{args.export_prefix}",
task="text-generation-with-past",
do_validation=False)
elif args.export_target == 'torch_qcdq':
export_torch_qcdq(
model, ref_input['input_ids'], export_path=f"{args.model.replace('/', '-')}.pt")
export_torch_qcdq(model, ref_input['input_ids'], export_path=f"{args.export_prefix}.pt")


def validate(args):
Expand Down Expand Up @@ -115,6 +114,9 @@ def main(args):
validate(args)
set_seed(args.seed)

if args.export_prefix is None:
args.export_prefix = f"{args.model.replace('/', '--')}"

if args.no_float16:
dtype = torch.float32
else:
Expand Down Expand Up @@ -452,6 +454,13 @@ def parse_args(args):
'sharded_torchmlir_group_weight',
'sharded_packed_torchmlir_group_weight'],
help='Model export.')
parser.add_argument(
'--export-prefix',
type=str,
default=None,
help=
"Path prefix to use for the various export flows. If None, a path will be derived from the model name (default: %(default)s)"
)
parser.add_argument(
'--checkpoint-name',
type=str,
Expand Down

0 comments on commit a86d858

Please sign in to comment.