
Commit

Fix cli args for TP support
Signed-off-by: aviator19941 <avinash.sharma@amd.com>
aviator19941 committed Nov 2, 2024
1 parent 04af5af commit 349ae8a
Showing 2 changed files with 11 additions and 44 deletions.
6 changes: 5 additions & 1 deletion sharktank/sharktank/utils/export_artifacts.py
@@ -244,21 +244,25 @@ def iree_benchmark_vmfb(
         benchmark_args = [
             f"ROCR_VISIBLE_DEVICES={hip_device_id}",
             "iree-benchmark-module",
-            f"--device=hip://{hip_device_id}",
             "--hip_use_streams=true",
             "--hip_allow_inline_execution=true",
             "--device_allocator=caching",
             f"--module={vmfb_name}",
         ]
         if self.tensor_parallelism_size > 1:
             base_irpa_path, _ = os.path.splitext(irpa_path)
             params = [
                 f"--parameters=model={base_irpa_path}.rank{i}.irpa"
                 for i in range(self.tensor_parallelism_size)
             ]
+            devices = [
+                f"--device=hip[{i}]" for i in range(self.tensor_parallelism_size)
+            ]
         else:
             params = [f"--parameters=model={irpa_path}"]
+            devices = [f"--device=hip://{hip_device_id}"]
         benchmark_args += params
+        benchmark_args += devices
         benchmark_args += args
         cmd = subprocess.list2cmdline(benchmark_args)
         logging.getLogger().info(f"Launching run command:\n" f"cd {cwd} && {cmd}")
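For context, a minimal standalone sketch of the command line this logic now builds. The values (tensor_parallelism_size=2, hip_device_id=0, and the model.irpa/model.vmfb paths) are made-up placeholders for illustration, not values from the commit:

import os
import subprocess

# Assumed example values; in export_artifacts.py these come from the class instance.
tensor_parallelism_size = 2
hip_device_id = 0
irpa_path = "model.irpa"
vmfb_name = "model.vmfb"

benchmark_args = [
    f"ROCR_VISIBLE_DEVICES={hip_device_id}",
    "iree-benchmark-module",
    "--hip_use_streams=true",
    "--hip_allow_inline_execution=true",
    "--device_allocator=caching",
    f"--module={vmfb_name}",
]
if tensor_parallelism_size > 1:
    base_irpa_path, _ = os.path.splitext(irpa_path)
    # One parameter file and one device flag per rank.
    benchmark_args += [
        f"--parameters=model={base_irpa_path}.rank{i}.irpa"
        for i in range(tensor_parallelism_size)
    ]
    benchmark_args += [f"--device=hip[{i}]" for i in range(tensor_parallelism_size)]
else:
    benchmark_args += [f"--parameters=model={irpa_path}"]
    benchmark_args += [f"--device=hip://{hip_device_id}"]

print(subprocess.list2cmdline(benchmark_args))
# Prints (wrapped here for readability):
# ROCR_VISIBLE_DEVICES=0 iree-benchmark-module --hip_use_streams=true
# --hip_allow_inline_execution=true --device_allocator=caching --module=model.vmfb
# --parameters=model=model.rank0.irpa --parameters=model=model.rank1.irpa
# --device=hip[0] --device=hip[1]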
49 changes: 6 additions & 43 deletions sharktank/tests/models/llama/benchmark_amdgpu_test.py
@@ -146,16 +146,6 @@ def testBenchmark8B_f16_Decomposed(self):
         output_vmfb = self.llama8b_f16_decomposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp16_decomposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_f16_decomposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_f16_decomposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -196,16 +186,6 @@ def testBenchmark8B_f16_Decodeposed(self):
             suffix=".vmfb", prefix=output_file_name
         )
         self.llama8b_f16_decodeposed_artifacts.attention_kernel = "torch"
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp16_tp{self.tensor_parallelism_size}_parameters_torch_sdpa.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_f16_decodeposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_f16_decodeposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -247,16 +227,6 @@ def testBenchmark8B_fp8_Decomposed(self):
         output_vmfb = self.llama8b_fp8_decomposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp8_decomposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_fp8_decomposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_fp8_decomposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -298,16 +268,6 @@ def testBenchmark8B_fp8_Decodeposed(self):
         output_vmfb = self.llama8b_fp8_decodeposed_artifacts.create_file(
             suffix=".vmfb", prefix=output_file_name
         )
-        output_shard_file_name = str(
-            self.artifacts_dir
-            / f"llama3.1_8b_fp8_decodeposed_tp{self.tensor_parallelism_size}_parameters.irpa"
-        )
-        # shard_irpa file
-        shard_return_code = self.llama8b_fp8_decodeposed_artifacts.shard_irpa_file(
-            gguf_file=self.gguf_path, output_irpa=output_shard_file_name
-        )
-        if shard_return_code == 0:
-            self.irpa_path = output_shard_file_name
         export_return_code = self.llama8b_fp8_decodeposed_artifacts.export_to_mlir(
             mlir_path=output_mlir,
             json_path=output_json,
@@ -421,6 +381,9 @@ def setUp(self):
             "--benchmark_repetitions=3",
         ]

+    @pytest.mark.xfail(
+        reason="Runtime Error", strict=True, raises=IreeBenchmarkException
+    )
     def testBenchmark70B_f16_TP8_Decomposed(self):
         output_file_name = self.dir_path_70b / "f16_decomposed"
         output_mlir = self.llama70b_f16_decomposed_artifacts.create_file(
@@ -760,7 +723,7 @@ def testBenchmark405B_f16_TP8_Decomposed(self):
         )

     @pytest.mark.xfail(reason="Compile Error", strict=True, raises=IreeCompileException)
-    def testBenchmark405B_f16_Decodeposed(self):
+    def testBenchmark405B_f16_TP8_Decodeposed(self):
         output_file_name = self.dir_path_405b / "f16_torch"
         output_mlir = self.llama405b_f16_decodeposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
@@ -811,7 +774,7 @@ def testBenchmark405B_f16_Decodeposed(self):
     @pytest.mark.xfail(
         reason="Test not yet implemented", strict=True, raises=ExportMlirException
     )
-    def testBenchmark405B_fp8_Decomposed(self):
+    def testBenchmark405B_fp8_TP8_Decomposed(self):
         output_file_name = self.dir_path_405b / "fp8_decomposed"
         output_mlir = self.llama405b_fp8_decomposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
@@ -862,7 +825,7 @@ def testBenchmark405B_fp8_Decomposed(self):
     @pytest.mark.xfail(
         reason="Test not yet implemented", strict=True, raises=ExportMlirException
     )
-    def testBenchmark405B_fp8_Decodeposed(self):
+    def testBenchmark405B_fp8_TP8_Decodeposed(self):
         output_file_name = self.dir_path_405b / "fp8_torch"
         output_mlir = self.llama405b_fp8_decodeposed_artifacts.create_file(
             suffix=".mlir", prefix=output_file_name
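The strict xfail markers added in this commit document known failures precisely: each test reports XFAIL only while it raises the declared exception, and an unexpected pass fails the suite instead of going unnoticed. A self-contained sketch of that behavior, using a local stand-in class since IreeBenchmarkException comes from sharktank's utilities:

import pytest

class IreeBenchmarkException(RuntimeError):
    # Stand-in for sharktank's exception type, for illustration only.
    pass

@pytest.mark.xfail(
    reason="Runtime Error", strict=True, raises=IreeBenchmarkException
)
def test_known_runtime_failure():
    # Raising the declared exception is reported as XFAIL; any other
    # exception counts as a plain failure. With strict=True, if the bug
    # gets fixed and this stops raising, the unexpected pass (XPASS)
    # fails the suite, flagging the marker for removal.
    raise IreeBenchmarkException("iree-benchmark-module returned non-zero")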
