diff --git a/.github/workflows/ci-sglang-benchmark.yml b/.github/workflows/ci-sglang-benchmark.yml
index 3214100ba..e6af1e73c 100644
--- a/.github/workflows/ci-sglang-benchmark.yml
+++ b/.github/workflows/ci-sglang-benchmark.yml
@@ -4,6 +4,18 @@
 # See https://llvm.org/LICENSE.txt for license information.
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
+# =================================== README ===================================
+# The `benchmark_sglang` job in this CI depends mostly on code that lives
+# outside of the `shark-ai` repo itself. By including it here, we can maintain
+# an apples-to-apples comparison between shortfin and SGLang performance in a
+# centralized location while we invest more effort in shortfin LLM performance
+# and work towards a better alternative.
+
+# We should not repeat this pattern in general, and should only use it for
+# benchmarking shortfin apps against external projects as part of an organized
+# and clearly defined effort.
+# ==============================================================================
+
 name: SGLang Llama Benchmarking Tests
 
 on:
@@ -21,9 +33,9 @@ concurrency:
   cancel-in-progress: true
 
 jobs:
-  sglang_bench_serve:
+  benchmark_shortfin:
     if: ${{ github.repository_owner == 'nod-ai' || github.event_name != 'schedule' }}
-    name: "SGLang Serving Benchmark Tests"
+    name: "SGLang Serving Benchmark With Shortfin"
     strategy:
       matrix:
         version: [3.11]
@@ -48,7 +60,7 @@ jobs:
         id: cache-pip
         with:
           path: ${{ env.PIP_CACHE_DIR }}
-          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+          key: pip-${{ matrix.version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }}
 
       - name: Install pip deps
         run: |
@@ -72,13 +84,145 @@ jobs:
       - name: Install SGLang
         run: pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
 
-      - name: Launch Shortfin Server
-        run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --log-cli-level=INFO --html=out/llm/sglang/index.html
+      - name: Run Shortfin Benchmark Tests
+        run: pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py --log-cli-level=INFO --html=shortfin_index.html --self-contained-html
+
+      - name: Upload pytest report
+        uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882
+        with:
+          name: shortfin_benchmark
+          path: shortfin_index.html
+
+  benchmark_sglang:
+    name: "SGLang Serving Benchmark With SGLang"
+    strategy:
+      matrix:
+        version: [3.11]
+      fail-fast: false
+    runs-on: mi300x-4
+    defaults:
+      run:
+        shell: bash
+    env:
+      PIP_CACHE_DIR: "${{ github.workspace }}/.pip-cache"
+    steps:
+      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: ${{matrix.version}}
+
+      - name: Cache Pip Packages
+        uses: actions/cache@6849a6489940f00c2f30c0fb92c6274307ccb58a # v4.1.2
+        id: cache-pip
+        with:
+          path: ${{ env.PIP_CACHE_DIR }}
+          key: pip-${{ matrix.version }}
+
+      - name: Install SGLang
+        run: |
+          python -m pip install --no-compile --upgrade pip
+          pip install "git+https://github.com/nod-ai/sglang.git#subdirectory=python"
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v3
+
+      # Instructions for the SGLang image sourced from here:
+      # https://sgl-project.github.io/start/install.html#method-3-using-docker
+      # We have to run in a docker container due to their vLLM dependency.
+      # From their pyproject.toml:
+      # HIP (Heterogeneous-computing Interface for Portability) for AMD
+      # => base docker rocm/vllm-dev:20241022, not from public vllm whl
+      # srt_hip = ["sglang[runtime_common]", "torch", "vllm==0.6.3.dev13"]
+      - name: Pull SGLang Image (Had issues with sglang:v0.3.5.post2-rocm620)
+        run: |
+          docker pull lmsysorg/sglang:v0.3.5.post1-rocm620
+
+      - name: Run SGLang Server
+        run: |
+          docker run --rm -d \
+            --name=sglang-server \
+            --device=/dev/kfd \
+            --device=/dev/dri \
+            --ipc=host \
+            --shm-size 16G \
+            --group-add video \
+            --cap-add=SYS_PTRACE \
+            --security-opt seccomp=unconfined \
+            -v $HOME/dockerx:/dockerx \
+            -v /data:/data \
+            -p 30000:30000 \
+            -v ~/.cache/huggingface:/root/.cache/huggingface \
+            --env HF_TOKEN=${{ secrets.HF_TOKEN }} \
+            lmsysorg/sglang:v0.3.5.post1-rocm620 \
+            python3 -m sglang.launch_server \
+            --model-path meta-llama/Llama-3.1-8B-Instruct \
+            --host 0.0.0.0 \
+            --port 30000 \
+            --tp 1 \
+            --dtype float16 \
+            --disable-cuda-graph
+
+      - name: Run SGLang Benchmark Tests
+        run: |
+          pytest -v app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py --port 30000 --log-cli-level=INFO --html=sglang_index.html --self-contained-html
+
+      - name: Stop sglang-server
+        run: docker stop sglang-server || true # Stop container if it's running
+
+      # Deleting image after run due to large disk space requirement (83 GB)
+      - name: Cleanup SGLang Image
+        run: docker image rm lmsysorg/sglang:v0.3.5.post1-rocm620
+
+      - name: Upload pytest report
+        uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882
+        with:
+          name: sglang_benchmark
+          path: sglang_index.html
+
+  merge_and_upload_reports:
+    name: "Merge and upload benchmark reports"
+    needs: [benchmark_shortfin, benchmark_sglang]
+    if: needs.benchmark_shortfin.result == 'success' || needs.benchmark_sglang.result == 'success'
+    runs-on: ubuntu-24.04
+    defaults:
+      run:
+        shell: bash
+    steps:
+      - name: "Setting up Python"
+        id: setup_python
+        uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
+        with:
+          python-version: 3.11
+
+      - name: Install pytest-html-merger
+        run: pip install pytest-html-merger
+
+      - name: Download shortfin report
+        uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16
+        with:
+          name: shortfin_benchmark
+          path: reports
+        continue-on-error: true
+
+      - name: Download sglang report
+        uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16
+        with:
+          name: sglang_benchmark
+          path: reports
+        continue-on-error: true
+
+      - name: Merge html reports
+        run: |
+          mkdir merged_reports
+          pytest_html_merger -i reports -o merged_reports/index.html
 
       - name: Deploy to GitHub Pages
         uses: peaceiris/actions-gh-pages@4f9cc6602d3f66b9c108549d475ec49e8ef4d45e # v4.0.0
         with:
           github_token: ${{ secrets.SHARK_PLATFORM_GH_TOKEN }}
-          publish_dir: ./out/llm/sglang
+          publish_dir: merged_reports
           destination_dir: ./llm/sglang
           keep_files: true
diff --git a/.github/workflows/ci-sglang-integration-tests.yml b/.github/workflows/ci-sglang-integration-tests.yml
index e5cd47fd7..154657504 100644
--- a/.github/workflows/ci-sglang-integration-tests.yml
+++ b/.github/workflows/ci-sglang-integration-tests.yml
@@ -49,7 +49,7 @@ jobs:
         id: cache-pip
         with:
           path: ${{ env.PIP_CACHE_DIR }}
-          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements.txt') }}
+          key: pip-${{ steps.setup_python.outputs.python-version }}-${{ hashFiles('*requirements*.txt','shortfin/requirements*.txt','sharktank/requirements*.txt') }}
 
       - name: Install pip deps
         run: |
@@ -64,7 +64,7 @@ jobs:
 
           # Use newest possible releases to be able to track commits that may
           # cause errors.
-          pip install -f https://iree.dev/pip-release-links.html --upgrade \
+          pip install -f https://iree.dev/pip-release-links.html --upgrade --pre \
             iree-base-compiler \
             iree-base-runtime \
             "numpy<2.0"
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
index 1e1c64b24..95d628bf1 100644
--- a/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/conftest.py
@@ -54,3 +54,17 @@ def pre_process_model(request, tmp_path_factory):
     compile_model(mlir_path, vmfb_path, settings)
 
     return tmp_dir
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--port",
+        action="store",
+        default="30000",
+        help="Port that SGLang server is running on",
+    )
+
+
+@pytest.fixture(scope="module")
+def sglang_args(request):
+    return request.config.getoption("--port")
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
index b66904570..675f9ef54 100644
--- a/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/sglang_benchmark_test.py
@@ -5,9 +5,6 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 
 import logging
-import multiprocessing
-import os
-from pathlib import Path
 import pytest
 import time
 from unittest.mock import patch
@@ -17,71 +14,32 @@
 
 from .utils import SGLangBenchmarkArgs, log_jsonl_result
 
-from integration_tests.llm.utils import (
-    find_available_port,
-    start_llm_server,
-)
+from integration_tests.llm.utils import download_tokenizer, wait_for_server
 
 logger = logging.getLogger(__name__)
 
-device_settings = {
-    "device_flags": [
-        "--iree-hal-target-backends=rocm",
-        "--iree-hip-target=gfx942",
-    ],
-    "device": "hip",
-}
-
 
 @pytest.mark.parametrize(
-    "request_rate,model_param_file_name",
-    [
-        (req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
-        for req_rate in [1, 2, 4, 8, 16, 32]
-    ],
+    "request_rate,tokenizer_id",
+    [(req_rate, "NousResearch/Meta-Llama-3-8B") for req_rate in [1, 2, 4, 8, 16, 32]],
 )
-@pytest.mark.parametrize(
-    "pre_process_model",
-    [
-        (
-            {
-                "model_name": "llama3_8B_fp16",
-                "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
-                "settings": device_settings,
-                "batch_sizes": [1, 4],
-            }
-        )
-    ],
-    indirect=True,
-)
-def test_sglang_benchmark_server(
-    request_rate, model_param_file_name, pre_process_model
-):
-    # TODO: Remove when multi-device is fixed
-    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+def test_sglang_benchmark(request_rate, tokenizer_id, sglang_args, tmp_path_factory):
+    tmp_dir = tmp_path_factory.mktemp("sglang_benchmark_test")
 
-    tmp_dir = pre_process_model
+    # Download tokenizer for llama3_8B_fp16
+    download_tokenizer(tmp_dir, tokenizer_id)
 
-    config_path = tmp_dir / "config.json"
-    vmfb_path = tmp_dir / "model.vmfb"
-    tokenizer_path = tmp_dir / "tokenizer.json"
-    model_path = tmp_dir / model_param_file_name
+    logger.info("Beginning SGLang benchmark test...")
 
-    # Start shortfin llm server
-    port = find_available_port()
-    server_process = start_llm_server(
-        port,
-        tokenizer_path,
-        config_path,
-        vmfb_path,
-        model_path,
-        device_settings,
-        timeout=30,
-    )
+    port = sglang_args
+    base_url = f"http://localhost:{port}"
+
+    # A high timeout gives the server enough time to download model artifacts
+    # and start up; this takes a little longer than shortfin.
+    wait_for_server(base_url, timeout=600)
 
-    # Run and collect SGLang Serving Benchmark
     benchmark_args = SGLangBenchmarkArgs(
-        backend="shortfin",
+        backend="sglang",
         num_prompt=10,
         base_url=f"http://localhost:{port}",
         tokenizer=tmp_dir,
@@ -95,21 +53,15 @@ def test_sglang_benchmark_server(
 
     logger.info("Running SGLang Benchmark with the following args:")
     logger.info(benchmark_args)
+
     try:
         start = time.time()
         with patch.object(bench_serving, "print", side_effect=logger.info):
-            benchmark_process = multiprocessing.Process(
-                target=bench_serving.run_benchmark,
-                args=(benchmark_args.as_namespace(),),
+            bench_serving.run_benchmark(
+                benchmark_args.as_namespace(),
             )
-            benchmark_process.start()
-            benchmark_process.join()
-
         logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
         logger.info("======== RESULTS ========")
         log_jsonl_result(benchmark_args.output_file)
     except Exception as e:
         logger.error(e)
-
-    server_process.terminate()
-    server_process.wait()
diff --git a/app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py b/app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py
new file mode 100644
index 000000000..33c21b104
--- /dev/null
+++ b/app_tests/benchmark_tests/llm/sglang_benchmarks/shortfin_benchmark_test.py
@@ -0,0 +1,116 @@
+# Copyright 2024 Advanced Micro Devices, Inc.
+#
+# Licensed under the Apache License v2.0 with LLVM Exceptions.
+# See https://llvm.org/LICENSE.txt for license information.
+# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+import logging
+import multiprocessing
+import os
+from pathlib import Path
+import pytest
+import time
+from unittest.mock import patch
+
+pytest.importorskip("sglang")
+from sglang import bench_serving
+
+from .utils import (
+    SGLangBenchmarkArgs,
+    log_jsonl_result,
+)
+
+from integration_tests.llm.utils import (
+    find_available_port,
+    start_llm_server,
+)
+
+logger = logging.getLogger(__name__)
+
+device_settings = {
+    "device_flags": [
+        "--iree-hal-target-backends=rocm",
+        "--iree-hip-target=gfx942",
+    ],
+    "device": "hip",
+}
+
+
+@pytest.mark.parametrize(
+    "request_rate,model_param_file_name",
+    [
+        (req_rate, "meta-llama-3.1-8b-instruct.f16.gguf")
+        for req_rate in [1, 2, 4, 8, 16, 32]
+    ],
+)
+@pytest.mark.parametrize(
+    "pre_process_model",
+    [
+        (
+            {
+                "model_name": "llama3_8B_fp16",
+                "model_param_file_name": "meta-llama-3.1-8b-instruct.f16.gguf",
+                "settings": device_settings,
+                "batch_sizes": [1, 4],
+            }
+        )
+    ],
+    indirect=True,
+)
+def test_shortfin_benchmark(request_rate, model_param_file_name, pre_process_model):
+    # TODO: Remove when multi-device is fixed
+    os.environ["ROCR_VISIBLE_DEVICES"] = "1"
+
+    tmp_dir = pre_process_model
+
+    config_path = tmp_dir / "config.json"
+    vmfb_path = tmp_dir / "model.vmfb"
+    tokenizer_path = tmp_dir / "tokenizer.json"
+    model_path = tmp_dir / model_param_file_name
+
+    # Start shortfin llm server
+    port = find_available_port()
+    server_process = start_llm_server(
+        port,
+        tokenizer_path,
+        config_path,
+        vmfb_path,
+        model_path,
+        device_settings,
+        timeout=30,
+    )
+
+    # Run and collect SGLang Serving Benchmark
+    benchmark_args = SGLangBenchmarkArgs(
+        backend="shortfin",
+        num_prompt=10,
+        base_url=f"http://localhost:{port}",
+        tokenizer=tmp_dir,
+        request_rate=request_rate,
+    )
+    output_file = (
+        tmp_dir
+        / f"{benchmark_args.backend}_{benchmark_args.num_prompt}_{benchmark_args.request_rate}.jsonl"
+    )
+    benchmark_args.output_file = output_file
+
+    logger.info("Running SGLang Benchmark with the following args:")
+    logger.info(benchmark_args)
+    try:
+        start = time.time()
+        with patch.object(bench_serving, "print", side_effect=logger.info):
+            benchmark_process = multiprocessing.Process(
+                target=bench_serving.run_benchmark,
+                args=(benchmark_args.as_namespace(),),
+            )
+            benchmark_process.start()
+            benchmark_process.join()
+
+        logger.info(f"Benchmark run completed in {str(time.time() - start)} seconds")
+        logger.info("======== RESULTS ========")
+        log_jsonl_result(benchmark_args.output_file)
+    except Exception as e:
+        logger.error(e)
+
+    server_process.terminate()
+    server_process.wait()
diff --git a/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
index efab14ea7..72b3d4052 100644
--- a/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
+++ b/app_tests/integration_tests/llm/sglang/sglang_frontend_test.py
@@ -29,7 +29,7 @@
     "device": "hip",
 }
 
-ACCEPTED_THRESHOLD = 0.8
+ACCEPTED_THRESHOLD = 0.7
 
 
 def compute_similarity(model: SentenceTransformer, sentence_1: str, sentence_2: str):