From 349ae8a3edf3b65f418562fe49ba1a6a196cb779 Mon Sep 17 00:00:00 2001
From: aviator19941
Date: Sat, 2 Nov 2024 04:59:26 -0500
Subject: [PATCH] Fix CLI args for TP support

Signed-off-by: aviator19941
---
 sharktank/sharktank/utils/export_artifacts.py |  6 ++-
 .../models/llama/benchmark_amdgpu_test.py     | 49 +++----------------
 2 files changed, 11 insertions(+), 44 deletions(-)

diff --git a/sharktank/sharktank/utils/export_artifacts.py b/sharktank/sharktank/utils/export_artifacts.py
index 4bf2c9ce3..4cf7a0e38 100644
--- a/sharktank/sharktank/utils/export_artifacts.py
+++ b/sharktank/sharktank/utils/export_artifacts.py
@@ -244,21 +244,25 @@ def iree_benchmark_vmfb(
         benchmark_args = [
             f"ROCR_VISIBLE_DEVICES={hip_device_id}",
             "iree-benchmark-module",
-            f"--device=hip://{hip_device_id}",
             "--hip_use_streams=true",
             "--hip_allow_inline_execution=true",
             "--device_allocator=caching",
             f"--module={vmfb_name}",
         ]
         if self.tensor_parallelism_size > 1:
             base_irpa_path, _ = os.path.splitext(irpa_path)
             params = [
                 f"--parameters=model={base_irpa_path}.rank{i}.irpa"
                 for i in range(self.tensor_parallelism_size)
             ]
+            devices = [
+                f"--device=hip[{i}]" for i in range(self.tensor_parallelism_size)
+            ]
         else:
             params = [f"--parameters=model={irpa_path}"]
+            devices = [f"--device=hip://{hip_device_id}"]
         benchmark_args += params
+        benchmark_args += devices
         benchmark_args += args
         cmd = subprocess.list2cmdline(benchmark_args)
         logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
index 9f6523972..689a6277e 100644
--- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py
+++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -146,16 +146,6 @@ def testBenchmark8B_f16_Decomposed(self):
         output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp16_decomposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_f16_decomposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_f16_decomposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -196,16 +186,6 @@ def testBenchmark8B_f16_Decodeposed(self):
             suffix=".vmfb", prefix=output_file_name
         )
         self.llama8b_f16_decodeposed_artifacts.attention_kernel = "torch"
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters_torch_sdpa.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_f16_decodeposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_f16_decodeposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -247,16 +227,6 @@ def testBenchmark8B_fp8_Decomposed(self):
         output_vmfb = self.llama8b_fp8_decomposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp8_decomposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_fp8_decomposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_fp8_decomposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -298,16 +268,6 @@ def testBenchmark8B_fp8_Decodeposed(self):
         output_vmfb = self.llama8b_fp8_decodeposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp8_decodeposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_fp8_decodeposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_fp8_decodeposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -421,6 +381,9 @@ def setUp(self):
             "--benchmark_repetitions=3",
         ]
 
+    @pytest.mark.xfail(
+        reason="Runtime Error", strict=True, raises=IreeBenchmarkException
+    )
     def testBenchmark70B_f16_TP8_Decomposed(self):
         output_file_name = self.dir_path_70b / "f16_decomposed"
         output_mlir = self.llama70b_f16_decomposed_artifacts.create_file(
@@ -760,7 +723,7 @@ def testBenchmark405B_f16_TP8_Decomposed(self):
         )
 
     @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
-    def testBenchmark405B_f16_Decodeposed(self):
+    def testBenchmark405B_f16_TP8_Decodeposed(self):
         output_file_name = self.dir_path_405b / "f16_torch"
         output_mlir = self.llama405b_f16_decodeposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
@@ -811,7 +774,7 @@ def testBenchmark405B_f16_Decodeposed(self):
     @pytest.mark.xfail(
         reason="Test not yet implemented", strict=True, raises=ExportMlirException
     )
-    def testBenchmark405B_fp8_Decomposed(self):
+    def testBenchmark405B_fp8_TP8_Decomposed(self):
         output_file_name = self.dir_path_405b / "fp8_decomposed"
         output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
@@ -862,7 +825,7 @@ def testBenchmark405B_fp8_Decomposed(self):
     @pytest.mark.xfail(
         reason="Test not yet implemented", strict=True, raises=ExportMlirException
     )
-    def testBenchmark405B_fp8_Decodeposed(self):
+    def testBenchmark405B_fp8_TP8_Decodeposed(self):
         output_file_name = self.dir_path_405b / "fp8_torch"
         output_mlir = self.llama405b_fp8_decodeposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
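
--
Note: with this change, iree_benchmark_vmfb emits one --device flag and one
per-rank --parameters flag per rank when tensor_parallelism_size > 1, and the
single-device path keeps the old --device=hip://{hip_device_id} form, now
appended after the parameter flags rather than placed second in the argument
list. As a minimal sketch of the command logged via subprocess.list2cmdline
(hypothetical .vmfb/.irpa paths; assuming hip_device_id=0 and
tensor_parallelism_size=2):

  ROCR_VISIBLE_DEVICES=0 iree-benchmark-module \
    --hip_use_streams=true \
    --hip_allow_inline_execution=true \
    --device_allocator=caching \
    --module=/tmp/llama70b.vmfb \
    --parameters=model=/tmp/llama70b.rank0.irpa \
    --parameters=model=/tmp/llama70b.rank1.irpa \
    --device=hip[0] \
    --device=hip[1]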