From d9e8a386523df888ae33168a0fe696e2ed10c7dc Mon Sep 17 00:00:00 2001 From: Nick Fraser Date: Mon, 19 Aug 2024 12:20:40 +0100 Subject: [PATCH] Fix (example/llm): Fix API for newer `torch_mlir` releases. --- .../llm/llm_quant/sharded_mlir_group_export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py index ef0a72880..528ca9100 100644 --- a/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py +++ b/src/brevitas_examples/llm/llm_quant/sharded_mlir_group_export.py @@ -44,7 +44,7 @@ import torch from torch._decomp import get_decompositions import torch_mlir -from torch_mlir import TensorPlaceholder +from torch_mlir.torchscript import TensorPlaceholder from tqdm import tqdm from brevitas.backport.fx._symbolic_trace import wrap @@ -313,7 +313,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir if is_first: ts_g = compile_vicuna_layer( export_context_manager, export_class, layer, inputs[0], inputs[1], inputs[2]) - module = torch_mlir.compile( + module = torch_mlir.torchscript.compile( ts_g, (hidden_states_placeholder, inputs[1], inputs[2]), output_type="torch", backend_legal_ops=["quant.matmul_rhs_group_quant"], @@ -330,7 +330,7 @@ def compile_to_vmfb(inputs, layers, export_context_manager, export_class, is_fir inputs[2], inputs[3], inputs[4]) - module = torch_mlir.compile( + module = torch_mlir.torchscript.compile( ts_g, ( inputs[0],