diff --git a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml index eec2a51e2f8fd..3db77d5f16022 100644 --- a/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml +++ b/.buildkite/nightly-benchmarks/benchmark-pipeline.yaml @@ -9,8 +9,11 @@ steps: - image: badouralix/curl-jq command: - sh .buildkite/nightly-benchmarks/scripts/wait-for-image.sh + - wait + - label: "A100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" agents: queue: A100 plugins: @@ -41,20 +44,43 @@ steps: - name: devshm emptyDir: medium: Memory - # - label: "H100" - # agents: - # queue: H100 - # plugins: - # - docker#v5.11.0: - # image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT - # command: - # - bash - # - .buildkite/nightly-benchmarks/run-benchmarks-suite.sh - # mount-buildkite-agent: true - # propagate-environment: true - # ipc: host - # gpus: all - # environment: - # - VLLM_USAGE_SOURCE - # - HF_TOKEN + - label: "H200" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H200 + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: 4,5,6,7 + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN + + - label: "H100" + # skip: "use this flag to conditionally skip the benchmark step, useful for PR testing" + agents: + queue: H100 + plugins: + - docker#v5.12.0: + image: public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:$BUILDKITE_COMMIT + command: + - bash + - .buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh + mount-buildkite-agent: true + propagate-environment: true + ipc: host + gpus: all # see CUDA_VISIBLE_DEVICES for actual GPUs used + volumes: + - /data/benchmark-hf-cache:/root/.cache/huggingface + environment: + - VLLM_USAGE_SOURCE + - HF_TOKEN diff --git a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py index 7cf05610b9953..9d3646e2f6a15 100644 --- a/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py +++ b/.buildkite/nightly-benchmarks/scripts/convert-results-json-to-markdown.py @@ -157,6 +157,18 @@ def results_to_json(latency, throughput, serving): throughput_results, serving_results) + for df in [latency_results, serving_results, throughput_results]: + if df.empty: + continue + + # Sort all dataframes by their respective "Test name" columns + df.sort_values(by="Test name", inplace=True) + + # The GPUs sometimes come in format of "GPUTYPE\nGPUTYPE\n...", + # we want to turn it into "8xGPUTYPE" + df["GPU"] = df["GPU"].apply( + lambda x: f"{len(x.split('\n'))}x{x.split('\n')[0]}") + # get markdown tables latency_md_table = tabulate(latency_results, headers='keys', diff --git a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh index d397b05cdff23..0d16a83781ab2 100644 --- a/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh +++ b/.buildkite/nightly-benchmarks/scripts/run-performance-benchmarks.sh @@ -6,6 +6,7 @@ # Do not set -e, as the mixtral 8x22B model tends to crash occasionally # and we still want to see other benchmarking results even when mixtral crashes. +set -x set -o pipefail check_gpus() { @@ -85,11 +86,7 @@ kill_gpu_processes() { ps -aux lsof -t -i:8000 | xargs -r kill -9 - pkill -f pt_main_thread - # this line doesn't work now - # ps aux | grep python | grep openai | awk '{print $2}' | xargs -r kill -9 - pkill -f python3 - pkill -f /usr/bin/python3 + pgrep python3 | xargs -r kill -9 # wait until GPU memory usage smaller than 1GB @@ -289,7 +286,7 @@ run_serving_tests() { # run the server echo "Running test case $test_name" echo "Server command: $server_command" - eval "$server_command" & + bash -c "$server_command" & server_pid=$! # wait until the server is alive @@ -322,7 +319,7 @@ run_serving_tests() { echo "Running test case $test_name with qps $qps" echo "Client command: $client_command" - eval "$client_command" + bash -c "$client_command" # record the benchmarking commands jq_output=$(jq -n \ diff --git a/.buildkite/run-amd-test.sh b/.buildkite/run-amd-test.sh index 902e162720b89..3515ccd65667e 100755 --- a/.buildkite/run-amd-test.sh +++ b/.buildkite/run-amd-test.sh @@ -85,7 +85,6 @@ if [[ $commands == *" kernels "* ]]; then --ignore=kernels/test_encoder_decoder_attn.py \ --ignore=kernels/test_flash_attn.py \ --ignore=kernels/test_flashinfer.py \ - --ignore=kernels/test_gguf.py \ --ignore=kernels/test_int8_quant.py \ --ignore=kernels/test_machete_gemm.py \ --ignore=kernels/test_mamba_ssm.py \ diff --git a/.buildkite/run-cpu-test-ppc64le.sh b/.buildkite/run-cpu-test-ppc64le.sh index 5d7a0bff90963..bc06838d804ff 100755 --- a/.buildkite/run-cpu-test-ppc64le.sh +++ b/.buildkite/run-cpu-test-ppc64le.sh @@ -4,49 +4,11 @@ # It serves a sanity check for compilation and basic model usage. set -ex -# Try building the docker image -docker build -t cpu-test -f Dockerfile.ppc64le . - # Setup cleanup -remove_docker_container() { docker rm -f cpu-test || true; } +remove_docker_container() { docker rm -f cpu-test || true; docker system prune -f; } trap remove_docker_container EXIT remove_docker_container -# Run the image, setting --shm-size=4g for tensor parallel. -source /etc/environment -#docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN --env VLLM_CPU_KVCACHE_SPACE=4 --shm-size=4g --name cpu-test cpu-test -docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/huggingface --privileged=true --network host -e HF_TOKEN="$HF_TOKEN" --name cpu-test cpu-test - -function cpu_tests() { - set -e - - # Run basic model test - docker exec cpu-test bash -c " - set -e - pip install pytest pytest-asyncio \ - decord einops librosa peft Pillow sentence-transformers soundfile \ - transformers_stream_generator matplotlib datamodel_code_generator - pip install torchvision --index-url https://download.pytorch.org/whl/cpu - pytest -v -s tests/models/decoder_only/language -m cpu_model - pytest -v -s tests/models/embedding/language -m cpu_model - pytest -v -s tests/models/encoder_decoder/language -m cpu_model - pytest -v -s tests/models/decoder_only/audio_language -m cpu_model - pytest -v -s tests/models/decoder_only/vision_language -m cpu_model" - - # online inference - docker exec cpu-test bash -c " - set -e - python3 -m vllm.entrypoints.openai.api_server --model facebook/opt-125m & - timeout 600 bash -c 'until curl localhost:8000/v1/models; do sleep 1; done' || exit 1 - python3 benchmarks/benchmark_serving.py \ - --backend vllm \ - --dataset-name random \ - --model facebook/opt-125m \ - --num-prompts 20 \ - --endpoint /v1/completions \ - --tokenizer facebook/opt-125m" -} +# Try building the docker image +docker build -t cpu-test -f Dockerfile.ppc64le . -# All of CPU tests are expected to be finished less than 25 mins. -export -f cpu_tests -timeout 25m bash -c "cpu_tests" diff --git a/.buildkite/run-cpu-test.sh b/.buildkite/run-cpu-test.sh index f0128f091b742..4f1729d46dae2 100644 --- a/.buildkite/run-cpu-test.sh +++ b/.buildkite/run-cpu-test.sh @@ -25,6 +25,7 @@ docker run -itd --entrypoint /bin/bash -v ~/.cache/huggingface:/root/.cache/hugg function cpu_tests() { set -e + export NUMA_NODE=$2 # offline inference docker exec cpu-test-avx2-"$NUMA_NODE" bash -c " @@ -57,6 +58,12 @@ function cpu_tests() { pytest -s -v \ tests/quantization/test_ipex_quant.py" + # Run chunked-prefill and prefix-cache test + docker exec cpu-test-"$NUMA_NODE" bash -c " + set -e + pytest -s -v -k cpu_model \ + tests/basic_correctness/test_chunked_prefill.py" + # online inference docker exec cpu-test-"$NUMA_NODE" bash -c " set -e @@ -75,4 +82,4 @@ function cpu_tests() { # All of CPU tests are expected to be finished less than 25 mins. export -f cpu_tests -timeout 25m bash -c "cpu_tests $CORE_RANGE" +timeout 30m bash -c "cpu_tests $CORE_RANGE $NUMA_NODE" diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 545b253c07db0..fc23c9cff0d87 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -9,8 +9,7 @@ # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only -# nightly(bool): run this test in nightly pipeline only -# optional(bool): never run this test by default (i.e. need to unblock manually) +# optional(bool): never run this test by default (i.e. need to unblock manually) unless it's scheduled nightly run. # command(str): the single command to run for tests. incompatible with commands. # commands(list): the list of commands to run for test. incompatbile with command. # mirror_hardwares(list): the list of hardwares to run the test on as well. currently only supports [amd] @@ -51,7 +50,9 @@ steps: - tests/multimodal - tests/test_utils - tests/worker + - tests/test_lazy_torch_compile.py commands: + - python3 test_lazy_torch_compile.py - pytest -v -s mq_llm_engine # MQLLMEngine - pytest -v -s async_engine # AsyncLLMEngine - NUM_SCHEDULER_STEPS=4 pytest -v -s async_engine/test_async_llm_engine.py @@ -229,7 +230,7 @@ steps: source_file_dependencies: - vllm/lora - tests/lora - command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py + command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore lora/test_long_context.py lora/test_chatglm3_tp.py lora/test_llama_tp.py parallelism: 4 - label: "PyTorch Fullgraph Smoke Test" # 9min @@ -336,7 +337,7 @@ steps: - pytest -v -s models/embedding/vision_language -m core_model - label: Language Models Test (Extended) # 50min - nightly: true + optional: true source_file_dependencies: - vllm/ - tests/models/decoder_only/language @@ -362,7 +363,7 @@ steps: - pytest -v -s models/encoder_decoder/vision_language -m core_model - label: Multi-Modal Models Test (Extended) # 1h15m - nightly: true + optional: true source_file_dependencies: - vllm/ - tests/models/decoder_only/audio_language @@ -474,18 +475,23 @@ steps: - pytest -v -s distributed/test_pp_cudagraph.py - pytest -v -s distributed/test_pipeline_parallel.py -- label: LoRA Long Context (Distributed) # 11min - # This test runs llama 13B, so it is required to run on 4 GPUs. +- label: LoRA TP Test (Distributed) num_gpus: 4 soft_fail: true source_file_dependencies: - vllm/lora - - tests/lora/test_long_context + - tests/lora commands: # FIXIT: find out which code initialize cuda before running the test # before the fix, we need to use spawn to test it - export VLLM_WORKER_MULTIPROC_METHOD=spawn + # This test runs llama 13B, so it is required to run on 4 GPUs. - pytest -v -s -x lora/test_long_context.py + # There is some Tensor Parallelism related processing logic in LoRA that + # requires multi-GPU testing for validation. + - pytest -v -s -x lora/test_chatglm3_tp.py + - pytest -v -s -x lora/test_llama_tp.py + - label: Weight Loading Multiple GPU Test # 33min working_dir: "/vllm-workspace/tests" @@ -513,6 +519,7 @@ steps: - label: Distributed Tests (A100) # optional gpu: a100 + optional: true num_gpus: 4 source_file_dependencies: - vllm/ @@ -526,6 +533,7 @@ steps: - label: LM Eval Large Models # optional gpu: a100 + optional: true num_gpus: 4 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" source_file_dependencies: diff --git a/.buildkite/upload-wheels.sh b/.buildkite/upload-wheels.sh index 541b395eddbe7..7345dd4e66b29 100644 --- a/.buildkite/upload-wheels.sh +++ b/.buildkite/upload-wheels.sh @@ -25,7 +25,12 @@ echo "Version: $version" # If the version contains "dev", rename it to v1.0.0.dev for consistency if [[ $version == *dev* ]]; then - new_version="1.0.0.dev" + suffix="${version##*.}" + if [[ $suffix == cu* ]]; then + new_version="1.0.0.dev+${suffix}" + else + new_version="1.0.0.dev" + fi new_wheel="${wheel/$version/$new_version}" mv -- "$wheel" "$new_wheel" wheel="$new_wheel" diff --git a/.github/workflows/sphinx-lint.yml b/.github/workflows/sphinx-lint.yml new file mode 100644 index 0000000000000..e0bb24276a653 --- /dev/null +++ b/.github/workflows/sphinx-lint.yml @@ -0,0 +1,32 @@ +name: Lint documentation + +on: + push: + branches: + - main + paths: + - "docs/**" + pull_request: + branches: + - main + paths: + - "docs/**" + +jobs: + sphinx-lint: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.12"] + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -r requirements-lint.txt + - name: Linting docs + run: tools/sphinx-lint.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index 5acbd762ee957..ff34225537cdd 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -196,6 +196,7 @@ set(VLLM_EXT_SRC "csrc/quantization/gptq/q_gemm.cu" "csrc/quantization/compressed_tensors/int8_quant_kernels.cu" "csrc/quantization/fp8/common.cu" + "csrc/quantization/gguf/gguf_kernel.cu" "csrc/cuda_utils_kernels.cu" "csrc/prepare_inputs/advance_step.cu" "csrc/torch_bindings.cpp") @@ -206,7 +207,19 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Set CUTLASS_REVISION manually -- its revision detection doesn't work in this case. set(CUTLASS_REVISION "v3.5.1" CACHE STRING "CUTLASS revision to use") - FetchContent_Declare( + # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided + if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) + set(VLLM_CUTLASS_SRC_DIR $ENV{VLLM_CUTLASS_SRC_DIR}) + endif() + + if(VLLM_CUTLASS_SRC_DIR) + if(NOT IS_ABSOLUTE VLLM_CUTLASS_SRC_DIR) + get_filename_component(VLLM_CUTLASS_SRC_DIR "${VLLM_CUTLASS_SRC_DIR}" ABSOLUTE) + endif() + message(STATUS "The VLLM_CUTLASS_SRC_DIR is set, using ${VLLM_CUTLASS_SRC_DIR} for compilation") + FetchContent_Declare(cutlass SOURCE_DIR ${VLLM_CUTLASS_SRC_DIR}) + else() + FetchContent_Declare( cutlass GIT_REPOSITORY https://github.com/nvidia/cutlass.git GIT_TAG v3.5.1 @@ -216,7 +229,8 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Important: If GIT_SHALLOW is enabled then GIT_TAG works only with branch names and tags. # So if the GIT_TAG above is updated to a commit hash, GIT_SHALLOW must be set to FALSE GIT_SHALLOW TRUE - ) + ) + endif() FetchContent_MakeAvailable(cutlass) list(APPEND VLLM_EXT_SRC @@ -224,7 +238,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/gguf/gguf_kernel.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") diff --git a/Dockerfile.rocm b/Dockerfile.rocm index 8fb79afaebe97..62d4a9b4909c3 100644 --- a/Dockerfile.rocm +++ b/Dockerfile.rocm @@ -51,9 +51,9 @@ RUN --mount=type=cache,target=/root/.cache/pip \ *"rocm-6.2"*) \ python3 -m pip uninstall -y torch torchvision \ && python3 -m pip install --pre \ - torch==2.6.0.dev20240918 \ + torch==2.6.0.dev20241113+rocm6.2 \ 'setuptools-scm>=8' \ - torchvision==0.20.0.dev20240918 \ + torchvision==0.20.0.dev20241113+rocm6.2 \ --extra-index-url https://download.pytorch.org/whl/nightly/rocm6.2;; \ *) ;; esac diff --git a/README.md b/README.md index 0ef073210d070..cfeb24cbb5823 100644 --- a/README.md +++ b/README.md @@ -16,9 +16,9 @@ Easy, fast, and cheap LLM serving for everyone --- *Latest News* 🔥 -- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing). +- [2024/11] We hosted [the seventh vLLM meetup](https://lu.ma/h0qvrajz) with Snowflake! Please find the meetup slides from vLLM team [here](https://docs.google.com/presentation/d/1e3CxQBV3JsfGp30SwyvS3eM_tW-ghOhJ9PAJGK6KR54/edit?usp=sharing), and Snowflake team [here](https://docs.google.com/presentation/d/1qF3RkDAbOULwz9WK5TOltt2fE9t6uIc_hVNLFAaQX6A/edit?usp=sharing). - [2024/10] We have just created a developer slack ([slack.vllm.ai](https://slack.vllm.ai)) focusing on coordinating contributions and discussing features. Please feel free to join us there! -- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://raysummit.anyscale.com/flow/anyscale/raysummit2024/landing/page/sessioncatalog?tab.day=20241001&search.sessiontracks=1719251906298001uzJ2) from other vLLM contributors and users! +- [2024/10] Ray Summit 2024 held a special track for vLLM! Please find the opening talk slides from the vLLM team [here](https://docs.google.com/presentation/d/1B_KQxpHBTRa_mDF-tR6i8rWdOU5QoTZNcEg2MKZxEHM/edit?usp=sharing). Learn more from the [talks](https://www.youtube.com/playlist?list=PLzTswPQNepXl6AQwifuwUImLPFRVpksjR) from other vLLM contributors and users! - [2024/09] We hosted [the sixth vLLM meetup](https://lu.ma/87q3nvnh) with NVIDIA! Please find the meetup slides [here](https://docs.google.com/presentation/d/1wrLGwytQfaOTd5wCGSPNhoaW3nq0E-9wqyP7ny93xRs/edit?usp=sharing). - [2024/07] We hosted [the fifth vLLM meetup](https://lu.ma/lp0gyjqr) with AWS! Please find the meetup slides [here](https://docs.google.com/presentation/d/1RgUD8aCfcHocghoP3zmXzck9vX3RCI9yfUAB2Bbcl4Y/edit?usp=sharing). - [2024/07] In partnership with Meta, vLLM officially supports Llama 3.1 with FP8 quantization and pipeline parallelism! Please check out our blog post [here](https://blog.vllm.ai/2024/07/23/llama31.html). diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 25c8b1bbf3e22..c3fed56e8a956 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -54,6 +54,7 @@ async def async_request_tgi( "do_sample": True, "temperature": 0.01, # TGI does not accept 0.0 temperature. "top_p": 0.99, # TGI does not accept 1.0 top_p. + "truncate": request_func_input.prompt_len, # TGI does not accept ignore_eos flag. } payload = { diff --git a/csrc/ops.h b/csrc/ops.h index 672e608e9c47e..ea001190bc202 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -128,6 +128,7 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, int64_t thx, int64_t thy); torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); +#endif torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); @@ -138,6 +139,7 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); +#ifndef USE_ROCM bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, diff --git a/csrc/quantization/gguf/ggml-common.h b/csrc/quantization/gguf/ggml-common.h index fba94fd1d157b..d42205a6571db 100644 --- a/csrc/quantization/gguf/ggml-common.h +++ b/csrc/quantization/gguf/ggml-common.h @@ -1,7 +1,7 @@ // copied from https://github.com/ggerganov/llama.cpp/blob/b2899/ggml-common.h #define QK_K 256 #define K_QUANTS_PER_ITERATION 2 -#define WARP_SIZE 32 +#define WARP_SIZE_GGUF 32 #define K_SCALE_SIZE 12 #define CUDA_DEQUANTIZE_BLOCK_SIZE 256 #define CUDA_QUANTIZE_BLOCK_SIZE 256 @@ -1112,4 +1112,19 @@ static __device__ __forceinline__ int __dp4a(const int a, const int b, int c) { #endif return c; } + +static __device__ __forceinline__ uint32_t __vcmpeq4(const uint32_t a, const uint32_t b) { + uint32_t neq = a^b; + return !(neq & 0xff000000) * 0xff000000 | + !(neq & 0x00ff0000) * 0x00ff0000 | + !(neq & 0x0000ff00) * 0x0000ff00 | + !(neq & 0x000000ff) * 0x000000ff; +} + +static __device__ __forceinline__ uint32_t __vsub4(const uint32_t a, const uint32_t b) { + return (static_cast(((a & 0xff000000) >> 24) - ((b & 0xff000000) >> 24)) << 24) + + (static_cast(((a & 0x00ff0000) >> 16) - ((b & 0x00ff0000) >> 16)) << 16) + + (static_cast(((a & 0x0000ff00) >> 8) - ((b & 0x0000ff00) >> 8)) << 8) + + (static_cast(((a & 0x000000ff) >> 0) - ((b & 0x000000ff) >> 0)) << 0); +} #endif // defined(USE_ROCM) diff --git a/csrc/quantization/gguf/gguf_kernel.cu b/csrc/quantization/gguf/gguf_kernel.cu index 37e4de4e14dd3..5f0eaf5a973fb 100644 --- a/csrc/quantization/gguf/gguf_kernel.cu +++ b/csrc/quantization/gguf/gguf_kernel.cu @@ -4,6 +4,8 @@ #include #include +#include "cuda_compat.h" + #include "ggml-common.h" #include "vecdotq.cuh" #include "dequantize.cuh" @@ -32,8 +34,8 @@ static __global__ void quantize_q8_1(const half* __restrict__ x, #pragma unroll for (int mask = 16; mask > 0; mask >>= 1) { - amax = fmaxf(amax, __shfl_xor_sync(0xffffffff, amax, mask, 32)); - sum += __shfl_xor_sync(0xffffffff, sum, mask, 32); + amax = fmaxf(amax, VLLM_SHFL_XOR_SYNC_WIDTH(amax, mask, 32)); + sum += VLLM_SHFL_XOR_SYNC_WIDTH(sum, mask, 32); } const float d = amax / 127; diff --git a/csrc/quantization/gguf/mmq.cuh b/csrc/quantization/gguf/mmq.cuh index d13efd5965313..c935faa07df0c 100644 --- a/csrc/quantization/gguf/mmq.cuh +++ b/csrc/quantization/gguf/mmq.cuh @@ -10,7 +10,7 @@ static __device__ __forceinline__ void mul_mat_q( const int blocks_per_row_x = ncols_x / qk; const int blocks_per_col_y = nrows_y / QK8_1; - const int blocks_per_warp = WARP_SIZE / qi; + const int blocks_per_warp = WARP_SIZE_GGUF / qi; const int & ncols_dst = ncols_y; @@ -27,10 +27,10 @@ static __device__ __forceinline__ void mul_mat_q( allocate_tiles(&tile_x_ql, &tile_x_dm, &tile_x_qh, &tile_x_sc); - __shared__ int tile_y_qs[mmq_x * WARP_SIZE]; - __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE/QI8_1]; + __shared__ int tile_y_qs[mmq_x * WARP_SIZE_GGUF]; + __shared__ half2 tile_y_ds[mmq_x * WARP_SIZE_GGUF/QI8_1]; - float sum[mmq_y/WARP_SIZE][mmq_x/nwarps] = {{0.0f}}; + float sum[mmq_y/WARP_SIZE_GGUF][mmq_x/nwarps] = {{0.0f}}; for (int ib0 = 0; ib0 < blocks_per_row_x; ib0 += blocks_per_warp) { @@ -39,26 +39,26 @@ static __device__ __forceinline__ void mul_mat_q( #pragma unroll for (int ir = 0; ir < qr; ++ir) { - const int kqs = ir*WARP_SIZE + threadIdx.x; + const int kqs = ir*WARP_SIZE_GGUF + threadIdx.x; const int kbxd = kqs / QI8_1; #pragma unroll for (int i = 0; i < mmq_x; i += nwarps) { const int col_y_eff = min(col_y_0 + threadIdx.y + i, ncols_y-1); // to prevent out-of-bounds memory accesses const block_q8_1 * by0 = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + kbxd]; - const int index_y = (threadIdx.y + i) * WARP_SIZE + kqs % WARP_SIZE; + const int index_y = (threadIdx.y + i) * WARP_SIZE_GGUF + kqs % WARP_SIZE_GGUF; tile_y_qs[index_y] = get_int_from_int8_aligned(by0->qs, threadIdx.x % QI8_1); } #pragma unroll for (int ids0 = 0; ids0 < mmq_x; ids0 += nwarps * QI8_1) { - const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE/QI8_1)) % mmq_x; - const int kby = threadIdx.x % (WARP_SIZE/QI8_1); + const int ids = (ids0 + threadIdx.y * QI8_1 + threadIdx.x / (WARP_SIZE_GGUF/QI8_1)) % mmq_x; + const int kby = threadIdx.x % (WARP_SIZE_GGUF/QI8_1); const int col_y_eff = min(col_y_0 + ids, ncols_y-1); // if the sum is not needed it's faster to transform the scale to f32 ahead of time - const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE/QI8_1) + kby].ds; - half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE/QI8_1) + kby]; + const half2 * dsi_src = &y[col_y_eff*blocks_per_col_y + ib0 * (qk/QK8_1) + ir*(WARP_SIZE_GGUF/QI8_1) + kby].ds; + half2 * dsi_dst = &tile_y_ds[ids * (WARP_SIZE_GGUF/QI8_1) + kby]; if (need_sum) { *dsi_dst = *dsi_src; } else { @@ -70,12 +70,12 @@ static __device__ __forceinline__ void mul_mat_q( __syncthreads(); // #pragma unroll // unrolling this loop causes too much register pressure - for (int k = ir*WARP_SIZE/qr; k < (ir+1)*WARP_SIZE/qr; k += vdr) { + for (int k = ir*WARP_SIZE_GGUF/qr; k < (ir+1)*WARP_SIZE_GGUF/qr; k += vdr) { #pragma unroll for (int j = 0; j < mmq_x; j += nwarps) { #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { - sum[i/WARP_SIZE][j/nwarps] += vec_dot( + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { + sum[i/WARP_SIZE_GGUF][j/nwarps] += vec_dot( tile_x_ql, tile_x_dm, tile_x_qh, tile_x_sc, tile_y_qs, tile_y_ds, threadIdx.x + i, threadIdx.y + j, k); } @@ -93,12 +93,12 @@ static __device__ __forceinline__ void mul_mat_q( } #pragma unroll - for (int i = 0; i < mmq_y; i += WARP_SIZE) { + for (int i = 0; i < mmq_y; i += WARP_SIZE_GGUF) { const int row_dst = row_dst_0 + threadIdx.x + i; if (row_dst >= nrows_dst) { continue; } - dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE][j/nwarps]); + dst[col_dst*nrows_dst + row_dst] = __float2half(sum[i/WARP_SIZE_GGUF][j/nwarps]); } } } @@ -115,7 +115,7 @@ static __device__ __forceinline__ void mul_mat_q( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_0, 2) #endif mul_mat_q4_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -140,7 +140,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -165,7 +165,7 @@ static void ggml_mul_mat_q4_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_1, 2) #endif mul_mat_q4_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -190,7 +190,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -215,7 +215,7 @@ static void ggml_mul_mat_q4_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_0, 2) #endif mul_mat_q5_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -240,7 +240,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -265,7 +265,7 @@ static void ggml_mul_mat_q5_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_1, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_1, 2) #endif mul_mat_q5_1( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -289,7 +289,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -314,7 +314,7 @@ static void ggml_mul_mat_q5_1_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q8_0, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q8_0, 2) #endif mul_mat_q8_0( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -338,7 +338,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -363,7 +363,7 @@ static void ggml_mul_mat_q8_0_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q2_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q2_K, 2) #endif mul_mat_q2_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -387,7 +387,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -412,7 +412,7 @@ static void ggml_mul_mat_q2_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q3_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q3_K, 2) #endif mul_mat_q3_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -438,7 +438,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -463,7 +463,7 @@ static void ggml_mul_mat_q3_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q4_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q4_K, 2) #endif mul_mat_q4_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -487,7 +487,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -512,7 +512,7 @@ static void ggml_mul_mat_q4_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q5_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q5_K, 2) #endif mul_mat_q5_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -537,7 +537,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; @@ -562,7 +562,7 @@ static void ggml_mul_mat_q5_K_q8_1_cuda( template static __global__ void #if defined(USE_ROCM) -__launch_bounds__(WARP_SIZE*NWARPS_Q6_K, 2) +__launch_bounds__(WARP_SIZE_GGUF*NWARPS_Q6_K, 2) #endif mul_mat_q6_K( const void * __restrict__ vx, const void * __restrict__ vy, half * __restrict__ dst, @@ -586,7 +586,7 @@ static void ggml_mul_mat_q6_K_q8_1_cuda( const int block_num_x = (nrows_x + mmq_y - 1) / mmq_y; const int block_num_y = (ncols_y + mmq_x - 1) / mmq_x; const dim3 block_nums(block_num_x, block_num_y, 1); - const dim3 block_dims(WARP_SIZE, nwarps, 1); + const dim3 block_dims(WARP_SIZE_GGUF, nwarps, 1); if (nrows_x % mmq_y == 0) { const bool need_check = false; diff --git a/csrc/quantization/gguf/mmvq.cuh b/csrc/quantization/gguf/mmvq.cuh index b221ae7896138..b01e939808a3f 100644 --- a/csrc/quantization/gguf/mmvq.cuh +++ b/csrc/quantization/gguf/mmvq.cuh @@ -28,8 +28,8 @@ static __global__ void mul_mat_vec_q(const void * __restrict__ vx, const void * // sum up partial sums and write back result #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) { - tmp += __shfl_xor_sync(0xffffffff, tmp, mask, 32); + for (int mask = WARP_SIZE/2; mask > 0; mask >>= 1) { + tmp += VLLM_SHFL_XOR_SYNC(tmp, mask); } if (threadIdx.x == 0) { diff --git a/csrc/quantization/gguf/vecdotq.cuh b/csrc/quantization/gguf/vecdotq.cuh index d5af345a6b26f..e00422637c65b 100644 --- a/csrc/quantization/gguf/vecdotq.cuh +++ b/csrc/quantization/gguf/vecdotq.cuh @@ -43,7 +43,7 @@ static __device__ __forceinline__ int get_int_from_uint8_aligned(const uint8_t * template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_impl( const int * v, const int * u, const float & d4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -68,7 +68,7 @@ template static __device__ __forceinline__ float vec_dot_q4_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_impl( const int * v, const int * u, const half2 & dm4, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -95,7 +95,7 @@ template static __device__ __forceinline__ float vec_dot_q4_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_impl( const int * vl, const int * vh, const int * u, const float & d5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -128,7 +128,7 @@ template static __device__ __forceinline__ float vec_dot_q5_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_impl( const int * vl, const int * vh, const int * u, const half2 & dm5, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -162,7 +162,7 @@ template static __device__ __forceinline__ float vec_dot_q5_1_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_impl( const int * v, const int * u, const float & d8_0, const float & d8_1) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -176,7 +176,7 @@ template static __device__ __forceinline__ float vec_dot_q8_0_q8_1_imp template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_impl( const int * v, const int * u, const half2 & dm8, const half2 & ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; @@ -202,7 +202,7 @@ template static __device__ __forceinline__ float vec_dot_q8_1_q8_1_imp static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( const int & v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -230,7 +230,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ scales, const half2 & dm2, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi_d = 0; int sumi_m = 0; @@ -267,7 +267,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const uint8_t * __restrict__ scales, const int & scale_offset, const float & d3, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; @@ -301,7 +301,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d3, const float & d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM int sumi = 0; #pragma unroll @@ -326,7 +326,7 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -351,7 +351,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -382,7 +382,7 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( const int * __restrict__ vl, const int * __restrict__ vh, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm5, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -413,7 +413,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_vmmq( static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const uint8_t * __restrict__ sc, const uint8_t * __restrict__ m, const half2 & dm4, const half2 * __restrict__ ds8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; float sumf_m = 0.0f; @@ -445,7 +445,7 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1_impl_mmq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( const int & vl, const int & vh, const int * __restrict__ u, const int8_t * __restrict__ scales, const float & d, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf = 0.0f; #pragma unroll @@ -465,7 +465,7 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmvq( static __device__ __forceinline__ float vec_dot_q6_K_q8_1_impl_mmq( const int * __restrict__ v, const int * __restrict__ u, const int8_t * __restrict__ sc, const float & d6, const float * __restrict__ d8) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM float sumf_d = 0.0f; #pragma unroll @@ -507,8 +507,8 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI4_0) + mmq_y/QI4_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI4_0) + mmq_y/QI4_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; } @@ -529,11 +529,11 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); - // x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbx] = bxi->d; + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + // x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbx] = bxi->d; } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -543,7 +543,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i / QI4_0 + kbxd] = __half2float(bxi->d); } } @@ -559,13 +559,13 @@ static __device__ __forceinline__ float vec_dot_q4_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_0) % WARP_SIZE_GGUF]; } return vec_dot_q4_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dmf[i * (WARP_SIZE/QI4_0) + i/QI4_0 + k/QI4_0], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dmf[i * (WARP_SIZE_GGUF/QI4_0) + i/QI4_0 + k/QI4_0], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q4_1_q8_1( @@ -587,8 +587,8 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_1) + mmq_y/QI4_1]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_1) + mmq_y/QI4_1]; *x_ql = tile_x_qs; *x_dm = tile_x_dm; } @@ -608,10 +608,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -621,7 +621,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i / QI4_1 + kbxd] = bxi->dm; } } @@ -634,13 +634,13 @@ static __device__ __forceinline__ float vec_dot_q4_1_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q4_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI4_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI4_1) % WARP_SIZE_GGUF]; } return vec_dot_q4_1_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], u, x_dm[i * (WARP_SIZE/QI4_1) + i/QI4_1 + k/QI4_1], - y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], u, x_dm[i * (WARP_SIZE_GGUF/QI4_1) + i/QI4_1 + k/QI4_1], + y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_0_q8_1( @@ -664,8 +664,8 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI5_0) + mmq_y/QI5_0]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI5_0) + mmq_y/QI5_0]; *x_ql = tile_x_ql; *x_dm = (half2 *) tile_x_d; @@ -697,7 +697,7 @@ template static __device__ __forceinlin qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 qs0 = __vsubss4(qs0, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -706,10 +706,10 @@ template static __device__ __forceinlin qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 qs1 = __vsubss4(qs1, 0x10101010); // subtract 16 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_0; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -722,7 +722,7 @@ template static __device__ __forceinlin } const block_q5_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI5_0) + i / QI5_0 + kbxd] = __half2float(bxi->d); } } @@ -730,7 +730,7 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_0) + i/QI5_0 + k/QI5_0; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_0) + i/QI5_0 + k/QI5_0; const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; @@ -738,12 +738,12 @@ static __device__ __forceinline__ float vec_dot_q5_0_q8_1_mul_mat( #pragma unroll for (int l = 0; l < VDR_Q5_0_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_0) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_0) % WARP_SIZE_GGUF]; } return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dmf[index_bx], y_df[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q5_1_q8_1( @@ -767,8 +767,8 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_1(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_1) + mmq_y/QI5_1]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_1) + mmq_y/QI5_1]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -801,7 +801,7 @@ template static __device__ __forceinlin qs0 |= (qh << 18) & 0x00100000; // 2 -> 20 qs0 |= (qh << 25) & 0x10000000; // 3 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+0] = qs0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+0] = qs0; int qs1 = (ql >> 4) & 0x0F0F0F0F; qs1 |= (qh >> 12) & 0x00000010; // 16 -> 4 @@ -809,10 +809,10 @@ template static __device__ __forceinlin qs1 |= (qh << 2) & 0x00100000; // 18 -> 20 qs1 |= (qh << 9) & 0x10000000; // 19 -> 28 - x_ql[i * (2*WARP_SIZE + 1) + 2*k+1] = qs1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2*k+1] = qs1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_1; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_1; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -825,7 +825,7 @@ template static __device__ __forceinlin const block_q5_1 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_1) + i / QI5_1 + kbxd] = bxi->dm; } } @@ -833,18 +833,18 @@ static __device__ __forceinline__ float vec_dot_q5_1_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { const int kyqs = k % (QI8_1/2) + QI8_1 * (k / (QI8_1/2)); - const int index_bx = i * (WARP_SIZE/QI5_1) + + i/QI5_1 + k/QI5_1; + const int index_bx = i * (WARP_SIZE_GGUF/QI5_1) + + i/QI5_1 + k/QI5_1; int u[2*VDR_Q5_1_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < VDR_Q5_1_Q8_1_MMQ; ++l) { - u[2*l+0] = y_qs[j * WARP_SIZE + (kyqs + l) % WARP_SIZE]; - u[2*l+1] = y_qs[j * WARP_SIZE + (kyqs + l + QI5_1) % WARP_SIZE]; + u[2*l+0] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l) % WARP_SIZE_GGUF]; + u[2*l+1] = y_qs[j * WARP_SIZE_GGUF + (kyqs + l + QI5_1) % WARP_SIZE_GGUF]; } return vec_dot_q8_1_q8_1_impl - (&x_ql[i * (2*WARP_SIZE + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE/QI8_1) + (2*k/QI8_1) % (WARP_SIZE/QI8_1)]); + (&x_ql[i * (2*WARP_SIZE_GGUF + 1) + 2 * k], u, x_dm[index_bx], y_ds[j * (WARP_SIZE_GGUF/QI8_1) + (2*k/QI8_1) % (WARP_SIZE_GGUF/QI8_1)]); } static __device__ __forceinline__ float vec_dot_q8_0_q8_1( @@ -865,8 +865,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q8_0(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_qs[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ float tile_x_d[mmq_y * (WARP_SIZE/QI8_0) + mmq_y/QI8_0]; + __shared__ int tile_x_qs[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ float tile_x_d[mmq_y * (WARP_SIZE_GGUF/QI8_0) + mmq_y/QI8_0]; *x_ql = tile_x_qs; *x_dm = (half2 *) tile_x_d; @@ -889,10 +889,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_int8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_int8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI8_0; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI8_0; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -903,7 +903,7 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q8_0 * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i / QI8_0 + kbxd] = __half2float(bxi->d); } } @@ -914,8 +914,8 @@ static __device__ __forceinline__ float vec_dot_q8_0_q8_1_mul_mat( const float * y_df = (const float *) y_ds; return vec_dot_q8_0_q8_1_impl - (&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[j * WARP_SIZE + k], x_dmf[i * (WARP_SIZE/QI8_0) + i/QI8_0 + k/QI8_0], - y_df[j * (WARP_SIZE/QI8_1) + k/QI8_1]); + (&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[j * WARP_SIZE_GGUF + k], x_dmf[i * (WARP_SIZE_GGUF/QI8_0) + i/QI8_0 + k/QI8_0], + y_df[j * (WARP_SIZE_GGUF/QI8_1) + k/QI8_1]); } static __device__ __forceinline__ float vec_dot_q2_K_q8_1( @@ -942,9 +942,9 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q2_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI2_K) + mmq_y/QI2_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI2_K) + mmq_y/QI2_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -967,10 +967,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI2_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI2_K; const int kbxd = k % blocks_per_tile_x_row; #pragma unroll @@ -981,18 +981,18 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q2_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI2_K) + i / QI2_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i / QI2_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI2_K/4); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); + const block_q2_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI2_K/4); + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = get_int_from_uint8_aligned(bxi->scales, k % (QI2_K/4)); } } @@ -1005,7 +1005,7 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( int v[QR2_K*VDR_Q2_K_Q8_1_MMQ]; - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI2_K + (QI2_K/2) * (ky/(2*QI2_K)) + ky % (QI2_K/2); const int shift = 2 * ((ky % (2*QI2_K)) / (QI2_K/2)); #pragma unroll @@ -1013,10 +1013,10 @@ static __device__ __forceinline__ float vec_dot_q2_K_q8_1_mul_mat( v[l] = (x_ql[kqsx + l] >> shift) & 0x03030303; } - const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE/4) + i/4 + kbx*4]) + ky/4; + const uint8_t * scales = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4]) + ky/4; - const int index_y = j * WARP_SIZE + (QR2_K*k) % WARP_SIZE; - return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR2_K*k) % WARP_SIZE_GGUF; + return vec_dot_q2_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dm[i * (WARP_SIZE_GGUF/QI2_K) + i/QI2_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q3_K_q8_1( @@ -1047,10 +1047,10 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q3_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI3_K) + mmq_y/QI3_K]; - __shared__ int tile_x_qh[mmq_y * (WARP_SIZE/2) + mmq_y/2]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/4) + mmq_y/4]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI3_K) + mmq_y/QI3_K]; + __shared__ int tile_x_qh[mmq_y * (WARP_SIZE_GGUF/2) + mmq_y/2]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/4) + mmq_y/4]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1073,10 +1073,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI3_K; + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI3_K; const int kbxd = k % blocks_per_tile_x_row; float * x_dmf = (float *) x_dm; @@ -1087,27 +1087,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q3_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i / QI3_K + kbxd] = __half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 2) { - int i = i0 + i_offset * 2 + k / (WARP_SIZE/2); + int i = i0 + i_offset * 2 + k / (WARP_SIZE_GGUF/2); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/2)) / (QI3_K/2); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/2)) / (QI3_K/2); // invert the mask with ~ so that a 0/1 results in 4/0 being subtracted - x_qh[i * (WARP_SIZE/2) + i / 2 + k % (WARP_SIZE/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); + x_qh[i * (WARP_SIZE_GGUF/2) + i / 2 + k % (WARP_SIZE_GGUF/2)] = ~get_int_from_uint8(bxi->hmask, k % (QI3_K/2)); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) { - int i = i0 + i_offset * 4 + k / (WARP_SIZE/4); + int i = i0 + i_offset * 4 + k / (WARP_SIZE_GGUF/4); if (need_check) { i = min(i, i_max); } - const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/4)) / (QI3_K/4); + const block_q3_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/4)) / (QI3_K/4); const int ksc = k % (QI3_K/4); @@ -1121,7 +1121,7 @@ template static __device__ __forceinlin const int sc = __vsubss4(sc_low | sc_high, 0x20202020); - x_sc[i * (WARP_SIZE/4) + i / 4 + k % (WARP_SIZE/4)] = sc; + x_sc[i * (WARP_SIZE_GGUF/4) + i / 4 + k % (WARP_SIZE_GGUF/4)] = sc; } } @@ -1134,24 +1134,24 @@ static __device__ __forceinline__ float vec_dot_q3_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE/4) + i/4 + kbx*4)) + ky/4; + const int8_t * scales = ((const int8_t *) (x_sc + i * (WARP_SIZE_GGUF/4) + i/4 + kbx*4)) + ky/4; int v[QR3_K*VDR_Q3_K_Q8_1_MMQ]; #pragma unroll for (int l = 0; l < QR3_K*VDR_Q3_K_Q8_1_MMQ; ++l) { - const int kqsx = i * (WARP_SIZE + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); + const int kqsx = i * (WARP_SIZE_GGUF + 1) + kbx*QI3_K + (QI3_K/2) * (ky/(2*QI3_K)) + ky % (QI3_K/2); const int shift = 2 * ((ky % 32) / 8); const int vll = (x_ql[kqsx + l] >> shift) & 0x03030303; - const int vh = x_qh[i * (WARP_SIZE/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); + const int vh = x_qh[i * (WARP_SIZE_GGUF/2) + i/2 + kbx * (QI3_K/2) + (ky+l)%8] >> ((ky+l) / 8); const int vlh = (vh << 2) & 0x04040404; v[l] = __vsubss4(vll, vlh); } - const int index_y = j * WARP_SIZE + (k*QR3_K) % WARP_SIZE; - return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (k*QR3_K) % WARP_SIZE_GGUF; + return vec_dot_q3_K_q8_1_impl_mmq(v, &y_qs[index_y], scales, x_dmf[i * (WARP_SIZE_GGUF/QI3_K) + i/QI3_K + kbx], y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q4_K_q8_1( @@ -1200,9 +1200,9 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q4_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI4_K) + mmq_y/QI4_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI4_K) + mmq_y/QI4_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1225,10 +1225,10 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbx; - x_ql[i * (WARP_SIZE + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); + x_ql[i * (WARP_SIZE_GGUF + 1) + k] = get_int_from_uint8_aligned(bxi->qs, kqsx); } - const int blocks_per_tile_x_row = WARP_SIZE / QI4_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI4_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1238,27 +1238,27 @@ template static __device__ __forceinlin i = min(i, i_max); } const block_q4_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI4_K) + i / QI4_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i / QI4_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI4_K/8); + const block_q4_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI4_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } @@ -1267,11 +1267,11 @@ static __device__ __forceinline__ float vec_dot_q4_K_q8_1_mul_mat( const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { (void)x_qh; - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2*((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2*((k % 16) / 8); - const int index_y = j * WARP_SIZE + (QR4_K*k) % WARP_SIZE; - return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE + 1) + k], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); + const int index_y = j * WARP_SIZE_GGUF + (QR4_K*k) % WARP_SIZE_GGUF; + return vec_dot_q4_K_q8_1_impl_mmq(&x_ql[i * (WARP_SIZE_GGUF + 1) + k], &y_qs[index_y], sc, sc+8, + x_dm[i * (WARP_SIZE_GGUF/QI4_K) + i/QI4_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q5_K_q8_1( @@ -1321,9 +1321,9 @@ static __device__ __forceinline__ float vec_dot_q5_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q5_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI5_K) + mmq_y/QI5_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI5_K) + mmq_y/QI5_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1360,11 +1360,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + 0; const int kq1 = ky - ky % (QI5_K/2) + k % (QI5_K/4) + (QI5_K/4); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = ql0 | qh0; - x_ql[i * (2*WARP_SIZE + 1) + kq1] = ql1 | qh1; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = ql0 | qh0; + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = ql1 | qh1; } - const int blocks_per_tile_x_row = WARP_SIZE / QI5_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI5_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 #pragma unroll @@ -1376,40 +1376,40 @@ template static __device__ __forceinlin } const block_q5_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dm[i * (WARP_SIZE/QI5_K) + i / QI5_K + kbxd] = bxi->dm; + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i / QI5_K + kbxd] = bxi->dm; } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / (QI5_K/8); + const block_q5_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / (QI5_K/8); const int * scales = (const int *) bxi->scales; - const int ksc = k % (WARP_SIZE/8); + const int ksc = k % (WARP_SIZE_GGUF/8); // scale arrangement after the following two lines: sc0,...,sc3, sc4,...,sc7, m0,...,m3, m4,...,m8 int scales8 = (scales[(ksc%2) + (ksc!=0)] >> (4 * (ksc & (ksc/2)))) & 0x0F0F0F0F; // lower 4 bits scales8 |= (scales[ksc/2] >> (2 * (ksc % 2))) & 0x30303030; // upper 2 bits - x_sc[i * (WARP_SIZE/8) + i / 8 + ksc] = scales8; + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + ksc] = scales8; } } static __device__ __forceinline__ float vec_dot_q5_K_q8_1_mul_mat( const int * __restrict__ x_ql, const half2 * __restrict__ x_dm, const int * __restrict__ x_qh, const int * __restrict__ x_sc, const int * __restrict__ y_qs, const half2 * __restrict__ y_ds, const int & i, const int & j, const int & k) { - const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); + const uint8_t * sc = ((const uint8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/16]) + 2 * ((k % 16) / 8); - const int index_x = i * (QR5_K*WARP_SIZE + 1) + QR5_K*k; - const int index_y = j * WARP_SIZE + (QR5_K*k) % WARP_SIZE; + const int index_x = i * (QR5_K*WARP_SIZE_GGUF + 1) + QR5_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR5_K*k) % WARP_SIZE_GGUF; return vec_dot_q5_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, sc+8, - x_dm[i * (WARP_SIZE/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); + x_dm[i * (WARP_SIZE_GGUF/QI5_K) + i/QI5_K], &y_ds[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_q6_K_q8_1( @@ -1439,9 +1439,9 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1( } template static __device__ __forceinline__ void allocate_tiles_q6_K(int ** x_ql, half2 ** x_dm, int ** x_qh, int ** x_sc) { - __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE) + mmq_y]; - __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE/QI6_K) + mmq_y/QI6_K]; - __shared__ int tile_x_sc[mmq_y * (WARP_SIZE/8) + mmq_y/8]; + __shared__ int tile_x_ql[mmq_y * (2*WARP_SIZE_GGUF) + mmq_y]; + __shared__ half2 tile_x_dm[mmq_y * (WARP_SIZE_GGUF/QI6_K) + mmq_y/QI6_K]; + __shared__ int tile_x_sc[mmq_y * (WARP_SIZE_GGUF/8) + mmq_y/8]; *x_ql = tile_x_ql; *x_dm = tile_x_dm; @@ -1478,11 +1478,11 @@ template static __device__ __forceinlin const int kq0 = ky - ky % QI6_K + k % (QI6_K/2) + 0; const int kq1 = ky - ky % QI6_K + k % (QI6_K/2) + (QI6_K/2); - x_ql[i * (2*WARP_SIZE + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); - x_ql[i * (2*WARP_SIZE + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq0] = __vsubss4(ql0 | qh0, 0x20202020); + x_ql[i * (2*WARP_SIZE_GGUF + 1) + kq1] = __vsubss4(ql1 | qh1, 0x20202020); } - const int blocks_per_tile_x_row = WARP_SIZE / QI6_K; // == 1 if QK_K == 256 + const int blocks_per_tile_x_row = WARP_SIZE_GGUF / QI6_K; // == 1 if QK_K == 256 const int kbxd = k % blocks_per_tile_x_row; // == 0 if QK_K == 256 float * x_dmf = (float *) x_dm; @@ -1496,20 +1496,20 @@ template static __device__ __forceinlin const block_q6_K * bxi = bx0 + i*blocks_per_row + kbxd; - x_dmf[i * (WARP_SIZE/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); + x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i / QI6_K + kbxd] = __half2float(bxi->d); } #pragma unroll for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 8) { - int i = (i0 + i_offset * 8 + k / (WARP_SIZE/8)) % mmq_y; + int i = (i0 + i_offset * 8 + k / (WARP_SIZE_GGUF/8)) % mmq_y; if (need_check) { i = min(i, i_max); } - const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE/8)) / 4; + const block_q6_K * bxi = bx0 + i*blocks_per_row + (k % (WARP_SIZE_GGUF/8)) / 4; - x_sc[i * (WARP_SIZE/8) + i / 8 + k % (WARP_SIZE/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); + x_sc[i * (WARP_SIZE_GGUF/8) + i / 8 + k % (WARP_SIZE_GGUF/8)] = get_int_from_int8(bxi->scales, k % (QI6_K/8)); } } @@ -1519,11 +1519,11 @@ static __device__ __forceinline__ float vec_dot_q6_K_q8_1_mul_mat( const float * x_dmf = (const float *) x_dm; const float * y_df = (const float *) y_ds; - const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE/8) + i/8 + k/8]); + const int8_t * sc = ((const int8_t *) &x_sc[i * (WARP_SIZE_GGUF/8) + i/8 + k/8]); - const int index_x = i * (QR6_K*WARP_SIZE + 1) + QR6_K*k; - const int index_y = j * WARP_SIZE + (QR6_K*k) % WARP_SIZE; - return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); + const int index_x = i * (QR6_K*WARP_SIZE_GGUF + 1) + QR6_K*k; + const int index_y = j * WARP_SIZE_GGUF + (QR6_K*k) % WARP_SIZE_GGUF; + return vec_dot_q6_K_q8_1_impl_mmq(&x_ql[index_x], &y_qs[index_y], sc, x_dmf[i * (WARP_SIZE_GGUF/QI6_K) + i/QI6_K], &y_df[index_y/QI8_1]); } static __device__ __forceinline__ float vec_dot_iq2_xxs_q8_1( @@ -1582,7 +1582,7 @@ static __device__ __forceinline__ float vec_dot_iq2_xs_q8_1( static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq2_s * bq2 = (const block_iq2_s *) vbq; const int ib32 = iqs; @@ -1619,7 +1619,7 @@ static __device__ __forceinline__ float vec_dot_iq2_s_q8_1( static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_xxs * bq2 = (const block_iq3_xxs *) vbq; const int ib32 = iqs; @@ -1646,7 +1646,7 @@ static __device__ __forceinline__ float vec_dot_iq3_xxs_q8_1( static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq3_s * bq2 = (const block_iq3_s *) vbq; const int ib32 = iqs; @@ -1671,7 +1671,7 @@ static __device__ __forceinline__ float vec_dot_iq3_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_s * bq1 = (const block_iq1_s *) vbq; const int qs_packed = get_int_b2(bq1->qs, iqs); @@ -1703,7 +1703,7 @@ static __device__ __forceinline__ float vec_dot_iq1_s_q8_1( static __device__ __forceinline__ float vec_dot_iq1_m_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq1_m * bq1 = (const block_iq1_m *) vbq; @@ -1763,7 +1763,7 @@ static __device__ __forceinline__ void get_int_from_table_16(const uint32_t & q4 static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_nl * bq = (const block_iq4_nl *) vbq; @@ -1788,7 +1788,7 @@ static __device__ __forceinline__ float vec_dot_iq4_nl_q8_1( static __device__ __forceinline__ float vec_dot_iq4_xs_q8_1( const void * __restrict__ vbq, const block_q8_1 * __restrict__ bq8_1, const int & iqs) { -#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 +#if defined __CUDA_ARCH__ && __CUDA_ARCH__ >= 610 || defined USE_ROCM const block_iq4_xs * bq4 = (const block_iq4_xs *) vbq; const uint8_t * values = (const uint8_t *)kvalues_iq4nl; diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index a33e2660d760e..8fce76eb52f9b 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -910,13 +910,16 @@ void marlin_cuda_2_4(const void* A, const void* B, const void* meta, void* C, // than better compute utilization thread_k = 128; thread_m = 128; - } else if (prob_n <= 256) { + } else { thread_k = 64; thread_m = 256; - } else { - thread_k = 32; - thread_m = 512; } + // Also had + // if prob_n > 256 + // thread_k = 32; + // thread_m = 512; + // but this is broken, + // TODO(Lucas, Alex M): figure out why } int thread_k_blocks = thread_k / 32; // 2:4 version with m16n8k32 instruction @@ -1079,6 +1082,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Verify A device and strides TORCH_CHECK(a.device().is_cuda(), "A is not on GPU"); TORCH_CHECK(a.is_contiguous(), "A is not contiguous"); + TORCH_CHECK(a.dtype() == torch::kFloat16, + "A is not float16, currently only float16 is supported"); // Verify B device and strides TORCH_CHECK(b_q_weight.device().is_cuda(), "b_q_weight is not on GPU"); @@ -1091,6 +1096,8 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, // Verify scales device and strides TORCH_CHECK(b_scales.device().is_cuda(), "b_scales is not on GPU"); TORCH_CHECK(b_scales.is_contiguous(), "b_scales is not contiguous"); + TORCH_CHECK(b_scales.dtype() == torch::kFloat16, + "A is not float16, currently only float16 is supported"); // Alloc C matrix const at::cuda::OptionalCUDAGuard device_guard(device_of(a)); diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3dccdf61abf3b..4e64b9c92773a 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -258,6 +258,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); // conditionally compiled so impl registrations are in source file +#endif // Dequantization for GGML. ops.def("ggml_dequantize(Tensor W, int type, SymInt m, SymInt n) -> Tensor"); @@ -274,6 +275,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor"); ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8); +#ifndef USE_ROCM // fp8_marlin Optimized Quantized GEMM for FP8 weight-only. ops.def( "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " diff --git a/docs/source/getting_started/cpu-installation.rst b/docs/source/getting_started/cpu-installation.rst index 69530fd778c55..649de1cd9b53c 100644 --- a/docs/source/getting_started/cpu-installation.rst +++ b/docs/source/getting_started/cpu-installation.rst @@ -5,11 +5,11 @@ Installation with CPU vLLM initially supports basic model inferencing and serving on x86 CPU platform, with data types FP32, FP16 and BF16. vLLM CPU backend supports the following vLLM features: -- Tensor Parallel (``-tp = N``) -- Quantization (``INT8 W8A8, AWQ``) - -.. note:: - More advanced features on `chunked-prefill`, `prefix-caching` and `FP8 KV cache` are under development and will be available soon. +- Tensor Parallel +- Model Quantization (``INT8 W8A8, AWQ``) +- Chunked-prefill +- Prefix-caching +- FP8-E5M2 KV-Caching (TODO) Table of contents: diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index f02626bda4c64..e3dbbc9affe66 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -170,6 +170,18 @@ To build vLLM using an existing PyTorch installation: $ pip install -e . --no-build-isolation +Use the local cutlass for compilation +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +Currently, before starting the build process, vLLM fetches cutlass code from GitHub. However, there may be scenarios where you want to use a local version of cutlass instead. +To achieve this, you can set the environment variable VLLM_CUTLASS_SRC_DIR to point to your local cutlass directory. + +.. code-block:: console + + $ git clone https://github.com/vllm-project/vllm.git + $ cd vllm + $ VLLM_CUTLASS_SRC_DIR=/path/to/cutlass pip install -e . + + Troubleshooting ~~~~~~~~~~~~~~~ diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index a70ebf99c746f..df06d736ca86b 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -38,41 +38,70 @@ For instance, vLLM's `OPT model Union[Tuple, CausalLMOutputWithPast]: - + positions: torch.Tensor, - + kv_caches: List[torch.Tensor], - + attn_metadata: AttentionMetadata, - + ) -> Optional[SamplerOutput]: - -1. Update the code by considering that :code:`input_ids` and :code:`positions` are now flattened tensors. -2. Replace the attention operation with either :code:`PagedAttention`, :code:`PagedAttentionWithRoPE`, or :code:`PagedAttentionWithALiBi` depending on the model's architecture. +To ensure compatibility with vLLM, your model must meet the following requirements: + +Initialization Code +^^^^^^^^^^^^^^^^^^^ + +All vLLM modules within the model must include a ``prefix`` argument in their constructor. This ``prefix`` is typically the full name of the module in the model's state dictionary and is crucial for: + +* Runtime support: vLLM's attention operators are registered in a model's state by their full names. Each attention operator must have a unique prefix as its layer name to avoid conflicts. +* Non-uniform quantization support: A quantized checkpoint can selectively quantize certain layers while keeping others in full precision. By providing the ``prefix`` during initialization, vLLM can match the current layer's ``prefix`` with the quantization configuration to determine if the layer should be initialized in quantized mode. + +The initialization code should look like this: + +.. code-block:: python + + from torch import nn + from vllm.config import VllmConfig + from vllm.attention import Attention + + class MyAttention(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.attn = Attention(prefix=f"{prefix}.attn") + + class MyDecoderLayer(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.self_attn = MyAttention(prefix=f"{prefix}.self_attn") + + class MyModel(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str): + super().__init__() + self.layers = nn.ModuleList( + [MyDecoderLayer(vllm_config, prefix=f"{prefix}.layers.{i}") for i in range(vllm_config.model_config.hf_config.num_hidden_layers)] + ) + + class MyModelForCausalLM(nn.Module): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + self.model = MyModel(vllm_config, prefix=f"{prefix}.model") + +Computation Code +^^^^^^^^^^^^^^^^ + +Rewrite the :meth:`~torch.nn.Module.forward` method of your model to remove any unnecessary code, such as training-specific code. Modify the input parameters to treat ``input_ids`` and ``positions`` as flattened tensors with a single batch size dimension, without a max-sequence length dimension. + +.. code-block:: python + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + ... .. note:: Currently, vLLM supports the basic multi-head attention mechanism and its variant with rotary positional embeddings. If your model employs a different attention mechanism, you will need to implement a new attention layer in vLLM. +For reference, check out the `LLAMA model `__. vLLM already supports a large number of models. It is recommended to find a model similar to yours and adapt it to your model's architecture. Check out the `vLLM models `__ directory for more examples. 3. (Optional) Implement tensor parallelism and quantization support ------------------------------------------------------------------- diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index e902d393f2f70..3f012284bfbff 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -325,6 +325,11 @@ Text Embedding - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`BertModel` + - BERT-based + - :code:`BAAI/bge-base-en-v1.5`, etc. + - + - * - :code:`Gemma2Model` - Gemma2-based - :code:`BAAI/bge-multilingual-gemma2`, etc. @@ -337,9 +342,19 @@ Text Embedding - ✅︎ * - :code:`Qwen2Model`, :code:`Qwen2ForCausalLM` - Qwen2-based - - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`, etc. + - :code:`ssmits/Qwen2-7B-Instruct-embed-base`, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` (see note), etc. - ✅︎ - ✅︎ + * - :code:`RobertaModel`, :code:`RobertaForMaskedLM` + - RoBERTa-based + - :code:`sentence-transformers/all-roberta-large-v1`, :code:`sentence-transformers/all-roberta-large-v1`, etc. + - + - + * - :code:`XLMRobertaModel` + - XLM-RoBERTa-based + - :code:`intfloat/multilingual-e5-large`, etc. + - + - .. important:: Some model architectures support both generation and embedding tasks. @@ -348,6 +363,13 @@ Text Embedding .. tip:: You can override the model's pooling method by passing :code:`--override-pooler-config`. +.. note:: + Unlike base Qwen2, :code:`Alibaba-NLP/gte-Qwen2-7B-instruct` uses bi-directional attention. + You can set :code:`--hf-overrides '{"is_causal": false}'` to change the attention mask accordingly. + + On the other hand, its 1.5B variant (:code:`Alibaba-NLP/gte-Qwen2-1.5B-instruct`) uses causal attention + despite being described otherwise on its model card. + Reward Modeling --------------- @@ -390,6 +412,36 @@ Classification .. note:: As an interim measure, these models are supported in both offline and online inference via Embeddings API. +Sentence Pair Scoring +--------------------- + +.. list-table:: + :widths: 25 25 50 5 5 + :header-rows: 1 + + * - Architecture + - Models + - Example HF Models + - :ref:`LoRA ` + - :ref:`PP ` + * - :code:`BertForSequenceClassification` + - BERT-based + - :code:`cross-encoder/ms-marco-MiniLM-L-6-v2`, etc. + - + - + * - :code:`RobertaForSequenceClassification` + - RoBERTa-based + - :code:`cross-encoder/quora-roberta-base`, etc. + - + - + * - :code:`XLMRobertaForSequenceClassification` + - XLM-RoBERTa-based + - :code:`BAAI/bge-reranker-v2-m3`, etc. + - + - + +.. note:: + These models are supported in both offline and online inference via Score API. Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -424,6 +476,12 @@ Text Generation - Example HF Models - :ref:`LoRA ` - :ref:`PP ` + * - :code:`AriaForConditionalGeneration` + - Aria + - T + I + - :code:`rhymes-ai/Aria` + - + - ✅︎ * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - T + I\ :sup:`E` @@ -561,10 +619,10 @@ Text Generation | :sup:`+` Multiple items can be inputted per text prompt for this modality. .. note:: - vLLM currently only supports adding LoRA to the language backbone of multimodal models. + vLLM currently only supports adding LoRA to the language backbone of multimodal models. .. note:: - For :code:`openbmb/MiniCPM-V-2`, the official repo doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. + The official :code:`openbmb/MiniCPM-V-2` doesn't work yet, so we need to use a fork (:code:`HwwwH/MiniCPM-V-2`) for now. For more details, please see: https://github.com/vllm-project/vllm/pull/4087#issuecomment-2250397630 Multimodal Embedding diff --git a/docs/source/quantization/fp8_e5m2_kvcache.rst b/docs/source/quantization/fp8_e5m2_kvcache.rst index 9ae07bcd3b991..b2d824427f786 100644 --- a/docs/source/quantization/fp8_e5m2_kvcache.rst +++ b/docs/source/quantization/fp8_e5m2_kvcache.rst @@ -4,7 +4,7 @@ FP8 E5M2 KV Cache ================== The int8/int4 quantization scheme requires additional scale GPU memory storage, which reduces the expected GPU memory benefits. -The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bflaot16 and fp8 to each other. +The FP8 data format retains 2~3 mantissa bits and can convert float/fp16/bfloat16 and fp8 to each other. Here is an example of how to enable this feature: diff --git a/docs/source/serving/compatibility_matrix.rst b/docs/source/serving/compatibility_matrix.rst index f629b3ca78318..fa03d2cde1486 100644 --- a/docs/source/serving/compatibility_matrix.rst +++ b/docs/source/serving/compatibility_matrix.rst @@ -39,12 +39,13 @@ Feature x Feature - :abbr:`prmpt adptr (Prompt Adapter)` - :ref:`SD ` - CUDA graph + - :abbr:`emd (Embedding Models)` - :abbr:`enc-dec (Encoder-Decoder Models)` - :abbr:`logP (Logprobs)` - :abbr:`prmpt logP (Prompt Logprobs)` - :abbr:`async output (Async Output Processing)` - multi-step - - :abbr:`MM (Multimodal)` + - :abbr:`mm (Multimodal)` - best-of - beam-search - :abbr:`guided dec (Guided Decoding)` @@ -64,6 +65,7 @@ Feature x Feature - - - + - * - :ref:`APC ` - ✅ - @@ -80,6 +82,7 @@ Feature x Feature - - - + - * - :ref:`LoRA ` - `✗ `__ - ✅ @@ -96,6 +99,7 @@ Feature x Feature - - - + - * - :abbr:`prmpt adptr (Prompt Adapter)` - ✅ - ✅ @@ -112,6 +116,7 @@ Feature x Feature - - - + - * - :ref:`SD ` - ✗ - ✅ @@ -128,6 +133,7 @@ Feature x Feature - - - + - * - CUDA graph - ✅ - ✅ @@ -144,6 +150,24 @@ Feature x Feature - - - + - + * - :abbr:`emd (Embedding Models)` + - ✗ + - ✗ + - ✗ + - ✗ + - ✗ + - ✗ + - + - + - + - + - + - + - + - + - + - * - :abbr:`enc-dec (Encoder-Decoder Models)` - ✗ - `✗ `__ @@ -151,6 +175,7 @@ Feature x Feature - ✗ - `✗ `__ - ✅ + - ✅ - - - @@ -166,7 +191,8 @@ Feature x Feature - ✅ - ✅ - ✅ - - ✅ + - ✅ + - ✗ - ✅ - - @@ -183,7 +209,8 @@ Feature x Feature - ✅ - `✗ `__ - ✅ - - ✅ + - ✗ + - ✅ - ✅ - - @@ -199,6 +226,7 @@ Feature x Feature - ✅ - ✗ - ✅ + - ✗ - ✗ - ✅ - ✅ @@ -215,6 +243,7 @@ Feature x Feature - ✅ - ✗ - ✅ + - ✗ - ✗ - ✅ - `✗ `__ @@ -224,14 +253,15 @@ Feature x Feature - - - - * - :abbr:`MM (Multimodal)` - - `✗ `__ + * - :abbr:`mm (Multimodal)` + - ✅ - `✗ `__ - `✗ `__ - ? - ? - ✅ - - ✗ + - ✅ + - ✅ - ✅ - ✅ - ✅ @@ -247,6 +277,7 @@ Feature x Feature - ✅ - `✗ `__ - ✅ + - ✗ - ✅ - ✅ - ✅ @@ -263,6 +294,7 @@ Feature x Feature - ✅ - `✗ `__ - ✅ + - ✗ - ✅ - ✅ - ✅ @@ -279,6 +311,7 @@ Feature x Feature - ? - ✅ - ✅ + - ✗ - ? - ✅ - ✅ @@ -311,7 +344,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`APC ` - `✗ `__ @@ -319,7 +352,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - - ✗ + - ✅ - ✅ * - :ref:`LoRA ` - ✅ @@ -353,6 +386,14 @@ Feature x Hardware - ✅ - ✗ - ✅ + * - :abbr:`emd (Embedding Models)` + - ✅ + - ✅ + - ✅ + - ✅ + - ✅ + - ✅ + - ? * - :abbr:`enc-dec (Encoder-Decoder Models)` - ✅ - ✅ @@ -361,7 +402,7 @@ Feature x Hardware - ✅ - ✅ - ✗ - * - :abbr:`logP (Logprobs)` + * - :abbr:`mm (Multimodal)` - ✅ - ✅ - ✅ @@ -369,7 +410,7 @@ Feature x Hardware - ✅ - ✅ - ✅ - * - :abbr:`prmpt logP (Prompt Logprobs)` + * - :abbr:`logP (Logprobs)` - ✅ - ✅ - ✅ @@ -377,29 +418,29 @@ Feature x Hardware - ✅ - ✅ - ✅ - * - :abbr:`async output (Async Output Processing)` + * - :abbr:`prmpt logP (Prompt Logprobs)` - ✅ - ✅ - ✅ - ✅ - ✅ - - ✗ - - ✗ - * - multi-step - ✅ - ✅ + * - :abbr:`async output (Async Output Processing)` - ✅ - ✅ - ✅ - - `✗ `__ - ✅ - * - :abbr:`MM (Multimodal)` - ✅ + - ✗ + - ✗ + * - multi-step - ✅ - ✅ - ✅ - ✅ - ✅ + - `✗ `__ - ✅ * - best-of - ✅ diff --git a/docs/source/serving/metrics.rst b/docs/source/serving/metrics.rst index 15e57bd3fec65..231111cd7b738 100644 --- a/docs/source/serving/metrics.rst +++ b/docs/source/serving/metrics.rst @@ -2,9 +2,34 @@ Production Metrics ================== vLLM exposes a number of metrics that can be used to monitor the health of the -system. These metrics are exposed via the `/metrics` endpoint on the vLLM +system. These metrics are exposed via the ``/metrics`` endpoint on the vLLM OpenAI compatible API server. +You can start the server using Python, or using [Docker](deploying_with_docker.rst): + +.. code-block:: console + + $ vllm serve unsloth/Llama-3.2-1B-Instruct + +Then query the endpoint to get the latest metrics from the server: + +.. code-block:: console + + $ curl http://0.0.0.0:8000/metrics + + # HELP vllm:iteration_tokens_total Histogram of number of tokens per engine_step. + # TYPE vllm:iteration_tokens_total histogram + vllm:iteration_tokens_total_sum{model_name="unsloth/Llama-3.2-1B-Instruct"} 0.0 + vllm:iteration_tokens_total_bucket{le="1.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="8.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="16.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="32.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="64.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="128.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="256.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + vllm:iteration_tokens_total_bucket{le="512.0",model_name="unsloth/Llama-3.2-1B-Instruct"} 3.0 + ... + The following metrics are exposed: .. literalinclude:: ../../../vllm/engine/metrics.py diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index 79d032bf8b211..c39cef85897ed 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -44,6 +44,148 @@ We currently support the following OpenAI APIs: - This enables multi-modal inputs to be passed to embedding models, see [Using VLMs](../models/vlm.rst). - *Note: You should run `vllm serve` with `--task embedding` to ensure that the model is being run in embedding mode.* +## Score API for Cross Encoder Models + +vLLM supports *cross encoders models* at the **/v1/score** endpoint, which is not an OpenAI API standard endpoint. You can find the documentation for these kind of models at [sbert.net](https://www.sbert.net/docs/package_reference/cross_encoder/cross_encoder.html). + +A ***Cross Encoder*** takes exactly two sentences / texts as input and either predicts a score or label for this sentence pair. It can for example predict the similarity of the sentence pair on a scale of 0 … 1. + +### Example of usage for a pair of a string and a list of texts + +In this case, the model will compare the first given text to each of the texts containing the list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "text_1": "What is the capital of France?", + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693570, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 0.001094818115234375 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two lists of texts + +In this case, the model will compare the one by one, making pairs by same index correspondent in each list. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": [ + "What is the capital of Brazil?", + "What is the capital of France?" + ], + "text_2": [ + "The capital of Brazil is Brasilia.", + "The capital of France is Paris." + ] +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + }, + { + "index": 1, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + +### Example of usage for a pair of two strings + +In this case, the model will compare the strings of texts. + +```bash +curl -X 'POST' \ + 'http://127.0.0.1:8000/v1/score' \ + -H 'accept: application/json' \ + -H 'Content-Type: application/json' \ + -d '{ + "model": "BAAI/bge-reranker-v2-m3", + "encoding_format": "float", + "text_1": "What is the capital of France?", + "text_2": "The capital of France is Paris." +}' +``` + +Response: + +```bash +{ + "id": "score-request-id", + "object": "list", + "created": 693447, + "model": "BAAI/bge-reranker-v2-m3", + "data": [ + { + "index": 0, + "object": "score", + "score": [ + 1 + ] + } + ], + "usage": {} +} +``` + ## Extra Parameters vLLM supports a set of parameters that are not part of the OpenAI API. diff --git a/examples/logging_configuration.md b/examples/logging_configuration.md index 0d278b0392403..9ac8b13cd5eaf 100644 --- a/examples/logging_configuration.md +++ b/examples/logging_configuration.md @@ -118,7 +118,7 @@ configuration for the root vLLM logger and for the logger you wish to silence: { "formatters": { "vllm": { - "class": "vllm.logging.NewLineFormatter", + "class": "vllm.logging_utils.NewLineFormatter", "datefmt": "%m-%d %H:%M:%S", "format": "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s" } diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 391ac6b9b6b03..9b758fa2479f6 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,80 +1,22 @@ -from dataclasses import asdict - from vllm import LLM, SamplingParams -from vllm.engine.arg_utils import EngineArgs -from vllm.utils import FlexibleArgumentParser - - -def get_prompts(num_prompts: int): - # The default sample prompts. - prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", - ] - - if num_prompts != len(prompts): - prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] - - return prompts - - -def main(args): - # Create prompts - prompts = get_prompts(args.num_prompts) - - # Create a sampling params object. - sampling_params = SamplingParams(n=args.n, - temperature=args.temperature, - top_p=args.top_p, - top_k=args.top_k, - max_tokens=args.max_tokens) - - # Create an LLM. - # The default model is 'facebook/opt-125m' - engine_args = EngineArgs.from_cli_args(args) - llm = LLM(**asdict(engine_args)) - - # Generate texts from the prompts. - # The output is a list of RequestOutput objects - # that contain the prompt, generated text, and other information. - outputs = llm.generate(prompts, sampling_params) - # Print the outputs. - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - - -if __name__ == '__main__': - parser = FlexibleArgumentParser() - parser = EngineArgs.add_cli_args(parser) - group = parser.add_argument_group("SamplingParams options") - group.add_argument("--num-prompts", - type=int, - default=4, - help="Number of prompts used for inference") - group.add_argument("--max-tokens", - type=int, - default=16, - help="Generated output length for sampling") - group.add_argument('--n', - type=int, - default=1, - help='Number of generated sequences per prompt') - group.add_argument('--temperature', - type=float, - default=0.8, - help='Temperature for text generation') - group.add_argument('--top-p', - type=float, - default=0.95, - help='top_p for text generation') - group.add_argument('--top-k', - type=int, - default=-1, - help='top_k for text generation') - args = parser.parse_args() - main(args) +# Sample prompts. +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] +# Create a sampling params object. +sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + +# Create an LLM. +llm = LLM(model="facebook/opt-125m") +# Generate texts from the prompts. The output is a list of RequestOutput objects +# that contain the prompt, generated text, and other information. +outputs = llm.generate(prompts, sampling_params) +# Print the outputs. +for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") diff --git a/examples/offline_inference_cli.py b/examples/offline_inference_cli.py new file mode 100644 index 0000000000000..391ac6b9b6b03 --- /dev/null +++ b/examples/offline_inference_cli.py @@ -0,0 +1,80 @@ +from dataclasses import asdict + +from vllm import LLM, SamplingParams +from vllm.engine.arg_utils import EngineArgs +from vllm.utils import FlexibleArgumentParser + + +def get_prompts(num_prompts: int): + # The default sample prompts. + prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", + ] + + if num_prompts != len(prompts): + prompts = (prompts * ((num_prompts // len(prompts)) + 1))[:num_prompts] + + return prompts + + +def main(args): + # Create prompts + prompts = get_prompts(args.num_prompts) + + # Create a sampling params object. + sampling_params = SamplingParams(n=args.n, + temperature=args.temperature, + top_p=args.top_p, + top_k=args.top_k, + max_tokens=args.max_tokens) + + # Create an LLM. + # The default model is 'facebook/opt-125m' + engine_args = EngineArgs.from_cli_args(args) + llm = LLM(**asdict(engine_args)) + + # Generate texts from the prompts. + # The output is a list of RequestOutput objects + # that contain the prompt, generated text, and other information. + outputs = llm.generate(prompts, sampling_params) + # Print the outputs. + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +if __name__ == '__main__': + parser = FlexibleArgumentParser() + parser = EngineArgs.add_cli_args(parser) + group = parser.add_argument_group("SamplingParams options") + group.add_argument("--num-prompts", + type=int, + default=4, + help="Number of prompts used for inference") + group.add_argument("--max-tokens", + type=int, + default=16, + help="Generated output length for sampling") + group.add_argument('--n', + type=int, + default=1, + help='Number of generated sequences per prompt') + group.add_argument('--temperature', + type=float, + default=0.8, + help='Temperature for text generation') + group.add_argument('--top-p', + type=float, + default=0.95, + help='top_p for text generation') + group.add_argument('--top-k', + type=int, + default=-1, + help='top_k for text generation') + + args = parser.parse_args() + main(args) diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 11af6880e1b5a..f08f22eec164a 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -402,6 +402,23 @@ def run_idefics3(question: str, modality: str): return llm, prompt, stop_token_ids +# Aria +def run_aria(question: str, modality: str): + assert modality == "image" + model_name = "rhymes-ai/Aria" + + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16") + + prompt = (f"<|im_start|>user\n<|img|>\n{question}" + "<|im_end|>\n<|im_start|>assistant\n") + + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return llm, prompt, stop_token_ids + + model_example_map = { "llava": run_llava, "llava-next": run_llava_next, @@ -423,6 +440,7 @@ def run_idefics3(question: str, modality: str): "molmo": run_molmo, "glm4v": run_glm4v, "idefics3": run_idefics3, + "aria": run_aria, } diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index dc12df8d78211..788b604cfd4a0 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -321,6 +321,25 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: ) +def load_aria(question, image_urls: List[str]) -> ModelRequestData: + model_name = "rhymes-ai/Aria" + llm = LLM(model=model_name, + tokenizer_mode="slow", + trust_remote_code=True, + dtype="bfloat16", + limit_mm_per_prompt={"image": len(image_urls)}) + placeholders = "<|img|>\n" * len(image_urls) + prompt = (f"<|im_start|>user\n{placeholders}{question}<|im_end|>\n" + "<|im_start|>assistant\n") + stop_token_ids = [93532, 93653, 944, 93421, 1019, 93653, 93519] + return ModelRequestData( + llm=llm, + prompt=prompt, + stop_token_ids=stop_token_ids, + image_data=[fetch_image(url) for url in image_urls], + chat_template=None) + + model_example_map = { "phi3_v": load_phi3v, "h2ovl_chat": load_h2onvl, @@ -330,6 +349,7 @@ def load_idefics3(question, image_urls: List[str]) -> ModelRequestData: "qwen_vl_chat": load_qwenvl_chat, "mllama": load_mllama, "idefics3": load_idefics3, + "aria": load_aria, } diff --git a/examples/openai_cross_encoder_score.py b/examples/openai_cross_encoder_score.py new file mode 100644 index 0000000000000..8c32eea5dd252 --- /dev/null +++ b/examples/openai_cross_encoder_score.py @@ -0,0 +1,58 @@ +"""Examples Python client Score for Cross Encoder Models +""" + +import argparse +import json +import pprint + +import requests + + +def post_http_request(prompt: json, api_url: str) -> requests.Response: + headers = {"User-Agent": "Test Client"} + response = requests.post(api_url, headers=headers, json=prompt) + return response + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--host", type=str, default="localhost") + parser.add_argument("--port", type=int, default=8000) + parser.add_argument("--model", type=str, default="BAAI/bge-reranker-v2-m3") + args = parser.parse_args() + api_url = f"http://{args.host}:{args.port}/v1/score" + + model_name = args.model + + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 is string and text_2 is a list:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = [ + "What is the capital of Brazil?", "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are lists:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) + + text_1 = "What is the capital of Brazil?" + text_2 = "The capital of Brazil is Brasilia." + prompt = {"model": model_name, "text_1": text_1, "text_2": text_2} + score_response = post_http_request(prompt=prompt, api_url=api_url) + print("Prompt for text_1 and text_2 are strings:") + pprint.pprint(prompt) + print("Score Response:") + pprint.pprint(score_response.data) \ No newline at end of file diff --git a/examples/tool_chat_template_granite.jinja b/examples/tool_chat_template_granite.jinja index 2cc19e77188dc..467dcb2d10237 100644 --- a/examples/tool_chat_template_granite.jinja +++ b/examples/tool_chat_template_granite.jinja @@ -21,11 +21,7 @@ {{- '<|start_of_role|>user<|end_of_role|>' + message['content'] + '<|end_of_text|> ' }} {%- elif message['role'] == 'assistant_tool_call' or (message['role'] == 'assistant' and message.tool_calls is defined) %} - {{- '<|start_of_role|>assistant<|end_of_role|>' }} - {% for tc in message.tool_calls %} - {{- '<|tool_call|> ' + {'name': tc.function.name, 'arguments': tc.function.arguments}|tojson }} - {% endfor %} - {{- '<|end_of_text|> + {{- '<|start_of_role|>assistant<|end_of_role|><|tool_call|>' + message.tool_calls|map(attribute='function')|list|tojson(indent=4) + '<|end_of_text|> ' }} {%- elif message['role'] == 'assistant' %} {{- '<|start_of_role|>assistant<|end_of_role|>' + message['content'] + '<|end_of_text|> diff --git a/examples/tool_chat_template_llama3.1_json.jinja b/examples/tool_chat_template_llama3.1_json.jinja index c24a7e51335ef..033830936a56b 100644 --- a/examples/tool_chat_template_llama3.1_json.jinja +++ b/examples/tool_chat_template_llama3.1_json.jinja @@ -19,10 +19,18 @@ {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} {%- endif %} {#- System message #} @@ -33,8 +41,8 @@ {{- "Cutting Knowledge Date: December 2023\n" }} {{- "Today Date: " + date_string + "\n\n" }} {%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -48,7 +56,11 @@ {%- if tools_in_user_message and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if messages | length != 0 %} - {%- set first_user_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} @@ -56,7 +68,7 @@ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} {{- "Given the following functions, please respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -67,7 +79,17 @@ {%- for message in messages %} {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} {%- elif 'tool_calls' in message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only supports single tool-calls at once!") }} @@ -81,10 +103,14 @@ {{- "<|eot_id|>" }} {%- elif message.role == "tool" or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is mapping %} - {{- message.content | tojson }} - {%- else %} + {%- if message.content is string %} {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} diff --git a/examples/tool_chat_template_llama3.2_json.jinja b/examples/tool_chat_template_llama3.2_json.jinja index 7e24777726a35..39f902c1c3c40 100644 --- a/examples/tool_chat_template_llama3.2_json.jinja +++ b/examples/tool_chat_template_llama3.2_json.jinja @@ -16,38 +16,70 @@ {%- set tools = none %} {%- endif %} +{#- Find out if there are any images #} +{% set image_ns = namespace(has_images=false) %} +{%- for message in messages %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {%- set image_ns.has_images = true %} + {%- endif %} + {%- endfor %} +{%- endfor %} + + {#- This block extracts the system message, so we can slot it into the right place. #} {%- if messages[0]['role'] == 'system' %} - {%- set system_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set system_message = messages[0]['content']|trim %} + {%- else %} + {#- Support vLLM's transforming of a content string to JSON. #} + {%- set system_message = messages[0]['content'][0]['text']|trim %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} - {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- if tools is not none %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} + {%- else %} + {%- set system_message = "" %} + {%- endif %} {%- endif %} -{#- System message #} -{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} -{%- if tools is not none %} - {{- "Environment: ipython\n" }} +{#- Including an image is not compatible with a system message #} +{%- if image_ns.has_images and not system_message == "" %} + {{- raise_exception("Prompting with images is incompatible with system messages and tool use.") }} {%- endif %} -{{- "Cutting Knowledge Date: December 2023\n" }} -{{- "Today Date: " + date_string + "\n\n" }} -{%- if tools is not none and not tools_in_user_message %} - {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} - {{- "Do not use variables.\n\n" }} - {%- for t in tools %} - {{- t | tojson(indent=4) }} - {{- "\n\n" }} - {%- endfor %} + + +{#- System message, if there are no images #} +{%- if not image_ns.has_images %} + {{- "<|start_header_id|>system<|end_header_id|>\n\n" }} + {%- if tools is not none %} + {{- "Environment: ipython\n" }} + {%- endif %} + {{- "Cutting Knowledge Date: December 2023\n" }} + {{- "Today Date: " + date_string + "\n\n" }} + {%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call. " }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {%- endif %} + {{- system_message }} + {{- "<|eot_id|>" }} {%- endif %} -{{- system_message }} -{{- "<|eot_id|>" }} {#- Custom tools are passed in a user message with some extra guidance #} {%- if tools_in_user_message and not tools is none %} {#- Extract the first user message so we can plug it in here #} {%- if messages | length != 0 %} - {%- set first_user_message = messages[0]['content']|trim %} + {%- if messages[0]['content'] is string %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- else %} + {%- set first_user_message = messages[0]['content'] | selectattr('type', 'equalto', 'text') | map(attribute='text') | map('trim') | join('\n') %} + {%- endif %} {%- set messages = messages[1:] %} {%- else %} {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} @@ -55,7 +87,7 @@ {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} {{- "Given the following functions, please respond with a JSON for a function call " }} {{- "with its proper arguments that best answers the given prompt.\n\n" }} - {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}. ' }} {{- "Do not use variables.\n\n" }} {%- for t in tools %} {{- t | tojson(indent=4) }} @@ -66,7 +98,19 @@ {%- for message in messages %} {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} - {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' }} + {%- if message['content'] is string %} + {{- message['content'] | trim}} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'image' %} + {{- '<|image|>' }} + {%- elif content['type'] == 'text' %} + {{- content['text'] | trim }} + {%- endif %} + {%- endfor %} + {%- endif %} + {{- '<|eot_id|>' }} {%- elif 'tool_calls' in message %} {%- if not message.tool_calls|length == 1 %} {{- raise_exception("This model only supports single tool-calls at once!") }} @@ -80,10 +124,14 @@ {{- "<|eot_id|>" }} {%- elif message.role == "tool" or message.role == "ipython" %} {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} - {%- if message.content is mapping %} - {{- message.content | tojson }} - {%- else %} + {%- if message.content is string %} {{- { "output": message.content } | tojson }} + {%- else %} + {%- for content in message['content'] %} + {%- if content['type'] == 'text' %} + {{- { "output": content['text'] } | tojson }} + {%- endif %} + {%- endfor %} {%- endif %} {{- "<|eot_id|>" }} {%- endif %} diff --git a/format.sh b/format.sh index b3dcdc15bf948..0b196de9d0773 100755 --- a/format.sh +++ b/format.sh @@ -41,6 +41,7 @@ MYPY_VERSION=$(mypy --version | awk '{print $2}') CODESPELL_VERSION=$(codespell --version) ISORT_VERSION=$(isort --vn) CLANGFORMAT_VERSION=$(clang-format --version | awk '{print $3}') +SPHINX_LINT_VERSION=$(sphinx-lint --version | awk '{print $2}') # # params: tool name, tool version, required version tool_version_check() { @@ -57,6 +58,7 @@ tool_version_check "mypy" "$MYPY_VERSION" tool_version_check "isort" "$ISORT_VERSION" tool_version_check "codespell" "$CODESPELL_VERSION" tool_version_check "clang-format" "$CLANGFORMAT_VERSION" +tool_version_check "sphinx-lint" "$SPHINX_LINT_VERSION" YAPF_FLAGS=( '--recursive' @@ -313,3 +315,7 @@ if ! git diff --quiet &>/dev/null; then else echo "✨🎉 Format check passed! Congratulations! 🎉✨" fi + +echo 'vLLM sphinx-lint:' +tools/sphinx-lint.sh +echo 'vLLM sphinx-lint: Done' diff --git a/pyproject.toml b/pyproject.toml index 3c8c46cc8621e..253b706a774a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,4 +98,5 @@ markers = [ "quant_model: run this model test under Quantized category", "distributed_2_gpus: run this test only in distributed tests for 2 GPUs", "skip_v1: do not run this test with v1", + "optional: optional tests that are automatically skipped, include --optional to run them", ] diff --git a/requirements-lint.txt b/requirements-lint.txt index f9132bbf96437..711bb50a0e936 100644 --- a/requirements-lint.txt +++ b/requirements-lint.txt @@ -6,6 +6,7 @@ ruff==0.6.5 codespell==2.3.0 isort==5.13.2 clang-format==18.1.5 +sphinx-lint==1.0.0 # type checking mypy==1.11.1 diff --git a/requirements-tpu.txt b/requirements-tpu.txt index f9a0770804e55..3d1e80f6be620 100644 --- a/requirements-tpu.txt +++ b/requirements-tpu.txt @@ -16,8 +16,8 @@ ray[default] --find-links https://storage.googleapis.com/libtpu-releases/index.html --find-links https://storage.googleapis.com/jax-releases/jax_nightly_releases.html --find-links https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html -torch==2.6.0.dev20241028+cpu -torchvision==0.20.0.dev20241028+cpu -torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241028-cp310-cp310-linux_x86_64.whl +torch==2.6.0.dev20241114+cpu +torchvision==0.20.0.dev20241114+cpu +torch_xla[tpu] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20241114-cp310-cp310-linux_x86_64.whl jaxlib==0.4.32.dev20240829 jax==0.4.32.dev20240829 diff --git a/tests/basic_correctness/test_basic_correctness.py b/tests/basic_correctness/test_basic_correctness.py index 7f16baa65a644..fcba253d159f3 100644 --- a/tests/basic_correctness/test_basic_correctness.py +++ b/tests/basic_correctness/test_basic_correctness.py @@ -14,11 +14,12 @@ from vllm.platforms import current_platform from vllm.worker.model_runner import ModelInputForGPUWithSamplingMetadata +from ..conftest import VllmRunner from ..models.utils import check_outputs_equal from ..utils import multi_gpu_test MODELS = [ - "facebook/opt-125m", + "google/gemma-2-2b-it", "meta-llama/Llama-3.2-1B", ] @@ -42,8 +43,6 @@ def test_vllm_gc_ed(): @pytest.mark.parametrize("enforce_eager", [False, True]) def test_models( hf_runner, - vllm_runner, - example_prompts, model: str, backend: str, dtype: str, @@ -54,15 +53,27 @@ def test_models( if backend == "FLASHINFER" and current_platform.is_rocm(): pytest.skip("Flashinfer does not support ROCm/HIP.") + if backend == "XFORMERS" and model == "google/gemma-2-2b-it": + pytest.skip( + "XFORMERS does not support gemma2 with full context length.") + os.environ["VLLM_ATTENTION_BACKEND"] = backend + # 5042 tokens for gemma2 + # gemma2 has alternating sliding window size of 4096 + # we need a prompt with more than 4096 tokens to test the sliding window + prompt = "The following numbers of the sequence " + ", ".join( + str(i) for i in range(1024)) + " are:" + example_prompts = [prompt] + with hf_runner(model, dtype=dtype) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) - with vllm_runner(model, - dtype=dtype, - enforce_eager=enforce_eager, - gpu_memory_utilization=0.7) as vllm_model: + with VllmRunner(model, + max_model_len=8192, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7) as vllm_model: vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens) check_outputs_equal( diff --git a/tests/basic_correctness/test_chunked_prefill.py b/tests/basic_correctness/test_chunked_prefill.py index cc5bc2aca27c9..469d18a4dd7af 100644 --- a/tests/basic_correctness/test_chunked_prefill.py +++ b/tests/basic_correctness/test_chunked_prefill.py @@ -12,6 +12,7 @@ import pytest from tests.kernels.utils import override_backend_env_variable +from vllm.platforms import current_platform from ..models.utils import check_logprobs_close, check_outputs_equal from ..utils import multi_gpu_test @@ -206,12 +207,14 @@ def test_models_with_fp8_kv_cache( # NOTE: Increasing this in this suite will fail CI because we currently cannot # reset distributed env properly. Use a value > 1 just when you test. @pytest.mark.parametrize("tensor_parallel_size", [1]) +@pytest.mark.parametrize("dtype", ["half"]) def test_with_prefix_caching( vllm_runner, max_tokens: int, enforce_eager: bool, chunk_size: int, tensor_parallel_size: int, + dtype: str, ) -> None: """ Checks exact match decode with and without prefix caching @@ -233,7 +236,7 @@ def test_with_prefix_caching( for enable in (True, False): with vllm_runner( model, - dtype="half", + dtype=dtype, max_num_batched_tokens=max_num_batched_tokens, enable_chunked_prefill=True, enable_prefix_caching=enable, @@ -260,3 +263,61 @@ def test_with_prefix_caching( name_0="w/o prefix caching", name_1="with prefix caching", ) + + +@pytest.mark.parametrize("model", ["facebook/opt-125m"]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [32]) +@pytest.mark.parametrize("chunked_prefill_token_size", [1, 4, 16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("attention_backend", ["TORCH_SDPA"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_models_cpu( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + chunked_prefill_token_size: int, + enforce_eager: bool, + attention_backend: str, + monkeypatch, +) -> None: + test_models( + hf_runner, + vllm_runner, + example_prompts, + model, + dtype, + max_tokens, + chunked_prefill_token_size, + enforce_eager, + 1, + attention_backend, + monkeypatch, + ) + + +@pytest.mark.parametrize("max_tokens", [16]) +@pytest.mark.parametrize("enforce_eager", [False]) +@pytest.mark.parametrize("chunk_size", [30, 32]) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.cpu_model +@pytest.mark.skipif(not current_platform.is_cpu(), reason="CPU only") +def test_with_prefix_caching_cpu( + vllm_runner, + max_tokens: int, + enforce_eager: bool, + chunk_size: int, + dtype: str, +) -> None: + test_with_prefix_caching( + vllm_runner, + max_tokens, + enforce_eager, + chunk_size, + 1, + dtype, + ) diff --git a/tests/compile/backend.py b/tests/compile/backend.py index 9d5c68274374e..8fa10e5bd1b37 100644 --- a/tests/compile/backend.py +++ b/tests/compile/backend.py @@ -1,7 +1,9 @@ from copy import deepcopy -from typing import Callable +from typing import Callable, Union -import torch +from torch import fx + +from vllm.compilation.inductor_pass import InductorPass class TestBackend: @@ -11,19 +13,21 @@ class TestBackend: It also saves the graph before and after the custom passes for inspection. """ - def __init__(self, *args: Callable[[torch.fx.Graph], None]): - self.custom_passes = args + def __init__(self, *passes: Union[InductorPass, Callable[[fx.Graph], + None]]): + self.custom_passes = list(passes) from torch._inductor import config self.current_config = config.shallow_copy_dict() + self.current_config['force_disable_caches'] = True self.current_config['post_grad_custom_post_pass'] = self.post_pass - def __call__(self, graph: torch.fx.GraphModule, example_inputs): + def __call__(self, graph: fx.GraphModule, example_inputs): from torch._inductor.compile_fx import compile_fx return compile_fx(graph, example_inputs, config_patches=self.current_config) - def post_pass(self, graph: torch.fx.Graph): + def post_pass(self, graph: fx.Graph): self.graph_pre_pass = deepcopy(graph) for pass_ in self.custom_passes: pass_(graph) diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 0e40e3b4ebc96..7ef502abee345 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -10,8 +10,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op global_counter = 0 @@ -79,7 +79,7 @@ def test_simple_piecewise_compile(): vllm_config = VllmConfig(compilation_config=CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], + splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, )) with set_current_vllm_config(vllm_config): diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index 356d119a40334..dbd5a3bbffeab 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -16,8 +16,8 @@ from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile -from vllm.config import CompilationConfig, CompilationLevel, VllmConfig -from vllm.plugins import set_current_vllm_config +from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, + set_current_vllm_config) from vllm.utils import direct_register_custom_op # create a library to hold the custom op @@ -258,7 +258,7 @@ def run_model(llama_config, use_cudagraph=True, ) if split_attn: - compilation_config.non_cudagraph_ops = ["silly.attention"] + compilation_config.splitting_ops = ["silly.attention"] else: compilation_config = CompilationConfig( level=CompilationLevel.NO_COMPILATION, ) @@ -378,7 +378,7 @@ def benchmark(): compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, - non_cudagraph_ops=["silly.attention"], + splitting_ops=["silly.attention"], ) else: compilation_config = CompilationConfig( diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py index c0db2e78824be..99781c55b672e 100644 --- a/tests/compile/test_basic_correctness.py +++ b/tests/compile/test_basic_correctness.py @@ -62,6 +62,16 @@ class TestSetting: method="encode", fullgraph=True, ), + # encoder-based embedding model (BERT) + TestSetting( + model="BAAI/bge-base-en-v1.5", + model_args=["--task", "embedding"], + pp_size=1, + tp_size=1, + attn_backend="XFORMERS", + method="encode", + fullgraph=True, + ), # vision language model TestSetting( model="microsoft/Phi-3.5-vision-instruct", @@ -103,7 +113,7 @@ def test_compile_correctness(test_setting: TestSetting): CompilationLevel.NO_COMPILATION, CompilationLevel.PIECEWISE, ]: - all_args.append(final_args + ["-O", str(level)]) + all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) # inductor will change the output, so we only compare if the output @@ -121,7 +131,7 @@ def test_compile_correctness(test_setting: TestSetting): CompilationLevel.DYNAMO_AS_IS, CompilationLevel.DYNAMO_ONCE, ]: - all_args.append(final_args + ["-O", str(level)]) + all_args.append(final_args + [f"-O{level}"]) all_envs.append({}) if level != CompilationLevel.DYNAMO_ONCE and not fullgraph: # "DYNAMO_ONCE" will always use fullgraph diff --git a/tests/compile/test_functionalization.py b/tests/compile/test_functionalization.py new file mode 100644 index 0000000000000..5036189077be2 --- /dev/null +++ b/tests/compile/test_functionalization.py @@ -0,0 +1,95 @@ +import pytest +import torch + +import vllm.envs as envs +from vllm import LLM, SamplingParams +from vllm.compilation.fix_functionalization import FixFunctionalizationPass +from vllm.compilation.fusion import (FusionPass, find_auto_fn, + find_auto_fn_maybe) +from vllm.compilation.reshapes import RedundantReshapesPass +from vllm.compilation.vllm_inductor_pass import is_func +from vllm.config import CompilationConfig + +from .backend import TestBackend + +OPS_IN_MODEL = [ + torch.ops._C.rotary_embedding.default, + torch.ops._C.fused_add_rms_norm.default, + torch.ops._C.silu_and_mul.default, +] + +RMS_OP = torch.ops._C.rms_norm.default + +RMS_QUANT_OPS = { + "static_fp8": [ + torch.ops._C.rms_norm_static_fp8_quant.default, + torch.ops._C.fused_add_rms_norm_static_fp8_quant.default + ], +} + +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] + + +@pytest.mark.parametrize("model", + ["nm-testing/TinyLlama-1.1B-Chat-v1.0-FP8-e2e"]) +@pytest.mark.parametrize("do_fusion", [True, False]) +@pytest.mark.skipif(envs.VLLM_TARGET_DEVICE != "cuda", + reason="Only test on CUDA") +def test_fix_functionalization(model: str, do_fusion: bool): + torch.set_default_device("cuda") + + config = CompilationConfig.PassConfig(enable_fusion=do_fusion, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + + passes = [reshape_pass, fusion_pass] if do_fusion else [reshape_pass] + func_pass = FixFunctionalizationPass(config) + backend_func = TestBackend(*passes, func_pass) + backend_no_func = TestBackend(*passes) + + # instantiate a full engine and manually compile the model 2x + # (with and without FixFunctionalizationPass) + llm = LLM(model=model, enforce_eager=True) + model_runner = llm.llm_engine.model_executor.driver_worker.model_runner + orig_model = model_runner.model + # TODO mark inputs dynamic? (currently torch.compile is triggered 4x) + # Can only do that by using the decorator but then we'd have to instantiate + # 2 LLM instances. + + sampling_params = SamplingParams(temperature=0.0, top_p=1.0) + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_func) + gen_func = llm.generate(prompts, sampling_params) + + model_runner.model = torch.compile(orig_model, + fullgraph=True, + backend=backend_no_func) + gen_no_func = llm.generate(prompts, sampling_params) + + for output_func, output_no_func in zip(gen_func, gen_no_func): + assert output_func.outputs[0].text == output_no_func.outputs[0].text + + # OPS_IN_MODEL always appear. RMS_OP is fused away if we run fusion, + # and replaced by fused quantized ops in RMS_QUANT_OPS. + ops = OPS_IN_MODEL + (RMS_QUANT_OPS["static_fp8"] + if do_fusion else [RMS_OP]) + + for op in ops: + find_auto_fn(backend_no_func.graph_post_pass.nodes, op) + assert find_auto_fn_maybe(backend_func.graph_post_pass.nodes, + op) is None # noqa: E501 + + # make sure the ops were all de-functionalized + found = dict() + for node in backend_func.graph_post_pass.nodes: + for op in ops: + if is_func(node, op): + found[op] = True + assert all(found[op] for op in ops) diff --git a/tests/compile/test_fusion.py b/tests/compile/test_fusion.py index 4db79b070fd8d..f92ec8d0de5f1 100644 --- a/tests/compile/test_fusion.py +++ b/tests/compile/test_fusion.py @@ -38,12 +38,6 @@ def forward(self, x): return y3 -# Init does pattern registration, which can only happen once -config = CompilationConfig(enable_fusion=True) -reshape_pass = RedundantReshapesPass(config) -fusion_pass = FusionPass.instance(config) - - @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize("hidden_size", [64, 3392, 4096]) @pytest.mark.parametrize("num_tokens", [7, 256, 533, 2048, 2049]) @@ -58,6 +52,11 @@ def test_fusion_rmsnorm_quant(dtype, hidden_size, num_tokens, eps): pytest.skip("Only test eps=1e-5 for now") # Reshape pass is needed for the fusion pass to work + config = CompilationConfig.PassConfig(enable_fusion=True, + enable_reshape=True) + reshape_pass = RedundantReshapesPass(config) + fusion_pass = FusionPass.instance(config) + backend = TestBackend(reshape_pass, fusion_pass) model = TestModel(hidden_size, eps) diff --git a/tests/compile/test_pass_manager.py b/tests/compile/test_pass_manager.py new file mode 100644 index 0000000000000..03e7535093c5d --- /dev/null +++ b/tests/compile/test_pass_manager.py @@ -0,0 +1,35 @@ +import pickle + +import pytest +import torch +from torch._inductor.codecache import BypassFxGraphCache + +from vllm.compilation.config import CompilationConfig +from vllm.compilation.inductor_pass import (CallableInductorPass, + as_inductor_pass) +from vllm.compilation.pass_manager import PostGradPassManager + + +def simple_callable(graph: torch.fx.Graph): + pass + + +@as_inductor_pass(files=(__file__, )) +def callable_decorated(graph: torch.fx.Graph): + pass + + +@pytest.mark.parametrize( + "works, callable", + [(False, simple_callable), (True, callable_decorated), + (True, CallableInductorPass(simple_callable, "simple_callable"))]) +def test_pass_manager(works: bool, callable): + config = CompilationConfig().pass_config + pass_manager = PostGradPassManager([callable]) + pass_manager.configure(config) # Adds default passes + + if works: + pickle.dumps(pass_manager) + else: + with pytest.raises(BypassFxGraphCache): + pickle.dumps(pass_manager) diff --git a/tests/compile/utils.py b/tests/compile/utils.py index 078c6bf9ea1df..7c92d165d05f7 100644 --- a/tests/compile/utils.py +++ b/tests/compile/utils.py @@ -4,7 +4,7 @@ from tests.quantization.utils import is_quant_method_supported from vllm import LLM, SamplingParams -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationLevel from vllm.platforms import current_platform TEST_MODELS = [ @@ -85,7 +85,7 @@ def check_full_graph_support(model, enforce_eager=True, tensor_parallel_size=tp_size, disable_custom_all_reduce=True, - compilation_config=CompilationConfig(level=optimization_level), + compilation_config=optimization_level, **model_kwargs) outputs = llm.generate(prompts, sampling_params) diff --git a/tests/conftest.py b/tests/conftest.py index 0dc1cc6e83c18..d56942d8912af 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -265,6 +265,7 @@ def __init__( model_kwargs: Optional[Dict[str, Any]] = None, is_embedding_model: bool = False, is_sentence_transformer: bool = False, + is_cross_encoder: bool = False, skip_tokenizer_init: bool = False, auto_cls: Type[_BaseAutoModelClass] = AutoModelForCausalLM, postprocess_inputs: Callable[..., BatchEncoding] = identity, @@ -282,6 +283,14 @@ def __init__( device="cpu", trust_remote_code=True, ).to(dtype=torch_dtype)) + elif is_cross_encoder: + # Lazy init required for AMD CI + from sentence_transformers import CrossEncoder + self.model = CrossEncoder(model_name, + device="cpu", + trust_remote_code=True) + self.model.model = self.wrap_device(self.model.model)\ + .to(dtype=torch_dtype) else: model_kwargs = model_kwargs if model_kwargs is not None else {} self.model = self.wrap_device( @@ -625,6 +634,9 @@ def generate_encoder_decoder_greedy_logprobs_limit( def encode(self, prompts: List[str]) -> List[List[torch.Tensor]]: return self.model.encode(prompts) + def predict(self, prompts: List[List[str]]) -> torch.Tensor: + return self.model.predict(prompts, convert_to_tensor=True) + def __enter__(self): return self @@ -898,6 +910,14 @@ def encode( req_outputs = self.model.encode(inputs) return [req_output.outputs.embedding for req_output in req_outputs] + def score( + self, + text_1: Union[str, List[str]], + text_2: Union[str, List[str]], + ) -> List[List[float]]: + req_outputs = self.model.score(text_1, text_2) + return [req_output.outputs.embedding for req_output in req_outputs] + def __enter__(self): return self @@ -1010,3 +1030,22 @@ def dummy_gemma2_embedding_path(): with open(json_path, "w") as f: json.dump(config, f) return _dummy_gemma2_embedding_path + + +# Add the flag `--optional` to allow run tests +# that are marked with @pytest.mark.optional +def pytest_addoption(parser): + parser.addoption("--optional", + action="store_true", + default=False, + help="run optional test") + + +def pytest_collection_modifyitems(config, items): + if config.getoption("--optional"): + # --optional given in cli: do not skip optional tests + return + skip_optional = pytest.mark.skip(reason="need --optional option to run") + for item in items: + if "optional" in item.keywords: + item.add_marker(skip_optional) diff --git a/tests/core/block/e2e/test_correctness_sliding_window.py b/tests/core/block/e2e/test_correctness_sliding_window.py index 9320a9ef62314..415d0bd8237df 100644 --- a/tests/core/block/e2e/test_correctness_sliding_window.py +++ b/tests/core/block/e2e/test_correctness_sliding_window.py @@ -3,6 +3,7 @@ import pytest +from tests.kernels.utils import override_backend_env_variable from vllm import LLM, SamplingParams from .conftest import get_text_from_llm_generator @@ -28,8 +29,9 @@ @pytest.mark.parametrize("test_llm_kwargs", [{}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, - batch_size, seed): + batch_size, seed, backend, monkeypatch): """ The test does a bunch of assignments "x1 = 10\nx2 = 33\n..." and then asks for value of one of them (which is outside the sliding window). @@ -38,6 +40,8 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, Additionally, we compare the results of the v1 and v2 managers. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=1024, ignore_eos=True, @@ -84,7 +88,9 @@ def test_sliding_window_retrival(baseline_llm_generator, test_llm_generator, @pytest.mark.parametrize("test_llm_kwargs", [{"enable_chunked_prefill": True}]) @pytest.mark.parametrize("batch_size", [5]) @pytest.mark.parametrize("seed", [1]) -def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): +@pytest.mark.parametrize("backend", ["FLASH_ATTN", "FLASHINFER", "XFORMERS"]) +def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed, + backend, monkeypatch): """ This is similar to test_sliding_window_retrival, however, it doesn't compare against the v1 block manager since v1 doesn't support @@ -93,6 +99,8 @@ def test_sliding_window_chunked_prefill(test_llm_generator, batch_size, seed): The results with and without chunked prefill are not the same due to numerical instabilities. """ + override_backend_env_variable(monkeypatch, backend) + sampling_params = SamplingParams( max_tokens=10, ignore_eos=True, diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index d325b9606843e..bbeb4b3a58f2a 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -5,9 +5,14 @@ import pytest +from tests.core.utils import create_dummy_sequence +from vllm.core.block.cpu_gpu_block_allocator import CpuGpuBlockAllocator from vllm.core.block.interfaces import Block, BlockAllocator -from vllm.core.block.prefix_caching_block import (PrefixCachingBlock, +from vllm.core.block.prefix_caching_block import (ComputedBlocksTracker, + PrefixCachingBlock, PrefixCachingBlockAllocator) +from vllm.sequence import Logprob +from vllm.utils import Device class TestPrefixCachingBlock: @@ -726,18 +731,71 @@ def test_touch_block(): token_ids=common_token_ids, allocator=allocator, ) - block_ids = [block.block_id for block in blocks] + block_hashes = [block.content_hash for block in blocks] # The allocated blocks should be marked as touched # but not computed. - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes) assert len(computed_block_ids) == 0 allocator.mark_blocks_as_computed([]) - computed_block_ids = allocator.get_computed_block_ids( - [], block_ids, skip_last_block_id=False) + computed_block_ids = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes) assert len(computed_block_ids) == common_blocks + @staticmethod + def test_find_cached_blocks_prefix(): + """ + This test verifies the behavior of find_cached_blocks_prefix. + """ + block_size = 4 + num_blocks = 8 + total_test_blocks = 12 + allocator = PrefixCachingBlockAllocator(num_blocks=num_blocks, + block_size=block_size) + + token_ids = list(range(total_test_blocks * block_size)) + block_tokens_seq1 = token_ids[:num_blocks * block_size] + blocks_seq1 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq1, + allocator=allocator, + ) + block_hashes_seq1 = [block.content_hash for block in blocks_seq1] + allocator.mark_blocks_as_computed([]) + + # All blocks should be cached. + cached_blocks_seq1 = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks_seq1) == num_blocks + + # Free the first sequence. + for block in blocks_seq1: + allocator.free(block) + + # All blocks should be still be cached if not required to be allocated. + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == num_blocks + + block_tokens_seq2 = token_ids[num_blocks * block_size:] + blocks_seq2 = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=block_tokens_seq2, + allocator=allocator, + ) + block_hashes_seq2 = [block.content_hash for block in blocks_seq2] + allocator.mark_blocks_as_computed([]) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq2) + assert len(cached_blocks) == len(blocks_seq2) + + # Half of the blocks from seq1 should still be cached. + num_evicted_blocks = len(blocks_seq2) + cached_blocks = allocator.find_cached_blocks_prefix( + block_hashes=block_hashes_seq1) + assert len(cached_blocks) == len(blocks_seq1) - num_evicted_blocks + @staticmethod def create_immutable_chain( block_size: int, @@ -762,3 +820,114 @@ def create_immutable_chain( blocks.append(prev_block) return blocks + + +class TestComputedBlocksTracker: + + @staticmethod + def _get_mock_allocator(): + return MagicMock(spec=PrefixCachingBlockAllocator) + + @staticmethod + def test_get_num_cached_tokens(): + """ + Test it correctly computes the number of cached tokens for a given + sequence: + + - The cache token count is derived from the number of cached blocks. + - The cache token count is updated when the allocator is updated. + - When a sequence is removed, the cache token count should be updated + accordingly. + + # TODO(rickyx): This behaviour for prefill sequence is a hack until + we fix the computed blocks tracking. + - The cache token count for prefill sequence doesn't change while + the sequence is in continuous prefill (chunked prefill). + """ + block_size = 4 + mock_allocator = TestComputedBlocksTracker._get_mock_allocator() + tracker = ComputedBlocksTracker( + allocator=mock_allocator, + block_size=block_size, + enable_caching=True, + ) + + # Not yet allocated. + tokens = [0, 1, 2, 3, 4, 5] + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [] + assert tracker.get_num_cached_tokens(seq1) == 0 + + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] # 1 block cached. + # Result is cached for prefill sequence. + assert tracker.get_num_cached_tokens(seq1) == 0 + + # Mark the sequence as non-prefill. + seq1.data.update_num_computed_tokens(len(tokens)) # 6 tokens computed. + assert not seq1.is_prefill() + + # Recomputes for decoding sequence. + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Append new tokens to the sequence. + num_new_tokens = 3 + for i in range(num_new_tokens): + seq1.append_token_id(i, {i: Logprob(logprob=0.0)}) + + assert tracker.get_num_cached_tokens(seq1) == 4 + + # Update the allocator. + mock_allocator.find_cached_blocks_prefix.return_value = [ + None + ] * 2 # 2 blocks cached. + assert tracker.get_num_cached_tokens(seq1) == 8 + + # Remove the sequence. + tracker.remove_seq(seq1.seq_id) + + # Re-create the sequence with the same request id to simulate recompute. + seq1 = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + mock_allocator.find_cached_blocks_prefix.return_value = [ + ] # no cached block + assert tracker.get_num_cached_tokens(seq1) == 0 + + @staticmethod + def test_correct_block_hash(): + """ + Test that the block hash is correctly computed for a sequence (should + match the underlying block allocator's block hash). So the number of + cached tokens is correctly retrieved. + """ + block_size = 4 + allocator = CpuGpuBlockAllocator.create( + allocator_type="prefix_caching", + num_gpu_blocks=16, + num_cpu_blocks=16, + block_size=block_size, + ) + gpu_allocator = allocator._allocators[Device.GPU] + + tracker = ComputedBlocksTracker( + allocator=allocator, + block_size=block_size, + enable_caching=True, + ) + + tokens = list(range(block_size * 4)) # 4 blocks. + seq = create_dummy_sequence(request_id=0, + token_ids=tokens, + block_size=block_size) + _ = TestPrefixCachingBlockAllocator.create_immutable_chain( + block_size=block_size, + token_ids=tokens, + allocator=gpu_allocator, + ) + allocator.mark_blocks_as_computed([]) + + assert tracker.get_num_cached_tokens(seq) == len(tokens) diff --git a/tests/core/test_scheduler.py b/tests/core/test_scheduler.py index 5ff32be611592..8f6de84e566e7 100644 --- a/tests/core/test_scheduler.py +++ b/tests/core/test_scheduler.py @@ -12,9 +12,9 @@ from vllm.lora.request import LoRARequest from vllm.sequence import SequenceGroup -from .utils import (append_new_token, append_new_token_seq_group, - create_dummy_prompt, get_sequence_groups, - schedule_and_update_computed_tokens) +from .utils import (append_new_token, append_new_token_seq, + append_new_token_seq_group, create_dummy_prompt, + get_sequence_groups, schedule_and_update_computed_tokens) def test_scheduler_add_seq_group(): @@ -305,6 +305,8 @@ def initialize_scheduler( block_size=4, num_cpu_blocks=8, num_gpu_blocks=8, + enable_prefix_caching=False, + enable_chunked_prefill=False, ): block_size = block_size scheduler_config = SchedulerConfig( @@ -312,8 +314,15 @@ def initialize_scheduler( max_num_batched_tokens=max_token_budget, max_num_seqs=max_num_seqs, max_model_len=max_model_len, + enable_chunked_prefill=enable_chunked_prefill, + ) + cache_config = CacheConfig( + block_size, + 1.0, + 1, + "auto", + enable_prefix_caching=enable_prefix_caching, ) - cache_config = CacheConfig(block_size, 1.0, 1, "auto") cache_config.num_cpu_blocks = num_cpu_blocks cache_config.num_gpu_blocks = num_gpu_blocks scheduler = Scheduler(scheduler_config, cache_config, lora_config) @@ -800,3 +809,165 @@ def test_scheduling_budget(): assert budget.num_curr_seqs == 0 budget.subtract_num_seqs(seq_group.request_id, 2) assert budget.num_curr_seqs == 0 + + +@pytest.mark.parametrize("enable_prefix_caching", [True, False]) +def test_prefix_caching_aware_prefills(enable_prefix_caching): + """ + Test the below scenario: + + For 3 sequences, seqA, seqB, seqC, share the first block as prefix. + + The test verifies the below scenarios: + 1. SeqA is first scheduled. + 2. SeqB and SeqC can be prefilled together in a single schedule round + even though there are not enough token budgets to prefill both without + considering prefix caching. + """ + + block_size = 4 + max_num_batched_tokens = 12 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=max_num_batched_tokens, + enable_prefix_caching=enable_prefix_caching, + ) + + seqA_tokens = list(range(8)) + num_shared_tokens = 4 + seqB_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 12, 16)) # Shared prefix first 4. + seqC_tokens = seqA_tokens[:num_shared_tokens] + list(range( + 16, 20)) # Shared prefix first 4. + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + + # Schedule seqA prefill. + scheduler.add_seq_group(seqA_group) + metas, out, _ = scheduler.schedule() + assert (len(out.scheduled_seq_groups) == 1 + and out.scheduled_seq_groups[0].seq_group == seqA_group) + assert out.scheduled_seq_groups[0].token_chunk_size == len(seqA_tokens) + + # Schedule seqA decode. + append_new_token_seq_group(len(seqA_tokens), seqA_group, 999) + metas, out, _ = scheduler.schedule() + + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 1 + + # Schedule seqB and seqC prefills should work with prefix caching. + scheduler.add_seq_group(seqB_group) + scheduler.add_seq_group(seqC_group) + metas, out, _ = scheduler.schedule() + + if enable_prefix_caching: + assert len(out.scheduled_seq_groups) == 2 + assert set([ + out.scheduled_seq_groups[0].seq_group, + out.scheduled_seq_groups[1].seq_group, + ]) == set([seqB_group, seqC_group]) + assert len(metas) == 2 + for meta in metas: + assert meta.token_chunk_size == 8 + assert (len(meta.computed_block_nums) == num_shared_tokens // + block_size) # 1 Block for the 8 tokens. + else: + assert len(out.scheduled_seq_groups) == 1 + assert len(metas) == 1 + assert metas[0].token_chunk_size == 8 + assert len(metas[0].computed_block_nums) == 0 # No blocks computed. + + +def test_no_multiple_partial_prefills_with_chunked_prefill_and_prefix_caching( +): + """ + This test verifies that we don't schedule new prefills if there's already + a continuous prefill in progress even though the new prefills with shared + prefix can fit in the token budget: + + - SeqA is being chunked prefill. + - SeqB with the same prompt shouldn't be scheduled for prefill even though + there's enough token budget to prefill the cached tokens. + - Neither should seqC be scheduled. + + - When seqA is in decoding phase, seqB and seqC can be scheduled. + - Entire seqB should be prefilled since it's a full prefix cache hit. + - SeqC would be partially prefilled with the prefix shared, and the + remaining unique tokens would be prefilled (rounded down to be + block-size aligned). + """ + + block_size = 2 + max_num_batched_tokens = 4 + max_seq_group = 3 + scheduler = initialize_scheduler( + block_size=block_size, + num_cpu_blocks=16, + num_gpu_blocks=16, + max_token_budget=max_num_batched_tokens, + max_num_seqs=max_seq_group, + max_model_len=100, + enable_prefix_caching=True, + enable_chunked_prefill=True, + ) + + seqA_tokens = list(range(8)) + seqB_tokens = seqA_tokens + seqC_shared_prefix_len = 4 + seqC_tokens = seqA_tokens[:seqC_shared_prefix_len] + list(range(12, 20)) + + seqA, seqA_group = create_dummy_prompt("0", + prompt_tokens=seqA_tokens, + block_size=block_size) + seqB, seqB_group = create_dummy_prompt("1", + prompt_tokens=seqB_tokens, + block_size=block_size) + + # Chunked prefill seqA. + scheduler.add_seq_group(seqA_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # seqB should not be scheduled with ongoing prefills. + scheduler.add_seq_group(seqB_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 1 + assert out.scheduled_seq_groups[0].seq_group == seqA_group + assert out.scheduled_seq_groups[0].token_chunk_size == 4 + + # both seqB and seqC can now be scheduled with seqA is over. + # seqA is in decoding phase. + append_new_token_seq(seqA, 999) + seqC, seqC_group = create_dummy_prompt("2", + prompt_tokens=seqC_tokens, + block_size=block_size) + scheduler.add_seq_group(seqC_group) + metas, out = schedule_and_update_computed_tokens(scheduler) + assert len(out.scheduled_seq_groups) == 3 + + metas = {meta.request_id: meta for meta in metas} + assert metas[seqA_group.request_id].token_chunk_size == 1 # Decode + assert (metas[seqB_group.request_id].token_chunk_size == 8 + ) # Fully cached prefill + assert ( + metas[seqC_group.request_id].token_chunk_size == 6 + ), "A partial prefix of C (4 tokens) should be prefilled, with the " + "remaining tokens fit into 3 token budget (4-1 from the seqA). It will " + "then be rounded down to 2 tokens on block size, thus 6 tokens in total." diff --git a/tests/core/utils.py b/tests/core/utils.py index cd0caa4704e11..277368b57b938 100644 --- a/tests/core/utils.py +++ b/tests/core/utils.py @@ -1,17 +1,20 @@ import time -from typing import List, Optional +from collections import defaultdict +from typing import Any, Dict, List, Optional from typing import Sequence as GenericSequence from typing import Tuple from vllm import SamplingParams +from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.inputs import EncoderDecoderInputs, token_inputs from vllm.lora.request import LoRARequest -from vllm.sequence import Logprob, Sequence, SequenceGroup +from vllm.sequence import (Logprob, Sequence, SequenceGroup, + SequenceGroupMetadata) def create_dummy_prompt( request_id: str, - prompt_length: int, + prompt_length: int = -1, block_size: Optional[int] = None, lora_request: Optional[LoRARequest] = None, best_of: int = 1, @@ -26,6 +29,7 @@ def create_dummy_prompt( # Create dummy prompt sequence with tokens 0...block_size-1 # and prompt "0 ... block_size". prompt_tokens = list(range(prompt_length)) + prompt_str = " ".join([str(t) for t in prompt_tokens]) prompt = Sequence(int(request_id), inputs=token_inputs(prompt_tokens, prompt=prompt_str), @@ -42,6 +46,15 @@ def create_dummy_prompt( return prompt, seq_group +def create_dummy_sequence(request_id: int, token_ids: List[int], + block_size: int) -> Sequence: + return Sequence( + seq_id=request_id, + inputs=token_inputs(token_ids), + block_size=block_size, + ) + + def create_dummy_prompt_encoder_decoder( request_id: str, decoder_prompt_length: int, @@ -194,12 +207,40 @@ def append_new_token(out, token_id: int): def schedule_and_update_computed_tokens(scheduler): metas, out, _ = scheduler.schedule() - for s, meta in zip(out.scheduled_seq_groups, metas): - s.seq_group.update_num_computed_tokens(meta.token_chunk_size) + for s in out.scheduled_seq_groups: + s.seq_group.update_num_computed_tokens(s.token_chunk_size) return metas, out +def append_new_token_seq(seq: Sequence, token_id: int): + seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + def append_new_token_seq_group(token_chunk_size, seq_group, token_id: int): seq_group.update_num_computed_tokens(token_chunk_size) for seq in seq_group.get_seqs(): seq.append_token_id(token_id, {token_id: Logprob(token_id)}) + + +class SchedulerProxy: + """ + A proxy class to forward calls to the scheduler. + """ + + def __init__(self, scheduler: Scheduler): + self.scheduler_ = scheduler + self.call_history: Dict[str, List[Any]] = defaultdict(list) + + def __getattr__(self, name: str) -> Any: + + def wrapper(*args, **kwargs): + result = getattr(self.scheduler_, name)(*args, **kwargs) + self.call_history[name].append((args, kwargs, result)) + return result + + return wrapper + + def last_schedule_ret( + self, ) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs, Any]: + _, _, ret = self.call_history["schedule"][-1] + return ret diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index e0e424439e3a5..f702d7c46ea73 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -150,6 +150,75 @@ def worker_fn_with_cudagraph(): assert a.mean().cpu().item() == pynccl_comm.world_size**1 +@worker_fn_wrapper +def all_gather_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + result = torch.zeros(num_elems * world_size, + dtype=torch.float32, + device=device) + + expected = torch.cat([ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ]).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.all_gather(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_all_gather(): + distributed_run(all_gather_worker_fn, 2) + + +@worker_fn_wrapper +def reduce_scatter_worker_fn(): + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + + rank = pynccl_comm.rank + world_size = pynccl_comm.world_size + device = f'cuda:{pynccl_comm.rank}' + + num_elems = 1000 + tensor = torch.arange(num_elems, dtype=torch.float32, + device=device) + rank * num_elems + assert (num_elems % world_size == 0) + result = torch.zeros(num_elems // world_size, + dtype=torch.float32, + device=device) + + # Calculate expected result for this rank's chunk + scattered_size = num_elems // world_size + all_tensors = [ + torch.arange(num_elems, dtype=torch.float32) + r * num_elems + for r in range(world_size) + ] + expected = sum(tensor[rank * scattered_size:(rank + 1) * scattered_size] + for tensor in all_tensors).to(device) + + with pynccl_comm.change_state(enable=True): + pynccl_comm.reduce_scatter(result, tensor) + torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) + + +@pytest.mark.skipif(torch.cuda.device_count() < 2, + reason="Need at least 2 GPUs to run the test.") +def test_pynccl_reduce_scatter(): + distributed_run(reduce_scatter_worker_fn, 2) + + @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Need at least 2 GPUs to run the test.") def test_pynccl_with_cudagraph(): diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 7b1be5a9802fd..5b0e76fe53685 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -31,6 +31,34 @@ def test_limit_mm_per_prompt_parser(arg, expected): assert args.limit_mm_per_prompt == expected +def test_compilation_config(): + parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) + + # default value + args = parser.parse_args([]) + assert args.compilation_config is None + + # set to O3 + args = parser.parse_args(["-O3"]) + assert args.compilation_config.level == 3 + + # set to O 3 (space) + args = parser.parse_args(["-O", "3"]) + assert args.compilation_config.level == 3 + + # set to O 3 (equals) + args = parser.parse_args(["-O=3"]) + assert args.compilation_config.level == 3 + + # set to json + args = parser.parse_args(["--compilation-config", '{"level": 3}']) + assert args.compilation_config.level == 3 + + # set to json + args = parser.parse_args(['--compilation-config={"level": 3}']) + assert args.compilation_config.level == 3 + + def test_valid_pooling_config(): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) args = parser.parse_args([ diff --git a/tests/entrypoints/openai/test_chat.py b/tests/entrypoints/openai/test_chat.py index 8d13f64dce01c..8d23a2be6f9bb 100644 --- a/tests/entrypoints/openai/test_chat.py +++ b/tests/entrypoints/openai/test_chat.py @@ -829,6 +829,20 @@ async def test_inconsistent_tool_choice_and_tools(client: openai.AsyncOpenAI, "name": "nondefined_function_name" } }) + with pytest.raises(openai.BadRequestError): + await client.chat.completions.create( + model=MODEL_NAME, + messages=messages, + max_completion_tokens=1000, + tools=[{ + "type": "function", + "function": { + "name": "dummy_function_name", + "description": "This is a dummy function", + "parameters": sample_json_schema + } + }], + tool_choice={}) @pytest.mark.asyncio @@ -899,19 +913,19 @@ async def test_response_format_json_schema(client: openai.AsyncOpenAI): @pytest.mark.asyncio -async def test_extra_fields(client: openai.AsyncOpenAI): - with pytest.raises(BadRequestError) as exc_info: - await client.chat.completions.create( - model=MODEL_NAME, - messages=[{ - "role": "system", - "content": "You are a helpful assistant.", - "extra_field": "0", - }], # type: ignore - temperature=0, - seed=0) - - assert "extra_forbidden" in exc_info.value.message +async def test_extra_fields_allowed(client: openai.AsyncOpenAI): + resp = await client.chat.completions.create( + model=MODEL_NAME, + messages=[{ + "role": "user", + "content": "what is 1+1?", + "extra_field": "0", + }], # type: ignore + temperature=0, + seed=0) + + content = resp.choices[0].message.content + assert content is not None @pytest.mark.asyncio diff --git a/tests/entrypoints/openai/test_chat_echo.py b/tests/entrypoints/openai/test_chat_echo.py new file mode 100644 index 0000000000000..223ac5b41aa83 --- /dev/null +++ b/tests/entrypoints/openai/test_chat_echo.py @@ -0,0 +1,79 @@ +from typing import NamedTuple + +import openai # use the official client for correctness check +import pytest +import pytest_asyncio + +from ...utils import RemoteOpenAIServer + +# # any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--enforce-eager", + "--max-model-len", + "4080", + "--chat-template", + DUMMY_CHAT_TEMPLATE, + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest_asyncio.fixture +async def client(server): + async with server.get_async_client() as async_client: + yield async_client + + +class TestCase(NamedTuple): + model_name: str + echo: bool + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + TestCase(model_name=MODEL_NAME, echo=True), + TestCase(model_name=MODEL_NAME, echo=False) + ], +) +async def test_chat_session_with_echo_and_continue_final_message( + client: openai.AsyncOpenAI, test_case: TestCase): + saying: str = "Here is a common saying about apple. An apple a day, keeps" + # test echo with continue_final_message parameter + chat_completion = await client.chat.completions.create( + model=test_case.model_name, + messages=[{ + "role": "user", + "content": "tell me a common saying" + }, { + "role": "assistant", + "content": saying + }], + extra_body={ + "echo": test_case.echo, + "continue_final_message": True, + "add_generation_prompt": False + }) + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + + message = choice.message + if test_case.echo: + assert message.content is not None and saying in message.content + else: + assert message.content is not None and saying not in message.content + assert message.role == "assistant" diff --git a/tests/entrypoints/openai/test_root_path.py b/tests/entrypoints/openai/test_root_path.py new file mode 100644 index 0000000000000..20f7960619efb --- /dev/null +++ b/tests/entrypoints/openai/test_root_path.py @@ -0,0 +1,103 @@ +import contextlib +import os +from typing import Any, List, NamedTuple + +import openai # use the official client for correctness check +import pytest + +from ...utils import RemoteOpenAIServer + +# # any model with a chat template should work here +MODEL_NAME = "Qwen/Qwen2-1.5B-Instruct" +DUMMY_CHAT_TEMPLATE = """{% for message in messages %}{{message['role'] + ': ' + message['content'] + '\\n'}}{% endfor %}""" # noqa: E501 +API_KEY = "abc-123" +ERROR_API_KEY = "abc" +ROOT_PATH = "llm" + + +@pytest.fixture(scope="module") +def server(): + args = [ + # use half precision for speed and memory savings in CI environment + "--dtype", + "float16", + "--enforce-eager", + "--max-model-len", + "4080", + "--root-path", # use --root-path=/llm for testing + "/" + ROOT_PATH, + "--chat-template", + DUMMY_CHAT_TEMPLATE, + ] + envs = os.environ.copy() + + envs["VLLM_API_KEY"] = API_KEY + with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server: + yield remote_server + + +class TestCase(NamedTuple): + model_name: str + base_url: List[str] + api_key: str + expected_error: Any + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "test_case", + [ + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=ERROR_API_KEY, + expected_error=openai.AuthenticationError), + TestCase( + model_name=MODEL_NAME, + base_url=["v1"], # http://localhost:8000/v1 + api_key=API_KEY, + expected_error=None), + TestCase( + model_name=MODEL_NAME, + base_url=[ROOT_PATH, "v1"], # http://localhost:8000/llm/v1 + api_key=API_KEY, + expected_error=None), + ], +) +async def test_chat_session_root_path_with_api_key(server: RemoteOpenAIServer, + test_case: TestCase): + saying: str = "Here is a common saying about apple. An apple a day, keeps" + ctx = contextlib.nullcontext() + if test_case.expected_error is not None: + ctx = pytest.raises(test_case.expected_error) + with ctx: + client = openai.AsyncOpenAI( + api_key=test_case.api_key, + base_url=server.url_for(*test_case.base_url), + max_retries=0) + chat_completion = await client.chat.completions.create( + model=test_case.model_name, + messages=[{ + "role": "user", + "content": "tell me a common saying" + }, { + "role": "assistant", + "content": saying + }], + extra_body={ + "continue_final_message": True, + "add_generation_prompt": False + }) + + assert chat_completion.id is not None + assert len(chat_completion.choices) == 1 + choice = chat_completion.choices[0] + assert choice.finish_reason == "stop" + message = choice.message + assert len(message.content) > 0 + assert message.role == "assistant" diff --git a/tests/entrypoints/openai/test_score.py b/tests/entrypoints/openai/test_score.py new file mode 100644 index 0000000000000..7565ff7192f67 --- /dev/null +++ b/tests/entrypoints/openai/test_score.py @@ -0,0 +1,93 @@ +import pytest +import requests + +from vllm.entrypoints.openai.protocol import ScoreResponse + +from ...utils import RemoteOpenAIServer + +MODEL_NAME = "BAAI/bge-reranker-v2-m3" + + +@pytest.fixture(scope="module") +def server(): + args = [ + "--enforce-eager", + ] + + with RemoteOpenAIServer(MODEL_NAME, args) as remote_server: + yield remote_server + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_list_text_2_list(server: RemoteOpenAIServer, + model_name: str): + text_1 = [ + "What is the capital of the United States?", + "What is the capital of France?" + ] + text_2 = [ + "The capital of Brazil is Brasilia.", "The capital of France is Paris." + ] + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 2 + assert score.data[0].score[0] <= 0.01 + assert score.data[1].score[0] >= 0.9 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("model_name", [MODEL_NAME]) +async def test_text_1_str_text_2_str(server: RemoteOpenAIServer, + model_name: str): + text_1 = "What is the capital of France?" + text_2 = "The capital of France is Paris." + + score_response = requests.post(server.url_for("v1/score"), + json={ + "model": model_name, + "text_1": text_1, + "text_2": text_2, + }) + score_response.raise_for_status() + score = ScoreResponse.model_validate(score_response.json()) + + assert score.id is not None + assert score.data is not None + assert len(score.data) == 1 + assert score.data[0].score[0] >= 0.9 diff --git a/tests/entrypoints/test_chat_utils.py b/tests/entrypoints/test_chat_utils.py index 72477e048eafa..996e60bfee592 100644 --- a/tests/entrypoints/test_chat_utils.py +++ b/tests/entrypoints/test_chat_utils.py @@ -766,8 +766,8 @@ def test_resolve_content_format_hf_defined(model, expected_format): ("tool_chat_template_granite_20b_fc.jinja", "string"), ("tool_chat_template_hermes.jinja", "string"), ("tool_chat_template_internlm2_tool.jinja", "string"), - ("tool_chat_template_llama3.1_json.jinja", "string"), - ("tool_chat_template_llama3.2_json.jinja", "string"), + ("tool_chat_template_llama3.1_json.jinja", "openai"), + ("tool_chat_template_llama3.2_json.jinja", "openai"), ("tool_chat_template_mistral_parallel.jinja", "string"), ("tool_chat_template_mistral.jinja", "string")], ) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index 3d3724c50421d..d943b048b7934 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -18,6 +18,7 @@ from vllm.attention.backends.utils import STR_NOT_IMPL_ENC_DEC_ROCM_HIP from vllm.attention.selector import (_Backend, _cached_get_attn_backend, global_force_attn_backend_context_manager) +from vllm.config import VllmConfig, set_current_vllm_config from vllm.forward_context import set_forward_context from vllm.platforms import current_platform @@ -594,6 +595,7 @@ def _run_encoder_attention_test( encoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run encoder attention. @@ -623,7 +625,7 @@ def _run_encoder_attention_test( attn_type = AttentionType.ENCODER packed_qkv = encoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -648,6 +650,7 @@ def _run_decoder_self_attention_test( decoder_test_params: PhaseTestParameters, attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run decoder self-attention test. @@ -677,7 +680,7 @@ def _run_decoder_self_attention_test( kv_cache = test_rsrcs.kv_cache packed_qkv = decoder_test_params.packed_qkvo.packed_qkv assert packed_qkv is not None - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -701,6 +704,7 @@ def _run_encoder_decoder_cross_attention_test( cross_test_params: Optional[PhaseTestParameters], attn_metadata: AttentionMetadata, test_pt: TestPoint, + vllm_config: VllmConfig, ) -> torch.Tensor: ''' Run encoder/decoder cross-attention test. @@ -748,7 +752,7 @@ def _run_encoder_decoder_cross_attention_test( cross_pckd_qkv = cross_test_params.packed_qkvo.packed_qkv key = (None if cross_pckd_qkv is None else cross_pckd_qkv.key) value = (None if cross_pckd_qkv is None else cross_pckd_qkv.value) - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, vllm_config): # In the test setup the shape of the query is # [batch_size, seq_len, num_heads, head_size]. However # the attention backend expect the shape to be @@ -839,7 +843,9 @@ def test_encoder_only( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + test_rsrcs = _make_test_resources(test_pt) # Construct encoder attention test params (only used # during prefill) @@ -863,7 +869,8 @@ def test_encoder_only( test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt)) + test_pt=test_pt, + vllm_config=vllm_config)) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, @@ -960,7 +967,9 @@ def test_e2e_enc_dec_attn( # Attention scale factor, attention backend instance, attention wrapper # instance, KV cache init - test_rsrcs = _make_test_resources(test_pt) + vllm_config = VllmConfig() + with set_current_vllm_config(vllm_config): + test_rsrcs = _make_test_resources(test_pt) # Construct encoder attention test params (only used # during prefill) @@ -1011,7 +1020,8 @@ def test_e2e_enc_dec_attn( enc_pckd_act_out = _run_encoder_attention_test(test_rsrcs.attn, enc_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is encoder attention result correct? assert_actual_matches_ideal(enc_test_params, enc_pckd_act_out, @@ -1023,7 +1033,8 @@ def test_e2e_enc_dec_attn( test_rsrcs, prephase_dec_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is prefill decoder self-attention correct? assert_actual_matches_ideal(prephase_dec_test_params, @@ -1037,7 +1048,8 @@ def test_e2e_enc_dec_attn( prephase_dec_test_params, prephase_cross_test_params, prephase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is prefill encoder/decoder cross-attention correct? assert_actual_matches_ideal(prephase_cross_test_params, @@ -1061,7 +1073,8 @@ def test_e2e_enc_dec_attn( test_rsrcs, decphase_dec_test_params, decphase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is decode-phase decoder self-attention correct? assert_actual_matches_ideal(decphase_dec_test_params, @@ -1075,7 +1088,8 @@ def test_e2e_enc_dec_attn( decphase_dec_test_params, None, decphase_attn_metadata, - test_pt=test_pt) + test_pt=test_pt, + vllm_config=vllm_config) # - Is decode-phase encoder/decoder cross-attention correct? assert_actual_matches_ideal(decphase_cross_test_params, diff --git a/tests/kernels/test_marlin_gemm.py b/tests/kernels/test_marlin_gemm.py index 3899ad1a325cf..5e047f4b099f1 100644 --- a/tests/kernels/test_marlin_gemm.py +++ b/tests/kernels/test_marlin_gemm.py @@ -50,6 +50,8 @@ (13, 17, 67), (26, 37, 13), (67, 13, 11), + (257, 13, 11), + (658, 13, 11), ] DTYPES = [torch.float16, torch.bfloat16] diff --git a/tests/kernels/test_prefix_prefill.py b/tests/kernels/test_prefix_prefill.py index a8a187ebaede4..3fdb7996ba4e0 100644 --- a/tests/kernels/test_prefix_prefill.py +++ b/tests/kernels/test_prefix_prefill.py @@ -40,6 +40,13 @@ def test_contexted_kv_attention( kv_cache_dtype: str, device: str, ) -> None: + + if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( + 89): + pytest.skip( + 'Triton limitation: fp8e4nv data type is not supported on CUDA' + ' arch < 89') + current_platform.seed_everything(0) torch.set_default_device(device) @@ -235,6 +242,13 @@ def test_contexted_kv_attention_alibi( kv_cache_dtype: str, device: str, ) -> None: + + if 'fp8' in kv_cache_dtype and not current_platform.has_device_capability( + 89): + pytest.skip( + 'Triton limitation: fp8e4nv data type is not supported on CUDA' + ' arch < 89') + current_platform.seed_everything(0) torch.set_default_device(device) @@ -462,3 +476,52 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: print(f"xformers Time: {(end_time - start_time)*1000:.2f} ms") atol = 1e-3 if "fp8" in kv_cache_dtype else 1e-6 torch.testing.assert_close(output, output_ref, atol=atol, rtol=0) + + +# These tests are optional to only run when explicitly invoked +# +# pytest -v -s --optional \ +# tests/kernels/test_prefix_prefill.py::test_contexted_kv_attention_f32 +# +# These tests are useful to test model dtype float32 on Turing devices. +# We skip them to not increase the time when running tests on CI +@pytest.mark.optional +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@pytest.mark.parametrize("sliding_window", SLIDING_WINDOW) +@torch.inference_mode() +def test_contexted_kv_attention_f32( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + sliding_window: int, + dtype: torch.dtype, + kv_cache_dtype: str, + device: str, +) -> None: + test_contexted_kv_attention(num_heads, num_queries_per_kv, head_size, + sliding_window, dtype, kv_cache_dtype, device) + + +@pytest.mark.optional +@pytest.mark.parametrize("num_heads", NUM_HEADS) +@pytest.mark.parametrize("num_queries_per_kv", NUM_QUERIES_PER_KV) +@pytest.mark.parametrize("head_size", HEAD_SIZES) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("kv_cache_dtype", KV_CACHE_DTYPES) +@pytest.mark.parametrize("device", CUDA_DEVICES) +@torch.inference_mode() +def test_contexted_kv_attention_alibi_f32( + num_heads: int, + num_queries_per_kv: int, + head_size: int, + dtype: torch.dtype, + kv_cache_dtype: str, + device: str, +) -> None: + test_contexted_kv_attention_alibi(num_heads, num_queries_per_kv, head_size, + dtype, kv_cache_dtype, device) diff --git a/tests/lora/test_chatglm3.py b/tests/lora/test_chatglm3_tp.py similarity index 56% rename from tests/lora/test_chatglm3.py rename to tests/lora/test_chatglm3_tp.py index de4cbea80924e..f17464573459f 100644 --- a/tests/lora/test_chatglm3.py +++ b/tests/lora/test_chatglm3_tp.py @@ -1,12 +1,21 @@ from typing import List import vllm +from tests.utils import fork_new_process_for_each_test from vllm.lora.request import LoRARequest +from ..utils import multi_gpu_test + MODEL_PATH = "THUDM/chatglm3-6b" PROMPT_TEMPLATE = """I want you to act as a SQL terminal in front of an example database, you need only to return the sql command to me.Below is an instruction that describes a task, Write a response that appropriately completes the request.\n"\n##Instruction:\nconcert_singer contains tables such as stadium, singer, concert, singer_in_concert. Table stadium has columns such as Stadium_ID, Location, Name, Capacity, Highest, Lowest, Average. Stadium_ID is the primary key.\nTable singer has columns such as Singer_ID, Name, Country, Song_Name, Song_release_year, Age, Is_male. Singer_ID is the primary key.\nTable concert has columns such as concert_ID, concert_Name, Theme, Stadium_ID, Year. concert_ID is the primary key.\nTable singer_in_concert has columns such as concert_ID, Singer_ID. concert_ID is the primary key.\nThe Stadium_ID of concert is the foreign key of Stadium_ID of stadium.\nThe Singer_ID of singer_in_concert is the foreign key of Singer_ID of singer.\nThe concert_ID of singer_in_concert is the foreign key of concert_ID of concert.\n\n###Input:\n{query}\n\n###Response:""" # noqa: E501 +EXPECTED_LORA_OUTPUT = [ + "SELECT count(*) FROM singer", + "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 + "SELECT name , country , age FROM singer ORDER BY age", +] + def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: prompts = [ @@ -20,7 +29,6 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: "Show name, country, age for all singers ordered by age from the oldest to the youngest." # noqa: E501 ), ] - print(prompts) sampling_params = vllm.SamplingParams(temperature=0, max_tokens=32) outputs = llm.generate( prompts, @@ -37,23 +45,58 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts +@fork_new_process_for_each_test def test_chatglm3_lora(chatglm3_lora_files): llm = vllm.LLM(MODEL_PATH, max_model_len=1024, enable_lora=True, max_loras=4, max_lora_rank=64, + tensor_parallel_size=1, trust_remote_code=True) - expected_lora_output = [ - "SELECT count(*) FROM singer", - "SELECT avg(age) , min(age) , max(age) FROM singer WHERE country = 'France'", # noqa: E501 - "SELECT name , country , age FROM singer ORDER BY age", - ] + output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_chatglm3_lora_tp4(chatglm3_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=False) + + output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] + output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_chatglm3_lora_tp4_fully_sharded_loras(chatglm3_lora_files): + llm = vllm.LLM(MODEL_PATH, + max_model_len=1024, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=True) output1 = do_sample(llm, chatglm3_lora_files, lora_id=1) - for i in range(len(expected_lora_output)): - assert output1[i] == expected_lora_output[i] + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output1[i] == EXPECTED_LORA_OUTPUT[i] output2 = do_sample(llm, chatglm3_lora_files, lora_id=2) - for i in range(len(expected_lora_output)): - assert output2[i] == expected_lora_output[i] + for i in range(len(EXPECTED_LORA_OUTPUT)): + assert output2[i] == EXPECTED_LORA_OUTPUT[i] diff --git a/tests/lora/test_llama.py b/tests/lora/test_llama.py deleted file mode 100644 index e2a4f1ed0496a..0000000000000 --- a/tests/lora/test_llama.py +++ /dev/null @@ -1,146 +0,0 @@ -from typing import List - -import pytest -import ray - -import vllm -from vllm.distributed import cleanup_dist_env_and_memory -from vllm.lora.request import LoRARequest - -MODEL_PATH = "meta-llama/Llama-2-7b-hf" - - -def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: - prompts = [ - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 - "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 - ] - sampling_params = vllm.SamplingParams(temperature=0, - max_tokens=256, - stop=["[/assistant]"]) - outputs = llm.generate( - prompts, - sampling_params, - lora_request=LoRARequest(str(lora_id), lora_id, lora_path) - if lora_id else None) - # Print the outputs. - generated_texts: List[str] = [] - for output in outputs: - prompt = output.prompt - generated_text = output.outputs[0].text - generated_texts.append(generated_text) - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") - return generated_texts - - -@pytest.mark.parametrize("tp_size", [1, 2, 4]) -def test_llama_lora(sql_lora_files, tp_size, num_gpus_available): - if num_gpus_available < tp_size: - pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") - - llm = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=tp_size) - - expected_no_lora_output = [ - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 - "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 - " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 - "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 - ] - expected_lora_output = [ - " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 - " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 - " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 - " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 - " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 - " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 - ] - - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == expected_lora_output - - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == expected_no_lora_output - - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == expected_lora_output - - print("removing lora") - - -def test_llama_tensor_parallel_equality(sql_lora_files, num_gpus_available): - if num_gpus_available < 4: - pytest.skip("Not enough GPUs for tensor parallelism 4") - - llm_tp1 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1) - output_tp1 = do_sample(llm_tp1, sql_lora_files, lora_id=1) - - del llm_tp1 - cleanup_dist_env_and_memory() - - llm_tp2 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=2) - output_tp2 = do_sample(llm_tp2, sql_lora_files, lora_id=1) - - del llm_tp2 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp2 - - llm_tp4 = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=4) - output_tp4 = do_sample(llm_tp4, sql_lora_files, lora_id=1) - - del llm_tp4 - cleanup_dist_env_and_memory() - - assert output_tp1 == output_tp4 - - -def test_llama_lora_warmup(sql_lora_files): - """Test that the LLM initialization works with a warmup LORA path and - is more conservative""" - - @ray.remote(num_gpus=1) - def get_num_gpu_blocks_lora(): - llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) - num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks - return num_gpu_blocks_lora_warmup - - @ray.remote(num_gpus=1) - def get_num_gpu_blocks_no_lora(): - llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) - num_gpu_blocks_no_lora_warmup = ( - llm.llm_engine.cache_config.num_gpu_blocks) - return num_gpu_blocks_no_lora_warmup - - num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) - num_gpu_blocks_no_lora_warmup = ray.get( - get_num_gpu_blocks_no_lora.remote()) - assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( - "The warmup with lora should be more " - "conservative than without lora, therefore the number of " - "memory blocks for the KV cache should be " - "less when using lora than when not using lora") diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py new file mode 100644 index 0000000000000..aae6310a2a213 --- /dev/null +++ b/tests/lora/test_llama_tp.py @@ -0,0 +1,161 @@ +from typing import List + +import ray + +import vllm +from tests.utils import fork_new_process_for_each_test +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "meta-llama/Llama-2-7b-hf" + +EXPECTED_NO_LORA_OUTPUT = [ + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_75 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_76 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_77 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_78 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user]", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? ", # noqa: E501 + "\n\n answer: 1\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_96 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_97 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a high tone mora with a gloss of /˧kot/ [kòt]? [/user] [assistant]\n\n answer: 2\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_98 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one m", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. ", # noqa: E501 + " Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? ", # noqa: E501 + "\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]\n\n [user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE", # noqa: E501 +] +EXPECTED_LORA_OUTPUT = [ + " SELECT icao FROM table_name_74 WHERE airport = 'lilongwe international airport' ", # noqa: E501 + " SELECT nationality FROM table_name_11 WHERE elector = 'anchero pantaleone' ", # noqa: E501 + " SELECT one_mora FROM table_name_95 WHERE gloss = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] AND accented_mora = 'low tone mora with a gloss of /˩okiru/' [òkìɽɯ́] ", # noqa: E501 + " SELECT sex FROM people WHERE people_id IN (SELECT people_id FROM candidate GROUP BY sex ORDER BY COUNT(people_id) DESC LIMIT 1) ", # noqa: E501 + " SELECT pick FROM table_name_60 WHERE former_wnba_team = 'Minnesota Lynx' ", # noqa: E501 + " SELECT womens_doubles FROM table_28138035_4 WHERE mens_singles = 'Werner Schlager' " # noqa: E501 +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + prompts = [ + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_74 (icao VARCHAR, airport VARCHAR)\n\n question: Name the ICAO for lilongwe international airport [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_11 (nationality VARCHAR, elector VARCHAR)\n\n question: When Anchero Pantaleone was the elector what is under nationality? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_95 (one_mora VARCHAR, gloss VARCHAR, accented_mora VARCHAR)\n\n question: What is the one mora for a low tone mora with a gloss of /˩okiru/ [òkìɽɯ́]? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE candidate (people_id VARCHAR, unsure_rate INTEGER); CREATE TABLE people (sex VARCHAR, people_id VARCHAR)\n\n question: which gender got the highest average uncertain ratio. [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_name_60 (pick INTEGER, former_wnba_team VARCHAR)\n\n question: What pick was a player that previously played for the Minnesota Lynx? [/user] [assistant]", # noqa: E501 + "[user] Write a SQL query to answer the question based on the table schema.\n\n context: CREATE TABLE table_28138035_4 (womens_doubles VARCHAR, mens_singles VARCHAR)\n\n question: Name the women's doubles for werner schlager [/user] [assistant]" # noqa: E501 + ] + sampling_params = vllm.SamplingParams(temperature=0, + max_tokens=256, + stop=["[/assistant]"]) + outputs = llm.generate( + prompts, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@fork_new_process_for_each_test +def test_llama_lora(sql_lora_files): + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") + + +@fork_new_process_for_each_test +def test_llama_lora_warmup(sql_lora_files): + """Test that the LLM initialization works with a warmup LORA path and + is more conservative""" + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_lora(): + llm = vllm.LLM(MODEL_PATH, enable_lora=True, max_num_seqs=16) + num_gpu_blocks_lora_warmup = llm.llm_engine.cache_config.num_gpu_blocks + return num_gpu_blocks_lora_warmup + + @ray.remote(num_gpus=1) + def get_num_gpu_blocks_no_lora(): + llm = vllm.LLM(MODEL_PATH, max_num_seqs=16) + num_gpu_blocks_no_lora_warmup = ( + llm.llm_engine.cache_config.num_gpu_blocks) + return num_gpu_blocks_no_lora_warmup + + num_gpu_blocks_lora_warmup = ray.get(get_num_gpu_blocks_lora.remote()) + num_gpu_blocks_no_lora_warmup = ray.get( + get_num_gpu_blocks_no_lora.remote()) + assert num_gpu_blocks_lora_warmup < num_gpu_blocks_no_lora_warmup, ( + "The warmup with lora should be more " + "conservative than without lora, therefore the number of " + "memory blocks for the KV cache should be " + "less when using lora than when not using lora") + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4(sql_lora_files): + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + ) + + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") + + +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): + + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + fully_sharded_loras=True, + ) + print("lora adapter created") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 1") + assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + + print("no lora") + assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT + + print("lora 2") + assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT + + print("removing lora") diff --git a/tests/lora/test_punica_variation.py b/tests/lora/test_punica_variation.py index 52b82f25d23e1..3b20033271d26 100644 --- a/tests/lora/test_punica_variation.py +++ b/tests/lora/test_punica_variation.py @@ -6,12 +6,13 @@ import pytest import torch -from vllm.lora.ops.bgmv_expand import bgmv_expand -from vllm.lora.ops.bgmv_expand_slice import bgmv_expand_slice -from vllm.lora.ops.bgmv_shrink import bgmv_shrink -from vllm.lora.ops.sgmv_expand import sgmv_expand -from vllm.lora.ops.sgmv_expand_slice import sgmv_expand_slice -from vllm.lora.ops.sgmv_shrink import sgmv_shrink +# Enable custom op register +import vllm.lora.ops.bgmv_expand +import vllm.lora.ops.bgmv_expand_slice +import vllm.lora.ops.bgmv_shrink +import vllm.lora.ops.sgmv_expand +import vllm.lora.ops.sgmv_expand_slice +import vllm.lora.ops.sgmv_shrink # noqa: F401 from vllm.platforms import current_platform from .utils import (generate_data, generate_data_for_expand_nslices, @@ -37,6 +38,16 @@ def assert_close(a, b): torch.testing.assert_close(a, b, rtol=rtol, atol=atol) +# Unlike test_punica_sizes.py, we directly utilize custom op for +# testing, which verifies the correct registration of these ops. +bgmv_expand = torch.ops.vllm.bgmv_expand +bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice +bgmv_shrink = torch.ops.vllm.bgmv_shrink +sgmv_expand = torch.ops.vllm.sgmv_expand +sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice +sgmv_shrink = torch.ops.vllm.sgmv_shrink + + @pytest.mark.parametrize("batches", BATCHES) @pytest.mark.parametrize("num_loras", NUM_LORA) @pytest.mark.parametrize("rank", MAX_RANKS) diff --git a/tests/model_executor/test_enabled_custom_ops.py b/tests/model_executor/test_enabled_custom_ops.py index c54e30995da49..0a3aba255fd76 100644 --- a/tests/model_executor/test_enabled_custom_ops.py +++ b/tests/model_executor/test_enabled_custom_ops.py @@ -2,13 +2,12 @@ import pytest -from vllm.config import CompilationConfig, VllmConfig +from vllm.config import CompilationConfig, VllmConfig, set_current_vllm_config from vllm.model_executor.custom_op import CustomOp from vllm.model_executor.layers.activation import (GeluAndMul, ReLUSquaredActivation, SiluAndMul) from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.plugins import set_current_vllm_config # Registered subclass for test diff --git a/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py new file mode 100644 index 0000000000000..af0c2aa211998 --- /dev/null +++ b/tests/models/decoder_only/vision_language/mm_processor_kwargs/test_internvl.py @@ -0,0 +1,206 @@ +"""Tests for InternVL's multimodal preprocessing kwargs.""" +from typing import Callable, Optional + +import pytest +from transformers import AutoTokenizer + +from vllm.inputs import InputContext, token_inputs +from vllm.multimodal import MultiModalRegistry + +from .....conftest import _ImageAssets +from ....utils import build_model_context + +models = ["OpenGVLab/InternVL2-2B"] + + +# Wrap lazy imports to avoid initializing CUDA during test collection +@pytest.fixture() +def input_processor_for_internvl(): + from vllm.model_executor.models.internvl import InternVLInputPipeline + + pipeline = InternVLInputPipeline('', '', '') + return pipeline.input_processor + + +@pytest.fixture() +def dummy_data_for_internvl(): + from vllm.model_executor.models.internvl import InternVLInputPipeline + + pipeline = InternVLInputPipeline('', '', '') + return pipeline.dummy_data + + +@pytest.fixture() +def get_max_internvl_image_tokens(): + from vllm.model_executor.models.internvl import ( + get_max_internvl_image_tokens) + return get_max_internvl_image_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_input_mapper_override( + model: str, + image_assets: _ImageAssets, + max_dynamic_patch: int, + dynamic_image_size: Optional[bool], +): + mm_processor_kwargs = { + "max_dynamic_patch": max_dynamic_patch, + } + if dynamic_image_size is not None: + mm_processor_kwargs["dynamic_image_size"] = dynamic_image_size + + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=mm_processor_kwargs, + ) + + mm_registry = MultiModalRegistry() + mm_registry.init_mm_limits_per_prompt(ctx.model_config) + + image = image_assets[0].pil_image.resize((448 * 2, 448 * 2)) + vllm_result = mm_registry.map_input( + ctx.model_config, + {"image": image}, + ) + assert vllm_result["pixel_values"].size(1) == expected_num_patches + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_max_tokens_override( + get_max_internvl_image_tokens: Callable, + model: str, + max_dynamic_patch: Optional[int], + dynamic_image_size: Optional[bool], +): + """Ensure get_max_internvl_image_tokens handles mm_processor_kwargs.""" + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + if max_dynamic_patch is None: + max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + expected_max_tokens = 256 * expected_num_patches + + actual_max_tokens = get_max_internvl_image_tokens( + ctx=InputContext(ctx.model_config), + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + assert expected_max_tokens == actual_max_tokens + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("num_imgs", [1, 2]) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4, None]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +def test_dummy_data_override( + dummy_data_for_internvl: Callable, + model: str, + num_imgs: int, + max_dynamic_patch: Optional[int], + dynamic_image_size: Optional[bool], +): + """Ensure dummy_data_for_internvl handles kwargs properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the dummy data func. + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + + if max_dynamic_patch is None: + max_dynamic_patch = ctx.get_hf_config().max_dynamic_patch + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + expected_max_tokens = 256 * expected_num_patches + + dummy_data = dummy_data_for_internvl( + ctx=ctx, + seq_len=8192, # Should be bigger than num_imgs * toks_per_img + mm_counts={"image": num_imgs}, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + sequence_data = dummy_data.seq_data + + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + image_token_id = tokenizer.encode('', + add_special_tokens=False)[0] + + # Ensure we have the right number of placeholders per size + img_tok_count = sequence_data.get_token_ids().count(image_token_id) + assert img_tok_count == expected_max_tokens * num_imgs + + +@pytest.mark.parametrize("model", models) +@pytest.mark.parametrize("max_dynamic_patch", [1, 4]) +@pytest.mark.parametrize("dynamic_image_size", [True, False, None]) +@pytest.mark.parametrize("num_imgs", [1, 2]) +def test_input_processor_override( + input_processor_for_internvl: Callable, + image_assets: _ImageAssets, + model: str, + num_imgs: int, + max_dynamic_patch: int, + dynamic_image_size: Optional[bool], +): + """Ensure input_processor_for_internvl handles kwargs properly.""" + # Same as the previous test - don't initialize mm_processor_kwargs + # in this test and assume that the kwargs will be correctly expanded by + # the partial when calling the custom input processor. + expected_num_patches = max_dynamic_patch + 1 if max_dynamic_patch > 1 else 1 + if dynamic_image_size is False: + expected_num_patches = 1 + + ctx = build_model_context( + model_name=model, + tokenizer_name=model, + trust_remote_code=True, + mm_processor_kwargs=None, + ) + expected_toks_per_img = 256 * expected_num_patches + + # Build the image str / prompt based on the number of images we pass + tokenizer = AutoTokenizer.from_pretrained(model, trust_remote_code=True) + placeholders = "" if num_imgs == 1 else "\n".join( + f"Image-{i}: \n" for i in range(1, num_imgs + 1)) + prompt = placeholders + images = [image_assets[0].pil_image.resize((448 * 2, 448 * 2))] * num_imgs + + inputs = token_inputs(prompt_token_ids=tokenizer.encode(prompt), + prompt=prompt, + multi_modal_data={"image": images}) + + processed_inputs = input_processor_for_internvl( + ctx, + inputs, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) + + # Ensure we have the right number of placeholders per num_crops size + image_token_id = tokenizer.encode('', + add_special_tokens=False)[0] + img_tok_count = processed_inputs["prompt_token_ids"].count(image_token_id) + assert img_tok_count == expected_toks_per_img * num_imgs diff --git a/tests/models/embedding/language/test_embedding.py b/tests/models/embedding/language/test_embedding.py index c3f351ef707be..36b1e5887981c 100644 --- a/tests/models/embedding/language/test_embedding.py +++ b/tests/models/embedding/language/test_embedding.py @@ -21,6 +21,7 @@ marks=[pytest.mark.core_model]), pytest.param("ssmits/Qwen2-7B-Instruct-embed-base"), pytest.param("Alibaba-NLP/gte-Qwen2-1.5B-instruct"), + pytest.param("Alibaba-NLP/gte-Qwen2-7B-instruct"), ], ) @pytest.mark.parametrize("dtype", ["half"]) @@ -31,6 +32,10 @@ def test_models( model, dtype: str, ) -> None: + vllm_extra_kwargs = {} + if model == "Alibaba-NLP/gte-Qwen2-7B-instruct": + vllm_extra_kwargs["hf_overrides"] = {"is_causal": False} + # The example_prompts has ending "\n", for example: # "Write a short story about a robot that dreams for the first time.\n" # sentence_transformers will strip the input texts, see: @@ -43,8 +48,11 @@ def test_models( is_sentence_transformer=True) as hf_model: hf_outputs = hf_model.encode(example_prompts) - with vllm_runner(model, task="embedding", dtype=dtype, - max_model_len=None) as vllm_model: + with vllm_runner(model, + task="embedding", + dtype=dtype, + max_model_len=None, + **vllm_extra_kwargs) as vllm_model: vllm_outputs = vllm_model.encode(example_prompts) # This test is for verifying whether the model's extra_repr # can be printed correctly. diff --git a/tests/models/embedding/language/test_scoring.py b/tests/models/embedding/language/test_scoring.py new file mode 100644 index 0000000000000..30fa5ea7b36c0 --- /dev/null +++ b/tests/models/embedding/language/test_scoring.py @@ -0,0 +1,95 @@ +"""Compare the embedding outputs of HF and vLLM models. + +Run `pytest tests/models/embedding/language/test_embedding.py`. +""" +import math + +import pytest + +MODELS = [ + "cross-encoder/ms-marco-MiniLM-L-6-v2", # Bert + "BAAI/bge-reranker-v2-m3", # Roberta +] + +TEXTS_1 = [ + "What is the capital of France?", + "What is the capital of Germany?", +] + +TEXTS_2 = [ + "The capital of France is Paris.", + "The capital of Germany is Berlin.", +] + + +@pytest.fixture(scope="module", params=MODELS) +def model_name(request): + yield request.param + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_1(vllm_runner, hf_runner, model_name, dtype: str): + + text_pair = [TEXTS_1[0], TEXTS_2[0]] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict([text_pair]).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(text_pair[0], text_pair[1]) + + assert len(vllm_outputs) == 1 + assert len(hf_outputs) == 1 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_1_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[0], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1[0], TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) + + +@pytest.mark.parametrize("dtype", ["half"]) +def test_llm_N_to_N(vllm_runner, hf_runner, model_name, dtype: str): + + text_pairs = [ + [TEXTS_1[0], TEXTS_2[0]], + [TEXTS_1[1], TEXTS_2[1]], + ] + + with hf_runner(model_name, dtype=dtype, is_cross_encoder=True) as hf_model: + hf_outputs = hf_model.predict(text_pairs).tolist() + + with vllm_runner(model_name, + task="embedding", + dtype=dtype, + max_model_len=None) as vllm_model: + vllm_outputs = vllm_model.score(TEXTS_1, TEXTS_2) + + assert len(vllm_outputs) == 2 + assert len(hf_outputs) == 2 + + assert math.isclose(hf_outputs[0], vllm_outputs[0][0], rel_tol=0.01) + assert math.isclose(hf_outputs[1], vllm_outputs[1][0], rel_tol=0.01) diff --git a/tests/models/embedding/utils.py b/tests/models/embedding/utils.py index fd1c44d9c117e..f96c7d2b176db 100644 --- a/tests/models/embedding/utils.py +++ b/tests/models/embedding/utils.py @@ -24,7 +24,7 @@ def check_embeddings_close( dim=0) fail_msg = (f"Test{prompt_idx}:" - f"\n{name_0}:\t{embeddings_0!r}" - f"\n{name_1}:\t{embeddings_1!r}") + f"\n{name_0}:\t{embeddings_0[:16]!r}" + f"\n{name_1}:\t{embeddings_1[:16]!r}") assert sim >= 1 - tol, fail_msg diff --git a/tests/models/registry.py b/tests/models/registry.py index 3848367b6126c..669c832b1df3a 100644 --- a/tests/models/registry.py +++ b/tests/models/registry.py @@ -43,6 +43,8 @@ class _HfExamplesInfo: trust_remote_code=True), "ArcticForCausalLM": _HfExamplesInfo("Snowflake/snowflake-arctic-instruct", trust_remote_code=True), + "AriaForConditionalGeneration": _HfExamplesInfo("rhymes-ai/Aria", + trust_remote_code=True), "BaiChuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan-7B", trust_remote_code=True), "BaichuanForCausalLM": _HfExamplesInfo("baichuan-inc/Baichuan2-7B-chat", @@ -135,6 +137,7 @@ class _HfExamplesInfo: "Qwen2ForRewardModel": _HfExamplesInfo("Qwen/Qwen2.5-Math-RM-72B"), "Qwen2ForSequenceClassification": _HfExamplesInfo("jason9693/Qwen2.5-1.5B-apeach"), # noqa: E501 "RobertaModel": _HfExamplesInfo("sentence-transformers/stsb-roberta-base-v2"), # noqa: E501 + "RobertaForMaskedLM": _HfExamplesInfo("sentence-transformers/all-roberta-large-v1"), # noqa: E501 "XLMRobertaModel": _HfExamplesInfo("intfloat/multilingual-e5-large"), # [Multimodal] "LlavaNextForConditionalGeneration": _HfExamplesInfo("royokong/e5-v"), @@ -143,6 +146,13 @@ class _HfExamplesInfo: "Qwen2VLForConditionalGeneration": _HfExamplesInfo("MrLight/dse-qwen2-2b-mrl-v1"), # noqa: E501 } +_CROSS_ENCODER_EXAMPLE_MODELS = { + # [Text-only] + "BertForSequenceClassification": _HfExamplesInfo("cross-encoder/ms-marco-MiniLM-L-6-v2"), # noqa: E501 + "RobertaForSequenceClassification": _HfExamplesInfo("cross-encoder/quora-roberta-base"), # noqa: E501 + "XLMRobertaForSequenceClassification": _HfExamplesInfo("BAAI/bge-reranker-v2-m3"), # noqa: E501 +} + _MULTIMODAL_EXAMPLE_MODELS = { # [Decoder-only] "Blip2ForConditionalGeneration": _HfExamplesInfo("Salesforce/blip2-opt-2.7b"), # noqa: E501 @@ -195,6 +205,7 @@ class _HfExamplesInfo: _EXAMPLE_MODELS = { **_TEXT_GENERATION_EXAMPLE_MODELS, **_EMBEDDING_EXAMPLE_MODELS, + **_CROSS_ENCODER_EXAMPLE_MODELS, **_MULTIMODAL_EXAMPLE_MODELS, **_SPECULATIVE_DECODING_EXAMPLE_MODELS, } diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index e462dae3dc688..289ea66b5ebc5 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -6,7 +6,10 @@ from vllm.model_executor.models import (is_embedding_model, is_text_generation_model, supports_multimodal) -from vllm.model_executor.models.registry import (_EMBEDDING_MODELS, +# yapf conflicts with isort for this block +# yapf: disable +from vllm.model_executor.models.registry import (_CROSS_ENCODER_MODELS, + _EMBEDDING_MODELS, _MULTIMODAL_MODELS, _SPECULATIVE_DECODING_MODELS, _TEXT_GENERATION_MODELS, @@ -29,22 +32,28 @@ def test_registry_imports(model_arch): model_arch in _TEXT_GENERATION_MODELS or model_arch in _MULTIMODAL_MODELS) + embedding_models = {**_EMBEDDING_MODELS, **_CROSS_ENCODER_MODELS} assert is_embedding_model(model_cls) is (model_arch - in _EMBEDDING_MODELS) + in embedding_models) assert supports_multimodal(model_cls) is (model_arch in _MULTIMODAL_MODELS) @fork_new_process_for_each_test -@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ - ("LlamaForCausalLM", False, False), - ("MllamaForConditionalGeneration", True, False), - ("LlavaForConditionalGeneration", True, True), +@pytest.mark.parametrize("model_arch,is_mm,init_cuda,is_ce", [ + ("LlamaForCausalLM", False, False, False), + ("MllamaForConditionalGeneration", True, False, False), + ("LlavaForConditionalGeneration", True, True, False), + ("BertForSequenceClassification", False, False, True), + ("RobertaForSequenceClassification", False, False, True), + ("XLMRobertaForSequenceClassification", False, False, True), ]) -def test_registry_is_multimodal(model_arch, is_mm, init_cuda): +def test_registry_model_property(model_arch, is_mm, init_cuda, is_ce): assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + assert ModelRegistry.is_cross_encoder_model(model_arch) is is_ce + if init_cuda and current_platform.is_cuda_alike(): assert not torch.cuda.is_initialized() diff --git a/tests/multimodal/test_processing.py b/tests/multimodal/test_processing.py new file mode 100644 index 0000000000000..b2367060c6c1b --- /dev/null +++ b/tests/multimodal/test_processing.py @@ -0,0 +1,370 @@ +from typing import cast + +import pytest +from transformers import BatchFeature + +from vllm.multimodal.processing import (PromptReplacement, find_text_matches, + find_token_matches, iter_token_matches, + iter_token_runs, replace_text_matches) +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import full_groupby + + +# yapf: disable +@pytest.mark.parametrize( + ("token_ids", "expected"), + [ + ([], []), + ( + [32000, 32000, 32000], + [{ "token_id": 32000, "start_idx": 0, "length": 3 }], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [ + { "token_id": 9833, "start_idx": 0, "length": 1 }, + { "token_id": 28747, "start_idx": 1, "length": 1 }, + { "token_id": 32000, "start_idx": 2, "length": 3 }, + { "token_id": 9833, "start_idx": 5, "length": 1 }, + { "token_id": 28747, "start_idx": 6, "length": 1 }, + { "token_id": 32000, "start_idx": 7, "length": 2 }, + { "token_id": 918, "start_idx": 9, "length": 1 }, + ], + ), + ], +) +# yapf: enable +def test_iter_token_runs(token_ids, expected): + result = list(iter_token_runs(token_ids)) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + assert [item._asdict() for item in result] == expected + + # Invariants + assert sum(run_info.length for run_info in result) == len(token_ids) + + +# yapf: disable +@pytest.mark.parametrize( + ("token_ids", "match_ids", "expected"), + [ + ([], [], [{ "start_idx": 0, "end_idx": 0 }]), + ([], [32000], []), + ( + [32000, 32000, 32000], + [32000], + [ + { "start_idx": 0, "end_idx": 1 }, + { "start_idx": 1, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 3 }, + ], + ), + ( + [32000, 32000, 32000], + [32000, 32000], + [{ "start_idx": 0, "end_idx": 2 }], + ), + ( + [32000, 32000, 32000], + [32000, 32000, 32000], + [{ "start_idx": 0, "end_idx": 3 }], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000], + [ + { "start_idx": 1, "end_idx": 3 }, + { "start_idx": 6, "end_idx": 8 }, + ], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 32000, 32000, 32000], + [ + { "start_idx": 1, "end_idx": 5 }, + ], + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + [28747, 0, 32000], + [], + ), + ], +) +# yapf: enable +def test_iter_token_matches(token_ids, match_ids, expected): + result = list(iter_token_matches(token_ids, match_ids)) + + # Manually constructed results + assert [item._asdict() for item in result] == expected + + # Invariants + match_lens = [end - start for start, end in result] + print("match_lens:", match_lens) # Only displayed on error + assert all(match_len == len(match_ids) for match_len in match_lens) + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "expected_by_key"), + [ + ( + [], + { + "pattern_1": [], + "pattern_2": [32000], + }, + { + "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_2": [], + } + ), + ( + [32000, 32000, 32000, 32000], + { + "pattern_1": [32000], + "pattern_2": [32000, 32000], + "pattern_3": [32000, 32000, 32000], + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 1 }, + { "start_idx": 1, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 3 }, + { "start_idx": 3, "end_idx": 4 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 2 }, + { "start_idx": 2, "end_idx": 4 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 3 }, + ], + }, + ), + ( + [9833, 28747, 32000, 32000, 32000, 9833, 28747, 32000, 32000, 918], + { + "pattern_1": [28747, 32000], + "pattern_2": [28747, 32000, 32000, 32000], + "pattern_3": [28747, 0, 32000], + }, + { + "pattern_1": [ + { "start_idx": 1, "end_idx": 3 }, + { "start_idx": 6, "end_idx": 8 }, + ], + "pattern_2": [ + { "start_idx": 1, "end_idx": 5 }, + ], + "pattern_3": [], + }, + ), + ], +) +# yapf: enable +def test_find_token_matches(prompt, target_by_key, expected_by_key): + # Should not be used since there is nothing to convert to token IDs + mock_tokenizer = cast(AnyTokenizer, object()) + + result = find_token_matches( + prompt, + [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + result_groups = dict(full_groupby(result, key=lambda x: x.modality)) + assert { + key: [ + dict(start_idx=item.start_idx, end_idx=item.end_idx) + for item in result_groups.get(key, []) + ] + for key in expected_by_key + } == expected_by_key + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "expected_by_key"), + [ + # Detokenized test cases of `test_find_token_matches` + # using the vocab of llava-hf/llava-v1.6-mistral-7b-hf + ( + "", + { + "pattern_1": "", + "pattern_2": "", + }, + { + "pattern_1": [{ "start_idx": 0, "end_idx": 0 }], + "pattern_2": [], + } + ), + ( + "", + { + "pattern_1": "", + "pattern_2": "", + "pattern_3": "", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 7 }, + { "start_idx": 7, "end_idx": 14 }, + { "start_idx": 14, "end_idx": 21 }, + { "start_idx": 21, "end_idx": 28 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 14 }, + { "start_idx": 14, "end_idx": 28 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 21 }, + ], + }, + ), + ( + "Image:Image:!", + { + "pattern_1": "Image:", + "pattern_2": "Image:", + "pattern_3": "Image:", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 13 }, + { "start_idx": 27, "end_idx": 40 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 27 }, + ], + "pattern_3": [], + }, + ), + # Test regex escape + ( + "<|image|><|image|>", + { + "pattern_1": "<|image|>", + "pattern_2": "<|image|>", + "pattern_3": "<|image|><|image|>", + }, + { + "pattern_1": [ + { "start_idx": 0, "end_idx": 9 }, + { "start_idx": 16, "end_idx": 25 }, + ], + "pattern_2": [ + { "start_idx": 0, "end_idx": 16 }, + { "start_idx": 16, "end_idx": 32 }, + ], + "pattern_3": [ + { "start_idx": 0, "end_idx": 25 }, + ], + }, + ), + ], +) +# yapf: enable +def test_find_text_matches(prompt, target_by_key, expected_by_key): + # Should not be used since there is nothing to convert to text + mock_tokenizer = cast(AnyTokenizer, object()) + + result = find_text_matches( + prompt, + [ + PromptReplacement(target, [], 0).bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + + # Only displayed on error + print("result:", result) + + # Manually constructed results + result_groups = dict(full_groupby(result, key=lambda x: x.modality)) + assert { + key: [ + dict(start_idx=item.start_idx, end_idx=item.end_idx) + for item in result_groups.get(key, []) + ] + for key in expected_by_key + } == expected_by_key + + +# yapf: disable +@pytest.mark.parametrize( + ("prompt", "target_by_key", "repl_by_key", "expected_by_mm_count"), + [ + ( + "Image:Image:!", + { + # We use `` before `Image:` to test matches that + # occur out of order + "pattern_1": "", + "pattern_2": "Image:", + "pattern_3": "!", + }, + { + # Test whether target is confused with repl_unit + "pattern_1": ("", 1), + # Test empty repl_unit + "pattern_2": ("", 1), + # Test multiple repl_count + "pattern_3": ("?", 2), + }, + { + # Test no replacement + 0: "Image:Image:!", + # Test single replacement + 1: "Image:??", + # Test repeated replacement + 2: "??", + }, + ), + ] +) +# yapf: enable +def test_find_replace_text( + prompt, + target_by_key, + repl_by_key, + expected_by_mm_count, +): + # Should not be used since there is nothing to convert to text + mock_tokenizer = cast(AnyTokenizer, object()) + + matches = find_text_matches( + prompt, + [ + PromptReplacement(target, *repl_by_key[key]) \ + .bind(key, mock_tokenizer) + for key, target in target_by_key.items() + ], + ) + result_by_mm_count = { + mm_count: replace_text_matches( + prompt, + matches, + {key: list(range(mm_count)) + for key in repl_by_key}, + BatchFeature(), + ) + for mm_count in expected_by_mm_count + } + + # Only displayed on error + print("matches:", matches) + print("result_by_mm_count:", result_by_mm_count) + + # Manually constructed results + assert result_by_mm_count == expected_by_mm_count diff --git a/tests/multimodal/test_utils.py b/tests/multimodal/test_utils.py index 9869c8123f001..fd82fb0c55fd7 100644 --- a/tests/multimodal/test_utils.py +++ b/tests/multimodal/test_utils.py @@ -139,7 +139,8 @@ def test_repeat_and_pad_placeholder_tokens(model): 2, "", [32000, 32000, 32000], - [{ "offset": 0, "length": 2 }]), + [{ "offset": 0, "length": 2 }], + ), ( "", [3, 2], diff --git a/tests/prefix_caching/test_prefix_caching.py b/tests/prefix_caching/test_prefix_caching.py index 50723dbb610ac..8d16710f14585 100644 --- a/tests/prefix_caching/test_prefix_caching.py +++ b/tests/prefix_caching/test_prefix_caching.py @@ -2,10 +2,15 @@ Run `pytest tests/prefix_caching/test_prefix_caching.py`. """ + import pytest +from tests.conftest import VllmRunner +from tests.core.utils import SchedulerProxy, create_dummy_prompt from tests.kernels.utils import override_backend_env_variable from vllm import SamplingParams, TokensPrompt +from vllm.core.scheduler import Scheduler +from vllm.engine.llm_engine import LLMEngine from ..models.utils import check_outputs_equal @@ -27,6 +32,7 @@ @pytest.mark.parametrize("dtype", ["half"]) @pytest.mark.parametrize("max_tokens", [5]) @pytest.mark.parametrize("cached_position", [0, 1]) +@pytest.mark.parametrize("enable_chunked_prefill", [True, False]) @pytest.mark.parametrize("block_size", [16]) def test_mixed_requests( hf_runner, @@ -37,6 +43,7 @@ def test_mixed_requests( dtype: str, max_tokens: int, cached_position: int, + enable_chunked_prefill: bool, block_size: int, monkeypatch, ) -> None: @@ -55,6 +62,7 @@ def test_mixed_requests( model, dtype=dtype, enable_prefix_caching=True, + enable_chunked_prefill=enable_chunked_prefill, block_size=block_size, ) as vllm_model: # Run the first prompt so the cache is populated @@ -72,13 +80,13 @@ def test_mixed_requests( block_size) * block_size else: expected_num_cached_tokens = 0 - assert req_outputs[ - i].num_cached_tokens == expected_num_cached_tokens + assert ( + req_outputs[i].num_cached_tokens == expected_num_cached_tokens) - vllm_outputs = [ - (output.prompt_token_ids + list(output.outputs[0].token_ids), - output.prompt + output.outputs[0].text) for output in req_outputs - ] + vllm_outputs = [( + output.prompt_token_ids + list(output.outputs[0].token_ids), + output.prompt + output.outputs[0].text, + ) for output in req_outputs] check_outputs_equal( outputs_0_lst=hf_outputs, @@ -105,3 +113,89 @@ def test_unstable_prompt_sequence( for prompt in UNSTABLE_PROMPT_SEQUENCE: vllm_model.generate(TokensPrompt(prompt_token_ids=prompt), SamplingParams(max_tokens=1)) + + +@pytest.mark.parametrize("model", MODELS) +def test_fully_cached_prefill_needs_uncached_token(model): + block_size = 16 + max_num_batched_tokens = 16 + num_output_tokens = 5 + # Make a vllm engine + runner = VllmRunner( + model_name=model, + gpu_memory_utilization=0.7, + enable_chunked_prefill=True, + enforce_eager=True, + enable_prefix_caching=True, + block_size=block_size, + max_num_batched_tokens=max_num_batched_tokens, + max_num_seqs=max_num_batched_tokens, + ) + engine: LLMEngine = runner.model.llm_engine + + scheduler: Scheduler = SchedulerProxy(engine.scheduler[0]) # type: ignore + engine.scheduler[0] = scheduler + + # SeqA + seqA_tokens = list(range(2 * block_size)) + seqA, seq_groupA = create_dummy_prompt( + request_id="0", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupA) + + assert seqA.data.get_num_computed_tokens() == 0 + + # Prefill seqA + while not seqA.is_finished(): + engine.step() + + # seqB + seqB_tokens = [t + 1 for t in seqA_tokens] # shift by 1 + seqB, seq_groupB = create_dummy_prompt( + request_id="1", + prompt_tokens=seqB_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + # seqC is the same as seqA + seqC, seq_groupC = create_dummy_prompt( + request_id="2", + prompt_tokens=seqA_tokens, + max_tokens=num_output_tokens, + block_size=block_size, + ) + + scheduler.add_seq_group(seq_groupB) + scheduler.add_seq_group(seq_groupC) + + # Even seqC is fully cached, it should not be prefilled since we + # require at least 1 uncached token. + engine.step() + + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + assert (sched_out.scheduled_seq_groups[0].token_chunk_size == + max_num_batched_tokens) + + # When seqB is finished, seqC could be prefilled. + while not seqB.is_finished(): + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupB.request_id) + + engine.step() + sched_metas, sched_out, _ = scheduler.last_schedule_ret() + assert len(sched_out.scheduled_seq_groups) == 1 + assert (sched_out.scheduled_seq_groups[0].seq_group.request_id == + seq_groupC.request_id) + assert sched_out.scheduled_seq_groups[0].token_chunk_size == len( + seqA_tokens) diff --git a/tests/quantization/utils.py b/tests/quantization/utils.py index 061a077592e80..8ebd8dd2be0d5 100644 --- a/tests/quantization/utils.py +++ b/tests/quantization/utils.py @@ -1,4 +1,4 @@ -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import get_quantization_config from vllm.platforms import current_platform @@ -10,6 +10,6 @@ def is_quant_method_supported(quant_method: str) -> bool: capability = current_platform.get_device_capability() assert capability is not None - min_capability = QUANTIZATION_METHODS[quant_method].get_min_capability() + min_capability = get_quantization_config(quant_method).get_min_capability() return capability.to_int() >= min_capability diff --git a/tests/test_lazy_torch_compile.py b/tests/test_lazy_torch_compile.py new file mode 100644 index 0000000000000..b8ac4dd93732b --- /dev/null +++ b/tests/test_lazy_torch_compile.py @@ -0,0 +1,68 @@ +# Description: Test the lazy import module +# The utility function cannot be placed in `vllm.utils` +# this needs to be a standalone script + +import contextlib +import dataclasses +import sys +import traceback +from typing import Callable, Generator + + +@dataclasses.dataclass +class BlameResult: + found: bool = False + trace_stack: str = "" + + +@contextlib.contextmanager +def blame(func: Callable) -> Generator[BlameResult, None, None]: + """ + Trace the function calls to find the first function that satisfies the + condition. The trace stack will be stored in the result. + + Usage: + + ```python + with blame(lambda: some_condition()) as result: + # do something + + if result.found: + print(result.trace_stack) + """ + result = BlameResult() + + def _trace_calls(frame, event, arg=None): + nonlocal result + if event in ['call', 'return']: + # for every function call or return + try: + # Temporarily disable the trace function + sys.settrace(None) + # check condition here + if not result.found and func(): + result.found = True + result.trace_stack = "".join(traceback.format_stack()) + # Re-enable the trace function + sys.settrace(_trace_calls) + except NameError: + # modules are deleted during shutdown + pass + return _trace_calls + + sys.settrace(_trace_calls) + + yield result + + sys.settrace(None) + + +module_name = "torch._inductor.async_compile" + +with blame(lambda: module_name in sys.modules) as result: + import vllm # noqa + +assert not result.found, (f"Module {module_name} is already imported, the" + f" first import location is:\n{result.trace_stack}") + +print(f"Module {module_name} is not imported yet") diff --git a/tests/tpu/test_compilation.py b/tests/tpu/test_compilation.py index 65bee85e7a1ea..b7124ebc1b0f3 100644 --- a/tests/tpu/test_compilation.py +++ b/tests/tpu/test_compilation.py @@ -4,7 +4,7 @@ import depyf -from vllm.config import CompilationConfig, CompilationLevel +from vllm.config import CompilationLevel temp_dir = tempfile.mkdtemp() with depyf.prepare_debug(temp_dir): @@ -34,8 +34,7 @@ # all the control llm = LLM(model="google/gemma-2b", enforce_eager=True, - compilation_config=CompilationConfig( - level=CompilationLevel.DYNAMO_AS_IS)) + compilation_config={"level": CompilationLevel.DYNAMO_AS_IS}) outputs = llm.generate(prompts, sampling_params) for output, answer in zip(outputs, answers): prompt = output.prompt diff --git a/tests/tpu/test_custom_dispatcher.py b/tests/tpu/test_custom_dispatcher.py index df348258efcba..bb1379deba3fc 100644 --- a/tests/tpu/test_custom_dispatcher.py +++ b/tests/tpu/test_custom_dispatcher.py @@ -13,9 +13,10 @@ def test_custom_dispatcher(): compare_two_settings( "google/gemma-2b", - arg1=["--enforce-eager", "-O", - str(CompilationLevel.DYNAMO_ONCE)], - arg2=["--enforce-eager", "-O", - str(CompilationLevel.DYNAMO_AS_IS)], + arg1=[ + "--enforce-eager", + f"-O{CompilationLevel.DYNAMO_ONCE}", + ], + arg2=["--enforce-eager", f"-O{CompilationLevel.DYNAMO_AS_IS}"], env1={}, env2={}) diff --git a/tests/v1/core/test_prefix_caching.py b/tests/v1/core/test_prefix_caching.py index d614d3e67460f..83bfbb6ade8d7 100644 --- a/tests/v1/core/test_prefix_caching.py +++ b/tests/v1/core/test_prefix_caching.py @@ -1,8 +1,11 @@ """Compare the with and without prefix caching.""" +import pytest + from vllm.inputs import token_inputs from vllm.sampling_params import SamplingParams +from vllm.utils import cdiv from vllm.v1.core.kv_cache_manager import KVCacheManager, Request -from vllm.v1.core.kv_cache_utils import hash_block_tokens +from vllm.v1.core.kv_cache_utils import KVCacheBlock, hash_block_tokens def make_request(request_id, prompt_token_ids): @@ -31,7 +34,8 @@ def test_prefill(): # Fully cache miss # Incomplete 1 block (7 tokens) unique_token_ids = [3] * 7 - req0 = make_request("0", common_token_ids + unique_token_ids) + all_token_ids = common_token_ids + unique_token_ids + req0 = make_request("0", all_token_ids) computed_blocks = manager.get_computed_blocks(req0) assert not computed_blocks blocks = manager.allocate_slots(req0, 55, computed_blocks) @@ -40,24 +44,16 @@ def test_prefill(): # Check full block metadata parent_block_hash = None for block_id in (0, 1, 2): - block_hash = hash_block_tokens(parent_block_hash, - manager.block_pool[block_id].token_ids) + block_tokens = tuple(all_token_ids[block_id * 16:(block_id + 1) * 16]) + block_hash = hash_block_tokens(parent_block_hash, block_tokens) assert manager.block_pool[block_id].block_hash == block_hash assert manager.block_pool[block_id].ref_cnt == 1 - assert manager.block_pool[block_id].num_hashed_tokens == 16 * ( - block_id + 1) - assert manager.block_pool[block_id].token_ids == tuple([block_id] * 16) parent_block_hash = block_hash # Check partial/preallocated block metadata for block_id in (3, 4): assert manager.block_pool[block_id].block_hash is None assert manager.block_pool[block_id].ref_cnt == 1 - assert manager.block_pool[block_id].num_hashed_tokens == 0 - if block_id == 3: - assert manager.block_pool[block_id].token_ids == [3] * 7 - else: - assert not manager.block_pool[block_id].token_ids # Cache hit in the common prefix when the original block is still in use. # Incomplete 1 block (5 tokens) @@ -113,7 +109,7 @@ def test_prefill(): req3 = make_request("3", [99] * (16 * 9)) computed_blocks = manager.get_computed_blocks(req3) assert not computed_blocks - blocks = manager.allocate_slots(req2, 16 * 9, computed_blocks) + blocks = manager.allocate_slots(req3, 16 * 9, computed_blocks) # This block ID order also checks the eviction order. assert [b.block_id for b in blocks] == [9, 4, 3, 6, 5, 8, 7, 2, 1, 0] assert manager.free_block_queue.num_free_blocks == 0 @@ -148,7 +144,7 @@ def test_decode(): req0.append_output_token_ids(8) new_blocks = manager.append_slots(req0, 4) assert new_blocks is not None and len(new_blocks) == 0 - assert len(manager.block_pool[3].token_ids) == 11 + assert manager.req_to_blocks[req0.request_id][-2].block_hash is None # Append slots without allocating a new block, but start using the # preallocated block. @@ -159,8 +155,7 @@ def test_decode(): req0.append_output_token_ids(7) new_blocks = manager.append_slots(req0, 15) assert new_blocks is not None and len(new_blocks) == 0 - assert len(manager.block_pool[3].token_ids) == 16 - assert len(manager.block_pool[4].token_ids) == 10 + assert manager.req_to_blocks[req0.request_id][-2].block_hash is not None # Append slots with allocating a new block. req0.num_computed_tokens = 74 @@ -171,9 +166,6 @@ def test_decode(): new_blocks = manager.append_slots(req0, 17) # Plus one preallocated block. assert new_blocks is not None and len(new_blocks) == 2 - assert len(manager.block_pool[4].token_ids) == 16 - assert len(manager.block_pool[5].token_ids) == 11 - assert len(manager.block_pool[6].token_ids) == 0 def test_evict(): @@ -217,3 +209,198 @@ def test_evict(): blocks = manager.allocate_slots(req2, 3, computed_blocks) assert [b.block_id for b in blocks] == [6, 5] assert manager.free_block_queue.num_free_blocks == 6 + + +def test_hash_block_correct_reuse(): + """ + This tests when a previously cached block is reused as a new block, + its hash metadata should be correctly reset. + """ + block_size = 16 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=1, + sliding_window=False, + enable_caching=True, + num_preallocate_tokens=0, + ) + + # Allocate 1 block and cache it. + num_tokens = block_size * 1 + req = make_request("0", list(range(num_tokens))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + blocks = manager.allocate_slots(req, num_tokens, computed_blocks) + assert len(blocks) == 1 + + # Deallocate the block. + manager.free(req) + + # Allocate a new block that's not full, make sure hash info on the + # block is cleared. + req = make_request("1", list(range(num_tokens - 1))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + blocks = manager.allocate_slots(req, num_tokens - 1, computed_blocks) + assert len(blocks) == 1 + + assert manager.block_pool[blocks[0].block_id].block_hash is None + + +def test_computed_blocks_not_evicted(): + """ + Test that the computed blocks are not evicted when getting new blocks + for a request if there are any other free blocks. + """ + block_size = 16 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=2, + sliding_window=False, + enable_caching=True, + num_preallocate_tokens=0, + ) + + # Allocate a block and cache it. + num_tokens = block_size * 1 + req0 = make_request("0", list(range(num_tokens))) + computed_blocks = manager.get_computed_blocks(req0) + assert not computed_blocks + blocks = manager.allocate_slots(req0, num_tokens, computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 0 + + # Allocate another block. + req1 = make_request("1", list(range(num_tokens, num_tokens * 2))) + computed_blocks = manager.get_computed_blocks(req1) + assert not computed_blocks + blocks = manager.allocate_slots(req1, num_tokens, computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 1 + + # Free the blocks. + manager.free(req0) + manager.free(req1) + + # Now if we have a cache hit on the first block, we should evict the second + # cached block rather than the first one. + req2 = make_request("2", list(range(num_tokens * 2))) + computed_blocks = manager.get_computed_blocks(req2) + assert len(computed_blocks) == 1 + assert computed_blocks[0].block_id == 0 + + blocks = manager.allocate_slots(req2, num_tokens * 2 - num_tokens, + computed_blocks) + assert len(blocks) == 1 + assert blocks[0].block_id == 1 + + +def test_basic_prefix_caching_disabled(): + """ + This tests that the prefix caching is disabled. + """ + block_size = 4 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=4, + sliding_window=False, + enable_caching=False, + num_preallocate_tokens=0, + ) + + req1 = make_request("1", list(range(10))) # 2 blocks and some more + + computed_blocks = manager.get_computed_blocks(req1) + assert not computed_blocks + blocks = manager.allocate_slots(req1, 10, computed_blocks) + assert len(blocks) == 3 + + # Free the blocks. + manager.free(req1) + + # No caching. + req2 = make_request("2", list(range(16))) # shared prefix + computed_blocks = manager.get_computed_blocks(req2) + assert not computed_blocks + blocks = manager.allocate_slots(req2, 16, computed_blocks) + assert len(blocks) == 4 + + # New requests should not have any blocks. + req3 = make_request("3", list(range(4))) + computed_blocks = manager.get_computed_blocks(req3) + assert not computed_blocks + blocks = manager.allocate_slots(req3, 4, computed_blocks) + assert not blocks + + +@pytest.mark.parametrize("num_preallocate_tokens", list(range(0, 8))) +@pytest.mark.parametrize("block_size", [4]) +def test_preallocate_blocks(num_preallocate_tokens: int, block_size: int): + """ + This tests that the preallocated blocks are correctly added. + """ + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=10, + sliding_window=False, + enable_caching=True, + num_preallocate_tokens=num_preallocate_tokens, + ) + num_preallocated_blocks = cdiv(num_preallocate_tokens, block_size) + + req = make_request("0", list(range(block_size * 30))) + computed_blocks = manager.get_computed_blocks(req) + assert not computed_blocks + # Just ask for 1 block. + blocks = manager.allocate_slots(req, block_size, computed_blocks) + assert len(blocks) == 1 + num_preallocated_blocks + + # Append slots to the block. + req.num_computed_tokens = block_size * len(blocks) # Assume all used. + blocks = manager.append_slots(req, block_size) # Append 1 block. + assert len(blocks) == 1 + num_preallocated_blocks + + +def test_cache_blocks(): + """ + This is a unit test that tests the correctness of the _cache_full_blocks + function of KVCacheManager. + """ + block_size = 4 + manager = KVCacheManager( + block_size=block_size, + num_gpu_blocks=5, + sliding_window=False, + enable_caching=True, + num_preallocate_tokens=0, + ) + # Req: + # Block 0: [0, 1, 2, 3] + # Block 1: [4, 5, 6, 7] + # Block 2: [8, 9, 10, 11] + # Block 3: [12, 13] + req = make_request("0", list(range(14))) + + # Test that blocks are cached correctly for 2 full blocks from the start. + blocks = [KVCacheBlock(block_id=i) for i in range(2)] + + manager._cache_full_blocks( + request=req, + blk_start_idx=0, + full_blocks=blocks, + prev_block=None, + ) + + assert len(manager.cached_block_hash_to_block) == 2 + assert all([block.block_hash is not None for block in blocks]) + + # Test that blocks that don't start from the beginning are cached correctly. + blocks = [KVCacheBlock(block_id=2)] + manager._cache_full_blocks( + request=req, + blk_start_idx=2, + full_blocks=blocks, + prev_block=None, + ) + assert len(manager.cached_block_hash_to_block) == 3 + assert blocks[0].block_hash is not None diff --git a/tests/v1/engine/test_engine_core_client.py b/tests/v1/engine/test_engine_core_client.py index ee7c64138264c..582192196aaf9 100644 --- a/tests/v1/engine/test_engine_core_client.py +++ b/tests/v1/engine/test_engine_core_client.py @@ -81,7 +81,7 @@ def test_engine_core_client(monkeypatch, multiprocessing_mode: bool): with monkeypatch.context() as m: m.setenv("VLLM_USE_V1", "1") - engine_args = EngineArgs(model=MODEL_NAME) + engine_args = EngineArgs(model=MODEL_NAME, compilation_config=3) vllm_config = engine_args.create_engine_config( UsageContext.UNKNOWN_CONTEXT) executor_class = AsyncLLM._get_executor_cls(vllm_config) diff --git a/tools/sphinx-lint.sh b/tools/sphinx-lint.sh new file mode 100755 index 0000000000000..04f8075c5527f --- /dev/null +++ b/tools/sphinx-lint.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +sphinx-lint --disable trailing-whitespace,missing-final-newline docs diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 782dc6aed1b8c..c192c9a7b0e4d 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -19,9 +19,6 @@ except ImportError as e: logger.warning("Failed to import from vllm._C with %r", e) -if current_platform.is_rocm(): - import vllm._rocm_C # noqa: F401 - supports_moe_ops = False with contextlib.suppress(ImportError): import vllm._moe_C # noqa: F401 @@ -347,31 +344,6 @@ def _gptq_marlin_gemm_fake(a: torch.Tensor, is_zp_float: bool = False) -> torch.Tensor: return torch.empty((size_m, size_n), device=a.device, dtype=a.dtype) - @register_fake("_C::ggml_dequantize") - def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, - m: torch.SymInt, - n: torch.SymInt) -> torch.Tensor: - return torch.empty((m, n), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_vec_a8") - def _ggml_mul_mat_vec_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - return torch.empty((1, row), dtype=torch.float16, device=W.device) - - @register_fake("_C::ggml_mul_mat_a8") - def _ggml_mul_mat_a8_fake( - W: torch.Tensor, - X: torch.Tensor, - quant_type: int, - row: torch.SymInt, - ) -> torch.Tensor: - batch = X.size(0) - return torch.empty((batch, row), dtype=torch.float16, device=W.device) - @register_fake("_C::marlin_qqq_gemm") def _marlin_qqq_gemm_fake(a: torch.Tensor, b_q_weight: torch.Tensor, s_tok: torch.Tensor, s_ch: torch.Tensor, @@ -471,6 +443,34 @@ def machete_prepack_B_fake( memory_format=torch.contiguous_format) +if hasattr(torch.ops._C, "ggml_dequantize"): + + @register_fake("_C::ggml_dequantize") + def _ggml_dequantize_fake(W: torch.Tensor, quant_type: int, + m: torch.SymInt, + n: torch.SymInt) -> torch.Tensor: + return torch.empty((m, n), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_vec_a8") + def _ggml_mul_mat_vec_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + return torch.empty((1, row), dtype=torch.float16, device=W.device) + + @register_fake("_C::ggml_mul_mat_a8") + def _ggml_mul_mat_a8_fake( + W: torch.Tensor, + X: torch.Tensor, + quant_type: int, + row: torch.SymInt, + ) -> torch.Tensor: + batch = X.size(0) + return torch.empty((batch, row), dtype=torch.float16, device=W.device) + + # cutlass def cutlass_scaled_mm_supports_fp8(cuda_device_capability: int) -> bool: return torch.ops._C.cutlass_scaled_mm_supports_fp8(cuda_device_capability) diff --git a/vllm/attention/backends/abstract.py b/vllm/attention/backends/abstract.py index a504cb1f7e318..5be2d83346d00 100644 --- a/vllm/attention/backends/abstract.py +++ b/vllm/attention/backends/abstract.py @@ -1,7 +1,6 @@ from abc import ABC, abstractmethod from contextlib import contextmanager from dataclasses import dataclass, fields -from enum import Enum, auto from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Set, Tuple, Type, TypeVar) @@ -15,13 +14,19 @@ ModelRunnerInputBuilderBase) -class AttentionType(Enum): - DECODER = auto() # Decoder attention between previous layer Q/K/V - ENCODER = auto( - ) # Encoder attention between previous layer Q/K/V for encoder-decoder - ENCODER_ONLY = auto() # Encoder attention between previous layer Q/K/V - ENCODER_DECODER = auto( - ) # Attention between dec. Q and enc. K/V for encoder-decoder +class AttentionType: + """ + Attention type. + Use string to be compatible with `torch.compile`. + """ + # Decoder attention between previous layer Q/K/V + DECODER = "decoder" + # Encoder attention between previous layer Q/K/V for encoder-decoder + ENCODER = "encoder" + # Encoder attention between previous layer Q/K/V + ENCODER_ONLY = "encoder_only" + # Attention between dec. Q and enc. K/V for encoder-decoder + ENCODER_DECODER = "encoder_decoder" class AttentionBackend(ABC): @@ -241,6 +246,6 @@ def forward( attn_metadata: T, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: raise NotImplementedError diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 409a42187f46c..9e54c3b40c54e 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -87,6 +87,11 @@ def __post_init__(self): class BlocksparseFlashAttentionBackend(AttentionBackend): + @staticmethod + def get_name() -> str: + # For attention layer compatibility + return "FLASH_ATTN" + @staticmethod def get_impl_cls() -> Type["BlocksparseFlashAttentionImpl"]: return BlocksparseFlashAttentionImpl @@ -354,7 +359,7 @@ def forward( attn_metadata: BlocksparseFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 314822b695722..32738d1043b1d 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -16,10 +16,8 @@ compute_slot_mapping_start_idx, get_num_prefill_decode_query_kv_tokens, get_seq_len_block_table_args, is_all_cross_attn_metadata_set, is_all_encoder_attn_metadata_set, is_block_tables_empty) -from vllm.forward_context import get_forward_context from vllm.multimodal import MultiModalPlaceholderMap -from vllm.utils import (async_tensor_h2d, direct_register_custom_op, - make_tensor_with_pad) +from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -639,7 +637,7 @@ def forward( attn_metadata: FlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention. @@ -668,23 +666,174 @@ def forward( "requires setting cross-attention " "metadata attributes.") - output = torch.ops.vllm.unified_flash_attention( - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, - self.kv_cache_dtype, - k_scale, - v_scale, - self.scale, - attn_type.value, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, - ) + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes: Optional[torch.Tensor] = self.alibi_slopes + logits_soft_cap: Optional[float] = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + if (key is not None) and (value is not None): + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + # We skip updating the KV cache under two conditions: + # a. When the Attention Type is ENCODER. In this phase, we compute + # only the encoder attention without updating the cache. + # b. When both Key and Value are None. This occurs during + # cross-attention computation in the decoding phase, where the + # KV cache is already populated with the cross-attention + # tensor. Thus, we skip cache updates during this time. + if (attn_type != AttentionType.ENCODER) and (key is not None) and ( + value is not None): + if attn_type == AttentionType.ENCODER_DECODER: + # Update cross-attention KV cache (prefill-only) + updated_slot_mapping = attn_metadata.cross_slot_mapping + else: + # Update self-attention KV cache (prefill/decode) + updated_slot_mapping = attn_metadata.slot_mapping + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory + # profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + updated_slot_mapping.flatten(), # type: ignore[union-attr] + kv_cache_dtype, + k_scale, + v_scale, + ) + + (num_prefill_query_tokens, num_prefill_kv_tokens, + num_decode_query_tokens) = \ + get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) + decode_query = query[num_prefill_query_tokens:] + # QKV for prefill. + query = query[:num_prefill_query_tokens] + assert query.shape[0] == num_prefill_query_tokens + assert decode_query.shape[0] == num_decode_query_tokens + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ + _get_query_key_seq_metadata(prefill_meta, True, attn_type) + + key = key[:num_prefill_kv_tokens] + value = value[:num_prefill_kv_tokens] + + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=q_seq_start_loc, + cu_seqlens_k=k_seq_start_loc, + max_seqlen_q=q_seq_len, + max_seqlen_k=k_seq_len, + softmax_scale=softmax_scale, + causal=_get_causal_option(attn_type), + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support prefix caching") + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + prefill_output = flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + # Use flash_attn_varlen_func kernel for speculative decoding + # because different queries might have different lengths. + + assert decode_meta.max_decode_query_len is not None + # use only for actual varlen decoding + if decode_meta.max_decode_query_len > 1: + assert attn_type == AttentionType.DECODER, ( + "Only decoder-only models support max_decode_query_len > 1" + ) + decode_output = flash_attn_varlen_func( + q=decode_query, + k=key_cache, + v=value_cache, + cu_seqlens_q=decode_meta.query_start_loc, + max_seqlen_q=decode_meta.max_decode_query_len, + cu_seqlens_k=decode_meta.seq_start_loc, + max_seqlen_k=decode_meta.max_decode_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + block_table=decode_meta.block_tables, + ) + else: + # Use flash_attn_with_kvcache for normal decoding. + ( + seq_lens_arg, + _, + block_tables_arg, + ) = get_seq_len_block_table_args(decode_meta, False, attn_type) + decode_output = flash_attn_with_kvcache( + q=decode_query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, + block_table=block_tables_arg, + cache_seqlens=seq_lens_arg, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_query_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_query_tokens, hidden_size) + + assert decode_meta is not None + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) return output @@ -692,7 +841,7 @@ def forward( def _get_query_key_seq_metadata( attn_metadata, is_prompt: bool, - attn_type: AttentionType, + attn_type: str, ) -> tuple: """ Returns sequence metadata for key and query based on the specified @@ -754,7 +903,7 @@ def _get_query_key_seq_metadata( raise AttributeError(f"Invalid attention type {str(attn_type)}") -def _get_causal_option(attn_type: AttentionType) -> bool: +def _get_causal_option(attn_type: str) -> bool: """ Determine whether the given attention type is suitable for causal attention mechanisms. @@ -770,220 +919,3 @@ def _get_causal_option(attn_type: AttentionType) -> bool: return not (attn_type == AttentionType.ENCODER or attn_type == AttentionType.ENCODER_ONLY or attn_type == AttentionType.ENCODER_DECODER) - - -def unified_flash_attention( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - attn_type_int_val: int, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - - # Convert integer attn_type to enum - try: - attn_type = AttentionType(attn_type_int_val) - except ValueError as err: - raise AttributeError( - f"Invalid attention type {str(attn_type_int_val)}") from err - - current_metadata = get_forward_context() - assert current_metadata is not None - assert isinstance(current_metadata, FlashAttentionMetadata) - attn_metadata: FlashAttentionMetadata = current_metadata - - num_tokens, hidden_size = query.shape - - # Reshape the query, key, and value tensors. - query = query.view(-1, num_heads, head_size) - if (key is not None) and (value is not None): - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - # We skip updating the KV cache under two conditions: - # a. When the Attention Type is ENCODER. In this phase, we compute - # only the encoder attention without updating the cache. - # b. When both Key and Value are None. This occurs during - # cross-attention computation in the decoding phase, where the KV - # cache is already populated with the cross-attention tensor. - # Thus, we skip cache updates during this time. - if (attn_type != AttentionType.ENCODER) and (key is not None) and ( - value is not None): - if attn_type == AttentionType.ENCODER_DECODER: - # Update cross-attention KV cache (prefill-only) - updated_slot_mapping = attn_metadata.cross_slot_mapping - else: - # Update self-attention KV cache (prefill/decode) - updated_slot_mapping = attn_metadata.slot_mapping - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops._C_cache_ops.reshape_and_cache_flash( - key, - value, - kv_cache[0], - kv_cache[1], - updated_slot_mapping.flatten(), # type: ignore[union-attr] - kv_cache_dtype, - k_scale, - v_scale, - ) - - (num_prefill_query_tokens, num_prefill_kv_tokens, - num_decode_query_tokens) = \ - get_num_prefill_decode_query_kv_tokens(attn_metadata, attn_type) - decode_query = query[num_prefill_query_tokens:] - # QKV for prefill. - query = query[:num_prefill_query_tokens] - assert query.shape[0] == num_prefill_query_tokens - assert decode_query.shape[0] == num_decode_query_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. - q_seq_start_loc, q_seq_len, k_seq_start_loc, k_seq_len = \ - _get_query_key_seq_metadata(prefill_meta, True, attn_type) - - key = key[:num_prefill_kv_tokens] - value = value[:num_prefill_kv_tokens] - - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=q_seq_start_loc, - cu_seqlens_k=k_seq_start_loc, - max_seqlen_q=q_seq_len, - max_seqlen_k=k_seq_len, - softmax_scale=softmax_scale, - causal=_get_causal_option(attn_type), - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - ) - else: - # prefix-enabled attention - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support prefix caching") - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - # Use flash_attn_varlen_func kernel for speculative decoding - # because different queries might have different lengths. - - assert decode_meta.max_decode_query_len is not None - # use only for actual varlen decoding - if decode_meta.max_decode_query_len > 1: - assert attn_type == AttentionType.DECODER, ( - "Only decoder-only models support max_decode_query_len > 1") - decode_output = flash_attn_varlen_func( - q=decode_query, - k=key_cache, - v=value_cache, - cu_seqlens_q=decode_meta.query_start_loc, - max_seqlen_q=decode_meta.max_decode_query_len, - cu_seqlens_k=decode_meta.seq_start_loc, - max_seqlen_k=decode_meta.max_decode_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - block_table=decode_meta.block_tables, - ) - else: - # Use flash_attn_with_kvcache for normal decoding. - ( - seq_lens_arg, - _, - block_tables_arg, - ) = get_seq_len_block_table_args(decode_meta, False, attn_type) - decode_output = flash_attn_with_kvcache( - q=decode_query.unsqueeze(1), - k_cache=key_cache, - v_cache=value_cache, - block_table=block_tables_arg, - cache_seqlens=seq_lens_arg, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, - softcap=logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_query_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_query_tokens, hidden_size) - - assert decode_meta is not None - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - - -def unified_flash_attention_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - attn_type_int_val: int, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - return torch.empty_like(query) - - -direct_register_custom_op( - op_name="unified_flash_attention", - op_func=unified_flash_attention, - mutates_args=["kv_cache"], - fake_impl=unified_flash_attention_fake, -) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index 107e3bbf79666..1a2024705eb04 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -30,9 +30,8 @@ compute_slot_mapping_start_idx, is_block_tables_empty) from vllm.attention.ops.paged_attn import PagedAttention -from vllm.forward_context import get_forward_context -from vllm.utils import (async_tensor_h2d, direct_register_custom_op, - get_kv_cache_torch_dtype, make_tensor_with_pad) +from vllm.utils import (async_tensor_h2d, get_kv_cache_torch_dtype, + make_tensor_with_pad) if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, @@ -757,9 +756,8 @@ def __init__( if alibi_slopes is not None: alibi_slopes = torch.tensor(alibi_slopes, dtype=torch.float32) self.alibi_slopes = alibi_slopes - if sliding_window is not None: - raise ValueError("Sliding window is not supported in FlashInfer.") - self.sliding_window = (-1, -1) + self.sliding_window = ((sliding_window - 1, + 0) if sliding_window is not None else (-1, -1)) self.kv_cache_dtype = kv_cache_dtype self.logits_soft_cap = logits_soft_cap @@ -775,7 +773,7 @@ def forward( attn_metadata: FlashInferMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: if attn_type != AttentionType.DECODER: raise NotImplementedError("Encoder self-attention and " @@ -783,170 +781,117 @@ def forward( "are not implemented for " "FlashInferImpl") - return torch.ops.vllm.unified_flash_infer( - query, - key, - value, - self.num_heads, - self.head_size, - self.num_kv_heads, - kv_cache, - self.kv_cache_dtype, - k_scale, - v_scale, - self.scale, - self.sliding_window, - self.alibi_slopes, - self.logits_soft_cap, - ) - - -def unified_flash_infer( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - - current_metadata = get_forward_context() - assert current_metadata is not None - assert isinstance(current_metadata, FlashInferMetadata) - attn_metadata: FlashInferMetadata = current_metadata - - num_tokens, hidden_size = query.shape - query = query.view(-1, num_heads, head_size) - key = key.view(-1, num_kv_heads, head_size) - value = value.view(-1, num_kv_heads, head_size) - - if kv_cache.numel() > 0: - # Use the same reshape and cache kernel as flash attention. - ops.reshape_and_cache_flash( - key, - value, - kv_cache[:, 0], - kv_cache[:, 1], - attn_metadata.slot_mapping.flatten(), - kv_cache_dtype, - k_scale, - v_scale, - ) - # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 - # to process the cache when the kv_cache_dtype is fp8 - if kv_cache_dtype.startswith("fp8"): - torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( - kv_cache_dtype) - kv_cache = kv_cache.view(torch_dtype) - - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - query = query.contiguous() # Flashinfer requires query to be contiguous - # Query for decode. KV is not needed because it is already cached. - # QKV for prefill. - decode_query = query[num_prefill_tokens:] - query = query[:num_prefill_tokens] - - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - if prefill_meta := attn_metadata.prefill_metadata: - # We will use flash attention for prefill - # when kv_cache is not provided. - # This happens when vllm runs the profiling to - # determine the number of blocks. - if kv_cache.numel() == 0: - prefill_output = flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=softmax_scale, - causal=True, - window_size=window_size, - alibi_slopes=alibi_slopes, + num_heads: int = self.num_heads + head_size: int = self.head_size + num_kv_heads: int = self.num_kv_heads + kv_cache_dtype: str = self.kv_cache_dtype + softmax_scale: float = self.scale + window_size = self.sliding_window + alibi_slopes = self.alibi_slopes + logits_soft_cap = self.logits_soft_cap + + num_tokens, hidden_size = query.shape + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + # Use the same reshape and cache kernel as flash attention. + ops.reshape_and_cache_flash( + key, + value, + kv_cache[:, 0], + kv_cache[:, 1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, ) - else: - assert prefill_meta is not None - assert prefill_meta.prefill_wrapper is not None - prefill_output = prefill_meta.prefill_wrapper.forward( - query, + # The FlashInfer api requires data to be in fp8_e4m3 or fp8_e5m2 + # to process the cache when the kv_cache_dtype is fp8 + if kv_cache_dtype.startswith("fp8"): + torch_dtype = FlashInferBackend.get_fp8_dtype_for_flashinfer( + kv_cache_dtype) + kv_cache = kv_cache.view(torch_dtype) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + query = query.contiguous( + ) # Flashinfer requires query to be contiguous + # Query for decode. KV is not needed because it is already cached. + # QKV for prefill. + decode_query = query[num_prefill_tokens:] + query = query[:num_prefill_tokens] + + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + window_left = window_size[0] if window_size is not None else -1 + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + if prefill_meta := attn_metadata.prefill_metadata: + # We will use flash attention for prefill + # when kv_cache is not provided. + # This happens when vllm runs the profiling to + # determine the number of blocks. + if kv_cache.numel() == 0: + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + ) + else: + assert prefill_meta is not None + assert prefill_meta.prefill_wrapper is not None + prefill_output = prefill_meta.prefill_wrapper.forward( + query, + kv_cache, + logits_soft_cap=logits_soft_cap, + causal=True, + k_scale=k_scale, + v_scale=v_scale, + window_left=window_left) + if decode_meta := attn_metadata.decode_metadata: + assert decode_meta is not None + assert decode_meta.decode_wrapper is not None + decode_output = decode_meta.decode_wrapper.forward( + decode_query, kv_cache, + sm_scale=softmax_scale, logits_soft_cap=logits_soft_cap, - causal=True, k_scale=k_scale, - v_scale=v_scale) - if decode_meta := attn_metadata.decode_metadata: - assert attn_metadata.decode_metadata is not None - assert attn_metadata.decode_metadata.decode_wrapper is not None - decode_output = attn_metadata.decode_metadata.decode_wrapper.forward( - decode_query, - kv_cache, - sm_scale=softmax_scale, - logits_soft_cap=logits_soft_cap, - k_scale=k_scale, - v_scale=v_scale) - - if prefill_output is None and decode_output is not None: - # Decode only batch. - output, num_tokens = decode_output, num_decode_tokens - elif decode_output is None and prefill_output is not None: - # Prefill only batch. - output, num_tokens = prefill_output, num_prefill_tokens - else: - # Chunked prefill batch does not work with speculative decoding in - # FlashInfer backend, so the query length for decode should be 1. - assert prefill_output is not None - assert decode_output is not None - assert decode_meta is not None - assert decode_meta.decode_query_len == 1 - decode_output = decode_output.squeeze(1) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) - - -def unified_flash_infer_fake( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - num_heads: int, - head_size: int, - num_kv_heads: int, - kv_cache: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, - softmax_scale: float, - window_size: Optional[List[int]] = None, - alibi_slopes: Optional[torch.Tensor] = None, - logits_soft_cap: Optional[float] = None, -) -> torch.Tensor: - return torch.empty_like(query).contiguous() - - -direct_register_custom_op( - op_name="unified_flash_infer", - op_func=unified_flash_infer, - mutates_args=["kv_cache"], - fake_impl=unified_flash_infer_fake, -) + v_scale=v_scale, + window_left=window_left) + + if prefill_output is None and decode_output is not None: + # Decode only batch. + output, num_tokens = decode_output, num_decode_tokens + elif decode_output is None and prefill_output is not None: + # Prefill only batch. + output, num_tokens = prefill_output, num_prefill_tokens + else: + # Chunked prefill batch does not work with speculative decoding in + # FlashInfer backend, so the query length for decode should be 1. + assert prefill_output is not None + assert decode_output is not None + assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) diff --git a/vllm/attention/backends/hpu_attn.py b/vllm/attention/backends/hpu_attn.py index a8f4b09b67274..4a3ddd5db94e5 100644 --- a/vllm/attention/backends/hpu_attn.py +++ b/vllm/attention/backends/hpu_attn.py @@ -140,7 +140,7 @@ def forward( attn_metadata: HPUAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. diff --git a/vllm/attention/backends/ipex_attn.py b/vllm/attention/backends/ipex_attn.py index 87bdb1e0e6565..3b0d51ea4a3d8 100644 --- a/vllm/attention/backends/ipex_attn.py +++ b/vllm/attention/backends/ipex_attn.py @@ -172,7 +172,7 @@ def forward( attn_metadata: IpexAttnMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with IPEX varlen_attention and PagedAttention. diff --git a/vllm/attention/backends/pallas.py b/vllm/attention/backends/pallas.py index 6fee81de14420..5988be0e6b687 100644 --- a/vllm/attention/backends/pallas.py +++ b/vllm/attention/backends/pallas.py @@ -65,6 +65,7 @@ class PallasMetadata(AttentionMetadata): # or all decoding. block_tables: Optional[torch.Tensor] = None context_lens: Optional[torch.Tensor] = None + effective_query_lens: Optional[torch.Tensor] = None @property def prefill_metadata(self) -> Optional["PallasMetadata"]: @@ -72,8 +73,6 @@ def prefill_metadata(self) -> Optional["PallasMetadata"]: return None assert self.num_decode_tokens == 0 - assert self.block_tables is None - assert self.context_lens is None return self @property @@ -151,7 +150,7 @@ def forward( attn_metadata: PallasMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with Pallas attention. @@ -186,29 +185,50 @@ def forward( query = query * self.scale if attn_metadata.num_prefills > 0: - assert seq_len % 16 == 0, ( - "Pallas FlashAttention kernel requires seq_len to be a " - f"multiple of 16 but got {seq_len}") - - # Handle GQA/MQA. - if self.num_kv_heads != self.num_heads: - key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - key = key.view(batch_size, seq_len, self.num_heads, - self.head_size) - value = value.repeat_interleave(self.num_queries_per_kv, + if attn_metadata.block_tables is None: + # Prefill without paged KV cache. + assert seq_len % 16 == 0, ( + "Pallas FlashAttention kernel requires seq_len to be a " + f"multiple of 16 but got {seq_len}") + + # Handle GQA/MQA. + if self.num_kv_heads != self.num_heads: + key = key.repeat_interleave(self.num_queries_per_kv, dim=-2) - value = value.view(batch_size, seq_len, self.num_heads, + key = key.view(batch_size, seq_len, self.num_heads, self.head_size) - # FlashAttention requires [batch_size, num_heads, seq_len, d_model] - # while the input is [batch_size, seq_len, num_heads, d_model]. - # Permute the input to match the required format. - output = torch.ops.xla.flash_attention( - query.permute(0, 2, 1, 3), - key.permute(0, 2, 1, 3), - value.permute(0, 2, 1, 3), - True, - ) - output = output.permute(0, 2, 1, 3) + value = value.repeat_interleave(self.num_queries_per_kv, + dim=-2) + value = value.view(batch_size, seq_len, self.num_heads, + self.head_size) + # FlashAttention kernel requires the input shape to be + # [batch_size, num_heads, seq_len, d_model] + # while the input is [batch_size, seq_len, num_heads, d_model]. + # Permute the input to match the required format. + output = torch.ops.xla.flash_attention( + query.permute(0, 2, 1, 3), + key.permute(0, 2, 1, 3), + value.permute(0, 2, 1, 3), + True, + ) + output = output.permute(0, 2, 1, 3) + else: + # Prefill with paged KV cache. + # TODO(woosuk): Tune the below knobs. + num_kv_pages_per_compute_block = 16 + num_queries_per_compute_block = 16 + assert seq_len % num_queries_per_compute_block == 0 + output = torch.ops.xla.multi_queries_paged_attention( + query, + key_cache, + value_cache, + attn_metadata.context_lens, + attn_metadata.block_tables, + attn_metadata.effective_query_lens, + num_kv_pages_per_compute_block, + num_queries_per_compute_block, + use_kernel=True, + ) else: # Decoding run. assert kv_cache[0].numel() > 0 diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 2bae370eaa90f..6a494f4e73cb4 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -414,7 +414,7 @@ def forward( attn_metadata: ROCmFlashAttentionMetadata, k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with FlashAttention and PagedAttention. diff --git a/vllm/attention/backends/torch_sdpa.py b/vllm/attention/backends/torch_sdpa.py index 563178d3ab60d..16e044b618c40 100644 --- a/vllm/attention/backends/torch_sdpa.py +++ b/vllm/attention/backends/torch_sdpa.py @@ -7,18 +7,14 @@ from torch.nn.functional import scaled_dot_product_attention from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl, - AttentionMetadata, AttentionType) + AttentionMetadata, + AttentionMetadataBuilder, + AttentionType) from vllm.attention.backends.utils import CommonAttentionState +from vllm.attention.ops.ipex_attn import PagedAttention from vllm.attention.ops.paged_attn import PagedAttentionMetadata -from vllm.platforms import current_platform - -if current_platform.is_cpu(): - try: - from vllm.attention.ops.ipex_attn import PagedAttention - except ImportError: - from vllm.attention.ops.paged_attn import PagedAttention -else: - from vllm.attention.ops.paged_attn import PagedAttention +from vllm.utils import make_tensor_with_pad +from vllm.worker.cpu_model_runner import ModelInputForCPUBuilder class TorchSDPABackend(AttentionBackend): @@ -39,6 +35,10 @@ def get_metadata_cls() -> Type["AttentionMetadata"]: def get_state_cls() -> Type["CommonAttentionState"]: return CommonAttentionState + @staticmethod + def get_builder_cls() -> Type["TorchSDPAMetadataBuilder"]: + return TorchSDPAMetadataBuilder + @staticmethod def get_kv_cache_shape( num_blocks: int, @@ -71,9 +71,15 @@ class TorchSDPAMetadata(AttentionMetadata, PagedAttentionMetadata): """ # Currently, input sequences can only contain all prompts # or all decoding. True if all sequences are prompts. - is_prompt: bool - slot_mapping: torch.Tensor - seq_lens: Optional[List[int]] + chunked_prefill: bool + seq_lens: Optional[List[int]] = None # For non-chunked prefill + + # For chunked prefill only + max_query_len: Optional[int] = None + max_kv_len: Optional[int] = None + query_start_loc: Optional[torch.Tensor] = None + kv_start_loc: Optional[torch.Tensor] = None + prefill_block_tables: Optional[torch.Tensor] = None # Begin encoder attn & enc/dec cross-attn fields... # Encoder sequence lengths representation @@ -123,25 +129,19 @@ def is_all_cross_attn_metadata_set(self): @property def prefill_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_decode_tokens == 0: - assert self.num_prefills > 0 - return self - - return None + if self.num_prefill_tokens == 0: + return None + return self @property def decode_metadata(self) -> Optional["TorchSDPAMetadata"]: - # Currently chunked prefill is not supported - if self.num_prefills > 0: - assert self.num_decode_tokens == 0 + if self.num_decode_tokens == 0: return None - return self def get_seq_lens( self, - attn_type: AttentionType, + attn_type: str, ): ''' Extract appropriate sequence lengths from attention metadata @@ -174,7 +174,7 @@ def get_seq_lens( def get_attn_bias( self, - attn_type: AttentionType, + attn_type: str, ) -> Optional[List[torch.Tensor]]: ''' Extract appropriate attention bias from attention metadata @@ -203,7 +203,7 @@ def get_attn_bias( def set_attn_bias( self, attn_bias: List[torch.Tensor], - attn_type: AttentionType, + attn_type: str, ) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -229,7 +229,7 @@ def set_attn_bias( def get_seq_len_block_table_args( self, - attn_type: AttentionType, + attn_type: str, ) -> tuple: ''' The particular choice of sequence-length- and block-table-related @@ -274,6 +274,105 @@ def get_seq_len_block_table_args( raise AttributeError(f"Invalid attention type {str(attn_type)}") +class TorchSDPAMetadataBuilder(AttentionMetadataBuilder[TorchSDPAMetadata]): + + def __init__(self, input_builder: ModelInputForCPUBuilder) -> None: + self.chunked_prefill = input_builder.chunked_prefill + self.input_data = input_builder.input_data + + def build(self, seq_lens: List[int], query_lens: List[int], + cuda_graph_pad_size: int, batch_size: int) -> TorchSDPAMetadata: + input_data = self.input_data + prefill_seq_lens = seq_lens[0:input_data.num_prefills] + prefill_query_lens = query_lens[0:input_data.num_prefills] + slot_mapping = torch.tensor(input_data.slot_mapping, + dtype=torch.long, + device="cpu") + + # For chunked-prefill + if self.chunked_prefill and input_data.num_prefill_tokens != 0: + prefill_block_tables = make_tensor_with_pad( + self.input_data.prefill_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + query_lens_tensor = torch.tensor(prefill_query_lens, + dtype=torch.int32, + device="cpu") + kv_lens_tensor = torch.tensor(prefill_seq_lens, + dtype=torch.int32, + device="cpu") + query_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + kv_start_loc = torch.zeros(input_data.num_prefills + 1, + dtype=torch.int32, + device="cpu") + torch.cumsum(query_lens_tensor, + dim=0, + dtype=torch.int32, + out=query_start_loc[1:]) + torch.cumsum(kv_lens_tensor, + dim=0, + dtype=torch.int32, + out=kv_start_loc[1:]) + max_query_len = max(prefill_query_lens) + max_kv_len = max(prefill_seq_lens) + else: + prefill_block_tables = None + query_start_loc = None + kv_start_loc = None + max_query_len = None + max_kv_len = None + + # For paged attention + if input_data.num_decode_tokens != 0: + seq_lens_tensor = torch.tensor( + input_data.seq_lens[input_data.num_prefills:], + dtype=torch.int32, + device="cpu", + ) + block_tables = make_tensor_with_pad( + self.input_data.decode_block_tables, + pad=0, + dtype=torch.int32, + device="cpu", + ) + else: + block_tables = torch.tensor([]) + seq_lens_tensor = torch.tensor([]) + + # For multi-modal models + placeholder_index_maps = None + if len(input_data.multi_modal_inputs_list) != 0: + placeholder_index_maps = { + modality: placeholder_map.index_map() + for modality, placeholder_map in + input_data.multi_modal_placeholder_maps.items() + } + + attn_metadata = TorchSDPAMetadata( + chunked_prefill=self.chunked_prefill, + seq_lens=prefill_seq_lens, + seq_lens_tensor=seq_lens_tensor, + max_query_len=max_query_len, + max_kv_len=max_kv_len, + query_start_loc=query_start_loc, + kv_start_loc=kv_start_loc, + max_decode_seq_len=input_data.max_decode_seq_len, + num_prefills=input_data.num_prefills, + num_prefill_tokens=input_data.num_prefill_tokens, + num_decode_tokens=input_data.num_decode_tokens, + block_tables=block_tables, + prefill_block_tables=prefill_block_tables, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=placeholder_index_maps, + ) + + return attn_metadata + + class TorchSDPABackendImpl(AttentionImpl[TorchSDPAMetadata]): def __init__( @@ -327,7 +426,7 @@ def forward( attn_metadata: TorchSDPAMetadata, # type: ignore k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with torch SDPA and PagedAttention. @@ -409,19 +508,35 @@ def forward( assert key.shape[0] == num_prefill_tokens + num_decode_tokens assert value.shape[0] == num_prefill_tokens + num_decode_tokens + output = torch.empty_like(query) if prefill_meta := attn_metadata.prefill_metadata: assert attn_metadata.seq_lens is not None - if (kv_cache.numel() == 0 - or prefill_meta.block_tables.numel() == 0): - output = self._run_sdpa_forward(query, - key, - value, - prefill_meta, - attn_type=attn_type) + if not prefill_meta.prefill_metadata.chunked_prefill: # type: ignore + self._run_sdpa_forward(output, + query, + key, + value, + prefill_meta, + attn_type=attn_type) else: # prefix-enabled attention - raise RuntimeError( - "Torch SDPA backend doesn't support prefix decoding.") + assert not self.need_mask + import intel_extension_for_pytorch.llm.modules as ipex_modules + output = torch.empty_like(query) + ipex_modules.PagedAttention.flash_attn_varlen_func( + output[:prefill_meta.num_prefill_tokens, :, :], + query[:prefill_meta.num_prefill_tokens, :, :], + key_cache, + value_cache, + prefill_meta.query_start_loc, + prefill_meta.kv_start_loc, + prefill_meta.max_query_len, + prefill_meta.max_kv_len, + self.scale, + True, + prefill_meta.prefill_block_tables, + self.alibi_slopes, + ) if decode_meta := attn_metadata.decode_metadata: assert attn_type != AttentionType.ENCODER_ONLY, ( @@ -433,8 +548,9 @@ def forward( block_tables_arg, ) = decode_meta.get_seq_len_block_table_args(attn_type) - output = PagedAttention.forward_decode( - query, + PagedAttention.forward_decode( + output[attn_metadata.num_prefill_tokens:, :, :], + query[attn_metadata.num_prefill_tokens:, :, :], key_cache, value_cache, block_tables_arg, @@ -453,12 +569,13 @@ def forward( def _run_sdpa_forward( self, + output: torch.Tensor, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, attn_metadata: TorchSDPAMetadata, - attn_type: AttentionType = AttentionType.DECODER, - ): + attn_type: str = AttentionType.DECODER, + ) -> None: if self.num_kv_heads != self.num_heads: key = key.repeat_interleave(self.num_queries_per_kv, dim=1) value = value.repeat_interleave(self.num_queries_per_kv, dim=1) @@ -479,7 +596,6 @@ def _run_sdpa_forward( attn_masks = [None] * len(seq_lens) attn_metadata.set_attn_bias(attn_masks, attn_type) - output = torch.empty_like(query) query = query.movedim(0, query.dim() - 2) key = key.movedim(0, key.dim() - 2) value = value.movedim(0, value.dim() - 2) @@ -502,7 +618,6 @@ def _run_sdpa_forward( scale=self.scale).squeeze(0).movedim(query.dim() - 2, 0) output[start_q:end_q, :, :] = sub_out start_q, start_kv = end_q, end_kv - return output def _make_alibi_bias( diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 12800668af223..56cc43430301f 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -478,7 +478,7 @@ def is_all_cross_attn_metadata_set(attn_metadata): def get_seq_len_block_table_args( attn_metadata, is_prompt: bool, - attn_type: AttentionType, + attn_type: str, ) -> tuple: ''' The particular choice of sequence-length- and block-table-related @@ -529,7 +529,7 @@ def get_seq_len_block_table_args( def get_num_prefill_decode_query_kv_tokens( attn_metadata, - attn_type: AttentionType, + attn_type: str, ) -> Tuple[int, int, int]: """ Calculate the number of prefill and decode tokens for query, key/value diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 83d03606524dc..292575a8736bc 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -284,7 +284,7 @@ def decode_metadata(self) -> Optional["XFormersMetadata"]: def _get_attn_bias( attn_metadata: XFormersMetadata, - attn_type: AttentionType, + attn_type: str, ) -> Optional[AttentionBias]: ''' Extract appropriate attention bias from attention metadata @@ -314,7 +314,7 @@ def _get_attn_bias( def _set_attn_bias( attn_metadata: XFormersMetadata, attn_bias: List[Optional[AttentionBias]], - attn_type: AttentionType, + attn_type: str, ) -> None: ''' Update appropriate attention bias field of attention metadata, @@ -416,7 +416,7 @@ def forward( attn_metadata: "XFormersMetadata", k_scale: float = 1.0, v_scale: float = 1.0, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Forward pass with xFormers and PagedAttention. @@ -617,7 +617,7 @@ def _run_memory_efficient_xformers_forward( key: torch.Tensor, value: torch.Tensor, attn_metadata: XFormersMetadata, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: """Attention for 1D query of multiple prompts. Multiple prompt tokens are flattened in to `query` input. diff --git a/vllm/attention/layer.py b/vllm/attention/layer.py index 33d05cbd3fe01..17157617248f7 100644 --- a/vllm/attention/layer.py +++ b/vllm/attention/layer.py @@ -4,12 +4,16 @@ import torch import torch.nn as nn +import vllm.envs as envs from vllm.attention import AttentionMetadata, AttentionType -from vllm.attention.selector import get_attn_backend -from vllm.config import CacheConfig +from vllm.attention.selector import backend_name_to_enum, get_attn_backend +from vllm.config import CacheConfig, get_current_vllm_config +from vllm.forward_context import ForwardContext, get_forward_context from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.platforms import current_platform +from vllm.utils import direct_register_custom_op class Attention(nn.Module): @@ -35,18 +39,26 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, blocksparse_params: Optional[Dict[str, Any]] = None, logits_soft_cap: Optional[float] = None, + per_layer_sliding_window: Optional[int] = None, prefix: str = "", ) -> None: super().__init__() + if per_layer_sliding_window is not None: + # per-layer sliding window + sliding_window = per_layer_sliding_window + elif cache_config is not None: + # model-level sliding window + sliding_window = cache_config.sliding_window + else: + sliding_window = None + if cache_config is not None: kv_cache_dtype = cache_config.cache_dtype block_size = cache_config.block_size - sliding_window = cache_config.sliding_window is_attention_free = cache_config.is_attention_free else: kv_cache_dtype = "auto" block_size = 16 - sliding_window = None is_attention_free = False if num_kv_heads is None: num_kv_heads = num_heads @@ -85,6 +97,19 @@ def __init__( self.impl = impl_cls(num_heads, head_size, scale, num_kv_heads, alibi_slopes, sliding_window, kv_cache_dtype, blocksparse_params, logits_soft_cap) + self.backend = backend_name_to_enum(attn_backend.get_name()) + + # For cuda-alike (CUDA and ROCM) and cpu platforms, we control how + # torch.compile works by registering the attention as one giant + # opaque custom op. For other platforms, we directly call them + # and let torch.compile handle them. + self.use_direct_call = envs.VLLM_USE_V1 or not ( + current_platform.is_cuda_alike() or current_platform.is_cpu()) + compilation_config = get_current_vllm_config().compilation_config + if prefix in compilation_config.static_forward_context: + raise ValueError(f"Duplicate layer name: {prefix}") + compilation_config.static_forward_context[prefix] = self + self.layer_name = prefix def forward( self, @@ -93,17 +118,22 @@ def forward( value: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, - attn_type: AttentionType = AttentionType.DECODER, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: - return self.impl.forward(query, - key, - value, - kv_cache, - attn_metadata, - self._k_scale, - self._v_scale, - attn_type=attn_type) + if self.use_direct_call: + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + else: + return torch.ops.vllm.unified_attention(query, key, value, + kv_cache, attn_type, + self.layer_name) def extra_repr(self) -> str: s = f"head_size={self.impl.head_size}" # type: ignore @@ -112,3 +142,44 @@ def extra_repr(self) -> str: s += f", scale={self.impl.scale}" # type: ignore s += f", backend={self.impl.__class__.__name__}" return s + + +def unified_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> torch.Tensor: + forward_context: ForwardContext = get_forward_context() + attn_metadata = forward_context.dynamic_forward_context + self = forward_context.static_forward_context[layer_name] + return self.impl.forward(query, + key, + value, + kv_cache, + attn_metadata, + self._k_scale, + self._v_scale, + attn_type=attn_type) + + +def unified_attention_fake( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + kv_cache: torch.Tensor, + attn_type: str, + layer_name: str, +) -> torch.Tensor: + return torch.empty_like(query).contiguous() + + +direct_register_custom_op( + op_name="unified_attention", + op_func=unified_attention, + mutates_args=["kv_cache"], + fake_impl=unified_attention_fake, + dispatch_key=current_platform.dispatch_key, +) diff --git a/vllm/attention/ops/ipex_attn.py b/vllm/attention/ops/ipex_attn.py index 8df6d4ced9dc6..cbc6c74acf09a 100644 --- a/vllm/attention/ops/ipex_attn.py +++ b/vllm/attention/ops/ipex_attn.py @@ -1,12 +1,17 @@ from typing import Dict, List, Optional, Tuple -import intel_extension_for_pytorch.llm.modules as ipex_modules +try: + import intel_extension_for_pytorch.llm.modules as ipex_modules + _use_ipex = True +except ImportError: + _use_ipex = False + import torch from vllm import _custom_ops as ops -class PagedAttention: +class _PagedAttention: @staticmethod def get_supported_head_sizes() -> List[int]: @@ -22,6 +27,105 @@ def get_kv_cache_shape( ) -> Tuple[int, ...]: return (2, num_blocks, block_size * num_kv_heads * head_size) + @staticmethod + def split_kv_cache( + kv_cache: torch.Tensor, + num_kv_heads: int, + head_size: int, + *args, + ) -> Tuple[torch.Tensor, torch.Tensor]: + x = 16 // kv_cache.element_size() + num_blocks = kv_cache.shape[1] + + key_cache = kv_cache[0] + key_cache = key_cache.view(num_blocks, num_kv_heads, head_size // x, + -1, x) + value_cache = kv_cache[1] + value_cache = value_cache.view(num_blocks, num_kv_heads, head_size, -1) + return key_cache, value_cache + + @staticmethod + def write_to_paged_cache( + key: torch.Tensor, + value: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + slot_mapping: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + *args, + ) -> None: + ops.reshape_and_cache( + key, + value, + key_cache, + value_cache, + slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + @staticmethod + def forward_decode( + output: torch.Tensor, + query: torch.Tensor, + key_cache: torch.Tensor, + value_cache: torch.Tensor, + block_tables: torch.Tensor, + context_lens: torch.Tensor, + max_context_len: int, + kv_cache_dtype: str, + num_kv_heads: int, + scale: float, + alibi_slopes: Optional[torch.Tensor], + k_scale: float, + v_scale: float, + *args, + ) -> None: + tp_rank: int = 0 + blocksparse_local_blocks: int = 0 + blocksparse_vert_stride: int = 0 + blocksparse_block_size: int = 64 + blocksparse_head_sliding_step: int = 0 + block_size = value_cache.shape[3] + + ops.paged_attention_v1( + output, + query, + key_cache, + value_cache, + num_kv_heads, + scale, + block_tables, + context_lens, + block_size, + max_context_len, + alibi_slopes, + kv_cache_dtype, + k_scale, + v_scale, + tp_rank, + blocksparse_local_blocks, + blocksparse_vert_stride, + blocksparse_block_size, + blocksparse_head_sliding_step, + ) + + @staticmethod + def copy_blocks( + kv_caches: List[torch.Tensor], + src_to_dists: Dict[int, List[int]], + *args, + ) -> None: + key_caches = [kv_cache[0] for kv_cache in kv_caches] + value_caches = [kv_cache[1] for kv_cache in kv_caches] + ops.copy_blocks(key_caches, value_caches, src_to_dists) + + +class _IPEXPagedAttention(_PagedAttention): + @staticmethod def split_kv_cache( kv_cache: torch.Tensor, @@ -55,6 +159,7 @@ def write_to_paged_cache( @staticmethod def forward_decode( + output: torch.Tensor, query: torch.Tensor, key_cache: torch.Tensor, value_cache: torch.Tensor, @@ -68,8 +173,7 @@ def forward_decode( k_scale: float, v_scale: float, *args, - ) -> torch.Tensor: - output = torch.empty_like(query) + ) -> None: block_size = value_cache.shape[2] head_mapping = torch.arange( 0, @@ -83,41 +187,5 @@ def forward_decode( scale, block_tables, context_lens, block_size, max_context_len, alibi_slopes) - return output - - @staticmethod - def forward_prefix( - query: torch.Tensor, - key: torch.Tensor, - value: torch.Tensor, - kv_cache_dtype: str, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - block_tables: torch.Tensor, - subquery_start_loc: torch.Tensor, - prompt_lens_tensor: torch.Tensor, - context_lens: torch.Tensor, - max_subquery_len: int, - alibi_slopes: Optional[torch.Tensor], - *args, - ) -> torch.Tensor: - raise NotImplementedError - - @staticmethod - def swap_blocks( - src_kv_cache: torch.Tensor, - dst_kv_cache: torch.Tensor, - src_to_dst: Dict[int, int], - *args, - ) -> None: - raise NotImplementedError - @staticmethod - def copy_blocks( - kv_caches: List[torch.Tensor], - src_to_dists: Dict[int, List[int]], - *args, - ) -> None: - key_caches = [kv_cache[0] for kv_cache in kv_caches] - value_caches = [kv_cache[1] for kv_cache in kv_caches] - ops.copy_blocks(key_caches, value_caches, src_to_dists) +PagedAttention = _IPEXPagedAttention if _use_ipex else _PagedAttention diff --git a/vllm/attention/ops/prefix_prefill.py b/vllm/attention/ops/prefix_prefill.py index a2a649c8ebcfd..9c11a8df55278 100644 --- a/vllm/attention/ops/prefix_prefill.py +++ b/vllm/attention/ops/prefix_prefill.py @@ -7,6 +7,13 @@ from vllm.platforms import current_platform +# Static kernels parameters +BASE_BLOCK = 128 if current_platform.has_device_capability(80) else 64 +NUM_WARPS = 8 + +# To check compatibility +IS_TURING = current_platform.get_device_capability() == (7, 5) + if triton.__version__ >= "2.1.0": @triton.jit @@ -50,6 +57,7 @@ def _fwd_kernel( stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 @@ -130,7 +138,7 @@ def _fwd_kernel( k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) # [M,N] - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale @@ -178,7 +186,7 @@ def _fwd_kernel( v = v_load p = p.to(v.dtype) - acc += tl.dot(p, v) + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -204,7 +212,7 @@ def _fwd_kernel( other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk *= sm_scale # apply causal mask qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, @@ -238,7 +246,7 @@ def _fwd_kernel( other=0.0) p = p.to(v.dtype) - acc += tl.dot(p, v) + acc = tl.dot(p, v, acc=acc, input_precision=IN_PRECISION) # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -485,6 +493,7 @@ def _fwd_kernel_alibi( stride_v_cache_d, stride_v_cache_bl, num_queries_per_kv: int, + IN_PRECISION: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, # head size BLOCK_DMODEL_PADDED: tl.constexpr, # head size padded to a power of 2 @@ -560,7 +569,7 @@ def _fwd_kernel_alibi( k = k_load qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k) + qk = tl.dot(q, k, acc=qk, input_precision=IN_PRECISION) qk = tl.where((start_n + offs_n[None, :]) < cur_batch_ctx_len, qk, float("-inf")) qk *= sm_scale @@ -600,7 +609,7 @@ def _fwd_kernel_alibi( v = v_load p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) + acc = tl.dot(p, v, acc=acc, input_precision='ieee') # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -635,7 +644,7 @@ def _fwd_kernel_alibi( other=0.0) qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) - qk += tl.dot(q, k, allow_tf32=False) + qk = tl.dot(q, k, acc=qk, input_precision='ieee') qk *= sm_scale qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) @@ -673,7 +682,7 @@ def _fwd_kernel_alibi( other=0.0) p = p.to(v.dtype) - acc += tl.dot(p, v, allow_tf32=False) + acc = tl.dot(p, v, acc=acc, input_precision='ieee') # update m_i and l_i l_i = l_i_new m_i = m_i_new @@ -709,13 +718,17 @@ def context_attention_fwd(q, alibi_slopes=None, sliding_window=None): - BLOCK = 128 if current_platform.has_device_capability(80) else 64 - NUM_WARPS = 8 - + q_dtype_is_f32 = q.dtype is torch.float32 # need to reduce num. blocks when using fp32 # due to increased use of GPU shared memory - if q.dtype is torch.float32: - BLOCK = BLOCK // 2 + # if q.dtype is torch.float32: + BLOCK = BASE_BLOCK // 2 if q_dtype_is_f32 else BASE_BLOCK + + # Turing does have tensor core for float32 multiplication + # use ieee as fallback for triton kernels work. There is also + # warning on vllm/config.py to inform users this fallback + # implementation + IN_PRECISION = 'ieee' if IS_TURING and q_dtype_is_f32 else None # Conversion of FP8 Tensor from uint8 storage to # appropriate torch.dtype for interpretation by Triton @@ -799,6 +812,7 @@ def context_attention_fwd(q, v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, @@ -850,6 +864,7 @@ def context_attention_fwd(q, v_cache.stride( 3), #[num_blocks, num_kv_heads, head_size, block_size] num_queries_per_kv=num_queries_per_kv, + IN_PRECISION=IN_PRECISION, BLOCK_M=BLOCK, BLOCK_DMODEL=Lk, BLOCK_DMODEL_PADDED=Lk_padded, diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 0cf1e3a95fcba..464bc2af8fd6d 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -1,6 +1,5 @@ import copy import dataclasses -import operator from contextlib import ExitStack from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple from unittest.mock import patch @@ -11,205 +10,15 @@ import vllm.envs as envs from vllm.config import CompilationConfig from vllm.logger import init_logger -from vllm.utils import combine_fx_passes, weak_ref_tensors +from vllm.utils import weak_ref_tensors from .counter import compilation_counter -from .fusion import FusionPass -from .reshapes import RedundantReshapesPass +from .inductor_pass import InductorPass +from .pass_manager import PostGradPassManager logger = init_logger(__name__) -def fix_functionalization(graph: fx.Graph): - """ - Rewrite the graph module to replace the pattern involving - torch._higher_order_ops.auto_functionalize.auto_functionalized - with a direct call to the inplace custom op. - - # TODO: check if PyTorch nightly has fixed this issue - """ - - # debug code, if we want to see the graph before the transformation - # with open("before.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - nodes_to_remove = [] - - for node in graph.nodes: - # Identify the auto_functionalized node - if node.op == 'call_function' and node.target == torch._higher_order_ops.auto_functionalize.auto_functionalized: # noqa - if node.args[0] == torch.ops._C.rotary_embedding.default: - # manual replace for rotary_embedding - - # Now, collect the arguments - kwargs = node.kwargs - - query = kwargs['query'] - mm_node = query.args[0].args[0] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rotary_embedding.default, - kwargs=kwargs) - - # Remove the auto_functionalized node - # Since the node may have outputs, we need to handle its users - # Replace uses of the outputs (getitem nodes) with mm_node - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - for getitem_user in list(user.users): - if (getitem_user.op == 'call_function' - and getitem_user.target - == torch.ops.aten.slice_scatter.default): - # Replace the uses of slice_scatter node - # with mm_node - getitem_user.replace_all_uses_with(mm_node) - nodes_to_remove.append(getitem_user) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.fused_add_rms_norm.default: - # manual replace for fused_add_rms_norm - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - input = kwargs['input'] - residual = kwargs['residual'] - - # Create a new call to torch.ops._C.rotary_embedding.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm.default, kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = input - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - elif (node.args[0] == - torch.ops._C.fused_add_rms_norm_static_fp8_quant.default): - # manual replace for fused_add_rms_norm_static_fp8_quant - # this is the most effective optimization for llama - # failing to do this will result in many unnecessary copies - - kwargs = node.kwargs - - result = kwargs['result'] - residual = kwargs['residual'] - - # Create a new call to - # torch.ops._C.fused_add_rms_norm_static_fp8_quant.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.fused_add_rms_norm_static_fp8_quant. - default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - # Remove the getitem node - if user.args[1] == 1: - replace_node = result - elif user.args[1] == 2: - replace_node = residual - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.rms_norm.default: - # manual replace for rms_norm - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm.default - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function(torch.ops._C.rms_norm.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[ - 0] == torch.ops._C.rms_norm_static_fp8_quant.default: # noqa - # manual replace for rms_norm_static_fp8_quant - - kwargs = node.kwargs - - replace_node = kwargs['result'] - # Create a new call to torch.ops._C.rms_norm_static_fp8_quant.default # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.rms_norm_static_fp8_quant.default, - kwargs=kwargs) - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - elif node.args[0] == torch.ops._C.silu_and_mul.default: - # manual replace for silu_and_mul - - kwargs = node.kwargs - - input = kwargs['input'] - out = kwargs['out'] - - # Create a new call to torch.ops._C.silu_and_mul.default - # cannot use kwargs, because we have an `out`, see https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 # noqa - with graph.inserting_before(node): - # just insert the call to the custom op - # NOTE: don't run dead code elimination, - # otherwise this op will be removed - graph.call_function( - torch.ops._C.silu_and_mul.default, - args=(out, input), - ) - replace_node = out - - for user in list(node.users): - if user.op == 'call_function' and user.target == operator.getitem: # noqa - user.replace_all_uses_with(replace_node) - nodes_to_remove.append(user) - nodes_to_remove.append(node) - - # Remove the nodes all at once - for node in nodes_to_remove: - graph.erase_node(node) - - # debug code, if we want to see the graph after the transformation - # with open("after.py", "w") as f: - # print(graph.python_code(root_module="self", verbose=True).src, file=f) - - def wrap_inductor(graph, example_inputs, additional_inductor_config, @@ -368,12 +177,8 @@ class VllmBackend: The major work of this backend is to split the graph into piecewise graphs, and pass them to the piecewise backend. - This backend also handles custom passes and adds them to Inductor config. - The order of the post-grad post-passes is: - 1. post_grad_passes (constructor parameter) - 2. config["post_grad_custom_post_pass"] - 3. fix_functionalization - This way, all passes operate on a functionalized graph. + This backend also adds the PostGradPassManager to Inductor config, + which handles the post-grad passes. """ compilation_configs: CompilationConfig @@ -402,7 +207,9 @@ def __init__( # streams, it might not be safe to share a global pool. # only investigate this when we use multiple streams self.graph_pool = global_graph_pool - self.post_grad_passes = [] + + # Passes to run on the graph post-grad. + self.post_grad_pass_manager = PostGradPassManager() self.sym_tensor_indices = [] self.input_buffers = [] @@ -412,24 +219,19 @@ def __init__( # `torch.compile` is JIT compiled, so we don't need to # do anything here - def add_passes_to_config(self): + def configure_post_pass(self): config = self.compilation_configs - passes = list(self.post_grad_passes) - - passes = passes + [RedundantReshapesPass(config)] - - if config.enable_fusion: - passes = passes + [FusionPass.instance(config)] + self.post_grad_pass_manager.configure(config.pass_config) + # Post-grad custom passes are run using the post_grad_custom_post_pass + # hook. If a pass for that hook exists, add it to the pass manager. inductor_config = config.inductor_compile_config - if "post_grad_custom_post_pass" in inductor_config: - passes = passes + [inductor_config["post_grad_custom_post_pass"]] - - # add the fix_functionalization pass last, so that all other - # passes operate on a functionalized graph - passes = passes + [fix_functionalization] - combined_pass = combine_fx_passes(passes) - inductor_config["post_grad_custom_post_pass"] = combined_pass + PASS_KEY = "post_grad_custom_post_pass" + if PASS_KEY in inductor_config: + # Config should automatically wrap all inductor passes + assert isinstance(inductor_config[PASS_KEY], InductorPass) + self.post_grad_pass_manager.add(inductor_config[PASS_KEY]) + inductor_config[PASS_KEY] = self.post_grad_pass_manager def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: @@ -444,10 +246,10 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: # we get the sizes to capture for cudagraph # from compilation context self.compilation_configs.init_during_runtime() - self.add_passes_to_config() + self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( - graph, self.compilation_configs.non_cudagraph_ops) + graph, self.compilation_configs.splitting_ops) from torch._dynamo.utils import lazy_format_graph_code logger.debug("%s", lazy_format_graph_code("before split", self.graph)) diff --git a/vllm/compilation/counter.py b/vllm/compilation/counter.py index 100a49aba74ac..6385f1c5dbf81 100644 --- a/vllm/compilation/counter.py +++ b/vllm/compilation/counter.py @@ -5,6 +5,7 @@ @dataclasses.dataclass class CompilationCounter: + num_models_seen: int = 0 num_graphs_seen: int = 0 # including the splitting ops num_piecewise_graphs_seen: int = 0 diff --git a/vllm/compilation/decorators.py b/vllm/compilation/decorators.py index 4b78491bc5a48..8b81a29936989 100644 --- a/vllm/compilation/decorators.py +++ b/vllm/compilation/decorators.py @@ -3,6 +3,7 @@ import torch +from vllm.compilation.counter import compilation_counter from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import CompilationLevel, VllmConfig from vllm.logger import init_logger @@ -130,6 +131,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = '', **kwargs): ] or not supports_dynamo() if self.do_not_compile: return + compilation_counter.num_models_seen += 1 TorchCompileWrapperWithCustomDispatcher.__init__( self, compilation_level=vllm_config.compilation_config.level) diff --git a/vllm/compilation/fix_functionalization.py b/vllm/compilation/fix_functionalization.py new file mode 100644 index 0000000000000..3584cc3608caf --- /dev/null +++ b/vllm/compilation/fix_functionalization.py @@ -0,0 +1,177 @@ +import operator +from typing import Dict, Iterable, List, Optional, Tuple, Union + +import torch +from torch._higher_order_ops.auto_functionalize import auto_functionalized + +from vllm.logger import init_logger + +from .vllm_inductor_pass import VllmInductorPass, is_func + +logger = init_logger(__name__) + + +class FixFunctionalizationPass(VllmInductorPass): + """ + This pass defunctionalizes certain nodes to avoid redundant tensor copies. + After this pass, DCE (dead-code elimination) should never be run, + as de-functionalized nodes may appear as dead code. + + To add new nodes to defunctionalize, add to the if-elif chain in __call__. + """ + + def __call__(self, graph: torch.fx.Graph): + self.begin() + self.dump_graph(graph, "before_fix_functionalization") + + self.nodes_to_remove: List[torch.fx.Node] = [] + count = 0 + for node in graph.nodes: + if not is_func(node, auto_functionalized): + continue # Avoid deep if-elif nesting + + kwargs = node.kwargs + at_target = node.args[0] + + if at_target == torch.ops._C.rotary_embedding.default: + query = kwargs['query'] + mm_node = query.args[0].args[0] + + # rotary_embedding is a special case: the two mutating inputs + # are query and key, which are slices of mm_node. + # While functionalized, results at[1] and at[2] are scattered + # back into mm_node. After de-functionalization, we can just + # use mm_node directly. + for idx, user in self.getitem_users(node).items(): + for user_of_getitem in user.users: + if is_func(user_of_getitem, + torch.ops.aten.slice_scatter.default): + user_of_getitem.replace_all_uses_with(mm_node) + self._remove(user_of_getitem) + self._remove(user) + + self.insert_defunctionalized(graph, node) + self._remove(node) + + # These 2 replacements avoid the most copies for LLaMa. + elif at_target == torch.ops._C.fused_add_rms_norm.default: + mutated_args = {1: 'input', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + elif at_target == torch.ops._C.fused_add_rms_norm_static_fp8_quant.default: # noqa: E501 + mutated_args = {1: 'result', 2: 'residual'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target in [ + torch.ops._C.rms_norm.default, + torch.ops._C.rms_norm_static_fp8_quant.default + ]: + mutated_args = {1: 'result'} + self.defunctionalize(graph, node, mutated_args) + + elif at_target == torch.ops._C.silu_and_mul.default: + mutated_args = {1: 'out'} + # Because we have an 'out', need to specify args directly + self.defunctionalize(graph, + node, + mutated_args, + args=('out', 'input')) + else: + continue # skip the count + + count += 1 + + self.dump_graph(graph, "before_fix_functionalization_cleanup") + + # Remove the nodes all at once + count_removed = len(self.nodes_to_remove) + for node in self.nodes_to_remove: + graph.erase_node(node) + + logger.debug("De-functionalized %s nodes, removed %s nodes", count, + count_removed) + self.dump_graph(graph, "after_fix_functionalization") + self.end_and_log() + + def _remove(self, node_or_nodes: Union[torch.fx.Node, + Iterable[torch.fx.Node]]): + """ + Stage a node (or nodes) for removal at the end of the pass. + """ + if isinstance(node_or_nodes, torch.fx.Node): + self.nodes_to_remove.append(node_or_nodes) + else: + self.nodes_to_remove.extend(node_or_nodes) + + def defunctionalize(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + mutated_args: Dict[int, Union[torch.fx.Node, str]], + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + De-functionalize a node by replacing it with a call to the original. + It also replaces the getitem users with the mutated arguments. + See replace_users_with_mutated_args and insert_defunctionalized. + """ + self.replace_users_with_mutated_args(node, mutated_args) + self.insert_defunctionalized(graph, node, args=args) + self._remove(node) + + def replace_users_with_mutated_args(self, node: torch.fx.Node, + mutated_args: Dict[int, + Union[torch.fx.Node, + str]]): + """ + Replace all getitem users of the auto-functionalized node with the + mutated arguments. + :param node: The auto-functionalized node + :param mutated_args: The mutated arguments, indexed by getitem index. + If the value of an arg is a string, `node.kwargs[arg]` is used. + """ + for idx, user in self.getitem_users(node).items(): + arg = mutated_args[idx] + arg = node.kwargs[arg] if isinstance(arg, str) else arg + user.replace_all_uses_with(arg) + self._remove(user) + + def getitem_users(self, node: torch.fx.Node) -> Dict[int, torch.fx.Node]: + """ + Returns the operator.getitem users of the auto-functionalized node, + indexed by the index they are getting. + """ + users = {} + for user in node.users: + if is_func(user, operator.getitem): + idx = user.args[1] + users[idx] = user + return users + + def insert_defunctionalized(self, + graph: torch.fx.Graph, + node: torch.fx.Node, + args: Optional[Tuple[Union[torch.fx.Node, str], + ...]] = None): + """ + Insert a new defunctionalized node into the graph before node. + If one of the kwargs is 'out', provide args directly, + as node.kwargs cannot be used. + See https://github.com/pytorch/pytorch/blob/a00faf440888ffb724bad413f329a49e2b6388e7/torch/_inductor/lowering.py#L351 + + :param graph: Graph to insert the defunctionalized node into + :param node: The auto-functionalized node to defunctionalize + :param args: If we cannot use kwargs, specify args directly. + If an arg is a string, `node.kwargs[arg]` is used. + """ # noqa: E501 + assert is_func(node, auto_functionalized), \ + f"node must be auto-functionalized, is {node} instead" + + # Create a new call to the original function + with graph.inserting_before(node): + function = node.args[0] + if args is None: + graph.call_function(function, kwargs=node.kwargs) + else: + # Args passed as strings refer to items in node.kwargs + args = tuple(node.kwargs[arg] if isinstance(arg, str) else arg + for arg in args) + graph.call_function(function, args=args) diff --git a/vllm/compilation/fusion.py b/vllm/compilation/fusion.py index e6a3afef85e1b..5efa410fab6a0 100644 --- a/vllm/compilation/fusion.py +++ b/vllm/compilation/fusion.py @@ -6,10 +6,11 @@ from torch._inductor.pattern_matcher import (Match, PatternMatcherPass, fwd_only, register_replacement) -from vllm.compilation.inductor_pass import InductorPass from vllm.config import CompilationConfig from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) @@ -90,8 +91,6 @@ def empty_fp32(*args, **kwargs): # Utilities for post-processing multi-output matches -def is_func(node: torch.fx.Node, target) -> bool: - return node.op == "call_function" and node.target == target # Returns the first auto_functionalized node with the given op (if it exists) @@ -127,7 +126,7 @@ def find_getitem(node: torch.fx.Node, idx: int) -> torch.fx.Node: return ret -class FusionPass(InductorPass): +class FusionPass(VllmInductorPass): """ This pass fuses a pre-defined set of custom ops into fused ops. It uses the torch pattern matcher to find the patterns and replace them. @@ -142,7 +141,7 @@ class FusionPass(InductorPass): _instance: 'Optional[FusionPass]' = None @classmethod - def instance(cls, config: CompilationConfig): + def instance(cls, config: CompilationConfig.PassConfig): """ Get the singleton instance of the FusionPass. If the instance exists, the config is updated but @@ -154,7 +153,7 @@ def instance(cls, config: CompilationConfig): cls._instance.config = config return cls._instance - def __init__(self, config: CompilationConfig): + def __init__(self, config: CompilationConfig.PassConfig): assert self.__class__._instance is None, \ "FusionPass singleton instance already exists" super().__init__(config) @@ -278,6 +277,7 @@ def process_matches(self, graph: torch.fx.Graph): for node in match.nodes) def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_fusion") count = self.patterns.apply(graph) @@ -289,3 +289,4 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Post-processed %s matches", len(self.matches)) self.dump_graph(graph, "after_fusion") self.matches.clear() + self.end_and_log() diff --git a/vllm/compilation/inductor_pass.py b/vllm/compilation/inductor_pass.py index 8082a08b40019..f6846c08ac841 100644 --- a/vllm/compilation/inductor_pass.py +++ b/vllm/compilation/inductor_pass.py @@ -1,38 +1,84 @@ +import hashlib +import inspect +import types from abc import ABC, abstractmethod +from typing import Any, Callable, Optional, Union import torch - -from vllm.config import CompilationConfig -# yapf: disable -from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank -from vllm.distributed import ( - get_tensor_model_parallel_world_size as get_tp_world_size) -from vllm.distributed import model_parallel_is_initialized as p_is_init -# yapf: enable -from vllm.logger import init_logger - -logger = init_logger(__name__) +from torch import fx class InductorPass(ABC): + """ + General custom inductor pass interface. + TODO(torch==2.6) use torch._inductor.custom_graph_pass.CustomGraphPass + """ @abstractmethod def __call__(self, graph: torch.fx.Graph): + """ + Execute the pass on the given graph. + """ raise NotImplementedError - def __init__(self, config: CompilationConfig): - self.config = config - - def dump_graph(self, graph: torch.fx.Graph, stage: str): - if stage in self.config.dump_graph_stages: - # Make sure filename includes rank in the distributed setting - parallel = p_is_init() and get_tp_world_size() > 1 - rank = f"-{get_tp_rank()}" if parallel else "" - filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" - - logger.info("Printing graph to %s", filepath) - with open(filepath, "w") as f: - src = graph.python_code(root_module="self", verbose=True).src - # Add imports so it's not full of errors - print("import torch; from torch import device", file=f) - print(src, file=f) + def uuid(self) -> Any: + """ + Provide a unique identifier for the pass, used in Inductor code cache. + This should depend on the pass implementation, so that changes to the + pass result in recompilation. + By default, the object source is hashed. + """ + return InductorPass.hash_source(self) + + @staticmethod + def hash_source(*srcs: Union[str, Any]): + """ + Utility method to hash the sources of functions or objects. + :param srcs: strings or objects to add to the hash. + Objects and functions have their source inspected. + :return: + """ + hasher = hashlib.sha256() + for src in srcs: + if isinstance(src, str): + src_str = src + elif isinstance(src, types.FunctionType): + src_str = inspect.getsource(src) + else: + src_str = inspect.getsource(src.__class__) + hasher.update(src_str.encode("utf-8")) + return hasher.digest() + + +class CallableInductorPass(InductorPass): + """ + This class is a wrapper for a callable that automatically provides an + implementation of the UUID. + """ + + def __init__(self, + callable: Callable[[fx.Graph], None], + uuid: Optional[Any] = None): + self.callable = callable + if uuid is None: + uuid = InductorPass.hash_source(callable) + self._uuid = uuid + + def __call__(self, graph: torch.fx.Graph): + self.callable(graph) + + def uuid(self) -> Any: + return self._uuid + + def __getstate__(self): + """ + Pickling occurs in the Inductor code cache if a pass is not given to + the pass manager but is instead directly added to config as a pass. + See PostGradPassManager for more. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + return self._uuid + + def __setstate__(self, state): + raise ValueError("Cannot unpickle CallableInductorPass") diff --git a/vllm/compilation/pass_manager.py b/vllm/compilation/pass_manager.py new file mode 100644 index 0000000000000..fb522ae053e97 --- /dev/null +++ b/vllm/compilation/pass_manager.py @@ -0,0 +1,77 @@ +from typing import List + +from torch import fx as fx + +from vllm.config import CompilationConfig +from vllm.logger import init_logger + +from .fix_functionalization import FixFunctionalizationPass +from .fusion import FusionPass +from .inductor_pass import InductorPass +from .reshapes import RedundantReshapesPass + +logger = init_logger(__name__) + + +class PostGradPassManager: + """ + The pass manager for post-grad passes. + It handles configuration, adding custom passes, and running passes. + It also supports pickling, which is used by the Inductor code cache. + TODO(torch==2.6), use CustomGraphPass + (torch._inductor.custom_graph_pass.CustomGraphPass) + + The order of the post-grad post-passes is: + 1. passes (constructor parameter) + 2. default passes (RedundantReshapesPass, FusionPass) + 3. config["post_grad_custom_post_pass"] (if it exists) + 4. fix_functionalization + This way, all passes operate on a functionalized graph. + """ + + def __init__(self): + self.passes: List[InductorPass] = [] + + def __call__(self, graph: fx.Graph): + for pass_ in self.passes: + pass_(graph) + + # always run fix_functionalization last + self.fix_functionalization(graph) + + def configure(self, pass_config: CompilationConfig.PassConfig): + self.pass_config = pass_config + if pass_config.enable_reshape: + self.passes += [RedundantReshapesPass(pass_config)] + + if pass_config.enable_fusion: + self.passes += [FusionPass.instance(pass_config)] + + self.fix_functionalization = FixFunctionalizationPass(pass_config) + + def add(self, pass_: InductorPass): + assert isinstance(pass_, InductorPass) + self.passes.append(pass_) + + def __getstate__(self): + """ + Custom pickling for the pass manager, as some passes cannot be pickled. + Pickling occurs because the pass manager is set as the value of + `config["post_grad_custom_post_pass"]` in the Inductor config. + The config is pickled to act as a key in the Inductor code cache. + Any other passes in the config are pickled as well. + + TODO(torch==2.6), use the `uuid` method in CustomGraphPass instead. + """ + state = {"pass_config": self.pass_config.uuid(), "passes": []} + for pass_ in self.passes: + state["passes"].append(pass_.uuid()) + state["passes"].append(self.fix_functionalization.uuid()) + return state + + def __setstate__(self, state): + """ + Do not allow unpickling of the pass manager. + If this is needed in the future, it should properly pickle the passes. + """ + raise ValueError("Cannot unpickle PostGradPassManager") diff --git a/vllm/compilation/reshapes.py b/vllm/compilation/reshapes.py index 36597e119d2e1..63a369fe8d966 100644 --- a/vllm/compilation/reshapes.py +++ b/vllm/compilation/reshapes.py @@ -3,14 +3,14 @@ import torch.fx from torch import SymInt -from vllm.compilation.fusion import is_func -from vllm.compilation.inductor_pass import InductorPass from vllm.logger import init_logger +from .vllm_inductor_pass import VllmInductorPass, is_func + logger = init_logger(__name__) -class RedundantReshapesPass(InductorPass): +class RedundantReshapesPass(VllmInductorPass): """ This is an inductor pass that removes redundant reshape operations. It is required for RMSNorm-quant fusion to work properly. @@ -31,6 +31,7 @@ class RedundantReshapesPass(InductorPass): """ def __call__(self, graph: torch.fx.Graph): + self.begin() self.dump_graph(graph, "before_reshapes") count = 0 # Remove no-op reshapes/views: @@ -56,6 +57,7 @@ def __call__(self, graph: torch.fx.Graph): logger.debug("Removed %s no-op reshapes", count) self.dump_graph(graph, "after_reshapes") + self.end_and_log() def dims_equivalent(self, dim: Union[int, torch.fx.Node], i_dim: Union[int, SymInt]) -> bool: diff --git a/vllm/compilation/vllm_inductor_pass.py b/vllm/compilation/vllm_inductor_pass.py new file mode 100644 index 0000000000000..dbf6b8f7789e1 --- /dev/null +++ b/vllm/compilation/vllm_inductor_pass.py @@ -0,0 +1,53 @@ +import time + +import torch + +from vllm.config import CompilationConfig +# yapf: disable +from vllm.distributed import get_tensor_model_parallel_rank as get_tp_rank +from vllm.distributed import ( + get_tensor_model_parallel_world_size as get_tp_world_size) +from vllm.distributed import model_parallel_is_initialized as p_is_init +# yapf: enable +from vllm.logger import init_logger + +from .inductor_pass import InductorPass + +logger = init_logger(__name__) + + +def is_func(node: torch.fx.Node, target) -> bool: + return node.op == "call_function" and node.target == target + + +class VllmInductorPass(InductorPass): + """ + An inductor pass with access to vLLM PassConfig. + It provides timing, logging, and dumping utilities. + """ + + def __init__(self, config: CompilationConfig.PassConfig): + self.config = config + self.pass_name = self.__class__.__name__ + + def dump_graph(self, graph: torch.fx.Graph, stage: str): + if stage in self.config.dump_graph_stages: + # Make sure filename includes rank in the distributed setting + parallel = p_is_init() and get_tp_world_size() > 1 + rank = f"-{get_tp_rank()}" if parallel else "" + filepath = self.config.dump_graph_dir / f"{stage}{rank}.py" + + logger.info("%s printing graph to %s", self.pass_name, filepath) + with open(filepath, "w") as f: + src = graph.python_code(root_module="self", verbose=True).src + # Add imports so it's not full of errors + print("import torch; from torch import device", file=f) + print(src, file=f) + + def begin(self): + self._start_time = time.perf_counter_ns() + + def end_and_log(self): + self._end_time = time.perf_counter_ns() + duration_ms = float(self._end_time - self._start_time) / 1.0e6 + logger.debug("%s completed in %.1f ms", self.pass_name, duration_ms) diff --git a/vllm/compilation/wrapper.py b/vllm/compilation/wrapper.py index 0143d0301ca1a..bc4d292fef402 100644 --- a/vllm/compilation/wrapper.py +++ b/vllm/compilation/wrapper.py @@ -8,7 +8,7 @@ import torch import vllm.envs as envs -from vllm.config import CompilationLevel +from vllm.config import CompilationLevel, get_current_vllm_config class TorchCompileWrapperWithCustomDispatcher: @@ -32,7 +32,6 @@ def __init__(self, # default compilation settings # compiling the forward method - from vllm.plugins import get_current_vllm_config backend = get_current_vllm_config( ).compilation_config.init_backend() diff --git a/vllm/config.py b/vllm/config.py index e69cbd3eb402a..c87feaec3e5f6 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1,7 +1,9 @@ import copy import enum +import hashlib import json import warnings +from contextlib import contextmanager from dataclasses import dataclass, field, replace from pathlib import Path from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Counter, Dict, @@ -13,8 +15,10 @@ from transformers import PretrainedConfig import vllm.envs as envs +from vllm.compilation.inductor_pass import CallableInductorPass, InductorPass from vllm.logger import init_logger -from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS +from vllm.model_executor.layers.quantization import (QUANTIZATION_METHODS, + get_quantization_config) from vllm.model_executor.models import ModelRegistry from vllm.platforms import current_platform from vllm.tracing import is_otel_available, otel_import_error_traceback @@ -23,7 +27,7 @@ get_hf_text_config, get_pooling_config, get_sentence_transformer_tokenizer_config, is_encoder_decoder, uses_mrope) from vllm.utils import (GiB_bytes, cuda_device_count_stateless, get_cpu_memory, - identity, print_warning_once, resolve_obj_by_qualname) + print_warning_once, resolve_obj_by_qualname) if TYPE_CHECKING: from ray.util.placement_group import PlacementGroup @@ -179,7 +183,7 @@ def __init__( hf_overrides_fn = hf_overrides else: hf_overrides_kw = hf_overrides - hf_overrides_fn = identity + hf_overrides_fn = None if rope_scaling is not None: hf_override: Dict[str, Any] = {"rope_scaling": rope_scaling} @@ -208,8 +212,15 @@ def __init__( self.skip_tokenizer_init = skip_tokenizer_init hf_config = get_config(self.model, trust_remote_code, revision, - code_revision, config_format, **hf_overrides_kw) - hf_config = hf_overrides_fn(hf_config) + code_revision, config_format) + + if hf_overrides_kw: + logger.info("Overriding HF config with %s", hf_overrides_kw) + hf_config.update(hf_overrides_kw) + if hf_overrides_fn: + logger.info("Overriding HF config with %s", hf_overrides_fn) + hf_config = hf_overrides_fn(hf_config) + self.hf_config = hf_config self.hf_text_config = get_hf_text_config(self.hf_config) @@ -230,15 +241,26 @@ def __init__( (self.hf_text_config.model_type in ["gemma2"])) if (not self.disable_sliding_window and has_interleaved_attention): - sliding_window_len_min = get_min_sliding_window( - self.hf_text_config.sliding_window) - - print_warning_once( - f"{self.hf_text_config.model_type} has interleaved attention, " - "which is currently not supported by vLLM. Disabling sliding " - "window and capping the max length to the sliding window size " - f"({sliding_window_len_min}).") - self.disable_sliding_window = True + if envs.VLLM_ATTENTION_BACKEND == "XFORMERS": + sliding_window_len_min = get_min_sliding_window( + self.hf_text_config.sliding_window) + + print_warning_once( + f"{self.hf_text_config.model_type} has interleaved " + "attention, which is currently not supported by the " + "XFORMERS backend. Disabling sliding window and capping " + "the max length to the sliding window size " + f"({sliding_window_len_min}).") + self.disable_sliding_window = True + else: + # for a model with interleaved attention, + # the scheduler and the model treat it as full attention + # (i.e., not dropping any tokens outside the window). + # only the attention layer itself is aware of the sliding + # window, and use the window size to compute the attention. + self.hf_text_config.interleaved_sliding_window = sliding_window + delattr(self.hf_text_config, "sliding_window") + sliding_window = None self.max_model_len = _get_and_verify_max_len( hf_config=self.hf_text_config, @@ -370,10 +392,10 @@ def _parse_quant_hf_config(self): return quant_cfg def _verify_quantization(self) -> None: - supported_quantization = [*QUANTIZATION_METHODS] + supported_quantization = QUANTIZATION_METHODS rocm_supported_quantization = [ "awq", "gptq", "fp8", "compressed_tensors", "compressed-tensors", - "fbgemm_fp8" + "fbgemm_fp8", "gguf" ] optimized_quantization_methods = [ "fp8", "marlin", "modelopt", "gptq_marlin_24", "gptq_marlin", @@ -392,7 +414,8 @@ def _verify_quantization(self) -> None: quant_method = quant_cfg.get("quant_method", "").lower() # Detect which checkpoint is it - for _, method in QUANTIZATION_METHODS.items(): + for name in QUANTIZATION_METHODS: + method = get_quantization_config(name) quantization_override = method.override_quantization_method( quant_cfg, self.quantization) if quantization_override: @@ -697,6 +720,11 @@ def uses_mrope(self) -> bool: def is_multimodal_model(self) -> bool: return self.multimodal_config is not None + @property + def is_cross_encoder(self) -> bool: + architectures = getattr(self.hf_config, "architectures", []) + return ModelRegistry.is_cross_encoder_model(architectures) + class CacheConfig: """Configuration for the KV cache. @@ -922,76 +950,71 @@ def _verify_load_format(self) -> None: f"{rocm_supported_load_format}") +@dataclass class ParallelConfig: - """Configuration for the distributed execution. + """Configuration for the distributed execution.""" - Args: - pipeline_parallel_size: Number of pipeline parallel groups. - tensor_parallel_size: Number of tensor parallel groups. - worker_use_ray: Deprecated, use distributed_executor_backend instead. - max_parallel_loading_workers: Maximum number of multiple batches - when load model sequentially. To avoid RAM OOM when using tensor - parallel and large models. - disable_custom_all_reduce: Disable the custom all-reduce kernel and - fall back to NCCL. - tokenizer_pool_config: Config for the tokenizer pool. - If None, will use synchronous tokenization. - ray_workers_use_nsight: Whether to profile Ray workers with nsight, see - https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. - placement_group: ray distributed model workers placement group. - distributed_executor_backend: Backend to use for distributed model - workers, either "ray" or "mp" (multiprocessing). If the product - of pipeline_parallel_size and tensor_parallel_size is less than - or equal to the number of GPUs available, "mp" will be used to - keep processing on a single host. Otherwise, this will default - to "ray" if Ray is installed and fail otherwise. Note that tpu - and hpu only support Ray for distributed inference. - """ + pipeline_parallel_size: int = 1 # Number of pipeline parallel groups. + tensor_parallel_size: int = 1 # Number of tensor parallel groups. - def __init__( - self, - pipeline_parallel_size: int, - tensor_parallel_size: int, - worker_use_ray: Optional[bool] = None, - max_parallel_loading_workers: Optional[int] = None, - disable_custom_all_reduce: bool = False, - tokenizer_pool_config: Optional[TokenizerPoolConfig] = None, - ray_workers_use_nsight: bool = False, - placement_group: Optional["PlacementGroup"] = None, - distributed_executor_backend: Optional[Union[ - str, Type["ExecutorBase"]]] = None, - ) -> None: - self.pipeline_parallel_size = pipeline_parallel_size - self.tensor_parallel_size = tensor_parallel_size - self.distributed_executor_backend = distributed_executor_backend - self.max_parallel_loading_workers = max_parallel_loading_workers - self.disable_custom_all_reduce = disable_custom_all_reduce - self.tokenizer_pool_config = tokenizer_pool_config - self.ray_workers_use_nsight = ray_workers_use_nsight - self.placement_group = placement_group - self.world_size = pipeline_parallel_size * self.tensor_parallel_size - - if worker_use_ray: + # Deprecated, use distributed_executor_backend instead. + worker_use_ray: Optional[bool] = None + + # Maximum number of multiple batches + # when load model sequentially. To avoid RAM OOM when using tensor + # parallel and large models. + max_parallel_loading_workers: Optional[int] = None + + # Disable the custom all-reduce kernel and fall back to NCCL. + disable_custom_all_reduce: bool = False + + # Config for the tokenizer pool. If None, will use synchronous tokenization. + tokenizer_pool_config: Optional[TokenizerPoolConfig] = None + + # Whether to profile Ray workers with nsight, see https://docs.ray.io/en/latest/ray-observability/user-guides/profiling.html#profiling-nsight-profiler. + ray_workers_use_nsight: bool = False + + # ray distributed model workers placement group. + placement_group: Optional["PlacementGroup"] = None + + # Backend to use for distributed model + # workers, either "ray" or "mp" (multiprocessing). If the product + # of pipeline_parallel_size and tensor_parallel_size is less than + # or equal to the number of GPUs available, "mp" will be used to + # keep processing on a single host. Otherwise, this will default + # to "ray" if Ray is installed and fail otherwise. Note that tpu + # and hpu only support Ray for distributed inference. + distributed_executor_backend: Optional[Union[str, + Type["ExecutorBase"]]] = None + + # the full name of the worker class to use. If "auto", the worker class + # will be determined based on the platform. + worker_cls: str = "auto" + + world_size: int = field(init=False) + + rank: int = 0 + + def __post_init__(self) -> None: + self.world_size = self.pipeline_parallel_size * \ + self.tensor_parallel_size + + if self.worker_use_ray: if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" elif not self.use_ray: raise ValueError(f"worker-use-ray can't be used with " f"distributed executor backend " f"'{self.distributed_executor_backend}'.") - - if current_platform.is_tpu() and self.world_size > 1: - if self.distributed_executor_backend is None: - self.distributed_executor_backend = "ray" - if self.distributed_executor_backend != "ray": - raise ValueError( - "TPU backend only supports Ray for distributed inference.") - - if current_platform.is_hpu() and self.world_size > 1: + ray_only_devices = ["tpu", "hpu"] + if (current_platform.device_type in ray_only_devices + and self.world_size > 1): if self.distributed_executor_backend is None: self.distributed_executor_backend = "ray" if self.distributed_executor_backend != "ray": raise ValueError( - "HPU backend only supports Ray for distributed inference.") + f"{current_platform.device_type.upper()} backend only " + "supports Ray for distributed inference.") if self.distributed_executor_backend is None and self.world_size > 1: # We use multiprocessing by default if world_size fits on the @@ -1022,7 +1045,6 @@ def __init__( backend) self._verify_args() - self.rank: int = 0 @property def use_ray(self) -> bool: @@ -1055,100 +1077,97 @@ def _verify_args(self) -> None: "run with Ray.") +@dataclass class SchedulerConfig: - """Scheduler configuration. + """Scheduler configuration.""" - Args: - task: The task to use the model for. - max_num_batched_tokens: Maximum number of tokens to be processed in - a single iteration. - max_num_seqs: Maximum number of sequences to be processed in a single - iteration. - max_model_len: Maximum length of a sequence (including prompt - and generated text). - num_lookahead_slots: The number of slots to allocate per sequence per - step, beyond the known token ids. This is used in speculative - decoding to store KV activations of tokens which may or may not be - accepted. - delay_factor: Apply a delay (of delay factor multiplied by previous - prompt latency) before scheduling next prompt. - enable_chunked_prefill: If True, prefill requests can be chunked based - on the remaining max_num_batched_tokens. - preemption_mode: Whether to perform preemption by swapping or - recomputation. If not specified, we determine the mode as follows: - We use recomputation by default since it incurs lower overhead than - swapping. However, when the sequence group has multiple sequences - (e.g., beam search), recomputation is not currently supported. In - such a case, we use swapping instead. - send_delta_data: Private API. If used, scheduler sends delta data to - workers instead of an entire data. It should be enabled only - when SPMD worker architecture is enabled. I.e., - VLLM_USE_RAY_SPMD_WORKER=1 - policy: The scheduling policy to use. "fcfs" (default) or "priority". - """ + task: str = "generate" # The task to use the model for. + + # Maximum number of tokens to be processed in a single iteration. + max_num_batched_tokens: int = field(default=None) # type: ignore + + # Maximum number of sequences to be processed in a single iteration. + max_num_seqs: int = 128 + + # Maximum length of a sequence (including prompt and generated text). + max_model_len: int = 8192 + + # The number of slots to allocate per sequence per + # step, beyond the known token ids. This is used in speculative + # decoding to store KV activations of tokens which may or may not be + # accepted. + num_lookahead_slots: int = 0 + + # Apply a delay (of delay factor multiplied by previous + # prompt latency) before scheduling next prompt. + delay_factor: float = 0.0 + + # If True, prefill requests can be chunked based + # on the remaining max_num_batched_tokens. + enable_chunked_prefill: bool = False + + is_multimodal_model: bool = False + + # Whether to perform preemption by swapping or + # recomputation. If not specified, we determine the mode as follows: + # We use recomputation by default since it incurs lower overhead than + # swapping. However, when the sequence group has multiple sequences + # (e.g., beam search), recomputation is not currently supported. In + # such a case, we use swapping instead. + preemption_mode: Optional[str] = None - def __init__(self, - task: _Task, - max_num_batched_tokens: Optional[int], - max_num_seqs: int, - max_model_len: int, - num_lookahead_slots: int = 0, - delay_factor: float = 0.0, - enable_chunked_prefill: bool = False, - is_multimodal_model: bool = False, - preemption_mode: Optional[str] = None, - num_scheduler_steps: int = 1, - multi_step_stream_outputs: bool = False, - send_delta_data: bool = False, - policy: str = "fcfs") -> None: - if max_num_batched_tokens is None: - if enable_chunked_prefill: - if num_scheduler_steps > 1: + num_scheduler_steps: int = 1 + + multi_step_stream_outputs: bool = False + + # Private API. If used, scheduler sends delta data to + # workers instead of an entire data. It should be enabled only + # when SPMD worker architecture is enabled. I.e., + # VLLM_USE_RAY_SPMD_WORKER=1 + send_delta_data: bool = False + + # The scheduling policy to use. "fcfs" (default) or "priority". + policy: str = "fcfs" + + chunked_prefill_enabled: bool = field(init=False) + + def __post_init__(self) -> None: + if self.max_num_batched_tokens is None: + if self.enable_chunked_prefill: + if self.num_scheduler_steps > 1: # Multi-step Chunked-Prefill doesn't allow prompt-chunking # for now. Have max_num_batched_tokens set to max_model_len # so we don't reject sequences on account of a short # max_num_batched_tokens. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) else: - # It is the values that have the best balance between ITL - # and TTFT on A100. Note it is not optimized for throughput. - max_num_batched_tokens = 512 + # This value is chosen to have a balance between ITL + # and TTFT. Note it is not optimized for throughput. + self.max_num_batched_tokens = 2048 else: # If max_model_len is too short, use 2048 as the default value # for higher throughput. - max_num_batched_tokens = max(max_model_len, 2048) + self.max_num_batched_tokens = max(self.max_model_len, 2048) - if task == "embedding": + if self.task == "embedding": # For embedding, choose specific value for higher throughput - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS, ) - if is_multimodal_model: + if self.is_multimodal_model: # The value needs to be at least the number of multimodal tokens - max_num_batched_tokens = max( - max_num_batched_tokens, + self.max_num_batched_tokens = max( + self.max_num_batched_tokens, _MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS, ) - self.max_num_batched_tokens = max_num_batched_tokens - - if enable_chunked_prefill: + if self.enable_chunked_prefill: logger.info( "Chunked prefill is enabled with max_num_batched_tokens=%d.", self.max_num_batched_tokens) - self.task: Final = task - self.max_num_seqs = max_num_seqs - self.max_model_len = max_model_len - self.num_lookahead_slots = num_lookahead_slots - self.delay_factor = delay_factor - self.chunked_prefill_enabled = enable_chunked_prefill - self.preemption_mode = preemption_mode - self.num_scheduler_steps = num_scheduler_steps - self.multi_step_stream_outputs = multi_step_stream_outputs - self.send_delta_data = send_delta_data - self.policy = policy + self.chunked_prefill_enabled = self.enable_chunked_prefill self._verify_args() def _verify_args(self) -> None: @@ -1187,25 +1206,13 @@ def is_multi_step(self) -> bool: class DeviceConfig: device: Optional[torch.device] + device_type: str def __init__(self, device: str = "auto") -> None: if device == "auto": # Automated device type detection - if current_platform.is_cuda_alike(): - self.device_type = "cuda" - elif current_platform.is_neuron(): - self.device_type = "neuron" - elif current_platform.is_hpu(): - self.device_type = "hpu" - elif current_platform.is_openvino(): - self.device_type = "openvino" - elif current_platform.is_tpu(): - self.device_type = "tpu" - elif current_platform.is_cpu(): - self.device_type = "cpu" - elif current_platform.is_xpu(): - self.device_type = "xpu" - else: + self.device_type = current_platform.device_type + if not self.device_type: raise RuntimeError("Failed to infer device type") else: # Device type is assigned explicitly @@ -2089,13 +2096,15 @@ class CompilationConfig(BaseModel): - 'none,+op1,+op2' to enable only op1 and op2 By default, all custom ops are enabled when running without Inductor and disabled when running with Inductor (compile_level >= Inductor). + - splitting_ops: a list of ops to split the full graph into subgraphs, used in piecewise compilation. - CudaGraph capture: - use_cudagraph: whether to use cudagraph inside compilation. - False: cudagraph inside compilation is not used. - True: cudagraph inside compilation is used. It requires - that all input buffers have fixed addresses. - Note that this is orthogonal to the cudagraph capture out - side of compilation. + that all input buffers have fixed addresses, and all + splitting ops write their outputs to input buffers. + Note that this is orthogonal to the cudagraph capture logic + outside of compilation. TODO: move outside cudagraph logic into compilation. torch.compile will handle cudagraph capture logic in the future. - cudagraph_capture_sizes: sizes to capture cudagraph. @@ -2129,12 +2138,7 @@ class CompilationConfig(BaseModel): name because the config uses json format. If we pass the config from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - - custom inductor passes: - - dump_graph_stages: list of stages for which we want to dump the graph. - Each pass defines its own stages (before, after, maybe in-between). - - dump_graph_dir: directory to dump the graph. Default is . - - enable_fusion: whether to enable the custom fusion pass. - TODO better pass enabling system. + - custom inductor passes: see PassConfig for more details Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used @@ -2149,6 +2153,10 @@ class CompilationConfig(BaseModel): level: int = 0 backend: str = "" custom_ops: List[str] = Field(default_factory=list) + splitting_ops: List[str] = Field(default_factory=lambda: [ + "vllm.unified_attention", + "vllm.unified_v1_flash_attention", + ]) use_inductor: bool = True inductor_specialize_for_cudagraph_no_more_than: Optional[int] = None @@ -2157,14 +2165,47 @@ class CompilationConfig(BaseModel): inductor_passes: Dict[str, str] = Field(default_factory=dict) use_cudagraph: bool = False - non_cudagraph_ops: List[str] = Field(default_factory=list) cudagraph_num_of_warmups: int = 0 cudagraph_capture_sizes: Optional[List[int]] = None cudagraph_copy_inputs: bool = False - dump_graph_stages: List[str] = Field(default_factory=list) - dump_graph_dir: Path = Field(default=Path(".")) - enable_fusion: bool = True + class PassConfig(BaseModel): + """ + Configuration for custom Inductor passes. + This is separate from general CompilationConfig so that inductor passes + don't all have access to full configuration - that would create a cycle + as the PassManager is set as a property of config. + - dump_graph_stages: list of stages for which we want to dump the graph. + Each pass defines its own stages (before, after, maybe in-between). + - dump_graph_dir: directory to dump the graphs. Default is . + - enable_fusion: whether to enable the custom fusion pass. + - enable_reshape: whether to enable the custom reshape elimination pass. + TODO better pass enabling system. + """ + dump_graph_stages: List[str] = Field(default_factory=list) + dump_graph_dir: Path = Field(default=Path(".")) + enable_fusion: bool = True + enable_reshape: bool = True + + def uuid(self): + """ + Produces a hash unique to the pass configuration. + Any new fields that affect compilation should be added to the hash. + Do not include dump_graph_* in the hash - they don't affect + compilation. + """ + dict_ = self.model_dump( + include={"enable_fusion", "enable_reshape"}) + encoded = json.dumps(dict_, sort_keys=True).encode("utf-8") + return hashlib.sha256(encoded).digest() + + def model_post_init(self, __context: Any) -> None: + if not self.enable_reshape and self.enable_fusion: + print_warning_once( + "Fusion enabled but reshape elimination disabled." + "RMSNorm + quant (fp8) fusion might not work") + + pass_config: PassConfig = Field(default_factory=PassConfig) # not configurable, computed after init compile_sizes: List[int] = PrivateAttr @@ -2174,6 +2215,11 @@ class CompilationConfig(BaseModel): enabled_custom_ops: Counter[str] = PrivateAttr disabled_custom_ops: Counter[str] = PrivateAttr + # Per-model forward context + # Mainly used to store attention cls + # Map from layer name to the attention cls + static_forward_context: Dict[str, Any] = PrivateAttr + @classmethod def from_cli(cls, cli_value: str) -> "CompilationConfig": """Parse the CLI value for the compilation config.""" @@ -2190,8 +2236,9 @@ def model_post_init(self, __context: Any) -> None: for k, v in self.inductor_passes.items(): if not isinstance(v, str): assert callable(v), ( - f"pass {k} should be a function or a qualified name") - self.inductor_compile_config[k] = v + f"pass {k} should be callable or a qualified name") + self.inductor_compile_config[k] = v if isinstance( + v, InductorPass) else CallableInductorPass(v) continue # resolve function from qualified name @@ -2199,10 +2246,12 @@ def model_post_init(self, __context: Any) -> None: module = ".".join(names[:-1]) func_name = names[-1] func = __import__(module).__dict__[func_name] - self.inductor_compile_config[k] = func + self.inductor_compile_config[k] = func if isinstance( + func, InductorPass) else CallableInductorPass(func) self.enabled_custom_ops = Counter() self.disabled_custom_ops = Counter() + self.static_forward_context = {} def init_backend(self) -> Union[str, Callable]: if self.level == CompilationLevel.NO_COMPILATION: @@ -2264,10 +2313,10 @@ class VllmConfig: model_config: ModelConfig = field(default=None, init=True) # type: ignore cache_config: CacheConfig = field(default=None, init=True) # type: ignore - parallel_config: ParallelConfig = field(default=None, - init=True) # type: ignore - scheduler_config: SchedulerConfig = field(default=None, - init=True) # type: ignore + parallel_config: ParallelConfig = field(default_factory=ParallelConfig, + init=True) + scheduler_config: SchedulerConfig = field(default_factory=SchedulerConfig, + init=True) device_config: DeviceConfig = field(default=None, init=True) # type: ignore load_config: LoadConfig = field(default=None, init=True) # type: ignore @@ -2339,20 +2388,43 @@ def __post_init__(self): self.quant_config = VllmConfig._get_quantization_config( self.model_config, self.load_config) + if self.scheduler_config is not None and \ + self.model_config is not None and \ + self.scheduler_config.chunked_prefill_enabled and \ + self.model_config.dtype == torch.float32 and \ + current_platform.get_device_capability() == (7, 5): + print_warning_once( + "Turing devices tensor cores do not support float32 matmul. " + "To workaround this limitation, vLLM will set 'ieee' input " + "precision for chunked prefill triton kernels.") + if self.compilation_config is None: self.compilation_config = CompilationConfig() - if envs.VLLM_USE_V1: + if envs.VLLM_USE_V1 and not self.model_config.enforce_eager: # NOTE(woosuk): Currently, we use inductor because the piecewise # CUDA graphs do not work properly with the custom CUDA kernels. # FIXME(woosuk): Disable inductor to reduce the compilation time # and avoid any potential issues with the inductor. self.compilation_config.custom_ops = ["none"] self.compilation_config.use_cudagraph = True - self.compilation_config.non_cudagraph_ops = [ - "vllm.unified_v1_flash_attention" - ] self.compilation_config.use_inductor = True - self.compilation_config.enable_fusion = False + self.compilation_config.pass_config.enable_fusion = False + self.compilation_config.pass_config.enable_reshape = False + self.compilation_config.level = CompilationLevel.PIECEWISE + + if self.cache_config is not None and \ + self.cache_config.cpu_offload_gb > 0 and \ + self.compilation_config.level != CompilationLevel.NO_COMPILATION: + logger.warning( + "CPU offload is not supported with `torch.compile` yet." + " Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION + + if self.lora_config is not None and self.compilation_config.level !=\ + CompilationLevel.NO_COMPILATION: + logger.warning("LoRA is not supported with `torch.compile` yet. " + "Disabling `torch.compile`.") + self.compilation_config.level = CompilationLevel.NO_COMPILATION current_platform.check_and_update_config(self) @@ -2396,3 +2468,53 @@ def __str__(self): self.cache_config.enable_prefix_caching, self.model_config.use_async_output_proc, self.model_config.mm_processor_kwargs) + + +_current_vllm_config: Optional[VllmConfig] = None + + +@contextmanager +def set_current_vllm_config(vllm_config: VllmConfig): + """ + Temporarily set the current VLLM config. + Used during model initialization. + We save the current VLLM config in a global variable, + so that all modules can access it, e.g. custom ops + can access the VLLM config to determine how to dispatch. + """ + global _current_vllm_config + old_vllm_config = _current_vllm_config + from vllm.compilation.counter import compilation_counter + num_models_seen = compilation_counter.num_models_seen + try: + _current_vllm_config = vllm_config + yield + finally: + logger.debug("enabled custom ops: %s", + vllm_config.compilation_config.enabled_custom_ops) + logger.debug("disabled custom ops: %s", + vllm_config.compilation_config.disabled_custom_ops) + if vllm_config.compilation_config.level == CompilationLevel.PIECEWISE \ + and compilation_counter.num_models_seen == num_models_seen: + # If the model supports compilation, + # compilation_counter.num_models_seen should be increased + # by at least 1. + # If it is not increased, it means the model does not support + # compilation (does not have @support_torch_compile decorator). + logger.warning( + "`torch.compile` is turned on, but the model %s" + " does not support it. Please open an issue on GitHub" + "if you want it to be supported.", + vllm_config.model_config.model) + _current_vllm_config = old_vllm_config + + +def get_current_vllm_config() -> VllmConfig: + if _current_vllm_config is None: + # in ci, usually when we test custom ops/modules directly, + # we don't set the vllm config. In that case, we set a default + # config. + logger.warning("Current VLLM config is not set.") + from vllm.config import VllmConfig + return VllmConfig() + return _current_vllm_config diff --git a/vllm/core/block/cpu_gpu_block_allocator.py b/vllm/core/block/cpu_gpu_block_allocator.py index 9727f6e19b84e..3197af3c2b7a4 100644 --- a/vllm/core/block/cpu_gpu_block_allocator.py +++ b/vllm/core/block/cpu_gpu_block_allocator.py @@ -306,14 +306,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: device = Device.GPU return self._allocators[device].mark_blocks_as_computed(block_ids) - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - # Prefix caching only supported on GPU. - device = Device.GPU - return self._allocators[device].get_computed_block_ids( - prev_computed_block_ids, block_ids, skip_last_block_id) - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: # Prefix caching only supported on GPU. @@ -342,6 +334,13 @@ def get_and_reset_swaps(self) -> List[Tuple[int, int]]: self._swap_mapping.clear() return list(mapping.items()) + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + return self._allocators[device].find_cached_blocks_prefix(block_hashes) + class NullBlock(Block): """ diff --git a/vllm/core/block/interfaces.py b/vllm/core/block/interfaces.py index 72bbab1dcea5d..06f4851af3466 100644 --- a/vllm/core/block/interfaces.py +++ b/vllm/core/block/interfaces.py @@ -159,12 +159,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -192,6 +186,13 @@ def get_prefix_cache_hit_rate(self) -> float: class NoFreeBlocksError(ValueError): pass + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + ) -> List[int]: + pass + class DeviceAwareBlockAllocator(ABC): @@ -207,9 +208,12 @@ def allocate_immutable_block(self, prev_block: Optional[Block], pass @abstractmethod - def allocate_immutable_blocks(self, prev_block: Optional[Block], - block_token_ids: List[List[int]], - device: Device) -> List[Block]: + def allocate_immutable_blocks( + self, + prev_block: Optional[Block], + block_token_ids: List[List[int]], + device: Device, + ) -> List[Block]: pass @abstractmethod @@ -246,12 +250,6 @@ def mark_blocks_as_accessed(self, block_ids: List[int], def mark_blocks_as_computed(self, block_ids: List[int]) -> None: pass - @abstractmethod - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - pass - @abstractmethod def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: @@ -284,3 +282,11 @@ def allocate_or_get_null_block(self) -> Block: def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def find_cached_blocks_prefix( + self, + block_hashes: List[int], + device: Device = Device.GPU, + ) -> List[int]: + pass diff --git a/vllm/core/block/naive_block.py b/vllm/core/block/naive_block.py index 9341a518d11c6..a2af5ad6362c1 100644 --- a/vllm/core/block/naive_block.py +++ b/vllm/core/block/naive_block.py @@ -262,13 +262,6 @@ def mark_blocks_as_computed(self, block_ids: List[int]) -> None: """ pass - def get_computed_block_ids(self, prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool) -> List[int]: - """No prefix caching here => return empty list - """ - return [] - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Determine blocks that can be skipped in prefill. @@ -329,6 +322,10 @@ def swap_in(self, blocks: List[Block]) -> None: def get_prefix_cache_hit_rate(self) -> float: return -1 + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + # Not applicable for naive block allocator. + return [] + class NaiveBlock(Block): """An implementation of the Block class that does not support prefix diff --git a/vllm/core/block/prefix_caching_block.py b/vllm/core/block/prefix_caching_block.py index 57527e39b9bdd..b736167f6ceb4 100644 --- a/vllm/core/block/prefix_caching_block.py +++ b/vllm/core/block/prefix_caching_block.py @@ -1,13 +1,18 @@ """Token blocks.""" +import sys +from bisect import bisect_left from os.path import commonprefix -from typing import Dict, FrozenSet, Iterable, List, Optional, Set, Tuple +from typing import (Callable, Dict, FrozenSet, Iterable, List, Optional, Set, + Tuple) from vllm.core.block.common import (CacheMetricData, CopyOnWriteTracker, get_all_blocks_recursively) -from vllm.core.block.interfaces import Block, BlockAllocator, BlockId, Device +from vllm.core.block.interfaces import (Block, BlockAllocator, BlockId, Device, + DeviceAwareBlockAllocator) from vllm.core.block.naive_block import (BlockPool, NaiveBlock, NaiveBlockAllocator) from vllm.core.evictor import EvictionPolicy, Evictor, make_evictor +from vllm.sequence import Sequence PrefixHash = int @@ -534,26 +539,6 @@ def block_is_computed(self, block_id: int) -> bool: else: return block_id in self.evictor - def get_computed_block_ids(self, - prev_computed_block_ids: List[int], - block_ids: List[int], - skip_last_block_id: bool = True) -> List[int]: - prev_prefix_size = len(prev_computed_block_ids) - cur_size = len(block_ids) - if skip_last_block_id: - cur_size -= 1 - - # Sanity checks - assert cur_size >= 0 - assert prev_prefix_size <= cur_size - - ret = prev_computed_block_ids - for i in range(prev_prefix_size, cur_size): - block_id = block_ids[i] - if self.block_is_computed(block_id): - ret.append(block_id) - return ret - def get_common_computed_block_ids( self, computed_seq_block_ids: List[List[int]]) -> List[int]: """Return the block ids that are common for a given sequence group. @@ -634,6 +619,47 @@ def swap_in(self, blocks: List[Block]) -> None: block.block_id = block_id # Assign block_id + def find_cached_blocks_prefix(self, block_hashes: List[int]) -> List[int]: + """ + Given a list of block hashes, return the prefix of the block hashes that + are all cached. + + Since a block's block hash includes the hashes of all previous blocks, + and we only allocate/deallocate blocks in the entire sequence, so if a + block is cached, then all previous blocks are also cached. With this + property, we can use binary search to find the prefix of cached blocks. + + Args: + block_hashes (List[int]): The list of block hashes. + + Returns: + List[int]: The prefix of the `block_hashes` that are cached. + """ + + def _block_is_cached(block_hash: PrefixHash) -> bool: + if block_hash not in self._cached_blocks: + return False + + cached_block_id = self._cached_blocks[block_hash] + # We only consider the blocks that are marked as computed. + return self.block_is_computed(cached_block_id) + + def _bisect_left(a, x, key: Callable[[PrefixHash], bool]) -> int: + + # python <= 3.10 don't have the key argument + if sys.version_info < (3, 10): + a = [key(e) for e in a] + return bisect_left(a, x) + else: + return bisect_left(a, x, key=key) + + # Look for the first block that's not cached, and returns the prefix + # i.e. blocks that are cached. + idx = _bisect_left(block_hashes, + True, + key=lambda x: not _block_is_cached(x)) + return block_hashes[:idx] + class PrefixCachingBlock(Block): """A block implementation that supports prefix caching. @@ -843,86 +869,126 @@ def hash_block_tokens(is_first_block: bool, prev_block_hash: Optional[int], class ComputedBlocksTracker: - """Handles caching of per-sequence computed block ids. - When a sequence appears for the first time, it traverses all of the - blocks and detects the prefix of blocks that is computed. On the - subsequent times, it only traverses the new blocks that were added - and updates the already recorded prefix of blocks with the newly - computed blocks. - - To avoid redundant traversals, the algorithm also detects when there - is a "gap" in the computed prefix. For example, if we have blocks = - [1,2,3,4,5], and we have detected [1,2,3] as the computed prefix, then - we won't try to add more computed blocks to [1,2,3] in this sequence - iteration, and will add more computed blocks only after the sequence is - freed and reused again. - - Note that currently, for a given sequence, we also skip the last - block id for caching purposes, to avoid caching of a full sequence """ + Tracks the computed blocks for each sequence. - def __init__(self, allocator): - self._allocator = allocator - self._cached_computed_seq_blocks: Dict[int, Tuple[List[int], - bool]] = {} + Internally, it maintains a map from sequence id to the list of block hashes + for the sequence. We cache the hashes of the full blocks for each sequence, + and make sure the hash is calculated in the same way as the allocator. + When a sequence is being decoded, we also update the sequence's hash + accordingly and incrementally. - def add_seq(self, seq_id: int) -> None: - """Start tracking seq_id - """ - assert seq_id not in self._cached_computed_seq_blocks - self._cached_computed_seq_blocks[seq_id] = ([], False) - - def remove_seq(self, seq_id: int) -> None: - """Stop tracking seq_id - """ - assert seq_id in self._cached_computed_seq_blocks - del self._cached_computed_seq_blocks[seq_id] - - def get_cached_computed_blocks_and_update( - self, seq_id: int, block_ids: List[int]) -> List[int]: - """ Look at the class documentation for details - """ - # Ensure seq_id is already tracked - assert seq_id in self._cached_computed_seq_blocks - - # Get cached data (may be empty on the first time) - prev_computed_block_ids, has_gap = self._cached_computed_seq_blocks[ - seq_id] - - if has_gap: - # When gap is detected, we do not add more computed blocks at this - # sequence iteration - return prev_computed_block_ids - - # We do not consider the last block id for caching purposes. - num_cur_blocks = len(block_ids) - 1 - assert num_cur_blocks >= 0 - - if len(prev_computed_block_ids) >= num_cur_blocks: - # Cache HIT - assert len(prev_computed_block_ids) == num_cur_blocks - return prev_computed_block_ids - - # If here, then we may possibly add more computed blocks. As a result, - # traverse the additional blocks after prev_computed_block_ids to - # detect more computed blocks and add them. - - # Incremental init for seq_id => Look only at the new blocks - computed_block_ids = self._allocator.get_computed_block_ids( # noqa: E501 - prev_computed_block_ids, - block_ids, - skip_last_block_id= - True, # We skip last block id to avoid caching of full seq - ) + From the sequence hash, with prefix caching enabled, we could also calculate + the number of cached tokens for the sequence by looking up the number of + cached block hashes in the allocator. + """ - # Detect if there is a "gap" - has_gap = len(computed_block_ids) < num_cur_blocks + def __init__( + self, + allocator: DeviceAwareBlockAllocator, + block_size: int, + enable_caching: bool, + ): + self._allocator = allocator + self._block_size = block_size + self._enable_caching = enable_caching + + # A map from seq_id to the list of block hashes for the + # sequence. This is so that we don't have to recompute the block hashes + # for the sequence when we need to check if the sequence is cached. + # Note a block that's not full will not have its hash calculated and + # recorded. + self._seq_id_to_blocks_hashes: Dict[int, List[int]] = {} + + # A map from seq_id to the number of tokens that are cached for the + # sequence. + # We need this so that a sequence in continuous prefill doesn't + # accidentally see its cached token count change. See comments in + # `get_num_cached_tokens` for more details. + self._seq_id_to_num_tokens_computed: Dict[int, int] = {} + + def _update_seq_hashes(self, seq: Sequence) -> None: + """Incrementally update the sequence's block hashes and record them.""" + assert self._enable_caching + + block_hashes_recorded = self._seq_id_to_blocks_hashes.get( + seq.seq_id, []) + cur_num_blocks_recorded = len(block_hashes_recorded) + token_ids = seq.get_token_ids() + assert len(token_ids) >= cur_num_blocks_recorded * self._block_size, ( + f"The sequence has {len(token_ids)} tokens, but" + f" already recorded {cur_num_blocks_recorded} blocks. " + "This should not happen since we assume blocks are " + "only appended other than recomputation. When the sequence is " + "recomputed, we should have removed the info of the old blocks.") + # Update the computed block hashes for the sequence. Since only full + # blocks are considered as "computed", we take floor here. + num_computed_blocks = len(token_ids) // self._block_size + + # We need to know the hash of the previous block to compute the hash of + # the current block so that blocks could be uniquely identified across + # sequences of prefixes. + prev_block_hash = (None if cur_num_blocks_recorded == 0 else + block_hashes_recorded[-1]) + # Only update the computed block hashes for the new blocks + for i in range(cur_num_blocks_recorded, num_computed_blocks): + assert len(token_ids) >= (i + 1) * self._block_size + block_token_ids = token_ids[i * self._block_size:(i + 1) * + self._block_size] + # This has to be kept in sync with the allocator's hash + # calculation. + block_hash = PrefixCachingBlock.hash_block_tokens( + is_first_block=prev_block_hash is None, + prev_block_hash=prev_block_hash, + cur_block_token_ids=block_token_ids, + ) + block_hashes_recorded.append(block_hash) + prev_block_hash = block_hash + + self._seq_id_to_blocks_hashes[seq.seq_id] = block_hashes_recorded + + def get_num_cached_tokens(self, seq: Sequence) -> int: + if not self._enable_caching: + return 0 + + # We always try to update the sequence hashes on the fly. + # This is to ensure that we don't miss any cached tokens for the + # sequence during decode. + # This routine should only update hash for any new blocks too. + self._update_seq_hashes(seq) + + num_computed_tokens_prev = self._seq_id_to_num_tokens_computed.get( + seq.seq_id, None) + + # TODO(rickyx): This hack could be removed once we mark blocks as + # computed correctly with chunked prefills. + if num_computed_tokens_prev is not None and seq.is_prefill(): + # For a sequence that is still in prefill, we don't + # recompute the number of cached tokens. + # This also handles correctly chunked prefill since currently + # we mark blocks as computed even if the sequence is still partially + # prefilled. So a continuously prefilled sequence should not + # see its cached token count change while running. + return num_computed_tokens_prev + + block_hashes = self._seq_id_to_blocks_hashes[seq.seq_id] + + # This is O(logN), where N is the number of blocks. + num_cached_blocks = len( + self._allocator.find_cached_blocks_prefix(block_hashes)) + num_cached_tokens = num_cached_blocks * self._block_size + self._seq_id_to_num_tokens_computed[seq.seq_id] = num_cached_tokens + return num_cached_tokens - # Record - self._cached_computed_seq_blocks[seq_id] = (computed_block_ids, - has_gap) + def remove_seq(self, seq_id: int) -> None: + """Stop tracking the sequence.""" + if not self._enable_caching: + return + assert seq_id in self._seq_id_to_blocks_hashes + del self._seq_id_to_blocks_hashes[seq_id] - return computed_block_ids + assert seq_id in self._seq_id_to_num_tokens_computed + del self._seq_id_to_num_tokens_computed[seq_id] class LastAccessBlocksTracker: diff --git a/vllm/core/block_manager.py b/vllm/core/block_manager.py index 21f4c63b6572d..209487c6b4f9e 100644 --- a/vllm/core/block_manager.py +++ b/vllm/core/block_manager.py @@ -101,7 +101,7 @@ def __init__( self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( - self.block_allocator) + self.block_allocator, self.block_size, self.enable_caching) self._last_access_blocks_tracker = LastAccessBlocksTracker( self.block_allocator) @@ -170,7 +170,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Assign the block table for each sequence. @@ -178,7 +177,6 @@ def allocate(self, seq_group: SequenceGroup) -> None: self.block_tables[seq.seq_id] = block_table.fork() # Track seq - self._computed_blocks_tracker.add_seq(seq.seq_id) self._last_access_blocks_tracker.add_seq(seq.seq_id) # Allocate cross-attention block table for encoder sequence @@ -314,11 +312,13 @@ def get_common_computed_block_ids( """ computed_seq_block_ids = [] for seq in seqs: - computed_seq_block_ids.append( - self._computed_blocks_tracker. - get_cached_computed_blocks_and_update( - seq.seq_id, - self.block_tables[seq.seq_id].physical_block_ids)) + all_blocks = self.block_tables[seq.seq_id].physical_block_ids + num_cached_tokens = ( + self._computed_blocks_tracker.get_num_cached_tokens(seq)) + assert num_cached_tokens % self.block_size == 0 + num_cached_blocks = num_cached_tokens // self.block_size + computed_block_ids = all_blocks[:num_cached_blocks] + computed_seq_block_ids.append(computed_block_ids) # NOTE(sang): This assumes seq_block_ids doesn't contain any None. return self.block_allocator.get_common_computed_block_ids( @@ -332,7 +332,6 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: self.block_tables[child_seq.seq_id] = src_block_table.fork() # Track child seq - self._computed_blocks_tracker.add_seq(child_seq.seq_id) self._last_access_blocks_tracker.add_seq(child_seq.seq_id) def can_swap_in(self, seq_group: SequenceGroup, @@ -503,3 +502,9 @@ def _can_swap(self, return AllocStatus.OK else: return AllocStatus.LATER + + def get_num_cached_tokens(self, seq: Sequence) -> int: + """Get the number of tokens in blocks that are already computed and + cached in the block manager for the sequence. + """ + return self._computed_blocks_tracker.get_num_cached_tokens(seq) diff --git a/vllm/core/interfaces.py b/vllm/core/interfaces.py index 9501a516bf020..b10b8d3f4a5bf 100644 --- a/vllm/core/interfaces.py +++ b/vllm/core/interfaces.py @@ -121,3 +121,7 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: """Prefix cache hit rate. -1 means not supported or disabled.""" pass + + @abstractmethod + def get_num_cached_tokens(self, seq: Sequence) -> int: + pass diff --git a/vllm/core/placeholder_block_space_manager.py b/vllm/core/placeholder_block_space_manager.py index a337392bbed53..26d42b7f1790e 100644 --- a/vllm/core/placeholder_block_space_manager.py +++ b/vllm/core/placeholder_block_space_manager.py @@ -89,3 +89,6 @@ def mark_blocks_as_computed(self, seq_group: SequenceGroup, def get_prefix_cache_hit_rate(self, device: Device) -> float: return -1 + + def get_num_cached_tokens(self, seq: Sequence) -> int: + return 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index af4671ec29be9..530cbdc3a9190 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -56,11 +56,16 @@ class SchedulingBudget: max_num_seqs: int _request_ids_num_batched_tokens: Set[str] = field(default_factory=set) _request_ids_num_curr_seqs: Set[str] = field(default_factory=set) + # Number of cached tokens in the batch. + _num_cached_tokens: int = 0 + # Number of actual non-cached tokens in the batch. _num_batched_tokens: int = 0 _num_curr_seqs: int = 0 def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): - assert num_new_tokens != 0 + # We allow num_new_tokens to be 0 when the entire sequence has + # been cached. + assert num_new_tokens >= 0 assert num_new_seqs != 0 return (self.num_batched_tokens + num_new_tokens <= self.token_budget and self.num_curr_seqs + num_new_seqs <= self.max_num_seqs) @@ -68,12 +73,18 @@ def can_schedule(self, *, num_new_tokens: int, num_new_seqs: int): def remaining_token_budget(self): return self.token_budget - self.num_batched_tokens - def add_num_batched_tokens(self, req_id: str, num_batched_tokens: int): + def add_num_batched_tokens(self, + req_id: str, + num_batched_tokens: int, + num_cached_tokens: int = 0): if req_id in self._request_ids_num_batched_tokens: return + assert num_cached_tokens >= 0 + assert num_batched_tokens >= 0 self._request_ids_num_batched_tokens.add(req_id) self._num_batched_tokens += num_batched_tokens + self._num_cached_tokens += num_cached_tokens def subtract_num_batched_tokens(self, req_id: str, num_batched_tokens: int): @@ -101,6 +112,10 @@ def num_batched_tokens(self): def num_curr_seqs(self): return self._num_curr_seqs + @property + def num_cached_tokens(self): + return self._num_cached_tokens + @dataclass class ScheduledSequenceGroup: @@ -541,9 +556,19 @@ def _schedule_running( assert len(self._async_stopped) == 0 while running_queue: seq_group = running_queue[0] - num_running_tokens = self._get_num_new_tokens( - seq_group, SequenceStatus.RUNNING, enable_chunking, budget) - + # We discard the cached tokens info here because we don't need it + # for running sequence: + # 1. If a sequence is running with chunked prefill, the cached + # tokens info was already used for the first prefill. + # 2. If a sequence is running with non-chunked prefill, then + # there it's a decoding sequence, and the cached tokens info is + # irrelevant. + num_uncached_new_tokens, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.RUNNING, enable_chunking, + budget)) + + num_running_tokens = num_uncached_new_tokens if num_running_tokens == 0: # No budget => Stop break @@ -715,13 +740,15 @@ def _schedule_swapped( # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.SWAPPED, - enable_chunking, budget) - - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.SWAPPED, enable_chunking, + budget)) + + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break if lora_int_id > 0 and curr_loras is not None: @@ -732,12 +759,19 @@ def _schedule_swapped( is_prefill = seq_group.is_prefill() if is_prefill: prefill_seq_groups.append( - ScheduledSequenceGroup(seq_group, - token_chunk_size=num_new_tokens)) + ScheduledSequenceGroup( + seq_group, + token_chunk_size=num_new_tokens_uncached + + num_new_tokens_cached, + )) else: decode_seq_groups.append( ScheduledSequenceGroup(seq_group, token_chunk_size=1)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) swapped_queue.extendleft(leftover_swapped) @@ -803,26 +837,30 @@ def _schedule_priority_preemption( if waiting_queue: seq_group = waiting_queue.popleft() num_new_seqs = seq_group.get_max_num_running_seqs() - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - False, budget) + num_new_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, False, budget)) #Only preempt if priority inversion exists while running_queue and self._get_priority( running_queue[-1]) > self._get_priority(seq_group): #Only preempt if waiting sequence cannot be allocated can_allocate = self.block_manager.can_allocate(seq_group) - if (num_new_tokens and can_allocate == AllocStatus.OK - and budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if (num_new_tokens_uncached > 0 + and can_allocate == AllocStatus.OK + and budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + )): break #Adjust budget to remove the victim sequence group vseq_group = running_queue.pop() - num_running_tokens = self._get_num_new_tokens( - vseq_group, SequenceStatus.RUNNING, False, budget) - budget.subtract_num_batched_tokens(vseq_group.request_id, - num_running_tokens) + num_running_tokens_uncached, _ = ( + self._get_num_new_uncached_and_cached_tokens( + vseq_group, SequenceStatus.RUNNING, False, budget)) + budget.subtract_num_batched_tokens( + vseq_group.request_id, num_running_tokens_uncached) num_running_seqs = vseq_group.get_max_num_running_seqs() budget.subtract_num_seqs(vseq_group.request_id, num_running_seqs) @@ -882,9 +920,12 @@ def _schedule_prefills( assert len(waiting_seqs) == 1, ( "Waiting sequence group should have only one prompt " "sequence.") - num_new_tokens = self._get_num_new_tokens(seq_group, - SequenceStatus.WAITING, - enable_chunking, budget) + num_new_tokens_uncached, num_new_tokens_cached = ( + self._get_num_new_uncached_and_cached_tokens( + seq_group, SequenceStatus.WAITING, enable_chunking, + budget)) + num_new_tokens = num_new_tokens_uncached + num_new_tokens_cached + if not enable_chunking: num_prompt_tokens = waiting_seqs[0].get_len() assert num_new_tokens == num_prompt_tokens @@ -935,10 +976,18 @@ def _schedule_prefills( waiting_queue.popleft() continue + if (budget.num_batched_tokens >= + self.scheduler_config.max_num_batched_tokens): + # We've reached the budget limit - since there might be + # continuous prefills in the running queue, we should break + # to avoid scheduling any new prefills. + break + num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_new_tokens == 0 - or not budget.can_schedule(num_new_tokens=num_new_tokens, - num_new_seqs=num_new_seqs)): + if num_new_tokens_uncached == 0 or not budget.can_schedule( + num_new_tokens=num_new_tokens_uncached, + num_new_seqs=num_new_seqs, + ): break # Can schedule this request. @@ -967,7 +1016,11 @@ def _schedule_prefills( seq_groups.append( ScheduledSequenceGroup(seq_group=seq_group, token_chunk_size=num_new_tokens)) - budget.add_num_batched_tokens(seq_group.request_id, num_new_tokens) + budget.add_num_batched_tokens( + seq_group.request_id, + num_batched_tokens=num_new_tokens_uncached, + num_cached_tokens=num_new_tokens_cached, + ) budget.add_num_seqs(seq_group.request_id, num_new_seqs) # Queue requests that couldn't be scheduled. @@ -1075,7 +1128,8 @@ def _schedule_default(self) -> SchedulerOutputs: return SchedulerOutputs( scheduled_seq_groups=scheduled_seq_groups, num_prefill_groups=num_prefill_groups, - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -1119,7 +1173,6 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: running_scheduled.swapped_out) == 0: swapped_in = self._schedule_swapped(budget, curr_loras) - # Schedule new prefills. prefills = self._schedule_prefills(budget, curr_loras, enable_chunking=True) @@ -1157,7 +1210,8 @@ def _schedule_chunked_prefill(self) -> SchedulerOutputs: num_prefill_groups=(len(prefills.seq_groups) + len(swapped_in.prefill_seq_groups) + len(running_scheduled.prefill_seq_groups)), - num_batched_tokens=budget.num_batched_tokens, + num_batched_tokens=budget.num_batched_tokens + + budget.num_cached_tokens, blocks_to_swap_in=swapped_in.blocks_to_swap_in, blocks_to_swap_out=running_scheduled.blocks_to_swap_out, blocks_to_copy=running_scheduled.blocks_to_copy + @@ -1303,6 +1357,7 @@ def schedule( encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, state=seq_group.state, + token_type_ids=seq_group.token_type_ids, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. # the subsequent comms can still use delta, but @@ -1584,64 +1639,178 @@ def _get_num_lookahead_slots(self, is_prefill: bool, return self.scheduler_config.num_lookahead_slots - def _get_num_new_tokens(self, seq_group: SequenceGroup, - status: SequenceStatus, enable_chunking: bool, - budget: SchedulingBudget) -> int: - """Get the next new tokens to compute for a given sequence group - that's in a given `status`. + def _get_num_new_uncached_and_cached_tokens( + self, + seq_group: SequenceGroup, + status: SequenceStatus, + enable_chunking: bool, + budget: SchedulingBudget, + ) -> Tuple[int, int]: + """ + Returns the number of new uncached and cached tokens to schedule for a + given sequence group that's in a given `status`. The API could chunk the number of tokens to compute based on `budget` if `enable_chunking` is True. If a sequence group has multiple sequences (e.g., running beam search), it means it is in decoding phase, so chunking doesn't happen. - Returns 0 if the new token cannot be computed due to token budget. + Returns (0, 0) if the new token cannot be computed due to token budget. + + The cached tokens's blocks are already computed, and the attention + backend will reuse the cached blocks rather than recomputing them. So + the scheduler could schedule these cached tokens "for free". + + Args: + seq_group: The sequence group to get the number of new tokens to + schedule. + status: The status of the sequences to get the number of new tokens + to schedule. + enable_chunking: Whether to chunk the number of tokens to compute. + budget: The budget to chunk the number of tokens to compute. + + + Returns: + A tuple of two ints. The first int is the number of new uncached + tokens to schedule. The second int is the number of cached tokens. + If no more new tokens can be scheduled, returns (0, 0). """ - num_new_tokens = 0 + num_cached_new_tokens = 0 + num_uncached_new_tokens = 0 + seqs = seq_group.get_seqs(status=status) + # Compute the number of new uncached and cached tokens for + # each sequence. for seq in seqs: - num_new_tokens += seq.get_num_new_tokens() - assert num_new_tokens > 0 - # Chunk if a running request cannot fit in the given budget. - # If number of seq > 1, it means it is doing beam search - # in a decode phase. Do not chunk. + if not seq.is_prefill(): + # Decode sequences should always just have 1 uncached token + # TODO(rickyx): Actually is this still correct for multi-step? + num_uncached_new_tokens += 1 + continue + + num_computed_tokens_seq = seq.get_num_computed_tokens() + all_num_new_tokens_seq = seq.get_len() - num_computed_tokens_seq + if not self.cache_config.enable_prefix_caching: + # If prefix caching is not enabled, all new tokens are uncached. + num_uncached_new_tokens += all_num_new_tokens_seq + continue + + # NOTE: the cache token might be currently in a block that's in an + # evictor meaning that it's not yet allocated. However, we don't + # exclude such tokens in the cache count because it will be + # guaranteed to be allocated later if the sequence can be allocated. + num_cached_tokens_seq = self.block_manager.get_num_cached_tokens( + seq) + + # Sanity check. + if num_cached_tokens_seq < num_computed_tokens_seq: + # This should only happen with chunked prefill, and + # the seq is still in prefill. The `num_cached_tokens_seq` + # is the value we calculated on scheduling the first prefill. + # For subsequent continuous prefill steps, we cached the + # number of cache tokens for the sequence so the cached token + # count could be less than the number of computed tokens. + # See comments on `ComputedBlocksTracker` for more details. + assert ( + seq.is_prefill() and seq.status == SequenceStatus.RUNNING + and self.scheduler_config.chunked_prefill_enabled + ), ("Number of cached tokens should not be less than the " + "number of computed tokens for a sequence that's still " + f"in prefill. But there are {num_cached_tokens_seq} cached " + f"tokens and {num_computed_tokens_seq} computed tokens " + f"for sequence {seq.seq_id}.") + + num_cached_new_tokens_seq = max( + 0, num_cached_tokens_seq - num_computed_tokens_seq) + num_uncached_new_tokens_seq = (all_num_new_tokens_seq - + num_cached_new_tokens_seq) + + num_uncached_new_tokens += num_uncached_new_tokens_seq + num_cached_new_tokens += num_cached_new_tokens_seq + + if num_uncached_new_tokens == 0 and num_cached_new_tokens > 0: + # For a fully cached hit sequence, we actually need to recompute the + # last token. So we need at least 1 uncached token to schedule. + # See ModelRunner._compute_for_prefix_cache_hit for more details. + num_uncached_new_tokens = 1 + num_cached_new_tokens -= 1 + if enable_chunking and len(seqs) == 1: - remaining_token_budget = budget.remaining_token_budget() - if self.scheduler_config.is_multi_step: - # The current multi-step + chunked prefill capability does - # not actually support chunking prompts. - # - # Therefore, `num_new_tokens` is computed in the same fashion - # for both multi-step+chunked-prefill & - # multi-step+chunked-prefill+APC - # - # Prompts with more tokens than the current remaining budget - # are postponed to future scheduler steps - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. - pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens - elif self.cache_config.enable_prefix_caching: - # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block - # size to avoid partial block matching. - block_size = self.cache_config.block_size - remainder = budget.token_budget % block_size - if remainder != 0: - raise ValueError("When enabling chunked prefill and " - "prefix caching, max_num_batched_tokens " - "(chunk size) must be dividable by " - "block size, but got chunk_size " - f"({budget.token_budget}) % block_size " - f"({block_size}) = {remainder}") - if remaining_token_budget < num_new_tokens: - num_new_tokens = (remaining_token_budget // - block_size) * block_size - else: - num_new_tokens = min(num_new_tokens, remaining_token_budget) + # Chunk if a running request cannot fit in the given budget. + # If number of seq > 1, it means it is doing beam search + # in a decode phase. Do not chunk. + num_uncached_new_tokens = self._chunk_new_tokens_to_schedule( + self.scheduler_config, + self.cache_config, + budget, + self._get_prompt_limit(seq_group), + num_uncached_new_tokens, + ) + + return num_uncached_new_tokens, num_cached_new_tokens + + @staticmethod + def _chunk_new_tokens_to_schedule( + scheduler_config: SchedulerConfig, + cache_config: CacheConfig, + budget: SchedulingBudget, + prompt_limit: int, + num_new_tokens: int, + ) -> int: + """ + Chunks the number of new tokens to schedule based on the budget when + chunked prefill is enabled. + + Args: + scheduler_config: The scheduler config. + cache_config: The cache config. + budget: The budget to chunk the number of tokens to compute. + prompt_limit: The maximum number of tokens allowed in a prompt. + num_new_tokens: The number of new tokens to schedule. + + Returns: + The number of new tokens to schedule after chunking. + """ + remaining_token_budget = budget.remaining_token_budget() + if scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > prompt_limit: + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + return num_new_tokens + + return (0 if num_new_tokens > remaining_token_budget else + num_new_tokens) + + if cache_config.enable_prefix_caching: + # Adjust the remaining token budget to be divisible by the block + # size when prefix caching is enabled. + + # When prefix caching is enabled, we always allocate + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. + block_size = cache_config.block_size + remainder = budget.token_budget % block_size + if remainder != 0: + raise ValueError("When enabling chunked prefill and " + "prefix caching, max_num_batched_tokens " + "(chunk size) must be dividable by " + "block size, but got chunk_size " + f"({budget.token_budget}) % block_size " + f"({block_size}) = {remainder}") + # Round down to block size. + remaining_token_budget = (remaining_token_budget // block_size * + block_size) + + num_new_tokens = min(num_new_tokens, remaining_token_budget) + return num_new_tokens diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index 7c6f48e88637b..7411304eb18fa 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -131,6 +131,48 @@ def all_reduce(self, ncclRedOpTypeEnum.from_torch(op), self.comm, cudaStream_t(stream.cuda_stream)) + def all_gather(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclAllGather( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), input_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), self.comm, + cudaStream_t(stream.cuda_stream)) + + def reduce_scatter(self, + output_tensor: torch.Tensor, + input_tensor: torch.Tensor, + op: ReduceOp = ReduceOp.SUM, + stream=None): + if self.disabled: + return + # nccl communicator created on a specific device + # will only work on tensors on the same device + # otherwise it will cause "illegal memory access" + assert input_tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {input_tensor.device}") + if stream is None: + stream = self.stream + self.nccl.ncclReduceScatter( + buffer_type(input_tensor.data_ptr()), + buffer_type(output_tensor.data_ptr()), output_tensor.numel(), + ncclDataTypeEnum.from_torch(input_tensor.dtype), + ncclRedOpTypeEnum.from_torch(op), self.comm, + cudaStream_t(stream.cuda_stream)) + def send(self, tensor: torch.Tensor, dst: int, stream=None): if self.disabled: return diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index 7619c98f22148..ff88f72470b27 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -151,6 +151,28 @@ class NCCLLibrary: ncclRedOp_t, ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclAllGather( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclAllGather", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclComm_t, cudaStream_t + ]), + + # ncclResult_t ncclReduceScatter( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, ncclRedOp_t op, ncclComm_t comm, + # cudaStream_t stream); + # note that cudaStream_t is a pointer type, so the last argument + # is a pointer + Function("ncclReduceScatter", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ncclRedOp_t, ncclComm_t, cudaStream_t + ]), + # ncclResult_t ncclSend( # const void* sendbuff, size_t count, ncclDataType_t datatype, # int dest, ncclComm_t comm, cudaStream_t stream); @@ -258,6 +280,28 @@ def ncclAllReduce(self, sendbuff: buffer_type, recvbuff: buffer_type, datatype, op, comm, stream)) + def ncclReduceScatter(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, op: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # and `op` should be `ncclRedOp_t` + # both are aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclReduceScatter"](sendbuff, recvbuff, + count, datatype, op, + comm, stream)) + + def ncclAllGather(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + # `datatype` actually should be `ncclDataType_t` + # which is an aliases of `ctypes.c_int` + # when we pass int to a function, it will be converted to `ctypes.c_int` + # by ctypes automatically + self.NCCL_CHECK(self._funcs["ncclAllGather"](sendbuff, recvbuff, count, + datatype, comm, stream)) + def ncclSend(self, sendbuff: buffer_type, count: int, datatype: int, dest: int, comm: ncclComm_t, stream: cudaStream_t) -> None: self.NCCL_CHECK(self._funcs["ncclSend"](sendbuff, count, datatype, diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 002b67e635bf4..60ad5ee54a2f2 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -192,6 +192,7 @@ class EngineArgs: override_neuron_config: Optional[Dict[str, Any]] = None override_pooler_config: Optional[PoolerConfig] = None compilation_config: Optional[CompilationConfig] = None + worker_cls: str = "auto" def __post_init__(self): if not self.tokenizer: @@ -202,6 +203,13 @@ def __post_init__(self): if self.enable_prefix_caching is None: self.enable_prefix_caching = bool(envs.VLLM_USE_V1) + # support `EngineArgs(compilation_config={...})` + # without having to manually construct a + # CompilationConfig object + if isinstance(self.compilation_config, (int, dict)): + self.compilation_config = CompilationConfig.from_cli( + json.dumps(self.compilation_config)) + # Setup plugins from vllm.plugins import load_general_plugins load_general_plugins() @@ -888,7 +896,16 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: 'testing only. level 3 is the recommended level ' 'for production.\n' 'To specify the full compilation config, ' - 'use a JSON string.') + 'use a JSON string.\n' + 'Following the convention of traditional ' + 'compilers, using -O without space is also ' + 'supported. -O3 is equivalent to -O 3.') + + parser.add_argument( + '--worker-cls', + type=str, + default="auto", + help='The worker class to use for distributed execution.') return parser @@ -1007,7 +1024,9 @@ def create_engine_config(self, self.tokenizer_pool_extra_config, ), ray_workers_use_nsight=self.ray_workers_use_nsight, - distributed_executor_backend=self.distributed_executor_backend) + distributed_executor_backend=self.distributed_executor_backend, + worker_cls=self.worker_cls, + ) max_model_len = model_config.max_model_len use_long_context = max_model_len > 32768 @@ -1025,7 +1044,8 @@ def create_engine_config(self, use_spec_decode = self.speculative_model is not None if (is_gpu and not use_sliding_window and not use_spec_decode and not self.enable_lora - and not self.enable_prompt_adapter): + and not self.enable_prompt_adapter + and model_config.task != "embedding"): self.enable_chunked_prefill = True logger.warning( "Chunked prefill is enabled by default for models with " @@ -1042,6 +1062,10 @@ def create_engine_config(self, "errors during the initial memory profiling phase, or result " "in low performance due to small KV cache space. Consider " "setting --max-model-len to a smaller value.", max_model_len) + elif self.enable_chunked_prefill and model_config.task == "embedding": + msg = "Chunked prefill is not supported for embedding models" + raise ValueError(msg) + speculative_config = SpeculativeConfig.maybe_create_spec_config( target_model_config=model_config, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 8cf5c4c308c47..a4975cece9a81 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -231,19 +231,18 @@ def __init__( use_cached_outputs: bool = False, ) -> None: - # TODO: remove the local variables and use self.* throughout the class. - model_config = self.model_config = vllm_config.model_config - cache_config = self.cache_config = vllm_config.cache_config - lora_config = self.lora_config = vllm_config.lora_config - parallel_config = self.parallel_config = vllm_config.parallel_config - scheduler_config = self.scheduler_config = vllm_config.scheduler_config - device_config = self.device_config = vllm_config.device_config - speculative_config = self.speculative_config = vllm_config.speculative_config # noqa - load_config = self.load_config = vllm_config.load_config - decoding_config = self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa + self.model_config = vllm_config.model_config + self.cache_config = vllm_config.cache_config + self.lora_config = vllm_config.lora_config + self.parallel_config = vllm_config.parallel_config + self.scheduler_config = vllm_config.scheduler_config + self.device_config = vllm_config.device_config + self.speculative_config = vllm_config.speculative_config # noqa + self.load_config = vllm_config.load_config + self.decoding_config = vllm_config.decoding_config or DecodingConfig( # noqa ) - prompt_adapter_config = self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa - observability_config = self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa + self.prompt_adapter_config = vllm_config.prompt_adapter_config # noqa + self.observability_config = vllm_config.observability_config or ObservabilityConfig( # noqa ) logger.info( @@ -265,54 +264,43 @@ def __init__( "mm_processor_kwargs=%s, pooler_config=%r," "compilation_config=%r", VLLM_VERSION, - model_config.model, - speculative_config, - model_config.tokenizer, - model_config.skip_tokenizer_init, - model_config.tokenizer_mode, - model_config.revision, - model_config.override_neuron_config, - model_config.tokenizer_revision, - model_config.trust_remote_code, - model_config.dtype, - model_config.max_model_len, - load_config.download_dir, - load_config.load_format, - parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size, - parallel_config.disable_custom_all_reduce, - model_config.quantization, - model_config.enforce_eager, - cache_config.cache_dtype, - model_config.quantization_param_path, - device_config.device, - decoding_config, - observability_config, - model_config.seed, - model_config.served_model_name, - scheduler_config.num_scheduler_steps, - scheduler_config.chunked_prefill_enabled, - scheduler_config.multi_step_stream_outputs, - cache_config.enable_prefix_caching, - model_config.use_async_output_proc, + self.model_config.model, + self.speculative_config, + self.model_config.tokenizer, + self.model_config.skip_tokenizer_init, + self.model_config.tokenizer_mode, + self.model_config.revision, + self.model_config.override_neuron_config, + self.model_config.tokenizer_revision, + self.model_config.trust_remote_code, + self.model_config.dtype, + self.model_config.max_model_len, + self.load_config.download_dir, + self.load_config.load_format, + self.parallel_config.tensor_parallel_size, + self.parallel_config.pipeline_parallel_size, + self.parallel_config.disable_custom_all_reduce, + self.model_config.quantization, + self.model_config.enforce_eager, + self.cache_config.cache_dtype, + self.model_config.quantization_param_path, + self.device_config.device, + self.decoding_config, + self.observability_config, + self.model_config.seed, + self.model_config.served_model_name, + self.scheduler_config.num_scheduler_steps, + self.scheduler_config.chunked_prefill_enabled, + self.scheduler_config.multi_step_stream_outputs, + self.cache_config.enable_prefix_caching, + self.model_config.use_async_output_proc, use_cached_outputs, - model_config.mm_processor_kwargs, - model_config.pooler_config, + self.model_config.mm_processor_kwargs, + self.model_config.pooler_config, vllm_config.compilation_config, ) # TODO(woosuk): Print more configs in debug mode. - self.model_config = model_config - self.cache_config = cache_config - self.lora_config = lora_config - self.parallel_config = parallel_config - self.scheduler_config = scheduler_config - self.device_config = device_config - self.speculative_config = speculative_config - self.load_config = load_config - self.decoding_config = decoding_config or DecodingConfig() - self.prompt_adapter_config = prompt_adapter_config - self.observability_config = observability_config or ObservabilityConfig( - ) + self.log_stats = log_stats self.use_cached_outputs = use_cached_outputs @@ -334,15 +322,15 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: self.seq_counter = Counter() self.generation_config_fields = _load_generation_config_dict( - model_config) + self.model_config) - self.input_preprocessor = InputPreprocessor(model_config, + self.input_preprocessor = InputPreprocessor(self.model_config, self.tokenizer, mm_registry) self.input_registry = input_registry self.input_processor = input_registry.create_input_processor( - model_config) + self.model_config) self.model_executor = executor_class(vllm_config=vllm_config, ) @@ -354,36 +342,36 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: from vllm.model_executor.model_loader import ( get_architecture_class_name) usage_message.report_usage( - get_architecture_class_name(model_config), + get_architecture_class_name(self.model_config), usage_context, extra_kvs={ # Common configuration "dtype": - str(model_config.dtype), + str(self.model_config.dtype), "tensor_parallel_size": - parallel_config.tensor_parallel_size, + self.parallel_config.tensor_parallel_size, "block_size": - cache_config.block_size, + self.cache_config.block_size, "gpu_memory_utilization": - cache_config.gpu_memory_utilization, + self.cache_config.gpu_memory_utilization, # Quantization "quantization": - model_config.quantization, + self.model_config.quantization, "kv_cache_dtype": - str(cache_config.cache_dtype), + str(self.cache_config.cache_dtype), # Feature flags "enable_lora": - bool(lora_config), + bool(self.lora_config), "enable_prompt_adapter": - bool(prompt_adapter_config), + bool(self.prompt_adapter_config), "enable_prefix_caching": - cache_config.enable_prefix_caching, + self.cache_config.enable_prefix_caching, "enforce_eager": - model_config.enforce_eager, + self.model_config.enforce_eager, "disable_custom_all_reduce": - parallel_config.disable_custom_all_reduce, + self.parallel_config.disable_custom_all_reduce, }) if self.tokenizer: @@ -402,7 +390,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: for _ in range(self.parallel_config.pipeline_parallel_size) ] - if model_config.use_async_output_proc: + if self.model_config.use_async_output_proc: process_model_outputs = weak_bind(self._process_model_outputs) self.async_callbacks = [ @@ -422,11 +410,11 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: # GPU and CPU blocks, which are profiled in the distributed executor. self.scheduler = [ Scheduler( - scheduler_config, cache_config, lora_config, - parallel_config.pipeline_parallel_size, + self.scheduler_config, self.cache_config, self.lora_config, + self.parallel_config.pipeline_parallel_size, self.async_callbacks[v_id] - if model_config.use_async_output_proc else None) - for v_id in range(parallel_config.pipeline_parallel_size) + if self.model_config.use_async_output_proc else None) + for v_id in range(self.parallel_config.pipeline_parallel_size) ] # Metric Logging. @@ -448,7 +436,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "prometheus": PrometheusStatLogger( local_interval=_LOCAL_LOGGING_INTERVAL_SEC, - labels=dict(model_name=model_config.served_model_name), + labels=dict( + model_name=self.model_config.served_model_name), max_model_len=self.model_config.max_model_len), } self.stat_loggers["prometheus"].info("cache_config", diff --git a/vllm/entrypoints/chat_utils.py b/vllm/entrypoints/chat_utils.py index abee5ac46391c..c2054dcbfce0e 100644 --- a/vllm/entrypoints/chat_utils.py +++ b/vllm/entrypoints/chat_utils.py @@ -412,6 +412,8 @@ def _placeholder_str(self, modality: ModalityStr, return "" if model_type == "idefics3": return "" + if model_type == "aria": + return "<|fim_prefix|><|img|><|fim_suffix|>" raise TypeError(f"Unknown {modality} model type: {model_type}") elif modality == "audio": diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 86b0b6893f1d9..e07f4c04abd84 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,5 @@ import itertools +import json import warnings from contextlib import contextmanager from typing import (Any, ClassVar, Dict, List, Optional, Sequence, Tuple, Type, @@ -9,6 +10,7 @@ from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, BeamSearchSequence, get_beam_search_score) +from vllm.config import CompilationConfig from vllm.engine.arg_utils import (EngineArgs, HfOverrides, PoolerConfig, TaskOption) from vllm.engine.llm_engine import LLMEngine @@ -18,7 +20,7 @@ apply_mistral_chat_template, parse_chat_messages, resolve_chat_template_content_format) -from vllm.inputs import PromptType, TextPrompt, TokensPrompt +from vllm.inputs import PromptType, SingletonPrompt, TextPrompt, TokensPrompt from vllm.inputs.parse import parse_and_batch_prompt from vllm.logger import init_logger from vllm.lora.request import LoRARequest @@ -107,6 +109,9 @@ class LLM: hf_overrides: If a dictionary, contains arguments to be forwarded to the HuggingFace config. If a callable, it is called to update the HuggingFace config. + compilation_config: Either an integer or a dictionary. If it is an + integer, it is used as the level of compilation optimization. If it + is a dictionary, it can specify the full compilation configuration. **kwargs: Arguments for :class:`~vllm.EngineArgs`. (See :ref:`engine_args`) @@ -166,6 +171,7 @@ def __init__( # After positional args are removed, move this right below `model` task: TaskOption = "auto", override_pooler_config: Optional[PoolerConfig] = None, + compilation_config: Optional[Union[int, Dict[str, Any]]] = None, **kwargs, ) -> None: ''' @@ -178,6 +184,12 @@ def __init__( if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True + if compilation_config is not None: + compilation_config_instance = CompilationConfig.from_cli( + json.dumps(compilation_config)) + else: + compilation_config_instance = None + engine_args = EngineArgs( model=model, task=task, @@ -202,6 +214,7 @@ def __init__( hf_overrides=hf_overrides, mm_processor_kwargs=mm_processor_kwargs, override_pooler_config=override_pooler_config, + compilation_config=compilation_config_instance, **kwargs, ) # Logic to switch between engines is done at runtime instead of import @@ -804,6 +817,128 @@ def encode( return self.engine_class.validate_outputs(outputs, EmbeddingRequestOutput) + def score( + self, + text_1: Union[SingletonPrompt, Sequence[SingletonPrompt]], + text_2: Union[SingletonPrompt, Sequence[SingletonPrompt]], + /, + truncate_prompt_tokens: Optional[int] = None, + use_tqdm: bool = True, + lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None, + prompt_adapter_request: Optional[PromptAdapterRequest] = None, + ) -> List[EmbeddingRequestOutput]: + """Generates similarity scores for all pairs . + + The inputs can be 1 -> 1, 1 -> N or N -> N. In the 1 - N case + the text_1 sentence will be replicated N times to pair with the text_2 + sentences. The input pairs are used to build a list of prompts for the + cross encoder model. This class automatically batches the prompts, + considering the memory constraint. For the best performance, put all + of your texts into a single list and pass it to this method. + + Args: + text_1: can be a single prompt or a list of prompts, in which + case it has to have the same length as the text_2 list + text_2: The texts to pair with the query to form the input + to the LLM. See :class:`~vllm.inputs.PromptType` for + more details about the format of each prompts. + use_tqdm: Whether to use tqdm to display the progress bar. + lora_request: LoRA request to use for generation, if any. + prompt_adapter_request: Prompt Adapter request to use for + generation, if any. + + Returns: + A list of ``EmbeddingRequestOutput`` objects containing the + generated scores in the same order as the input prompts. + """ + task = self.llm_engine.model_config.task + if task != "embedding": + messages = ["LLM.score() is only supported for embedding models."] + + supported_tasks = self.llm_engine.model_config.supported_tasks + if "embedding" in supported_tasks: + messages.append( + "Your model supports the 'embedding' task, but is " + f"currently initialized for the '{task}' task. Please " + "initialize the model using `--task embedding`.") + + raise ValueError(" ".join(messages)) + + if not self.llm_engine.model_config.is_cross_encoder: + raise ValueError("Your model does not support the cross encoding") + + tokenizer = self.llm_engine.get_tokenizer() + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + # the tokenizer for models such as + # "cross-encoder/ms-marco-MiniLM-L-6-v2" doesn't support passing + # lists of tokens to the `text` and `text_pair` kwargs + def ensure_str(prompt: SingletonPrompt): + if isinstance(prompt, dict): + if "multi_modal_data" in prompt: + raise ValueError("Multi-modal prompt is not " + "supported for cross encoding") + elif "prompt_token_ids" in prompt: + prompt = tokenizer.decode( + cast(TokensPrompt, prompt)["prompt_token_ids"]) + elif "prompt" in prompt: + prompt = cast(TextPrompt, prompt)["prompt"] + assert type(prompt) is str + return prompt + + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [ensure_str(t) for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [ensure_str(t) for t in text_2] + + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + input_pairs = [(t1, t2) for t1, t2 in zip(text_1, text_2)] + pooling_params = PoolingParams() + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + parsed_prompts = [] + + for q, t in input_pairs: + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + parsed_prompts.append(engine_prompt) + + self._validate_and_add_requests( + prompts=parsed_prompts, + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request, + ) + + outputs = self._run_engine(use_tqdm=use_tqdm) + return self.engine_class.validate_outputs(outputs, + EmbeddingRequestOutput) + def start_profile(self) -> None: self.llm_engine.start_profile() diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 0751a60f524a3..6bc31ef83ded4 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -45,6 +45,7 @@ EmbeddingRequest, EmbeddingResponse, ErrorResponse, LoadLoraAdapterRequest, + ScoreRequest, ScoreResponse, TokenizeRequest, TokenizeResponse, UnloadLoraAdapterRequest) @@ -53,6 +54,7 @@ from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.entrypoints.openai.serving_score import OpenAIServingScores from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) from vllm.entrypoints.openai.tool_parsers import ToolParserManager @@ -280,6 +282,10 @@ def embedding(request: Request) -> Optional[OpenAIServingEmbedding]: return request.app.state.openai_serving_embedding +def score(request: Request) -> Optional[OpenAIServingScores]: + return request.app.state.openai_serving_scores + + def tokenization(request: Request) -> OpenAIServingTokenization: return request.app.state.openai_serving_tokenization @@ -391,6 +397,23 @@ async def create_embedding(request: EmbeddingRequest, raw_request: Request): assert_never(generator) +@router.post("/v1/score") +async def create_score(request: ScoreRequest, raw_request: Request): + handler = score(raw_request) + if handler is None: + return base(raw_request).create_error_response( + message="The model does not support Score API") + + generator = await handler.create_score(request, raw_request) + if isinstance(generator, ErrorResponse): + return JSONResponse(content=generator.model_dump(), + status_code=generator.code) + elif isinstance(generator, ScoreResponse): + return JSONResponse(content=generator.model_dump()) + + assert_never(generator) + + if envs.VLLM_TORCH_PROFILER_DIR: logger.warning( "Torch Profiler is enabled in the API server. This should ONLY be " @@ -466,8 +489,9 @@ def build_app(args: Namespace) -> FastAPI: @app.exception_handler(RequestValidationError) async def validation_exception_handler(_, exc): - chat = app.state.openai_serving_chat - err = chat.create_error_response(message=str(exc)) + err = ErrorResponse(message=str(exc), + type="BadRequestError", + code=HTTPStatus.BAD_REQUEST) return JSONResponse(err.model_dump(), status_code=HTTPStatus.BAD_REQUEST) @@ -475,10 +499,12 @@ async def validation_exception_handler(_, exc): @app.middleware("http") async def authentication(request: Request, call_next): - root_path = "" if args.root_path is None else args.root_path if request.method == "OPTIONS": return await call_next(request) - if not request.url.path.startswith(f"{root_path}/v1"): + url_path = request.url.path + if app.root_path and url_path.startswith(app.root_path): + url_path = url_path[len(app.root_path):] + if not url_path.startswith("/v1"): return await call_next(request) if request.headers.get("Authorization") != "Bearer " + token: return JSONResponse(content={"error": "Unauthorized"}, @@ -565,6 +591,13 @@ def init_app_state( chat_template=resolved_chat_template, chat_template_content_format=args.chat_template_content_format, ) if model_config.task == "embedding" else None + state.openai_serving_scores = OpenAIServingScores( + engine_client, + model_config, + base_model_paths, + request_logger=request_logger + ) if (model_config.task == "embedding" \ + and model_config.is_cross_encoder) else None state.openai_serving_tokenization = OpenAIServingTokenization( engine_client, model_config, diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index b7b064ae01f05..ee94a9413f098 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -9,12 +9,15 @@ from typing_extensions import Annotated from vllm.entrypoints.chat_utils import ChatCompletionMessageParam +from vllm.logger import init_logger from vllm.pooling_params import PoolingParams from vllm.sampling_params import (BeamSearchParams, GuidedDecodingParams, RequestOutputKind, SamplingParams) from vllm.sequence import Logprob from vllm.utils import random_uuid +logger = init_logger(__name__) + # torch is mocked during docs generation, # so we have to provide the values as literals _MOCK_LONG_INFO = Namespace(min=-9223372036854775808, max=9223372036854775807) @@ -35,8 +38,19 @@ class OpenAIBaseModel(BaseModel): - # OpenAI API does not allow extra fields - model_config = ConfigDict(extra="forbid") + # OpenAI API does allow extra fields + model_config = ConfigDict(extra="allow") + + @model_validator(mode="before") + @classmethod + def __log_extra_fields__(cls, data): + if isinstance(data, dict): + extra_fields = data.keys() - cls.model_fields.keys() + if extra_fields: + logger.warning( + "The following fields were present in the request " + "but ignored: %s", extra_fields) + return data class ErrorResponse(OpenAIBaseModel): @@ -464,17 +478,17 @@ def check_tool_usage(cls, data): # it matches a valid tool if isinstance(data["tool_choice"], dict): valid_tool = False - specified_function = data["tool_choice"]["function"] + specified_function = data["tool_choice"].get("function") if not specified_function: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\"," + "Expected field `function` in `tool_choice`." + " Correct usage: `{\"type\": \"function\"," " \"function\": {\"name\": \"my_function\"}}`") - specified_function_name = specified_function["name"] + specified_function_name = specified_function.get("name") if not specified_function_name: raise ValueError( - "Incorrectly formatted `tool_choice`. Should be like " - "`{\"type\": \"function\", " + "Expected field `name` in `function` in `tool_choice`." + "Correct usage: `{\"type\": \"function\", " "\"function\": {\"name\": \"my_function\"}}`") for tool in data["tools"]: if tool["function"]["name"] == specified_function_name: @@ -746,22 +760,6 @@ class EmbeddingChatRequest(OpenAIBaseModel): # doc: end-chat-embedding-pooling-params # doc: begin-chat-embedding-extra-params - add_generation_prompt: bool = Field( - default=True, - description= - ("If true, the generation prompt will be added to the chat template. " - "This is a parameter used by chat template in tokenizer config of the " - "model."), - ) - continue_final_message: bool = Field( - default=False, - description= - ("If this is set, the chat will be formatted so that the final " - "message in the chat is open-ended, without any EOS tokens. The " - "model will continue this message rather than starting a new one. " - "This allows you to \"prefill\" part of the model's response for it. " - "Cannot be used at the same time as `add_generation_prompt`."), - ) add_special_tokens: bool = Field( default=False, description=( @@ -808,6 +806,27 @@ def to_pooling_params(self): EmbeddingRequest = Union[EmbeddingCompletionRequest, EmbeddingChatRequest] +class ScoreRequest(OpenAIBaseModel): + model: str + text_1: Union[List[str], str] + text_2: Union[List[str], str] + truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None + + # doc: begin-chat-embedding-pooling-params + additional_data: Optional[Any] = None + # doc: end-chat-embedding-pooling-params + + priority: int = Field( + default=0, + description=( + "The priority of the request (lower means earlier handling; " + "default: 0). Any priority other than 0 will raise an error " + "if the served model does not use priority scheduling.")) + + def to_pooling_params(self): + return PoolingParams(additional_data=self.additional_data) + + class CompletionLogProbs(OpenAIBaseModel): text_offset: List[int] = Field(default_factory=list) token_logprobs: List[Optional[float]] = Field(default_factory=list) @@ -878,6 +897,21 @@ class EmbeddingResponse(OpenAIBaseModel): usage: UsageInfo +class ScoreResponseData(OpenAIBaseModel): + index: int + object: str = "score" + score: Union[List[float], str] + + +class ScoreResponse(OpenAIBaseModel): + id: str = Field(default_factory=lambda: f"embd-{random_uuid()}") + object: str = "list" + created: int = Field(default_factory=lambda: int(time.time())) + model: str + data: List[ScoreResponseData] + usage: UsageInfo + + class FunctionCall(OpenAIBaseModel): name: str arguments: str diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 2eef909eb9319..54ca0463bcab1 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -361,7 +361,7 @@ async def chat_completion_stream_generator( # Send response to echo the input portion of the # last message - if request.echo or request.continue_final_message: + if request.echo: last_msg_content: Union[str, List[Dict[str, str]]] = "" if conversation and "content" in conversation[ -1] and conversation[-1].get("role") == role: @@ -706,7 +706,7 @@ async def chat_completion_full_generator( stop_reason=output.stop_reason) choices.append(choice_data) - if request.echo or request.continue_final_message: + if request.echo: last_msg_content: Union[str, List[Dict[str, str]]] = "" if conversation and "content" in conversation[-1] and conversation[ -1].get("role") == role: diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 74ad7389784fc..c84a7d2d8e13e 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -148,8 +148,10 @@ async def create_embedding( chat_template=request.chat_template or self.chat_template, chat_template_content_format=self. chat_template_content_format, - add_generation_prompt=request.add_generation_prompt, - continue_final_message=request.continue_final_message, + # In embedding requests, we are not generating tokens, + # so there is no need to append extra tokens to the input + add_generation_prompt=False, + continue_final_message=False, truncate_prompt_tokens=truncate_prompt_tokens, add_special_tokens=request.add_special_tokens, ) diff --git a/vllm/entrypoints/openai/serving_score.py b/vllm/entrypoints/openai/serving_score.py new file mode 100644 index 0000000000000..156fea6f47982 --- /dev/null +++ b/vllm/entrypoints/openai/serving_score.py @@ -0,0 +1,215 @@ +import asyncio +import time +from typing import Any, AsyncGenerator, Dict, List, Optional, Union, cast + +from fastapi import Request + +from vllm.config import ModelConfig +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.logger import RequestLogger +from vllm.entrypoints.openai.protocol import (ErrorResponse, ScoreRequest, + ScoreResponse, ScoreResponseData, + UsageInfo) +from vllm.entrypoints.openai.serving_engine import BaseModelPath, OpenAIServing +from vllm.inputs.data import TokensPrompt +from vllm.logger import init_logger +from vllm.outputs import EmbeddingRequestOutput +from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer +from vllm.utils import merge_async_iterators, random_uuid + +logger = init_logger(__name__) + + +def request_output_to_score_response( + final_res_batch: List[EmbeddingRequestOutput], request_id: str, + created_time: int, model_name: str) -> ScoreResponse: + data: List[ScoreResponseData] = [] + score = None + num_prompt_tokens = 0 + for idx, final_res in enumerate(final_res_batch): + if final_res is not None: + score = final_res.outputs.embedding + score_data = ScoreResponseData(index=idx, score=score) + data.append(score_data) + + usage = UsageInfo( + prompt_tokens=num_prompt_tokens, + total_tokens=num_prompt_tokens, + ) + + return ScoreResponse( + id=request_id, + created=created_time, + model=model_name, + data=data, + usage=usage, + ) + + +def make_pairs(text_1: Union[List[str], str], text_2: Union[List[str], + str]) -> List: + if isinstance(text_1, (str, dict)): + # Convert a single prompt to a list. + text_1 = [text_1] + text_1 = [t for t in text_1] + + if isinstance(text_2, (str, dict)): + # Convert a single prompt to a list. + text_2 = [text_2] + text_2 = [t for t in text_2] + if len(text_1) > 1 and len(text_1) != len(text_2): + raise ValueError("Input lengths must be either 1:1, 1:N or N:N") + if len(text_1) == 0: + raise ValueError("At least one text element must be given") + if len(text_2) == 0: + raise ValueError("At least one text_pair element must be given") + + if len(text_1) == 1: + text_1 = text_1 * len(text_2) + + return [(t1, t2) for t1, t2 in zip(text_1, text_2)] + + +class OpenAIServingScores(OpenAIServing): + + def __init__( + self, + engine_client: EngineClient, + model_config: ModelConfig, + base_model_paths: List[BaseModelPath], + *, + request_logger: Optional[RequestLogger], + ) -> None: + super().__init__(engine_client=engine_client, + model_config=model_config, + base_model_paths=base_model_paths, + lora_modules=None, + prompt_adapters=None, + request_logger=request_logger) + + async def create_score( + self, + request: ScoreRequest, + raw_request: Optional[Request] = None, + ) -> Union[ScoreResponse, ErrorResponse]: + """ + Score API similar to Sentence Transformers cross encoder + + See https://sbert.net/docs/package_reference/cross_encoder + """ + error_check_ret = await self._check_model(request) + if error_check_ret is not None: + return error_check_ret + + model_name = request.model + request_id = f"score-{random_uuid()}" + created_time = int(time.monotonic()) + truncate_prompt_tokens = request.truncate_prompt_tokens + + request_prompts = [] + engine_prompts = [] + + try: + ( + lora_request, + prompt_adapter_request, + ) = self._maybe_get_adapters(request) + + tokenizer = await self.engine_client.get_tokenizer(lora_request) + + if prompt_adapter_request is not None: + raise NotImplementedError("Prompt adapter is not supported " + "for embedding models") + + if isinstance(tokenizer, MistralTokenizer): + raise ValueError( + "MistralTokenizer not supported for cross-encoding") + + if not self.model_config.is_cross_encoder: + raise ValueError("Model is not cross encoder.") + + except ValueError as e: + logger.exception("Error in preprocessing prompt inputs") + return self.create_error_response(str(e)) + + # Schedule the request and get the result generator. + generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] + + input_pairs = make_pairs(request.text_1, request.text_2) + + for q, t in input_pairs: + request_prompt = f"{q}{tokenizer.sep_token}{t}" + + tokenization_kwargs: Dict[str, Any] = {} + if truncate_prompt_tokens is not None: + tokenization_kwargs["truncation"] = True + tokenization_kwargs["max_length"] = truncate_prompt_tokens + + prompt_inputs = tokenizer(text=q, + text_pair=t, + **tokenization_kwargs) + engine_prompt = TokensPrompt( + prompt_token_ids=prompt_inputs["input_ids"], + token_type_ids=prompt_inputs.get("token_type_ids")) + + request_prompts.append(request_prompt) + engine_prompts.append(engine_prompt) + + try: + pooling_params = request.to_pooling_params() + + for i, engine_prompt in enumerate(engine_prompts): + request_id_item = f"{request_id}-{i}" + + self._log_inputs(request_id_item, + request_prompts[i], + params=pooling_params, + lora_request=lora_request, + prompt_adapter_request=prompt_adapter_request) + + trace_headers = (None if raw_request is None else await + self._get_trace_headers(raw_request.headers)) + + generator = self.engine_client.encode( + engine_prompt, + pooling_params, + request_id_item, + lora_request=lora_request, + trace_headers=trace_headers, + priority=request.priority, + ) + + generators.append(generator) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + result_generator = merge_async_iterators( + *generators, + is_cancelled=raw_request.is_disconnected if raw_request else None, + ) + + num_prompts = len(engine_prompts) + + # Non-streaming response + final_res_batch: List[Optional[EmbeddingRequestOutput]] + final_res_batch = [None] * num_prompts + + try: + async for i, res in result_generator: + final_res_batch[i] = res + + assert all(final_res is not None for final_res in final_res_batch) + + final_res_batch_checked = cast(List[EmbeddingRequestOutput], + final_res_batch) + + response = request_output_to_score_response( + final_res_batch_checked, request_id, created_time, model_name) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) + + return response diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index a5f44d69e5fd2..1856308b88cfa 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -29,7 +29,8 @@ class Llama3JsonToolParser(ToolParser): Tool call parser for Llama 3.1 models intended for use with the examples/tool_chat_template_llama.jinja template. - Used when --enable-auto-tool-choice --tool-call-parser mistral are all set + Used when --enable-auto-tool-choice --tool-call-parser llama3_json + are all set """ def __init__(self, tokenizer: PreTrainedTokenizerBase): diff --git a/vllm/envs.py b/vllm/envs.py index 853c49bc4dbc1..14c1617f1be19 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -49,7 +49,7 @@ VLLM_WORKER_MULTIPROC_METHOD: str = "fork" VLLM_ASSETS_CACHE: str = os.path.join(VLLM_CACHE_ROOT, "assets") VLLM_IMAGE_FETCH_TIMEOUT: int = 5 - VLLM_VIDEO_FETCH_TIMEOUT: int = 15 + VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 VLLM_TARGET_DEVICE: str = "cuda" MAX_JOBS: Optional[str] = None diff --git a/vllm/executor/cpu_executor.py b/vllm/executor/cpu_executor.py index 1542a2ae367eb..336f9bc8efb20 100644 --- a/vllm/executor/cpu_executor.py +++ b/vllm/executor/cpu_executor.py @@ -115,13 +115,8 @@ def _create_worker( local_rank: int = 0, rank: int = 0, ): - worker_module_name = "vllm.worker.cpu_worker" - worker_class_name = "CPUWorker" - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) assert self.distributed_init_method is not None diff --git a/vllm/executor/gpu_executor.py b/vllm/executor/gpu_executor.py index c65d0836e5ff7..7fa34456028dd 100644 --- a/vllm/executor/gpu_executor.py +++ b/vllm/executor/gpu_executor.py @@ -1,4 +1,4 @@ -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Union from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger @@ -8,19 +8,14 @@ from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) -from vllm.worker.worker_base import WorkerBase, WorkerWrapperBase +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) -def create_worker(worker_module_name: str, worker_class_name: str, - worker_class_fn: Optional[Callable[[], Type[WorkerBase]]], - **kwargs): - wrapper = WorkerWrapperBase( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) +def create_worker(**kwargs): + vllm_config = kwargs.get("vllm_config") + wrapper = WorkerWrapperBase(vllm_config=vllm_config) wrapper.init_worker(**kwargs) return wrapper.worker @@ -57,43 +52,11 @@ def _get_worker_kwargs( or (rank % self.parallel_config.tensor_parallel_size == 0), ) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_worker" - worker_class_name = "MultiStepWorker" - elif self.speculative_config: - worker_module_name = "vllm.spec_decode.spec_decode_worker" - worker_class_name = "create_spec_worker" - else: - worker_module_name = "vllm.worker.worker" - worker_class_name = "Worker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_create_worker_kwargs( - self, - local_rank: int = 0, - rank: int = 0, - distributed_init_method: Optional[str] = None) -> Dict: - worker_kwargs = self._get_worker_kwargs(local_rank, rank, - distributed_init_method) - - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - worker_kwargs.update( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - ) - - return worker_kwargs - def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - return create_worker(**self._get_create_worker_kwargs( + return create_worker(**self._get_worker_kwargs( local_rank=local_rank, rank=rank, distributed_init_method=distributed_init_method)) diff --git a/vllm/executor/hpu_executor.py b/vllm/executor/hpu_executor.py index 220e9eee87bb3..c9b7bfa71edfa 100644 --- a/vllm/executor/hpu_executor.py +++ b/vllm/executor/hpu_executor.py @@ -48,10 +48,7 @@ def _create_worker(self, local_rank: int = 0, rank: int = 0, distributed_init_method: Optional[str] = None): - wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.hpu_worker", - worker_class_name="HPUWorker", - ) + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) return wrapper.worker diff --git a/vllm/executor/multiproc_gpu_executor.py b/vllm/executor/multiproc_gpu_executor.py index 3eb14fb931925..a6c05a71d2b6f 100644 --- a/vllm/executor/multiproc_gpu_executor.py +++ b/vllm/executor/multiproc_gpu_executor.py @@ -90,7 +90,7 @@ def _init_executor(self) -> None: result_handler, partial( create_worker, - **self._get_create_worker_kwargs( + **self._get_worker_kwargs( rank=rank, local_rank=rank, distributed_init_method=distributed_init_method, diff --git a/vllm/executor/neuron_executor.py b/vllm/executor/neuron_executor.py index 02d37cd7fbf23..31e6fdc3ab1bb 100644 --- a/vllm/executor/neuron_executor.py +++ b/vllm/executor/neuron_executor.py @@ -7,6 +7,7 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -25,10 +26,10 @@ def _init_executor(self) -> None: self._init_worker() def _init_worker(self): - from vllm.worker.neuron_worker import NeuronWorker + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = NeuronWorker( + self.driver_worker = wrapper.init_worker( vllm_config=self.vllm_config, local_rank=0, rank=0, diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index d06b0ccb7906e..db0070ce510ee 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -1,19 +1,17 @@ from typing import List, Set, Tuple import openvino as ov -import openvino.properties.hint as hints -import torch import vllm.envs as envs -from vllm.config import CacheConfig, ModelConfig from vllm.executor.executor_base import ExecutorAsyncBase, ExecutorBase from vllm.logger import init_logger from vllm.lora.request import LoRARequest from vllm.model_executor.layers.sampler import SamplerOutput from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest -from vllm.utils import (GiB_bytes, get_distributed_init_method, get_ip, - get_open_port, make_async) +from vllm.utils import (get_distributed_init_method, get_ip, get_open_port, + make_async) +from vllm.worker.worker_base import WorkerWrapperBase logger = init_logger(__name__) @@ -29,25 +27,17 @@ def _init_executor(self) -> None: current_platform.is_openvino_gpu(), \ "OpenVINO backend supports only CPU and GPU devices" - self.ov_core = ov.Core() - self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config( - self.ov_core, self.cache_config) - # Instantiate the worker and load the model to CPU. self._init_worker() def _init_worker(self): - from vllm.worker.openvino_worker import OpenVINOWorker - assert ( - self.parallel_config.world_size == 1 - ), "OpenVINOExecutor only supports single CPU socket currently." + wrapper = WorkerWrapperBase(vllm_config=self.vllm_config) distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) - self.driver_worker = OpenVINOWorker( - ov_core=self.ov_core, + self.driver_worker = wrapper.init_worker( + ov_core=ov.Core(), vllm_config=self.vllm_config, local_rank=0, rank=0, @@ -132,70 +122,3 @@ async def check_health_async(self) -> None: # OpenVINOExecutor will always be healthy as long as # it's running. return - - -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - if config.dtype != torch.float32: - logger.warning( - f"Only float32 dtype is supported on OpenVINO, casting from {config.dtype}." # noqa: G004, E501 - ) - config.dtype = torch.float32 - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on OpenVINO backend, fallback to the " - "eager mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_cache_config(ov_core: ov.Core, - config: CacheConfig) -> CacheConfig: - if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": - if not current_platform.is_openvino_cpu(): - logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" - "ignored for GPU, f16 data type will be used.") - config.cache_dtype = ov.Type.f16 - else: - logger.info("KV cache type is overridden to u8 via " - "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") - config.cache_dtype = ov.Type.u8 - else: - if current_platform.is_openvino_cpu(): - ov_device = envs.VLLM_OPENVINO_DEVICE - inference_precision = ov_core.get_property( - ov_device, hints.inference_precision) - if inference_precision == ov.Type.bf16: - config.cache_dtype = ov.Type.bf16 - else: - config.cache_dtype = ov.Type.f16 - else: - config.cache_dtype = ov.Type.f16 - - if current_platform.is_openvino_cpu(): - if config.block_size != 32: - logger.info( - f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 - ) - config.block_size = 32 - else: - if config.block_size != 16: - logger.info( - f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501 - ) - config.block_size = 16 - - kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE - if kv_cache_space >= 0: - if kv_cache_space == 0 and current_platform.is_openvino_cpu(): - config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore - logger.warning( - "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " - "for OpenVINO backend is not set, using 4 by default.") - else: - config.openvino_kvcache_space_bytes = kv_cache_space * GiB_bytes # type: ignore - else: - raise RuntimeError( - "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" - f" {kv_cache_space}, expect a positive integer value.") - - return config diff --git a/vllm/executor/ray_gpu_executor.py b/vllm/executor/ray_gpu_executor.py index 66bab2c686c67..810b0f06ff7b2 100644 --- a/vllm/executor/ray_gpu_executor.py +++ b/vllm/executor/ray_gpu_executor.py @@ -91,17 +91,6 @@ def _configure_ray_workers_use_nsight(self, return ray_remote_kwargs - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - # child class could overwrite this to return actual env vars. def _get_env_vars_to_be_updated(self): return self._env_vars_for_all_workers @@ -135,7 +124,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("GPU", 0): continue @@ -150,7 +138,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", num_gpus=num_gpus, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -161,7 +149,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) diff --git a/vllm/executor/ray_hpu_executor.py b/vllm/executor/ray_hpu_executor.py index a24bab6df370e..6fe8c6c403358 100644 --- a/vllm/executor/ray_hpu_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -2,8 +2,7 @@ import os from collections import defaultdict from itertools import islice, repeat -from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, - Type) +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple import msgspec @@ -18,7 +17,6 @@ from vllm.utils import (_run_task_with_lock, get_distributed_init_method, get_ip, get_open_port, get_vllm_instance_id, make_async) -from vllm.worker.worker_base import WorkerBase if ray is not None: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -81,33 +79,6 @@ def shutdown(self) -> None: def finish_measurements(self): self._run_workers("finish_measurements") - def _get_worker_module_and_class( - self - ) -> Tuple[str, str, Optional[Callable[[], - Type[WorkerBase]]]]: # noqa: F821 - worker_class_fn = None - if self.scheduler_config.is_multi_step: - raise NotImplementedError( - "Multi-step execution is not implemented for HPU") - elif self.speculative_config: - raise NotImplementedError( - "Speculative decoding is not implemented for HPU") - else: - worker_module_name = "vllm.worker.hpu_worker" - worker_class_name = "HPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - - def _get_worker_wrapper_args(self) -> Dict[str, Any]: - (worker_module_name, worker_class_name, - worker_class_fn) = self._get_worker_module_and_class() - - return dict( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - worker_class_fn=worker_class_fn, - trust_remote_code=self.model_config.trust_remote_code, - ) - def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): # Otherwise, the ray workers are allocated with a full GPU. @@ -128,7 +99,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # Create the workers. driver_ip = get_ip() - worker_wrapper_kwargs = self._get_worker_wrapper_args() for bundle_id, bundle in enumerate(placement_group.bundle_specs): if not bundle.get("HPU", 0): continue @@ -144,7 +114,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={'HPU': num_gpus}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote(**worker_wrapper_kwargs) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if self.use_ray_spmd_worker: self.workers.append(worker) @@ -155,7 +125,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - **worker_wrapper_kwargs) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) diff --git a/vllm/executor/ray_tpu_executor.py b/vllm/executor/ray_tpu_executor.py index d02fecb46f007..c227b5e283c68 100644 --- a/vllm/executor/ray_tpu_executor.py +++ b/vllm/executor/ray_tpu_executor.py @@ -69,14 +69,6 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", placement_group_bundle_index=bundle_id, ) - assert self.speculative_config is None - if self.scheduler_config.is_multi_step: - worker_module_name = "vllm.worker.multi_step_tpu_worker" - worker_class_name = "MultiStepTPUWorker" - else: - worker_module_name = "vllm.worker.tpu_worker" - worker_class_name = "TPUWorker" - # GKE does not fetch environment information from metadata server # and instead sets these from within the Ray process. Therefore we # need to override the Ray environment variables manually. @@ -95,11 +87,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", resources={"TPU": 1}, scheduling_strategy=scheduling_strategy, **ray_remote_kwargs, - )(RayWorkerWrapper).remote( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + )(RayWorkerWrapper).remote(vllm_config=self.vllm_config) if override_env: worker.override_env_vars.remote(override_env) @@ -109,10 +97,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", # as the resource holder for the driver process. self.driver_dummy_worker = worker self.driver_worker = RayWorkerWrapper( - worker_module_name=worker_module_name, - worker_class_name=worker_class_name, - trust_remote_code=self.model_config.trust_remote_code, - ) + vllm_config=self.vllm_config) else: # Else, added to the list of workers. self.workers.append(worker) diff --git a/vllm/executor/xpu_executor.py b/vllm/executor/xpu_executor.py index 36b7e2265efab..722b86a95ff8a 100644 --- a/vllm/executor/xpu_executor.py +++ b/vllm/executor/xpu_executor.py @@ -1,15 +1,11 @@ -from typing import Callable, List, Optional, Tuple, Type, Union +from typing import List, Optional, Union -import torch - -from vllm.config import ModelConfig, ParallelConfig from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutor from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput from vllm.sequence import ExecuteModelRequest, PoolerOutput from vllm.utils import make_async -from vllm.worker.worker_base import WorkerBase logger = init_logger(__name__) @@ -23,20 +19,8 @@ def _init_executor(self) -> None: assert self.speculative_config is None, ( "Speculative decoding not yet supported for XPU backend") - self.model_config = _verify_and_get_model_config(self.model_config) GPUExecutor._init_executor(self) - def _get_worker_module_and_class( - self) -> Tuple[str, str, Optional[Callable[[], Type[WorkerBase]]]]: - worker_class_fn = None - if self.speculative_config is not None: - raise NotImplementedError( - "XPU does not support speculative decoding") - else: - worker_module_name = "vllm.worker.xpu_worker" - worker_class_name = "XPUWorker" - return (worker_module_name, worker_class_name, worker_class_fn) - def execute_model( self, execute_model_req: ExecuteModelRequest ) -> Optional[List[Union[SamplerOutput, PoolerOutput]]]: @@ -53,26 +37,3 @@ async def execute_model_async( output = await make_async(self.driver_worker.execute_model )(execute_model_req=execute_model_req) return output - - -def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: - if config.dtype == torch.bfloat16: - logger.warning( - "bfloat16 is not fully supported on XPU, casting to float16.") - config.dtype = torch.float16 - if not config.enforce_eager: - logger.warning( - "CUDA graph is not supported on XPU, fallback to the eager " - "mode.") - config.enforce_eager = True - return config - - -def _verify_and_get_parallel_config(config: ParallelConfig) -> ParallelConfig: - if (config.distributed_executor_backend is not None - and config.distributed_executor_backend != "ray"): - logger.warning( - "%s is not supported on XPU, fallback to ray distributed executor " - "backend.", config.distributed_executor_backend) - config.distributed_executor_backend = "ray" - return config diff --git a/vllm/forward_context.py b/vllm/forward_context.py index 777747505e14a..aaa3e4bb3a1e8 100644 --- a/vllm/forward_context.py +++ b/vllm/forward_context.py @@ -1,21 +1,38 @@ from contextlib import contextmanager -from typing import Any +from dataclasses import dataclass +from typing import Any, Dict, Optional -_forward_context: Any = None +from vllm.config import VllmConfig -def get_forward_context() -> Any: +@dataclass +class ForwardContext: + static_forward_context: Dict[str, Any] + # TODO: extend to support per-layer dynamic forward context + dynamic_forward_context: Any + + +_forward_context: Optional[ForwardContext] = None + + +def get_forward_context() -> ForwardContext: """Get the current forward context.""" + assert _forward_context is not None, ( + "Forward context is not set. " + "Please use `set_forward_context` to set the forward context.") return _forward_context @contextmanager -def set_forward_context(context: Any): +def set_forward_context(context: Any, vllm_config: VllmConfig): """A context manager that stores the current forward context, can be attention metadata, etc.""" global _forward_context prev_context = _forward_context - _forward_context = context + _forward_context = ForwardContext( + static_forward_context=vllm_config.compilation_config. + static_forward_context, + dynamic_forward_context=context) try: yield finally: diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index 07ff9faa50f13..fb7dbbebd7b90 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -38,6 +38,9 @@ class TokensPrompt(TypedDict): prompt_token_ids: List[int] """A list of token IDs to pass to the model.""" + token_type_ids: NotRequired[List[int]] + """A list of token type IDs to pass to the cross encoder model.""" + multi_modal_data: NotRequired["MultiModalDataDict"] """ DEPRECATED: Optional multi-modal data to pass to the model, @@ -133,6 +136,9 @@ class TokenInputs(TypedDict): prompt_token_ids: List[int] """The token IDs of the prompt.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + prompt: NotRequired[str] """ The original prompt text corresponding to the token IDs, if available. @@ -160,6 +166,7 @@ class TokenInputs(TypedDict): def token_inputs( prompt_token_ids: List[int], + token_type_ids: Optional[List[int]] = None, prompt: Optional[str] = None, multi_modal_data: Optional["MultiModalDataDict"] = None, multi_modal_placeholders: Optional["MultiModalPlaceholderDict"] = None, @@ -170,6 +177,8 @@ def token_inputs( if prompt is not None: inputs["prompt"] = prompt + if token_type_ids is not None: + inputs["token_type_ids"] = token_type_ids if multi_modal_data is not None: inputs["multi_modal_data"] = multi_modal_data if multi_modal_placeholders is not None: @@ -234,6 +243,15 @@ def prompt_token_ids(self) -> List[int]: assert_never(inputs) + @cached_property + def token_type_ids(self) -> List[int]: + inputs = self.inputs + + if inputs["type"] == "token" or inputs["type"] == "multimodal": + return inputs.get("token_type_ids", []) + + assert_never(inputs) + @cached_property def prompt_embeds(self) -> Optional[torch.Tensor]: inputs = self.inputs diff --git a/vllm/inputs/preprocess.py b/vllm/inputs/preprocess.py index aacff87df6d79..3d606817e90aa 100644 --- a/vllm/inputs/preprocess.py +++ b/vllm/inputs/preprocess.py @@ -10,7 +10,7 @@ from vllm.multimodal.processing import MultiModalDataDict, MultiModalInputsV2 from vllm.prompt_adapter.request import PromptAdapterRequest from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup -from vllm.utils import print_warning_once +from vllm.utils import print_info_once, print_warning_once from .data import (DecoderOnlyInputs, EncoderDecoderInputs, ProcessorInputs, PromptType, SingletonInputs, SingletonPrompt, token_inputs) @@ -212,7 +212,7 @@ def _can_process_multimodal(self) -> bool: # updated to use the new multi-modal processor can_process_multimodal = self.mm_registry.has_processor(model_config) if not can_process_multimodal: - logger.info( + print_info_once( "Your model uses the legacy input pipeline instead of the new " "multi-modal processor. Please note that the legacy pipeline " "will be removed in a future release. For more details, see: " @@ -305,6 +305,7 @@ def _prompt_to_llm_inputs( tokens_content = parsed["content"] prompt_token_ids = tokens_content["prompt_token_ids"] + token_type_ids = tokens_content.get("token_type_ids") multi_modal_data = tokens_content.get("multi_modal_data") mm_processor_kwargs = tokens_content.get("mm_processor_kwargs") @@ -318,6 +319,7 @@ def _prompt_to_llm_inputs( return token_inputs( prompt_token_ids=prompt_token_ids, + token_type_ids=token_type_ids, multi_modal_data=multi_modal_data, mm_processor_kwargs=mm_processor_kwargs, ) diff --git a/vllm/logger.py b/vllm/logger.py index 9e16e591315ba..538db0dcf19aa 100644 --- a/vllm/logger.py +++ b/vllm/logger.py @@ -50,7 +50,7 @@ def _configure_vllm_root_logger() -> None: - logging_config: Optional[Dict] = None + logging_config: Dict = {} if not VLLM_CONFIGURE_LOGGING and VLLM_LOGGING_CONFIG_PATH: raise RuntimeError( @@ -75,6 +75,11 @@ def _configure_vllm_root_logger() -> None: type(custom_config).__name__) logging_config = custom_config + for formatter in logging_config.get("formatters", {}).values(): + # This provides backwards compatibility after #10134. + if formatter.get("class") == "vllm.logging.NewLineFormatter": + formatter["class"] = "vllm.logging_utils.NewLineFormatter" + if logging_config: dictConfig(logging_config) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 3443c3feb4d2a..f5c2eced9d2bb 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -44,6 +44,11 @@ class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA): Based on S-LoRA, slicing happens along the rank dim. """ + # For all LoRA layers where the `base_layer` is `ColumnParallelLinear`, + # their `lora_a` and `lora_b` have different sharding patterns. After + # completing the `lora_a` GEMM , a gather operation is performed. + # Therefore, the sharding of `lora_a` only needs to correspond with the + # gather operation. def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: tp_rank = get_tensor_model_parallel_rank() shard_size = self.lora_a_stacked.shape[2] diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 6afe80219fe07..3701988ff692f 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -451,6 +451,12 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA): def __init__(self, base_layer: ColumnParallelLinear) -> None: super().__init__() + # The base_layer type is ColumnParallelLinear or + # MergedColumnParallelLinear, their weight sharding logic is + # inconsistent when TP is greater than 1. + self.is_merged_col_linear = type( + base_layer) is MergedColumnParallelLinear + self.base_layer = base_layer self.tp_size = get_tensor_model_parallel_world_size() self.input_size = self.base_layer.input_size @@ -508,14 +514,30 @@ def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor: return lora_a def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor: - tensor_model_parallel_rank = get_tensor_model_parallel_rank() - shard_size = self.output_dim - start_idx = tensor_model_parallel_rank * shard_size - end_idx = (tensor_model_parallel_rank + 1) * shard_size - lora_b = lora_b[:, start_idx:end_idx] + # Applicable to cases where the base_layer is + # MergedColumnParallelLinear. + if self.is_merged_col_linear: + tp_rank = get_tensor_model_parallel_rank() + shard_size = self.output_size // 2 + offset = lora_b.shape[-1] // 2 + + left_weight = lora_b[:, tp_rank * shard_size:(tp_rank + 1) * + shard_size] + right_weight = lora_b[:, offset + tp_rank * shard_size:offset + + (tp_rank + 1) * shard_size] + lora_b = torch.cat([left_weight, right_weight], dim=1) + # Applicable to cases where the base_layer is + # ColumnParallelLinear. + else: + tensor_model_parallel_rank = get_tensor_model_parallel_rank() + shard_size = self.output_dim + start_idx = tensor_model_parallel_rank * shard_size + end_idx = (tensor_model_parallel_rank + 1) * shard_size + lora_b = lora_b[:, start_idx:end_idx] return lora_b def slice_bias(self, bias: torch.Tensor) -> torch.Tensor: + # TODO: Fix the slicing logic of bias. if bias is None: return bias tensor_model_parallel_rank = get_tensor_model_parallel_rank() @@ -779,7 +801,7 @@ def can_replace_layer( class QKVParallelLinearWithLora(ColumnParallelLinearWithLoRA): """ ColumnParallelLinear layer that is specifically designed for - qkv_proj. Certain models, such as chtglm3 and baichuan-7b, + qkv_proj. Certain models, such as chatglm3 and baichuan-7b, only contains a single LoRA within their qkv_proj layer. During inference with Tensor Parallel, the weights of lora_b diff --git a/vllm/lora/ops/bgmv_expand.py b/vllm/lora/ops/bgmv_expand.py index f176259fddc78..42adb191b8ead 100644 --- a/vllm/lora/ops/bgmv_expand.py +++ b/vllm/lora/ops/bgmv_expand.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -162,9 +164,24 @@ def _bgmv_expand( return +def bgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand = torch.library.custom_op("lora::bgmv_expand", - _bgmv_expand, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand", + op_func=_bgmv_expand, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_fake, + ) + bgmv_expand = torch.ops.vllm.bgmv_expand + except AttributeError: bgmv_expand = _bgmv_expand diff --git a/vllm/lora/ops/bgmv_expand_slice.py b/vllm/lora/ops/bgmv_expand_slice.py index 2c6ed96c253f0..f397d752a3ea9 100644 --- a/vllm/lora/ops/bgmv_expand_slice.py +++ b/vllm/lora/ops/bgmv_expand_slice.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -179,9 +181,26 @@ def _bgmv_expand_slice( return +def bgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + slice_offset: int, + slice_size: int, + add_inputs: bool = True, +) -> None: + return + + try: - bgmv_expand_slice = torch.library.custom_op("lora::bgmv_expand_slice", - _bgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_expand_slice", + op_func=_bgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=bgmv_expand_slice_fake, + ) + bgmv_expand_slice = torch.ops.vllm.bgmv_expand_slice + except AttributeError: bgmv_expand_slice = _bgmv_expand_slice diff --git a/vllm/lora/ops/bgmv_shrink.py b/vllm/lora/ops/bgmv_shrink.py index 0846ff36b1692..f3ef01d39e776 100644 --- a/vllm/lora/ops/bgmv_shrink.py +++ b/vllm/lora/ops/bgmv_shrink.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + from .utils import get_lora_op_configs @@ -142,9 +144,24 @@ def _bgmv_shrink( return +def bgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + scaling: float = 1.0, +) -> None: + return + + try: - bgmv_shrink = torch.library.custom_op("lora::bgmv_shrink", - _bgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="bgmv_shrink", + op_func=_bgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=bgmv_shrink_fake, + ) + bgmv_shrink = torch.ops.vllm.bgmv_shrink + except AttributeError: bgmv_shrink = _bgmv_shrink diff --git a/vllm/lora/ops/sgmv_expand.py b/vllm/lora/ops/sgmv_expand.py index ee2cd2e05e2ee..77c5178493c44 100644 --- a/vllm/lora/ops/sgmv_expand.py +++ b/vllm/lora/ops/sgmv_expand.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_kernel( @@ -196,9 +198,30 @@ def _sgmv_expand( return +def sgmv_expand_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand = torch.library.custom_op("lora::sgmv_expand", - _sgmv_expand, - mutates_args=["output_tensor"]) + + direct_register_custom_op( + op_name="sgmv_expand", + op_func=_sgmv_expand, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_fake, + ) + sgmv_expand = torch.ops.vllm.sgmv_expand + except AttributeError: sgmv_expand = _sgmv_expand diff --git a/vllm/lora/ops/sgmv_expand_slice.py b/vllm/lora/ops/sgmv_expand_slice.py index 5244fa14913a4..55c4fb68ed128 100644 --- a/vllm/lora/ops/sgmv_expand_slice.py +++ b/vllm/lora/ops/sgmv_expand_slice.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_expand_slice_kernel( @@ -209,9 +211,31 @@ def _sgmv_expand_slice( return +def sgmv_expand_slice_fake( + inputs: torch.Tensor, + lora_b_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + slice_offset: int, + slice_size: int, + add_inputs: bool = False, +) -> None: + return + + try: - sgmv_expand_slice = torch.library.custom_op("lora::sgmv_expand_slice", - _sgmv_expand_slice, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_expand_slice", + op_func=_sgmv_expand_slice, + mutates_args=["output_tensor"], + fake_impl=sgmv_expand_slice_fake, + ) + sgmv_expand_slice = torch.ops.vllm.sgmv_expand_slice + except AttributeError: sgmv_expand_slice = _sgmv_expand_slice diff --git a/vllm/lora/ops/sgmv_shrink.py b/vllm/lora/ops/sgmv_shrink.py index b4d893047b06b..37d1dc84eebca 100644 --- a/vllm/lora/ops/sgmv_shrink.py +++ b/vllm/lora/ops/sgmv_shrink.py @@ -9,6 +9,8 @@ import triton import triton.language as tl +from vllm.utils import direct_register_custom_op + @triton.jit def _sgmv_shrink_kernel( @@ -190,9 +192,29 @@ def _sgmv_shrink( return +def sgmv_shrink_fake( + inputs: torch.Tensor, + lora_a_weights: torch.Tensor, + output_tensor: torch.Tensor, + b_seq_start_loc: torch.Tensor, + seq_len_tensor: torch.Tensor, + lora_indices_tensor: torch.Tensor, + batches: int, + max_seq_length: int, + token_nums: int, + scaling: float, +) -> None: + return + + try: - sgmv_shrink = torch.library.custom_op("lora::sgmv_shrink", - _sgmv_shrink, - mutates_args=["output_tensor"]) + direct_register_custom_op( + op_name="sgmv_shrink", + op_func=_sgmv_shrink, + mutates_args=["output_tensor"], + fake_impl=sgmv_shrink_fake, + ) + sgmv_shrink = torch.ops.vllm.sgmv_shrink + except AttributeError: sgmv_shrink = _sgmv_shrink diff --git a/vllm/model_executor/custom_op.py b/vllm/model_executor/custom_op.py index b07966f2ab7d0..fddc8bad09ef5 100644 --- a/vllm/model_executor/custom_op.py +++ b/vllm/model_executor/custom_op.py @@ -2,9 +2,9 @@ import torch.nn as nn +from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.platforms import current_platform -from vllm.plugins import get_current_vllm_config from vllm.utils import print_warning_once logger = init_logger(__name__) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 2471c160d66b7..46ef11e7d02c6 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -1,3 +1,4 @@ +import itertools from abc import abstractmethod from typing import Dict, List, Optional, Tuple @@ -41,12 +42,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset): def adjust_bitsandbytes_4bit_shard(param: Parameter, - qkv_offsets: Dict[str, Tuple[int, int]], + shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str) -> Tuple[int, int]: """Adjust the quantization offsets and sizes for BitsAndBytes sharding.""" - total, _ = qkv_offsets["total"] - orig_offset, orig_size = qkv_offsets[loaded_shard_id] + total, _ = shard_offsets["total"] + orig_offset, orig_size = shard_offsets[loaded_shard_id] quantized_total = param.data.shape[0] quantized_offset = orig_offset * quantized_total // total @@ -499,9 +500,17 @@ def weight_loader(self, # Special case for Marlin. shard_size, shard_offset = adjust_marlin_shard( param, shard_size, shard_offset) + if use_bitsandbytes_4bit: - shard_size = loaded_weight.shape[output_dim] // 2 - shard_offset = shard_size * shard_id + index = list(itertools.accumulate([0] + self.output_sizes)) + orig_offsets = { + str(i): (index[i], size) + for i, size in enumerate(self.output_sizes) + } + orig_offsets["total"] = (self.output_size, 0) + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_offsets, str(shard_id)) + loaded_weight_shard = loaded_weight.narrow( output_dim, shard_offset, shard_size) self.weight_loader(param, loaded_weight_shard, shard_id) diff --git a/vllm/model_executor/layers/pooler.py b/vllm/model_executor/layers/pooler.py index bfe2d7d0f382e..f9437b4112ceb 100644 --- a/vllm/model_executor/layers/pooler.py +++ b/vllm/model_executor/layers/pooler.py @@ -3,11 +3,14 @@ import torch import torch.nn as nn +from transformers import PretrainedConfig from vllm.config import PoolerConfig from vllm.model_executor.pooling_metadata import (PoolingMetadata, PoolingTensors) from vllm.sequence import EmbeddingSequenceGroupOutput, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) class PoolingType(IntEnum): @@ -94,14 +97,10 @@ def forward( pooled_data = hidden_states[last_token_flat_indices] elif self.pooling_type == PoolingType.ALL: offset = 0 - pooled_data_lst = [] + pooled_data = [] for prompt_len in prompt_lens: - pooled_data_i = hidden_states[offset:offset + prompt_len] - - pooled_data_lst.append(pooled_data_i) + pooled_data.append(hidden_states[offset:offset + prompt_len]) offset += prompt_len - - pooled_data = torch.stack(pooled_data_lst) elif self.pooling_type == PoolingType.MEAN: # Calculate mean pooling cumsum = torch.cumsum(hidden_states, dim=0) @@ -121,7 +120,7 @@ def forward( step_tag_id = self.step_tag_id offset = 0 - pooled_data_lst = [] + pooled_data = [] for prompt_len, seq_data_i in zip( prompt_lens, pooling_metadata.seq_data.values()): pooled_data_i = hidden_states[offset:offset + prompt_len] @@ -130,20 +129,90 @@ def forward( pooled_data_i = pooled_data_i[token_ids == step_tag_id] offset += prompt_len - pooled_data_lst.append(pooled_data_i) - - pooled_data = torch.stack(pooled_data_lst) + pooled_data.append(pooled_data_i) else: raise ValueError(f"Invalid pooling type: {self.pooling_type}") if self.normalize: - pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) + if isinstance(pooled_data, list): + pooled_data = [ + nn.functional.normalize(data, p=2, dim=1) + for data in pooled_data + ] + else: + pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1) if self.softmax: - pooled_data = nn.functional.softmax(pooled_data, dim=-1) + if isinstance(pooled_data, list): + pooled_data = [ + nn.functional.softmax(data, dim=-1) for data in pooled_data + ] + else: + pooled_data = nn.functional.softmax(pooled_data, dim=-1) pooled_outputs = [ EmbeddingSequenceGroupOutput(data.tolist()) for data in pooled_data ] return PoolerOutput(outputs=pooled_outputs) + + +class CrossEncodingPooler(nn.Module): + """A layer that pools specific information from hidden states. + + This layer does the following: + 1. Extracts specific tokens or aggregates data based on pooling method. + 2. Normalizes output if specified. + 3. Returns structured results as `PoolerOutput`. + + Attributes: + pooling_type: The type of pooling to use. + normalize: Whether to normalize the pooled data. + """ + + def __init__( + self, + config: PretrainedConfig, + classifier: nn.Module, + pooler: Optional[nn.Module] = None, + ): + super().__init__() + self.classifier = classifier + self.pooler = pooler + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + def forward( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> PoolerOutput: + """Pools sentence pair scores from the hidden_states.""" + + prompt_lens = PoolingTensors.from_pooling_metadata( + pooling_metadata, hidden_states.device).prompt_lens + + offset = 0 + pooled_data_lst = [] + for prompt_len in prompt_lens: + pooled_data_i = hidden_states[offset:offset + prompt_len] + + if self.pooler is not None: + final_shape_tensor = self.pooler(pooled_data_i) + else: + final_shape_tensor = self.classifier(pooled_data_i) + + pooled_data_lst.append(final_shape_tensor) + offset += prompt_len + + pooled_output = torch.stack(pooled_data_lst) + + if self.pooler is not None: + # apply classifier once on the full batch if possible + pooled_output = self.classifier(pooled_output) + logits = self.default_activation_function(pooled_output) + + pooled_outputs = [ + EmbeddingSequenceGroupOutput(data.tolist()) for data in logits + ] + return PoolerOutput(outputs=pooled_outputs) diff --git a/vllm/model_executor/layers/quantization/__init__.py b/vllm/model_executor/layers/quantization/__init__.py index ff342c4f9479e..dd10c434f0752 100644 --- a/vllm/model_executor/layers/quantization/__init__.py +++ b/vllm/model_executor/layers/quantization/__init__.py @@ -1,65 +1,87 @@ -from typing import Dict, Type +from typing import Dict, List, Type -from vllm.model_executor.layers.quantization.aqlm import AQLMConfig -from vllm.model_executor.layers.quantization.awq import AWQConfig -from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig) -from vllm.model_executor.layers.quantization.bitsandbytes import ( - BitsAndBytesConfig) -from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 - CompressedTensorsConfig) -from vllm.model_executor.layers.quantization.deepspeedfp import ( - DeepSpeedFPConfig) -from vllm.model_executor.layers.quantization.experts_int8 import ( - ExpertsInt8Config) -from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config -from vllm.model_executor.layers.quantization.fp8 import Fp8Config -from vllm.model_executor.layers.quantization.gguf import GGUFConfig -from vllm.model_executor.layers.quantization.gptq import GPTQConfig -from vllm.model_executor.layers.quantization.gptq_marlin import ( - GPTQMarlinConfig) -from vllm.model_executor.layers.quantization.gptq_marlin_24 import ( - GPTQMarlin24Config) -from vllm.model_executor.layers.quantization.hqq_marlin import HQQMarlinConfig -from vllm.model_executor.layers.quantization.ipex_quant import IPEXConfig -from vllm.model_executor.layers.quantization.marlin import MarlinConfig -from vllm.model_executor.layers.quantization.modelopt import ModelOptFp8Config -from vllm.model_executor.layers.quantization.neuron_quant import ( - NeuronQuantConfig) -from vllm.model_executor.layers.quantization.qqq import QQQConfig -from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig -QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = { - "aqlm": AQLMConfig, - "awq": AWQConfig, - "deepspeedfp": DeepSpeedFPConfig, - "tpu_int8": Int8TpuConfig, - "fp8": Fp8Config, - "fbgemm_fp8": FBGEMMFp8Config, - "modelopt": ModelOptFp8Config, +QUANTIZATION_METHODS: List[str] = [ + "aqlm", + "awq", + "deepspeedfp", + "tpu_int8", + "fp8", + "fbgemm_fp8", + "modelopt", # The order of gptq methods is important for config.py iteration over # override_quantization_method(..) - "marlin": MarlinConfig, - "gguf": GGUFConfig, - "gptq_marlin_24": GPTQMarlin24Config, - "gptq_marlin": GPTQMarlinConfig, - "awq_marlin": AWQMarlinConfig, - "gptq": GPTQConfig, - "compressed-tensors": CompressedTensorsConfig, - "bitsandbytes": BitsAndBytesConfig, - "qqq": QQQConfig, - "hqq": HQQMarlinConfig, - "experts_int8": ExpertsInt8Config, - "neuron_quant": NeuronQuantConfig, - "ipex": IPEXConfig, -} + "marlin", + "gguf", + "gptq_marlin_24", + "gptq_marlin", + "awq_marlin", + "gptq", + "compressed-tensors", + "bitsandbytes", + "qqq", + "hqq", + "experts_int8", + "neuron_quant", + "ipex", +] def get_quantization_config(quantization: str) -> Type[QuantizationConfig]: if quantization not in QUANTIZATION_METHODS: raise ValueError(f"Invalid quantization method: {quantization}") - return QUANTIZATION_METHODS[quantization] + + # lazy import to avoid triggering `torch.compile` too early + from .aqlm import AQLMConfig + from .awq import AWQConfig + from .awq_marlin import AWQMarlinConfig + from .bitsandbytes import BitsAndBytesConfig + from .compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig) + from .deepspeedfp import DeepSpeedFPConfig + from .experts_int8 import ExpertsInt8Config + from .fbgemm_fp8 import FBGEMMFp8Config + from .fp8 import Fp8Config + from .gguf import GGUFConfig + from .gptq import GPTQConfig + from .gptq_marlin import GPTQMarlinConfig + from .gptq_marlin_24 import GPTQMarlin24Config + from .hqq_marlin import HQQMarlinConfig + from .ipex_quant import IPEXConfig + from .marlin import MarlinConfig + from .modelopt import ModelOptFp8Config + from .neuron_quant import NeuronQuantConfig + from .qqq import QQQConfig + from .tpu_int8 import Int8TpuConfig + + method_to_config: Dict[str, Type[QuantizationConfig]] = { + "aqlm": AQLMConfig, + "awq": AWQConfig, + "deepspeedfp": DeepSpeedFPConfig, + "tpu_int8": Int8TpuConfig, + "fp8": Fp8Config, + "fbgemm_fp8": FBGEMMFp8Config, + "modelopt": ModelOptFp8Config, + # The order of gptq methods is important for config.py iteration over + # override_quantization_method(..) + "marlin": MarlinConfig, + "gguf": GGUFConfig, + "gptq_marlin_24": GPTQMarlin24Config, + "gptq_marlin": GPTQMarlinConfig, + "awq_marlin": AWQMarlinConfig, + "gptq": GPTQConfig, + "compressed-tensors": CompressedTensorsConfig, + "bitsandbytes": BitsAndBytesConfig, + "qqq": QQQConfig, + "hqq": HQQMarlinConfig, + "experts_int8": ExpertsInt8Config, + "neuron_quant": NeuronQuantConfig, + "ipex": IPEXConfig, + } + + return method_to_config[quantization] __all__ = [ diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 936c2fe415375..441dd409b4f9d 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -5,6 +5,7 @@ import fnmatch import glob import inspect +import itertools import json import math import os @@ -22,12 +23,14 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import (LoadConfig, LoadFormat, ModelConfig, ParallelConfig, - VllmConfig) + VllmConfig, set_current_vllm_config) from vllm.distributed import (get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.envs import VLLM_USE_MODELSCOPE from vllm.logger import init_logger -from vllm.model_executor.layers.linear import (ReplicatedLinear, +from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, + QKVParallelLinear, + ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.quantization.base_config import ( QuantizeMethodBase) @@ -44,7 +47,6 @@ safetensors_weights_iterator) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform -from vllm.plugins import set_current_vllm_config from vllm.utils import is_pin_memory_available @@ -936,6 +938,34 @@ def _unquantized_generator(self, hf_weights_files, use_safetensors, end_index = total_size // tp_size * (tp_rank + 1) weight_sub_tensor = weight_tensor[..., start_index:end_index] + # Weights have fused on disk. In this case, we assume that the + # weight and module use same name. + elif any( + weight_name.startswith(module) + for module in self.maybe_fused_weights_modules): + # special case for fused weights + # get the size of each shard weight tensor + total_shard_sizes = next( + (sizes for module, sizes in + self.maybe_fused_weights_modules.items() + if weight_name.startswith(module))) + total_size = weight_tensor.size(0) + assert total_size == sum(total_shard_sizes) + # get the start/end index of each shard weight tensor + total_start_index = list( + itertools.accumulate([0] + total_shard_sizes))[:-1] + shard_weights_index = [ + (idx + size // tp_size * tp_rank, + idx + size // tp_size * (tp_rank + 1)) + for idx, size in zip(total_start_index, + total_shard_sizes) + ] + # slice and reorder the weight tensor + weight_tensor = [ + weight_tensor[start_index:end_index, ...] + for start_index, end_index in shard_weights_index + ] + weight_sub_tensor = torch.cat(weight_tensor, dim=0) # Shard by row else: total_size = weight_tensor.size(0) @@ -985,12 +1015,22 @@ def _load_weights(self, model_config: ModelConfig, else: self.target_modules = self.default_target_modules + # Modules whose weights might have fused on disk + # we need their output_sizes to make shard in flight correctly with TP + self.maybe_fused_weights_modules: Dict[str, List[int]] = {} + for name, module in model.named_modules(): # Some modules like `ReplicatedLinear` should not have their weights # sharded. The reason for implementing it this way is to avoid new # static variable in the model implementation. if isinstance(module, (ReplicatedLinear, )): self.unsharded_weights_modules.append(name) + # `QKVParallelLinear` and `MergedColumnParallelLinear` might have + # fused weights on disk. We need to use the output sizes of these + # modules to shard the weights correctly. + elif isinstance(module, + (QKVParallelLinear, MergedColumnParallelLinear)): + self.maybe_fused_weights_modules[name] = module.output_sizes # In TP, these weights are partitioned along the column # dimension (dim=-1) elif isinstance(module, (RowParallelLinear, )): diff --git a/vllm/model_executor/model_loader/tensorizer.py b/vllm/model_executor/model_loader/tensorizer.py index c48b287ed181a..87f3fcb5cae00 100644 --- a/vllm/model_executor/model_loader/tensorizer.py +++ b/vllm/model_executor/model_loader/tensorizer.py @@ -13,7 +13,7 @@ from transformers import PretrainedConfig import vllm.envs as envs -from vllm.config import ModelConfig, ParallelConfig +from vllm.config import ModelConfig, ParallelConfig, set_current_vllm_config from vllm.engine.arg_utils import EngineArgs from vllm.engine.llm_engine import LLMEngine from vllm.logger import init_logger @@ -284,7 +284,8 @@ def _init_model(self): model_args = self.tensorizer_config.hf_config model_args.torch_dtype = self.tensorizer_config.dtype assert self.tensorizer_config.model_class is not None - with no_init_or_tensor(): + # TODO: Do we need to consider old-style model class? + with no_init_or_tensor(), set_current_vllm_config(self.vllm_config): return self.tensorizer_config.model_class( vllm_config=self.vllm_config, ) diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index e58ad19cab54c..fd6b5659df5d1 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -33,7 +33,7 @@ from vllm.transformers_utils.configs.arctic import ArcticConfig from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -44,15 +44,14 @@ class ArcticMLP(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, expert_id: int = -1, is_residual_mlp: bool = False, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.expert_id = expert_id - self.layer_id = layer_id self.ffn_dim = config.intermediate_size if not is_residual_mlp \ else self.hidden_size @@ -85,13 +84,14 @@ class ArcticMoE(nn.Module): def __init__(self, config: ArcticConfig, - layer_id: int, tp_size: Optional[int] = None, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - reduce_results: bool = True): + reduce_results: bool = True, + prefix: str = ""): super().__init__() + layer_id = extract_layer_index(prefix) self.tp_size = tp_size or get_tensor_model_parallel_world_size() self.hidden_size = config.hidden_size self.num_experts = config.num_local_experts @@ -109,15 +109,16 @@ def __init__(self, if not self.is_moe_layer: self.mlp = ArcticMLP(config, - layer_id=layer_id, quant_config=quant_config, - reduce_results=reduce_results) + reduce_results=reduce_results, + prefix=f"{prefix}.mlp") else: self.gate = ReplicatedLinear(self.hidden_size, self.num_experts, bias=False, params_dtype=self.params_dtype, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.gate") if self.is_quant: self.ws = DeepSpeedFPParameter( torch.Size((self.num_experts, 2 * self.intermediate_size, @@ -220,13 +221,12 @@ class ArcticAttention(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: Optional[int] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config - self.layer_idx = layer_idx self.hidden_size = config.hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -274,7 +274,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -296,24 +297,25 @@ class ArcticDecoderLayer(nn.Module): def __init__( self, config: ArcticConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.layer_idx = layer_idx self.hidden_size = config.hidden_size + layer_idx = extract_layer_index(prefix) is_moe_layer = (layer_idx + 1) % config.moe_layer_frequency == 0 self.use_residual = config.use_residual and is_moe_layer self.self_attn = ArcticAttention(config, - layer_idx, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.block_sparse_moe = ArcticMoE( config, - layer_id=layer_idx, quant_config=quant_config, - reduce_results=(not self.use_residual)) + reduce_results=(not self.use_residual), + prefix=f"{prefix}.block_sparse_moe", + ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -324,9 +326,9 @@ def __init__( self.residual_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.residual_mlp = ArcticMLP(config, - layer_id=layer_idx, is_residual_mlp=True, - reduce_results=False) + reduce_results=False, + prefix=f"{prefix}.residual_mlp") def forward( self, @@ -380,8 +382,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): org_num_embeddings=self.vocab_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: ArcticDecoderLayer(config, int( - prefix.split(".")[-1]), cache_config, quant_config), + lambda prefix: ArcticDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py new file mode 100644 index 0000000000000..0356435e9c257 --- /dev/null +++ b/vllm/model_executor/models/aria.py @@ -0,0 +1,695 @@ +import math +from typing import Iterable, List, Optional, Set, Tuple, TypedDict, Union + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from transformers import LlamaConfig + +from vllm.attention import AttentionMetadata +from vllm.config import CacheConfig, QuantizationConfig, VllmConfig +from vllm.distributed import get_tensor_model_parallel_rank +from vllm.inputs import INPUT_REGISTRY, token_inputs +from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( + get_compressed_tensors_cache_scale) +from vllm.model_executor.layers.sampler import (Sampler, SamplerOutput, + SamplingMetadata) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, maybe_remap_kv_scale_name) +from vllm.model_executor.models.idefics2_vision_model import ( + Idefics2VisionTransformer) +from vllm.model_executor.models.interfaces import SupportsMultiModal +from vllm.model_executor.models.llama import (LlamaDecoderLayer, LlamaMLP, + LlamaModel) +from vllm.model_executor.models.utils import (AutoWeightsLoader, WeightsMapper, + is_pp_missing_parameter, + make_layers, maybe_prefix, + merge_multimodal_embeddings) +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.multimodal.base import MultiModalInputs +from vllm.multimodal.image import cached_get_image_processor +from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.utils import (cached_get_tokenizer, + repeat_and_pad_placeholder_tokens) +from vllm.sequence import IntermediateTensors +from vllm.transformers_utils.configs.aria import (AriaMoELMConfig, + AriaVisionConfig) + +from .utils import flatten_bn + + +class AriaImagePixelInputs(TypedDict): + pixel_values: torch.Tensor + pixel_mask: Optional[torch.Tensor] + """ + Shape: + pixel_values: `(batch_size * num_images, num_channels, height, width)` + pixel_mask: `(batch_size * num_images, height, width)` + """ + + +class AriaVisionTransformer(Idefics2VisionTransformer): + """ + AriaVisionTransformer is a modified version of Idefics2VisionTransformer + that replaces the post-layernorm with an identity layer. + """ + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, quant_config, prefix) + self.post_layernorm = nn.Identity() + + +class AriaVisionModel(nn.Module): + config_class = AriaVisionConfig + + def __init__( + self, + config: AriaVisionConfig, + quant_config: Optional[QuantizationConfig] = None, + *, + prefix: str = "", + ) -> None: + super().__init__() + + self.vision_model = AriaVisionTransformer( + config, + quant_config, + prefix=f"{prefix}.vision_model", + ) + + def forward( + self, + pixel_values: torch.Tensor, + pixel_mask: Optional[torch.BoolTensor] = None, + ) -> Tuple[torch.Tensor, Optional[torch.BoolTensor]]: + patch_attention_mask = self._create_patch_attention_mask(pixel_mask) + + vit_oup = self.vision_model( + pixel_values=pixel_values, + patch_attention_mask=patch_attention_mask, + ) + + image_atts = self._create_image_attention_mask(patch_attention_mask) + + return vit_oup, image_atts + + def _create_patch_attention_mask(self, pixel_mask): + if pixel_mask is None: + return None + + patches_subgrid = pixel_mask.unfold( + dimension=1, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ).unfold( + dimension=2, + size=self.vision_model.config.patch_size, + step=self.vision_model.config.patch_size, + ) + return (patches_subgrid.sum(dim=(-1, -2)) > 0).bool() + + def _create_image_attention_mask(self, patch_attention_mask): + if patch_attention_mask is None: + return None + + flattened_mask = patch_attention_mask.flatten(1) + return torch.logical_not(flattened_mask) + + +class FFN(nn.Module): + + def __init__(self, embed_dim, ff_dim, output_dim): + super().__init__() + self.linear_in = ColumnParallelLinear(embed_dim, ff_dim, bias=False) + self.linear_out = RowParallelLinear(ff_dim, output_dim, bias=False) + self.act = get_act_fn("gelu_new") + + def forward(self, hidden_states): + hidden_states, _ = self.linear_in(hidden_states) + hidden_states = self.act(hidden_states) + hidden_states, _ = self.linear_out(hidden_states) + return hidden_states + + +class CrossAttention(nn.Module): + + def __init__(self, kv_dim, embed_dim, num_heads, drop_out_rate=0): + super().__init__() + self.num_heads = num_heads + self.q_proj = nn.Linear(embed_dim, embed_dim, bias=False) + self.k_proj = nn.Linear(kv_dim, embed_dim, bias=False) + self.v_proj = nn.Linear(kv_dim, embed_dim, bias=False) + + self.multihead_attn = nn.MultiheadAttention(embed_dim, num_heads) + self.linear = nn.Linear(embed_dim, embed_dim) + self.dropout = nn.Dropout(drop_out_rate) + + self.layer_norm = nn.LayerNorm(embed_dim) + self.ln_kv = nn.LayerNorm(kv_dim) + + def forward(self, x, hidden_states, attn_mask=None, add_residual=False): + normed_hidden_states = self.layer_norm(hidden_states) + query = self.q_proj(normed_hidden_states).permute(1, 0, 2) + + x = self.ln_kv(x) + key = self.k_proj(x).permute(1, 0, 2) + value = self.v_proj(x).permute(1, 0, 2) + + attn_output, _ = self.multihead_attn(query, + key, + value, + attn_mask=attn_mask) + + attn_output = attn_output.permute(1, 0, 2) + + if add_residual: + attn_output = hidden_states + self.dropout( + self.linear(attn_output)) + else: + attn_output = self.dropout(self.linear(attn_output)) + + return attn_output + + +class AriaProjector(nn.Module): + """ + A projection module with one cross attention layer and one FFN layer, which + projects ViT's outputs into MoE's inputs. + + Args: + patch_to_query_dict (dict): Maps patch numbers to their corresponding + query numbers, + e.g., {1225: 128, 4900: 256}. This allows for different query sizes + based on image resolution. + embed_dim (int): Embedding dimension. + num_heads (int): Number of attention heads. + kv_dim (int): Dimension of key and value. + ff_dim (int): Hidden dimension of the feed-forward network. + output_dim (int): Output dimension. + norm_layer (nn.Module): Normalization layer. Default is nn.LayerNorm. + + Outputs: + A tensor with the shape of (batch_size, query_number, output_dim) + """ + + def __init__( + self, + patch_to_query_dict, + embed_dim, + num_heads, + kv_dim, + ff_dim, + output_dim, + norm_layer=nn.LayerNorm, + ): + super().__init__() + self.patch_to_query_dict = patch_to_query_dict + self.embed_dim = embed_dim + self.num_heads = num_heads + + self.query = nn.Parameter( + torch.zeros(max(patch_to_query_dict.values()), self.embed_dim)) + + trunc_normal_(self.query, std=0.02) + + self.cross_attn = CrossAttention(kv_dim, embed_dim, num_heads) + + self.ln_ffn = norm_layer(embed_dim) + self.ffn = FFN(embed_dim, ff_dim, output_dim) + + def forward(self, x, attn_mask=None): + bs = x.shape[0] + queries = self.query.unsqueeze(0).repeat(bs, 1, 1) + + query_num = self.patch_to_query_dict.get(x.shape[1], None) + assert (query_num is not None + ), f"Query number for {x.shape[1]} patches is not provided" + + queries = queries[:, :query_num, :] + + if attn_mask is not None: + attn_mask = attn_mask.repeat_interleave(self.num_heads, 0) + attn_mask = attn_mask.unsqueeze(1).expand(-1, queries.size(1), -1) + + attention_out = self.cross_attn(x, queries, attn_mask=attn_mask) + + out = self.ffn(self.ln_ffn(attention_out)) + + return out + + +class AriaFusedMoE(FusedMoE): + + def weight_loader(self, param: nn.Parameter, loaded_weight: torch.Tensor, + shard_id: str) -> Set[str]: + # Override the weight_loader to handle the expert weights in the Aria + # model, which are already packed with experts, and merge the gate and + # up weights for each expert. + # Note: Loading expert weights with quantization is not supported + tp_rank = get_tensor_model_parallel_rank() + if shard_id == 'w13': + # the shape of loaded_weight is + # (num_experts, hidden_size, 2 * moe_intermediate_size) + if self.tp_size > 1: + up, gate = loaded_weight.chunk(2, dim=-1) + up_current_rank = up.chunk(self.tp_size, dim=-1)[tp_rank] + gate_current_rank = gate.chunk(self.tp_size, dim=-1)[tp_rank] + up_and_gate = torch.cat([up_current_rank, gate_current_rank], + dim=-1).transpose(1, 2) + param.data.copy_(up_and_gate) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + elif shard_id == 'w2': + # the shape of loaded_weight is + # (num_experts, moe_intermediate_size, hidden_size) + if self.tp_size > 1: + down_current_rank = loaded_weight.chunk(self.tp_size, + dim=1)[tp_rank] + param.data.copy_(down_current_rank.transpose(1, 2)) + else: + param.data.copy_(loaded_weight.transpose(1, 2)) + + +class MoELayer(nn.Module): + """ + Mixture of Experts (MoE) Layer for the AriaMoE model. + + This layer implements the MoE mechanism, which routes input tokens to + different experts based on a routing algorithm, processes them through the + experts, and then combines the outputs. + """ + + def __init__( + self, + config: AriaMoELMConfig, + quant_config: Optional[QuantizationConfig], + ) -> None: + super().__init__() + self.config = config + + self.router_weight = nn.Parameter( + torch.empty( + (self.config.moe_num_experts, self.config.hidden_size))) + + self.experts = AriaFusedMoE( + num_experts=config.moe_num_experts, + top_k=config.moe_topk, + hidden_size=config.hidden_size, + intermediate_size=config.moe_intermediate_size, + quant_config=quant_config, + reduce_results=True, + ) + self.shared_experts = LlamaMLP( + config.hidden_size, + config.moe_intermediate_size * config.moe_num_shared_experts, + "silu", + quant_config=quant_config, + ) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + """ + Forward pass of the MoE Layer. + + Args: + hidden_states (torch.Tensor): Input tensor of shape (batch_size, + sequence_length, hidden_size). + + Returns: + torch.Tensor: Output tensor after passing through the MoE layer. + """ + + router_output = torch.nn.functional.linear(hidden_states, + self.router_weight) + + shared_expert_output = self.shared_experts(hidden_states) + sparse_expert_output = self.experts(hidden_states, router_output) + + return sparse_expert_output + shared_expert_output + + +class MoEDecoderLayer(LlamaDecoderLayer): + """ + Custom Decoder Layer for the AriaMoE model which modifies the standard + `LlamaDecoderLayer` by replacing the traditional MLP with a Mixture of + Experts (MoE) Layer. + """ + + def __init__( + self, + config: LlamaConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__(config, cache_config, quant_config, prefix) + self.mlp = MoELayer(config, quant_config=quant_config) + + +class AriaMoELMModel(LlamaModel): + """ + Custom LlamaModel for the AriaMoE model which modifies the standard + LlamaModel by replacing the `LlamaDecoderLayer` with `MoEDecoderLayer`. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__(vllm_config=vllm_config, prefix=prefix) + + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config + + # FIXME: this is a hack to disable the compilation of the model + self.do_not_compile = True + + self.layers = None + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MoEDecoderLayer( + config=config, + cache_config=cache_config, + quant_config=quant_config, + prefix=prefix, + ), + prefix=f"{prefix}.layers", + ) + + # Adapted from LlamaModel.load_weights with the modification of adding + # the expert weights mapping to `stacked_params_mapping` + def load_weights(self, weights: Iterable[Tuple[str, + torch.Tensor]]) -> Set[str]: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ("experts.w13_weight", "experts.fc1.weight", 'w13'), + ("experts.w2_weight", "experts.fc2.weight", 'w2'), + ] + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + if ("rotary_emb.cos_cached" in name + or "rotary_emb.sin_cached" in name): + # Models trained using ColossalAI may include these tensors in + # the checkpoint. Skip them. + continue + if scale_name := get_compressed_tensors_cache_scale(name): + # Loading kv cache scales for compressed-tensors quantization + param = params_dict[scale_name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + loaded_weight = loaded_weight[0] + weight_loader(param, loaded_weight) + loaded_params.add(scale_name) + continue + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + # Remapping the name of FP8 kv-scale. + name = maybe_remap_kv_scale_name(name, params_dict) + if name is None: + continue + + if is_pp_missing_parameter(name, self): + continue + + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + loaded_params.add(name) + return loaded_params + + +def build_mm_projector(config): + return AriaProjector( + patch_to_query_dict=config.projector_patch_to_query_dict, + embed_dim=config.vision_config.hidden_size, + num_heads=config.vision_config.num_attention_heads, + kv_dim=config.vision_config.hidden_size, + ff_dim=config.text_config.hidden_size, + output_dim=config.text_config.hidden_size, + ) + + +def get_max_multimodal_tokens(ctx): + return max(ctx.model_config.hf_config.image_size2tokens.values()) + + +def input_mapper_for_aria(ctx, data): + return MultiModalInputs(data) + + +def input_processor(ctx, llm_inputs): + multi_modal_data = llm_inputs.get("multi_modal_data") + # if it is pure text input, use it as is + if multi_modal_data is None or "image" not in multi_modal_data: + return llm_inputs + + model_config = ctx.model_config + + tokenizer = cached_get_tokenizer(model_config.tokenizer) + image_processor = cached_get_image_processor( + model_config.model, trust_remote_code=model_config.trust_remote_code) + hf_config = model_config.hf_config + + # prepare image tokens, the max_image_size is used to determine the number + # of patch_size for every image + max_image_size = multi_modal_data.pop("max_image_size", 980) + _split_image = multi_modal_data.pop("split_image", False) + + assert isinstance(max_image_size, + (int, float)), "max_image_size should be float or int" + images = (multi_modal_data["image"] if isinstance( + multi_modal_data["image"], list) else [multi_modal_data["image"]]) + + image_inputs = image_processor.preprocess(images, + max_image_size=max_image_size, + split_image=_split_image, + return_tensors="pt").data + image_inputs['pixel_values'] = image_inputs['pixel_values'].to( + ctx.model_config.dtype) + num_crops = image_inputs.pop("num_crops") + + prompt_token_ids = llm_inputs["prompt_token_ids"] + if num_crops.sum().item() > 0: + _, prompt_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=num_crops, + ) + + repeat_count = [hf_config.image_size2tokens[max_image_size] + ] * sum(num_crops).item() + new_prompt, new_token_ids, _ = repeat_and_pad_placeholder_tokens( + tokenizer, + None, + prompt_token_ids, + placeholder_token_id=hf_config.image_token_index, + repeat_count=repeat_count, + ) + + return token_inputs( + prompt_token_ids=new_token_ids, + prompt=new_prompt, + multi_modal_data={"image": image_inputs}, + ) + + +@MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_multimodal_tokens) +@MULTIMODAL_REGISTRY.register_image_input_mapper(input_mapper_for_aria) +@INPUT_REGISTRY.register_input_processor(input_processor) +class AriaForConditionalGeneration(nn.Module, SupportsMultiModal): + """ + Aria model for conditional generation tasks. + + This model combines a vision tower, a multi-modal projector, and a language + model to perform tasks that involve both image and text inputs. + """ + + def __init__( + self, + vllm_config: VllmConfig, + prefix: str = "", + ): + super().__init__() + config = vllm_config.model_config.hf_config + quant_config = vllm_config.quant_config + + # prepare the image_size to tokens mapping for the image preprocess, see + # input_processor + config.image_size2tokens = { + int(math.sqrt(k) * config.vision_config.patch_size): v + for k, v in config.projector_patch_to_query_dict.items() + } + self.config = config + self.vision_tower = AriaVisionModel(config.vision_config) + self.multi_modal_projector = build_mm_projector(config) + self.vocab_size = config.text_config.vocab_size + self.language_model = AriaMoELMModel( + vllm_config=vllm_config.with_hf_config(config.text_config), + prefix=maybe_prefix(prefix, "language_model.model"), + ) + self.pad_token_id = (self.config.pad_token_id + if self.config.pad_token_id is not None else -1) + self.unpadded_vocab_size = config.text_config.vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.text_config.hidden_size, + org_num_embeddings=self.language_model.org_vocab_size, + quant_config=quant_config, + ) + logit_scale = getattr(config, "logit_scale", 1.0) + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + self.vocab_size, logit_scale) + self.sampler = Sampler() + + def _validate_image_sizes( + self, images: List[torch.Tensor]) -> List[torch.Tensor]: + if not all(img.shape == images[0].shape for img in images): + raise ValueError("All images must be the same size") + return images + + def _parse_and_validate_image_input( + self, **kwargs: object) -> Optional[AriaImagePixelInputs]: + pixel_values = kwargs.pop("pixel_values", None) + pixel_mask = kwargs.pop("pixel_mask", None) + + if pixel_values is None: + return None + + if not isinstance(pixel_values, (torch.Tensor, list)): + raise ValueError("Incorrect type of pixel values. " + f"Got type: {type(pixel_values)}") + + pixel_values = self._validate_image_sizes(pixel_values) + pixel_values = flatten_bn(pixel_values, concat=True) + if pixel_mask is not None: + pixel_mask = flatten_bn(pixel_mask, concat=True) + + return AriaImagePixelInputs( + pixel_values=pixel_values, + pixel_mask=pixel_mask, + ) + + def _process_image_input( + self, image_input: AriaImagePixelInputs + ) -> Tuple[torch.Tensor, torch.Tensor]: + assert self.vision_tower is not None + + pixel_values = image_input['pixel_values'] + pixel_mask = image_input['pixel_mask'] + + image_feature, image_attn_mask = self.vision_tower( + pixel_values, pixel_mask=pixel_mask) + return self.multi_modal_projector(image_feature, image_attn_mask) + + def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is None: + return None + multimodal_embeddings = self._process_image_input(image_input) + return multimodal_embeddings + + def get_input_embeddings( + self, + input_ids: torch.Tensor, + multimodal_embeddings: Optional[NestedTensors] = None, + ) -> torch.Tensor: + inputs_embeds = self.language_model.get_input_embeddings(input_ids) + if multimodal_embeddings is not None: + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, multimodal_embeddings, + self.config.image_token_index) + return inputs_embeds + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + **kwargs: object, + ) -> Union[torch.Tensor, IntermediateTensors]: + if inputs_embeds is None: + multimodal_embeddings = self.get_multimodal_embeddings(**kwargs) + # always pass the input via `inputs_embeds` + # to make sure the computation graph is consistent + inputs_embeds = self.get_input_embeddings(input_ids, + multimodal_embeddings) + input_ids = None + + hidden_states = self.language_model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds, + ) + + return hidden_states + + def compute_logits(self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> torch.Tensor: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def sample( + self, + logits: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "language_model.model": "language_model", + "language_model.lm_head": "lm_head", + }, + orig_to_new_suffix={ + "router.weight": "router_weight", + }, + ) + + loader = AutoWeightsLoader(self) + loader.load_weights(weights, mapper=hf_to_vllm_mapper) diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index 3749a16a38994..39cb5a8b2cbbe 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -116,6 +116,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -158,7 +159,8 @@ def __init__( self.head_dim, scaling, alibi_slopes=alibi_slopes, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") else: self.rotary_emb = get_rope( self.head_dim, @@ -171,7 +173,8 @@ def __init__( self.head_dim, self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -195,7 +198,8 @@ def __init__(self, config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) @@ -209,6 +213,7 @@ def __init__(self, max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = BaiChuanMLP( hidden_size=self.hidden_size, @@ -275,8 +280,11 @@ def __init__( ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: BaiChuanDecoderLayer(config, position_embedding, - cache_config, quant_config), + lambda prefix: BaiChuanDecoderLayer(config, + position_embedding, + cache_config, + quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -342,6 +350,21 @@ class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): embedding_modules = {} embedding_padding_modules = [] + # BitandBytes specific attributes + default_bitsandbytes_target_modules = [ + ".W_pack.", + ".o_proj.", + ".down_proj.", + ".up_proj.", + ".gate_proj.", + ".up_proj.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__( self, *, diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index a50a5a5b018e1..3776490cb3465 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -126,6 +126,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -178,7 +179,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -208,6 +210,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -260,7 +263,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward(self, hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata) -> torch.Tensor: @@ -290,6 +294,7 @@ def __init__( config: Optional[BartConfig] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -342,7 +347,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -384,6 +390,7 @@ def __init__( config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.d_model @@ -393,7 +400,9 @@ def __init__( num_heads=config.encoder_attention_heads, config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) self.activation_fn = get_act_fn(config.activation_function) @@ -464,6 +473,7 @@ def __init__( config: BartConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.d_model @@ -473,7 +483,9 @@ def __init__( num_heads=config.decoder_attention_heads, config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.activation_fn = get_act_fn(config.activation_function) self.self_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -486,6 +498,7 @@ def __init__( self.embed_dim, config.decoder_attention_heads, config=config, + prefix=f"{prefix}.encoder_attn", ) self.encoder_attn_layer_norm = nn.LayerNorm(self.embed_dim) @@ -578,7 +591,8 @@ def __init__(self, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, - embed_tokens: Optional[nn.Embedding] = None): + embed_tokens: Optional[nn.Embedding] = None, + prefix: str = ""): super().__init__() self.cache_config = cache_config @@ -599,9 +613,13 @@ def __init__(self, config.max_position_embeddings, embed_dim, ) - self.layers = nn.ModuleList( - [BartEncoderLayer(config,cache_config,quant_config) \ - for _ in range(config.encoder_layers)]) + self.layers = nn.ModuleList([ + BartEncoderLayer(config, + cache_config, + quant_config, + prefix=f"{prefix}.layers.{layer_idx}") + for layer_idx in range(config.encoder_layers) + ]) self.layernorm_embedding = nn.LayerNorm(embed_dim) @@ -661,6 +679,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, embed_tokens: Optional[nn.Embedding] = None, + prefix: str = "", ): super().__init__() self.cache_config = cache_config @@ -683,8 +702,9 @@ def __init__( ) self.layers = nn.ModuleList( - [BartDecoderLayer(config,cache_config,quant_config) \ - for _ in range(config.decoder_layers)]) + [BartDecoderLayer(config,cache_config,quant_config, + prefix=f"{prefix}.layers.{layer_idx}") \ + for layer_idx in range(config.decoder_layers)]) self.layernorm_embedding = nn.LayerNorm(config.d_model) @@ -759,10 +779,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.encoder = BartEncoder(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.encoder") self.decoder = BartDecoder(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.decoder") def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, encoder_input_ids: torch.Tensor, diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index d8301a36acb01..1fff72b3490e9 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -5,22 +5,26 @@ from transformers import BertConfig from vllm.attention import Attention, AttentionMetadata, AttentionType +from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, PoolerConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, + PoolingType) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) -from .utils import maybe_prefix +from .interfaces import SupportsCrossEncoding +from .utils import WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): @@ -48,7 +52,9 @@ def __init__(self, config: BertConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() @@ -58,25 +64,42 @@ def forward( # Position embeddings. position_embeddings = self.position_embeddings(position_ids) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) + + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +class BertPooler(nn.Module): + + def __init__(self, config: BertConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.activation = nn.Tanh() + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # We "pool" the model by simply taking the hidden state corresponding + # to the first token. + first_token_tensor = hidden_states[0, :] + pooled_output = self.dense(first_token_tensor) + pooled_output = self.activation(pooled_output) + return pooled_output + + +@support_torch_compile class BertEncoder(nn.Module): - def __init__(self, - config: BertConfig, - cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + def __init__(self, vllm_config: VllmConfig, prefix: str = ""): super().__init__() + config = vllm_config.model_config.hf_config + cache_config = vllm_config.cache_config + quant_config = vllm_config.quant_config self.layer = nn.ModuleList([ BertLayer(config=config, cache_config=cache_config, @@ -309,16 +332,14 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "", - embedding_class: type = BertEmbedding): + embedding_class: type = BertEmbedding, + add_pooling_layer: bool = False): super().__init__() config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config self.embeddings = embedding_class(config) - self.encoder = BertEncoder(config, - cache_config, - quant_config, + self.encoder = BertEncoder(vllm_config=vllm_config, prefix=f"{prefix}.encoder") + self.pooler = BertPooler(config) if add_pooling_layer else None def forward( self, @@ -328,13 +349,17 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: if inputs_embeds is not None: hidden_states = inputs_embeds else: - hidden_states = self.embeddings(input_ids=input_ids, - position_ids=position_ids) - + assert hasattr(attn_metadata, "seq_lens_tensor") + hidden_states = self.embeddings( + input_ids=input_ids, + seq_lens=attn_metadata.seq_lens_tensor, + position_ids=position_ids, + token_type_ids=token_type_ids) return self.encoder(hidden_states, kv_caches, attn_metadata) def load_weights(self, weights: Iterable[Tuple[str, @@ -349,7 +374,7 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - if "pooler" in name: + if self.pooler is None and "pooler" in name: continue for (param_name, weight_name, shard_id) in stacked_params_mapping: if weight_name not in name: @@ -416,6 +441,8 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) def _build_model(self, @@ -430,3 +457,78 @@ def _build_pooler(self, pooler_config: PoolerConfig) -> Pooler: pooling_type=PoolingType.CLS, normalize=True, softmax=False) + + +class BertForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Bert to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + model: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.bert = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=BertEmbedding, + add_pooling_layer=True) + self.classifier = nn.Linear(config.hidden_size, config.num_labels) + self._pooler = CrossEncodingPooler(config, self.classifier, + self.bert.pooler) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("bert."): + yield (name[len("bert."):], weight) + else: + self_weights.append((name, weight)) + + self.bert.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + + def forward( + self, + input_ids: Optional[torch.Tensor], + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + return self.bert(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 1060d418474ef..fee74f491acc1 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -78,6 +78,7 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -116,7 +117,8 @@ def __init__( scaling, alibi_slopes=alibi_slopes, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -168,14 +170,17 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size self.input_layernorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.self_attention = BloomAttention(config, cache_config, - quant_config) + self.self_attention = BloomAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.post_attention_layernorm = nn.LayerNorm( hidden_size, eps=config.layer_norm_epsilon) self.mlp = BloomMLP(config, quant_config) @@ -242,7 +247,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: BloomBlock(config, cache_config, quant_config), + lambda prefix: BloomBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") # Final Layer Norm diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 8f91abffaea90..5a6d6432112f0 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -223,6 +223,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -276,7 +277,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: @@ -313,6 +315,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -336,6 +339,7 @@ def __init__( quant_config=quant_config, bias=False, cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = ChameleonMLP( hidden_size=self.hidden_size, @@ -386,6 +390,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -409,6 +414,7 @@ def __init__( quant_config=quant_config, bias=False, cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = ChameleonMLP( hidden_size=self.hidden_size, @@ -855,7 +861,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: decoder_layer(config=config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 2ea592aaba9f9..5bcbce7180ca4 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -230,6 +230,7 @@ def __init__( config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -285,7 +286,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -364,6 +366,7 @@ def __init__( config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.apply_residual_connection_post_layernorm = ( @@ -377,7 +380,10 @@ def __init__( eps=config.layernorm_epsilon) # Self attention. - self.self_attention = GLMAttention(config, cache_config, quant_config) + self.self_attention = GLMAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.hidden_dropout = config.hidden_dropout # Layernorm on the attention output @@ -446,7 +452,8 @@ def __init__( # Transformer layers. self.start_layer, self.end_layer, self.layers = make_layers( self.num_layers, - lambda prefix: GLMBlock(config, cache_config, quant_config), + lambda prefix: GLMBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) @@ -500,16 +507,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.num_layers = config.num_layers self.multi_query_group_num = config.multi_query_group_num self.kv_channels = config.kv_channels - self.encoder = GLMTransformer(config, cache_config, quant_config) + self.encoder = GLMTransformer(config, + cache_config, + quant_config, + prefix=f"{prefix}.encoder") self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.output_layer") vision_config_flag = getattr(config, 'vision_config', None) if vision_config_flag is not None: self.vision_config = Namespace(**config.vision_config) - self.vision = EVA2CLIPModel(self.config, quant_config) + self.vision = EVA2CLIPModel(self.config, + quant_config, + prefix=f"{prefix}.vision") else: self.vision = None @@ -747,7 +760,7 @@ def __new__( config = vllm_config.model_config.hf_config # Initialize VL if hasattr(config, "visual"): - return ChatGLM(vllm_config=vllm_config, prefix=prefix) + return ChatGLMV(vllm_config=vllm_config, prefix=prefix) # Initialize LLM else: - return ChatGLMV(vllm_config=vllm_config, prefix=prefix) + return ChatGLM(vllm_config=vllm_config, prefix=prefix) diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py index 7f638506f9fb2..cd89519e95986 100644 --- a/vllm/model_executor/models/clip.py +++ b/vllm/model_executor/models/clip.py @@ -21,7 +21,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) + repeat_and_pad_placeholder_tokens, + resolve_visual_encoder_outputs) from vllm.sequence import SequenceData from .utils import get_vit_attn_backend @@ -389,12 +390,20 @@ def __init__( for layer_idx in range(num_hidden_layers) ]) - def forward(self, inputs_embeds: torch.Tensor): - + def forward( + self, inputs_embeds: torch.Tensor, return_all_hidden_states: bool + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [] hidden_states = inputs_embeds + for encoder_layer in self.layers: hidden_states = encoder_layer(hidden_states) - + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return hidden_states @@ -419,6 +428,7 @@ def __init__( # NOTE: This typo of "layrnorm" is not fixed on purpose to match # the original transformers code and name of the model weights. self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps) + self.encoder = CLIPEncoder( config=config, quant_config=quant_config, @@ -446,16 +456,26 @@ def __init__( def forward( self, pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: hidden_states = self.embeddings(pixel_values) hidden_states = self.pre_layrnorm(hidden_states) - hidden_states = self.encoder(inputs_embeds=hidden_states) - if self.post_layernorm is None: - return hidden_states + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states) + + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) - return self.post_layernorm(hidden_states) + return encoder_outputs class CLIPVisionModel(nn.Module): @@ -478,11 +498,14 @@ def __init__( quant_config=quant_config, num_hidden_layers_override=num_hidden_layers_override, require_post_norm=require_post_norm, - prefix=f"{prefix}.vision_model", - ) + prefix=f"{prefix}.vision_model") - def forward(self, pixel_values: torch.Tensor) -> torch.Tensor: - return self.vision_model(pixel_values) + def forward( + self, + pixel_values: torch.Tensor, + feature_sample_layers: Optional[list[int]] = None, + ) -> torch.Tensor: + return self.vision_model(pixel_values, feature_sample_layers) @property def device(self): diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 9fd083e5a02a9..85e24ca660686 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -120,6 +120,7 @@ def __init__( config: CohereConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() tp_size = get_tensor_model_parallel_world_size() @@ -175,7 +176,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") if self.use_qk_norm: self.q_norm = LayerNorm(param_shape=(self.num_heads, self.head_dim), @@ -215,13 +217,15 @@ class CohereDecoderLayer(nn.Module): def __init__(self, config: CohereConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = CohereAttention(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = CohereMLP(config, quant_config=quant_config) self.input_layernorm = LayerNorm(param_shape=(config.hidden_size), @@ -271,8 +275,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: CohereDecoderLayer(config, cache_config, - quant_config), + lambda prefix: CohereDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index eab338800249e..3932d8b52a9d1 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -154,6 +154,7 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -208,7 +209,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -234,10 +236,14 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model - self.attn = DbrxAttention(config, cache_config, quant_config) + self.attn = DbrxAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.norm_1 = nn.LayerNorm(self.d_model) self.norm_2 = nn.LayerNorm(self.d_model) @@ -269,10 +275,14 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() - self.norm_attn_norm = DbrxFusedNormAttention(config, cache_config, - quant_config) + self.norm_attn_norm = DbrxFusedNormAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.norm_attn_norm") self.ffn = DbrxMoE(config, quant_config) def forward( @@ -308,7 +318,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: DbrxBlock(config, cache_config, quant_config), + lambda prefix: DbrxBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks", ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 8c5ad9904e925..74b6bfdf21909 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -49,7 +49,7 @@ from vllm.sequence import IntermediateTensors from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -63,6 +63,7 @@ def __init__( hidden_act: str, quant_config: Optional[QuantizationConfig] = None, reduce_results: bool = True, + prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( @@ -92,6 +93,7 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -184,6 +186,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -236,7 +239,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -258,11 +262,12 @@ class DeepseekDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() + layer_idx = extract_layer_index(prefix) self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) @@ -277,17 +282,21 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) if (config.n_routed_experts is not None and layer_idx >= config.first_k_dense_replace and layer_idx % config.moe_layer_freq == 0): - self.mlp = DeepseekMoE(config=config, quant_config=quant_config) + self.mlp = DeepseekMoE(config=config, + quant_config=quant_config, + prefix=f"{prefix}.mlp") else: self.mlp = DeepseekMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, + prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -343,10 +352,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: DeepseekDecoderLayer(config, - int(prefix.split(".")[-1]), - cache_config, - quant_config=quant_config), + lambda prefix: DeepseekDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index d2c4ca0bf85e9..4cf4e6c358bf2 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -268,7 +268,8 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 9d739d0479548..5ca26d53a17e7 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -174,6 +174,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -219,7 +220,7 @@ def __init__( quant_config=quant_config, bias=bias, cache_config=cache_config, - prefix=prefix, + prefix=f"{prefix}.attention", ) def forward( diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 2aa4b67d99894..096ad32b38e86 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -84,6 +84,7 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() @@ -158,7 +159,8 @@ def __init__( self.head_dim, self.inv_norm_factor, num_kv_heads=self.num_kv_heads, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") elif self.use_alibi: tp_rank = get_tensor_model_parallel_rank() head_start = tp_rank * self.num_heads @@ -171,14 +173,16 @@ def __init__( self.inv_norm_factor, num_kv_heads=self.num_kv_heads, alibi_slopes=alibi_slopes, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") else: self.attn = Attention(self.num_heads, self.head_dim, scale=self.inv_norm_factor, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -241,12 +245,16 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size self.num_heads = config.num_attention_heads - self.self_attention = FalconAttention(config, cache_config, - quant_config) + self.self_attention = FalconAttention( + config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attention") self.mlp = FalconMLP(config, quant_config) self.config = config @@ -357,8 +365,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # Transformer blocks self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: FalconDecoderLayer(config, cache_config, - quant_config), + lambda prefix: FalconDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") # Final Layer Norm diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py index d3a9ff6915b84..3a5fe8e1f4144 100644 --- a/vllm/model_executor/models/florence2.py +++ b/vllm/model_executor/models/florence2.py @@ -35,10 +35,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.shared = BartScaledWordEmbedding(self.vocab_size, config.d_model) self.encoder = BartEncoder(config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.encoder") self.decoder = BartDecoder(config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.decoder") if self.config.tie_word_embeddings: self.encoder.embed_tokens.weight = self.shared.weight @@ -99,7 +101,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.config = config self.model = Florence2LanguageModel(vllm_config=vllm_config, - prefix=prefix) + prefix=f"{prefix}.model") embed_scale = math.sqrt( config.d_model) if config.scale_embedding else 1.0 @@ -198,7 +200,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): # TODO(Isotr0py): Add vision backbone self.language_model = Florence2LanguageForConditionalGeneration( vllm_config=vllm_config.with_hf_config(config.text_config), - prefix=prefix, + prefix=f"{prefix}.language_model", ) @property diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 64e03b30bf2f1..131e9af139c2a 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -174,7 +174,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 4ba39223cc07f..d229eb74669ee 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -42,7 +42,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -85,7 +86,6 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class Gemma2Attention(nn.Module): def __init__(self, - layer_idx: int, config: Gemma2Config, hidden_size: int, num_heads: int, @@ -95,9 +95,9 @@ def __init__(self, rope_theta: float, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, - attn_logits_soft_cap: Optional[float] = None) -> None: + attn_logits_soft_cap: Optional[float] = None, + prefix: str = "") -> None: super().__init__() - self.layer_idx = layer_idx self.config = config self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() @@ -142,19 +142,22 @@ def __init__(self, is_neox_style=True, ) - # FIXME(woosuk): While Gemma 2 uses sliding window attention for every - # odd layer, vLLM currently ignores it and uses global attention for - # all layers. - use_sliding_window = (layer_idx % 2 == 1 - and config.sliding_window is not None) - del use_sliding_window # Unused. + # reference: + # https://github.com/huggingface/transformers/blob/54be2d7ae87e873482b984cc956e165ca4dc0ba3/src/transformers/models/gemma2/modeling_gemma2.py#L312 # noqa + layer_idx = extract_layer_index(prefix) + use_sliding_window = (layer_idx % 2 == 0 and + config.interleaved_sliding_window is not None) + sliding_window = config.interleaved_sliding_window if \ + use_sliding_window else None self.attn = Attention(self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, - logits_soft_cap=attn_logits_soft_cap) + logits_soft_cap=attn_logits_soft_cap, + per_layer_sliding_window=sliding_window, + prefix=f"{prefix}.attn") def forward( self, @@ -175,15 +178,14 @@ class Gemma2DecoderLayer(nn.Module): def __init__( self, - layer_idx: int, config: Gemma2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size self.self_attn = Gemma2Attention( - layer_idx=layer_idx, config=config, hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -194,6 +196,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, attn_logits_soft_cap=config.attn_logit_softcapping, + prefix=f"{prefix}.self_attn", ) self.hidden_size = config.hidden_size self.mlp = Gemma2MLP( @@ -257,8 +260,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ - -1]), config, cache_config, quant_config), + lambda prefix: Gemma2DecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -508,4 +511,6 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) diff --git a/vllm/model_executor/models/glm4_vision_encoder.py b/vllm/model_executor/models/glm4_vision_encoder.py index 025615b0920fd..f37ab0f82d52a 100644 --- a/vllm/model_executor/models/glm4_vision_encoder.py +++ b/vllm/model_executor/models/glm4_vision_encoder.py @@ -56,6 +56,7 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.hidden_size = config.hidden_size @@ -135,11 +136,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.input_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = Attention(config, quant_config=quant_config) + self.attention = Attention(config, + quant_config=quant_config, + prefix=f"{prefix}.attention") self.mlp = MLP(config, quant_config=quant_config) self.post_attention_layernorm = LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -161,11 +165,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() self.layers = nn.ModuleList([ - TransformerLayer(config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) + TransformerLayer(config, + quant_config=quant_config, + prefix=f"{prefix}.layer.{layer_idx}") + for layer_idx in range(config.num_hidden_layers) ]) def forward(self, hidden_states): @@ -252,12 +259,14 @@ def __init__( self, config, quant_config: Optional[QuantizationConfig] = None, + prefix: str = '', ): super().__init__() vision_config = Namespace(**config.vision_config) self.patch_embedding = PatchEmbedding(vision_config) self.transformer = Transformer(vision_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.transformer") self.linear_proj = GLU(config, in_features=config.hidden_size, quant_config=quant_config) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 1c61408ae1dd9..fd926ff0254d4 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -84,7 +84,8 @@ def __init__( self.head_dim, scale=self.scale, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index 50a143cb1b600..c64bc70688806 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -52,6 +52,7 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -92,7 +93,8 @@ def __init__( scale=self.scale, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -151,6 +153,7 @@ def __init__( config: GPTBigCodeConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -158,7 +161,10 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = GPTBigCodeAttention(config, cache_config, quant_config) + self.attn = GPTBigCodeAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = GPTBigMLP(inner_dim, config, quant_config) @@ -210,7 +216,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), + lambda prefix: GPTBigCodeBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index d5defc60764e6..4829578a56959 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -53,6 +53,7 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -94,7 +95,8 @@ def __init__( self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -147,12 +149,16 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() inner_dim = (4 * config.n_embd if config.n_inner is None else config.n_inner) self.ln_1 = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon) - self.attn = GPTJAttention(config, cache_config, quant_config) + self.attn = GPTJAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.mlp = GPTJMLP(inner_dim, config, quant_config) def forward( @@ -193,7 +199,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.n_layer, - lambda prefix: GPTJBlock(config, cache_config, quant_config), + lambda prefix: GPTJBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h", ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 0bb5e2f9b95f9..731642772011c 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -52,6 +52,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.total_num_heads = config.num_attention_heads @@ -94,7 +95,8 @@ def __init__( self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -145,6 +147,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.use_parallel_residual = config.use_parallel_residual @@ -152,7 +155,10 @@ def __init__( eps=config.layer_norm_eps) self.post_attention_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.attention = GPTNeoXAttention(config, cache_config, quant_config) + self.attention = GPTNeoXAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attention") self.mlp = GPTNeoXMLP(config, quant_config) def forward( @@ -205,7 +211,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), + lambda prefix: GPTNeoXLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.final_layer_norm = nn.LayerNorm(config.hidden_size, diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index c1e2e87f08ec3..bd2394e71c973 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -161,7 +161,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py index a91a18816995f..51296ef0cc08e 100644 --- a/vllm/model_executor/models/granitemoe.py +++ b/vllm/model_executor/models/granitemoe.py @@ -164,7 +164,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index dcead65115132..4f0c75b2c6a57 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -7,6 +7,8 @@ from vllm.logger import init_logger from vllm.utils import supports_kw +from .interfaces_base import is_embedding_model + if TYPE_CHECKING: from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.sequence import IntermediateTensors @@ -350,3 +352,37 @@ def is_attention_free( return isinstance(model, _IsAttentionFreeType) return isinstance(model, IsAttentionFree) + + +@runtime_checkable +class SupportsCrossEncoding(Protocol): + """The interface required for all models that support cross encoding.""" + + supports_cross_encoding: ClassVar[Literal[True]] = True + + +@overload +def supports_cross_encoding( + model: Type[object]) -> TypeIs[Type[SupportsCrossEncoding]]: + ... + + +@overload +def supports_cross_encoding(model: object) -> TypeIs[SupportsCrossEncoding]: + ... + + +def _supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + + if isinstance(model, type): + return isinstance(model, SupportsCrossEncoding) + + return isinstance(model, SupportsCrossEncoding) + + +def supports_cross_encoding( + model: Union[Type[object], object], +) -> Union[TypeIs[Type[SupportsCrossEncoding]], TypeIs[SupportsCrossEncoding]]: + return is_embedding_model(model) and _supports_cross_encoding(model) diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 94b819b5d9366..906128940ff76 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -1,5 +1,5 @@ from functools import partial -from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union +from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Type, Union import torch from torch import nn @@ -250,7 +250,12 @@ def forward( @support_torch_compile class InternLM2Model(nn.Module): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__( + self, + *, + vllm_config: VllmConfig, + prefix: str = "", + layer_type: Type[InternLMDecoderLayer] = InternLMDecoderLayer): super().__init__() config = vllm_config.model_config.hf_config @@ -266,7 +271,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: InternLMDecoderLayer( + lambda prefix: layer_type( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -316,14 +321,18 @@ def forward( class InternLM2ForCausalLM(nn.Module, SupportsPP): - def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + def __init__(self, + *, + vllm_config: VllmConfig, + prefix: str = "", + model_type: Type[InternLM2Model] = InternLM2Model): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.quant_config = quant_config - self.model = InternLM2Model(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + self.model = model_type(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "model")) self.output = ParallelLMHead(config.vocab_size, config.hidden_size, quant_config=quant_config, diff --git a/vllm/model_executor/models/internlm2_ve.py b/vllm/model_executor/models/internlm2_ve.py index f1b7c896cadfe..93ac2dcf8d587 100644 --- a/vllm/model_executor/models/internlm2_ve.py +++ b/vllm/model_executor/models/internlm2_ve.py @@ -14,8 +14,6 @@ InternLM2MLP, InternLM2Model) from vllm.sequence import IntermediateTensors -from .utils import make_layers, maybe_prefix - class InternLM2VEDecoderLayer(nn.Module): @@ -105,17 +103,9 @@ def forward( class InternLM2VEModel(InternLM2Model): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - config = vllm_config.model_config.hf_config - cache_config = vllm_config.cache_config - quant_config = vllm_config.quant_config - - self.start_layer, self.end_layer, self.layers = make_layers( - config.num_hidden_layers, - lambda prefix: InternLM2VEDecoderLayer( - config, cache_config, quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + super().__init__(vllm_config=vllm_config, + prefix=prefix, + layer_type=InternLM2VEDecoderLayer) def forward( self, @@ -159,7 +149,6 @@ def forward( class InternLM2VEForCausalLM(InternLM2ForCausalLM): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): - super().__init__(vllm_config=vllm_config, prefix=prefix) - - self.model = InternLM2VEModel(vllm_config=vllm_config, - prefix=maybe_prefix(prefix, "model")) + super().__init__(vllm_config=vllm_config, + prefix=prefix, + model_type=InternLM2VEModel) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 7ea2f9be2191d..47ac00b6afe9b 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -19,8 +19,8 @@ from vllm.config import VllmConfig from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext, token_inputs) -from vllm.model_executor.layers.quantization import (AWQConfig, - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.awq import AWQConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.models.intern_vit import (InternVisionModel, InternVisionPatchModel) @@ -123,8 +123,15 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, return blocks, target_width, target_height -def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, - max_dynamic_patch: Optional[int] = None): +def calculate_num_blocks_wrapper( + hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch min_num = hf_config.min_dynamic_patch @@ -183,10 +190,17 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, return pixel_values -def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, - max_dynamic_patch: Optional[int] = None): +def image_to_pixel_values_wrapper( + hf_config: PretrainedConfig, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): image_size = hf_config.vision_config.image_size min_num = hf_config.min_dynamic_patch + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -207,11 +221,17 @@ def get_internvl_num_patches(hf_config: PretrainedConfig): (downsample_ratio**2)) -def get_max_internvl_image_tokens(ctx: InputContext, - *, - max_dynamic_patch: Optional[int] = None): +def get_max_internvl_image_tokens( + ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): hf_config = ctx.get_hf_config() + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -222,12 +242,18 @@ def get_max_internvl_image_tokens(ctx: InputContext, return num_patches * max_dynamic_patch -def get_max_internvl_image_size(ctx: InputContext, - *, - max_dynamic_patch: Optional[int] = None): +def get_max_internvl_image_size( + ctx: InputContext, + *, + max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, +): hf_config = ctx.get_hf_config() image_size = hf_config.vision_config.image_size + if dynamic_image_size is None: + dynamic_image_size = hf_config.dynamic_image_size + max_dynamic_patch = max_dynamic_patch if dynamic_image_size else 1 if max_dynamic_patch is None: max_dynamic_patch = hf_config.max_dynamic_patch use_thumbnail = hf_config.use_thumbnail @@ -281,6 +307,7 @@ def input_processor( inputs: DecoderOnlyInputs, *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ) -> DecoderOnlyInputs: multi_modal_data = inputs.get("multi_modal_data") if multi_modal_data is None or "image" not in multi_modal_data: @@ -292,7 +319,7 @@ def input_processor( image_data = multi_modal_data["image"] num_patches = get_internvl_num_patches(hf_config) num_blocks_calculator = calculate_num_blocks_wrapper( - hf_config, max_dynamic_patch) + hf_config, max_dynamic_patch, dynamic_image_size) if isinstance(image_data, Image.Image): width, height = image_data.size num_blocks, _, _ = num_blocks_calculator(width, height) @@ -332,11 +359,12 @@ def input_mapper( data: object, *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ): hf_config = ctx.get_hf_config() image_pixel_values_mapper = image_to_pixel_values_wrapper( - hf_config, max_dynamic_patch) + hf_config, max_dynamic_patch, dynamic_image_size) if isinstance(data, Image.Image): data = image_pixel_values_mapper(data) # Add an N dimension for number of images per prompt (currently 1). @@ -366,13 +394,17 @@ def dummy_data( mm_counts: Mapping[str, int], *, max_dynamic_patch: Optional[int] = None, + dynamic_image_size: Optional[bool] = None, ): num_images = mm_counts["image"] hf_config = ctx.get_hf_config() image_feature_size = get_max_internvl_image_tokens( - ctx, max_dynamic_patch=max_dynamic_patch) + ctx, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) model_config = ctx.model_config tokenizer = cached_get_tokenizer( model_config.tokenizer, @@ -388,7 +420,10 @@ def dummy_data( ) max_image_width, max_image_height = get_max_internvl_image_size( - ctx, max_dynamic_patch=max_dynamic_patch) + ctx, + max_dynamic_patch=max_dynamic_patch, + dynamic_image_size=dynamic_image_size, + ) mm_data = dummy_image_for_clip( hf_config.vision_config, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 41db85b678456..8c81dff6b5768 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -76,6 +76,7 @@ def __init__( config: JAISConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size @@ -114,7 +115,8 @@ def __init__( scale=self.scale, alibi_slopes=alibi_slopes, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -178,6 +180,7 @@ def __init__( config: JAISConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.hidden_size @@ -185,7 +188,10 @@ def __init__( hidden_size) self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) - self.attn = JAISAttention(config, cache_config, quant_config) + self.attn = JAISAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon) self.mlp = JAISMLP(inner_dim, config, quant_config) @@ -241,7 +247,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: JAISBlock(config=config, cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.h", ) diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index f83f0fce7275f..099ca7e12b288 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -102,7 +102,8 @@ def __init__(self, config: JambaConfig, layer_idx: int, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.mamba = MambaMixer(hidden_size= config.hidden_size, @@ -157,6 +158,7 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -198,6 +200,7 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, + prefix=f"{prefix}.attn", ) num_experts = config.layers_num_experts[layer_idx] @@ -287,7 +290,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): layer_class(config, layer_idx=i, cache_config=cache_config, - quant_config=quant_config)) + quant_config=quant_config, + prefix=f"{prefix}.layers.{i}")) self.layers = nn.ModuleList(decoder_layers) self.final_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 2b40e9ec73fad..33d78d74129c8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -174,6 +175,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -688,6 +690,8 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) def load_kv_cache_scales(self, quantization_param_path: str) -> None: diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index e7d3161a7cb2d..05c6cc62efcd7 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -204,7 +204,41 @@ def input_processor_for_llava(ctx: InputContext, inputs: DecoderOnlyInputs): class LlavaLikeConfig(Protocol): vision_config: PretrainedConfig - vision_feature_layer: int + vision_feature_layer: Union[int, List[int]] + + +def _get_num_hidden_layers(hf_config: LlavaLikeConfig) -> int: + """Determine the number of hidden layers to initialize up to in the + visual encoder. + + Args: + hf_config: Model config with vision feature layer(s). + """ + feature_layers = hf_config.vision_feature_layer + num_hidden_layers = hf_config.vision_config.num_hidden_layers + # If we have one feature layer, initialize up to that layer + if isinstance(feature_layers, int): + return _get_layer_index(feature_layers, num_hidden_layers) + # If we have multiple feature layers, initialize up to the deepest one + elif isinstance(feature_layers, (list, tuple)): + return max( + _get_layer_index(idx, num_hidden_layers) for idx in feature_layers) + raise TypeError(f"vision_layer_feature type: {type(feature_layers)}" + " is not supported") + + +def _get_layer_index(feature_layer_index: int, num_hidden_layers: int) -> int: + """Given an signed vision feature layer, get the number of hidden layers + needed to leverage it. + + Args: + feature_layer_index: Index of a required layer in the visual encoder. + num_hidden_layers: The total number of hidden layers in the visual + encoder. + """ + if feature_layer_index < 0: + return num_hidden_layers + feature_layer_index + 1 + return feature_layer_index + 1 def init_vision_tower_for_llava( @@ -216,13 +250,8 @@ def init_vision_tower_for_llava( ): vision_config = hf_config.vision_config - # Initialize the vision tower only up to the required feature layer - vision_feature_layer = hf_config.vision_feature_layer - if vision_feature_layer < 0: - num_hidden_layers = hf_config.vision_config.num_hidden_layers \ - + vision_feature_layer + 1 - else: - num_hidden_layers = vision_feature_layer + 1 + # Initialize the vision tower only up to the deepest required feature layer + num_hidden_layers = _get_num_hidden_layers(hf_config) if isinstance(vision_config, CLIPVisionConfig): return CLIPVisionModel( diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 37e2227a52dcd..abeebb45fc4a7 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -288,6 +288,21 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: pooler_config = vllm_config.model_config.pooler_config multimodal_config = vllm_config.model_config.multimodal_config + vision_feature_layer = config.vision_feature_layer + # Determine the layer up to which we will initialize the vision tower + if isinstance(vision_feature_layer, int): + vision_hidden_size = config.vision_config.hidden_size + self.feature_sample_layers = None + # Used for multimodal granite models to control encoder outputs + elif isinstance(vision_feature_layer, (list, tuple)): + vision_hidden_size = config.vision_config.hidden_size * len( + vision_feature_layer) + self.feature_sample_layers = vision_feature_layer + else: + raise TypeError( + f"vision_layer_feature type: {type(vision_feature_layer)}" + " is not supported") + self.config = config self.multimodal_config = multimodal_config @@ -300,7 +315,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( - vision_hidden_size=config.vision_config.hidden_size, + vision_hidden_size=vision_hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) @@ -419,7 +434,8 @@ def _image_pixels_to_features( # NOTE: we skip the step to select the vision feature layer since # this is already done inside the vision tower - image_features = vision_tower(pixel_values) + image_features = vision_tower( + pixel_values, feature_sample_layers=self.feature_sample_layers) return self._select_image_features( image_features, diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index 405b8f7787ba8..ac0d265a961f0 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -1,5 +1,5 @@ """PyTorch MAMBA model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -243,10 +243,8 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters()) - loaded_params: Set[str] = set() for name, loaded_weight in weights: if "A_log" in name: name = name.replace("A_log", "A") @@ -258,5 +256,3 @@ def load_weights(self, weights: Iterable[Tuple[str, weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) - loaded_params.add(name) - return loaded_params diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index b4ed6538bddac..66bdcb89a0213 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -61,14 +61,25 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.truncated_vocab_size = config.truncated_vocab_size self.unpadded_vocab_size = self.truncated_vocab_size - self.lm_heads = nn.ModuleList([ - ParallelLMHead( + if getattr(config, "original_lm_head", False): + self.lm_head = ParallelLMHead( self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=self.truncated_vocab_size, padding_size=DEFAULT_VOCAB_PADDING_SIZE, - ) for _ in range(self.config.num_heads) - ]) + ) + self.lm_heads = [ + self.lm_head for _ in range(self.config.num_heads) + ] + else: + self.lm_heads = nn.ModuleList([ + ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=self.truncated_vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE, + ) for _ in range(self.config.num_heads) + ]) logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, @@ -172,6 +183,9 @@ def load_weights(self, weights: Iterable[Tuple[str, requires_grad=False) elif name in params_dict: weights_map[name] = loaded_weight + elif (getattr(self.config, "original_lm_head", False) + and name == "lm_heads.0.weight"): + weights_map["lm_head.weight"] = loaded_weight for name, loaded_weight in weights_map.items(): if "lm_head" in name and self.token_map is not None and\ diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index b92bff4d7c28c..c9a573278a136 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -192,6 +192,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -246,7 +247,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -273,6 +275,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -283,6 +286,7 @@ def __init__( self.rope_scaling = getattr(config, "rope_scaling", None) self.max_position_embeddings = getattr(config, "max_position_embeddings", 8192) + self.prefix = prefix self._init_attn_block() self._init_ffn_block() @@ -298,6 +302,7 @@ def _init_attn_block(self): max_position_embeddings=self.max_position_embeddings, cache_config=self.cache_config, quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", ) def _init_ffn_block(self): @@ -388,8 +393,8 @@ def _init_layers( ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MiniCPMDecoderLayer(config, cache_config, - quant_config), + lambda prefix: MiniCPMDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index 278c4bbe6e563..c38c31a0d4953 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -60,6 +60,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -119,7 +120,8 @@ def __init__( self.scaling, num_kv_heads=self.num_local_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -195,6 +197,7 @@ def _init_attn_block(self): max_position_embeddings=self.max_position_embeddings, cache_config=self.cache_config, quant_config=self.quant_config, + prefix=f"{self.prefix}.self_attn", ) @@ -209,8 +212,8 @@ def _init_layers( ): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: MiniCPM3DecoderLayer(config, cache_config, - quant_config), + lambda prefix: MiniCPM3DecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 0faffb4f1b00c..a5b364fe5ec85 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -166,7 +166,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index ddd6afcf6a1b6..7a9b8cd88cfd0 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -170,6 +170,7 @@ def __init__( rope_theta: float = 10000, quant_config: Optional[QuantizationConfig] = None, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -219,7 +220,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -243,6 +245,7 @@ def __init__( config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -255,7 +258,9 @@ def __init__( num_kv_heads=config.num_key_value_heads, rope_theta=rope_theta, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + ) self.block_sparse_moe = MixtralMoE(config=config, quant_config=quant_config) self.input_layernorm = RMSNorm(config.hidden_size, @@ -311,7 +316,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: MixtralDecoderLayer( - config, cache_config, quant_config=quant_config), + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py index 41f62b37f3bd9..9e6634a9a7579 100644 --- a/vllm/model_executor/models/mllama.py +++ b/vllm/model_executor/models/mllama.py @@ -32,9 +32,8 @@ import vllm.distributed.parallel_state as ps from vllm.attention import Attention, AttentionMetadata, AttentionType -from vllm.attention.backends.flash_attn import FlashAttentionMetadata -from vllm.attention.backends.xformers import XFormersMetadata from vllm.attention.ops.paged_attn import PagedAttention +from vllm.attention.selector import _Backend from vllm.config import VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.inputs import (INPUT_REGISTRY, DummyData, EncoderDecoderInputs, @@ -828,7 +827,8 @@ def _attention_with_mask( ) -> torch.Tensor: # Skip writing kv-cache for the initial profiling run. if len(kv_cache.shape) > 1: - if isinstance(attn_metadata, FlashAttentionMetadata): + if self.attn.backend in (_Backend.FLASH_ATTN, + _Backend.FLASH_ATTN_VLLM_V1): cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode]) torch.ops._C_cache_ops.reshape_and_cache_flash( @@ -842,7 +842,7 @@ def _attention_with_mask( 1.0, 1.0, ) - elif isinstance(attn_metadata, XFormersMetadata): + elif self.attn.backend in (_Backend.XFORMERS, _Backend.TORCH_SDPA): key_cache, value_cache = PagedAttention.split_kv_cache( kv_cache, self.num_local_key_value_heads, self.head_dim) cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode]) @@ -852,9 +852,9 @@ def _attention_with_mask( attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0) else: raise ValueError( - f"Unsupported AttentionMetadata {type(attn_metadata)} " - f"class found. Expected the AttentionMetadata to " - f"be either XFormersMetadata or FlashAttentionMetadata.") + f"Unsupported Attention backend {self.attn.backend} " + "enum found. Expected the Attention backend to be " + "FLASH_ATTN, FLASH_ATTN_VLLM_V1, XFORMERS or TORCH_SDPA.") # We have to call torch.sdpa for prefill when using a # custom cross-attention mask. Because the mask is not a diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py index f2aa2653c4f5c..d49da5f29aa14 100644 --- a/vllm/model_executor/models/mlp_speculator.py +++ b/vllm/model_executor/models/mlp_speculator.py @@ -193,7 +193,8 @@ def load_weights(self, weights: Iterable[Tuple[str, params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: - param = params_dict.get(name.replace("speculator.", "")) + name = name.replace("speculator.", "") + param = params_dict.get(name) if param is not None: weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/molmo.py b/vllm/model_executor/models/molmo.py index 2528f741864b3..ee7b560fe1ee4 100644 --- a/vllm/model_executor/models/molmo.py +++ b/vllm/model_executor/models/molmo.py @@ -370,6 +370,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -427,7 +428,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") # Attention output projection. self.o_proj = RowParallelLinear( @@ -517,10 +519,14 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() # Attention block. - self.self_attn = MolmoAttention(config, cache_config, quant_config) + self.self_attn = MolmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") # MLP block. self.mlp = MolmoMLP(config, quant_config=quant_config) @@ -738,7 +744,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): else MolmoDecoderLayer self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: decoder_layer(config, cache_config, quant_config), + lambda prefix: decoder_layer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 8716e92b0f1c2..1235816413a44 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -50,6 +50,7 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.d_model = config.d_model @@ -115,7 +116,8 @@ def __init__( alibi_slopes=alibi_slopes, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -176,11 +178,15 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() hidden_size = config.d_model self.norm_1 = nn.LayerNorm(hidden_size) - self.attn = MPTAttention(config, cache_config, quant_config) + self.attn = MPTAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.attn") self.norm_2 = nn.LayerNorm(hidden_size) self.ffn = MPTMLP(config, quant_config) @@ -224,7 +230,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.blocks = make_layers( config.n_layers, - lambda prefix: MPTBlock(config, cache_config, quant_config), + lambda prefix: MPTBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.blocks") self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index ceab299a7950a..c7b4c22b6896b 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -195,7 +195,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index dc138e2e636ad..538e31ec91699 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -62,6 +62,7 @@ def __init__( config: OlmoConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -101,7 +102,8 @@ def __init__( self.head_dim, scale=self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") # Attention output projection. self.o_proj = RowParallelLinear( @@ -184,10 +186,14 @@ class OlmoDecoderLayer(nn.Module): def __init__(self, config: OlmoConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() # Attention block. - self.self_attn = OlmoAttention(config, cache_config, quant_config) + self.self_attn = OlmoAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") # MLP block. self.mlp = OlmoMLP(config, quant_config) @@ -238,8 +244,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config - ), + lambda prefix: OlmoDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index ab87695d8e650..5d9091cfb9311 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -102,6 +102,7 @@ def __init__( max_position_embeddings: int = 4096, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -156,7 +157,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -179,9 +181,9 @@ class OlmoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -199,6 +201,7 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = OlmoeMoE( @@ -260,8 +263,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: OlmoeDecoderLayer(config, int( - prefix.split(".")[-1]), cache_config, quant_config), + lambda prefix: OlmoeDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=1e-5) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b01734af8ddd8..a3757b5c8808e 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -75,6 +75,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -126,7 +127,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -150,6 +152,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -166,6 +169,7 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) self.mlp = OrionMLP( hidden_size=self.hidden_size, @@ -226,10 +230,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: OrionDecoderLayer( - config, - cache_config, - quant_config, - ), + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = ( diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index 3b8199f4f1661..14dd4b5b1b4da 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -75,7 +75,8 @@ class PersimmonAttention(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config tensor_parallel_world_size = get_tensor_model_parallel_world_size() @@ -122,7 +123,8 @@ def __init__(self, self.head_dim, scale=self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def _split_heads(self, x: torch.Tensor) -> torch.Tensor: # [seq_length, hidden_size] -> [seq_length, num_heads, head_dim] @@ -167,12 +169,14 @@ class PersimmonDecoderLayer(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = PersimmonAttention(config=config, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = PersimmonMLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) @@ -226,8 +230,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PersimmonDecoderLayer(config, cache_config, - quant_config), + lambda prefix: PersimmonDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 0a117bf16c9b3..998d3723a0d7d 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -69,7 +69,8 @@ class PhiAttention(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.total_num_heads = config.num_attention_heads self.hidden_size = config.hidden_size @@ -116,7 +117,8 @@ def __init__(self, self.head_size, scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -167,11 +169,15 @@ class PhiLayer(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) - self.self_attn = PhiAttention(config, cache_config, quant_config) + self.self_attn = PhiAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") self.mlp = PhiMLP(config, quant_config) def forward( @@ -210,7 +216,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.hidden_size) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiLayer(config, cache_config, quant_config), + lambda prefix: PhiLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index f71cbd1264c45..da7e4cdbc6940 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -117,6 +117,7 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.layer_idx = layer_idx @@ -214,15 +215,14 @@ def __init__( "homo_head": self.homo_heads } - self.attn = Attention( - self.num_heads_per_partition, - self.head_dim, - self.scale, - num_kv_heads=self.num_kv_heads_per_partion, - cache_config=cache_config, - quant_config=quant_config, - blocksparse_params=bs_params, - ) + self.attn = Attention(self.num_heads_per_partition, + self.head_dim, + self.scale, + num_kv_heads=self.num_kv_heads_per_partion, + cache_config=cache_config, + quant_config=quant_config, + blocksparse_params=bs_params, + prefix=f"{prefix}.attn") def forward( self, @@ -259,13 +259,15 @@ def __init__( layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Phi3SmallSelfAttention(config, layer_idx, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = Phi3SmallMLP(config, quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, @@ -315,7 +317,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): config.num_hidden_layers, lambda prefix: Phi3SmallDecoderLayer(config, int(prefix.split('.')[-1]), - cache_config, quant_config), + cache_config, + quant_config, + prefix=prefix), prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index e475d286bd7ea..1febd62f2f705 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -294,6 +294,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, rope_scaling: Optional[dict] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -347,6 +348,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( @@ -371,6 +373,7 @@ def __init__( config: PhiMoEConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -385,6 +388,7 @@ def __init__( cache_config=cache_config, quant_config=quant_config, rope_scaling=config.rope_scaling, + prefix=f"{prefix}.self_attn", ) self.block_sparse_moe = PhiMoE( num_experts=config.num_local_experts, @@ -454,8 +458,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: PhiMoEDecoderLayer(config, cache_config, - quant_config), + lambda prefix: PhiMoEDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index d14b89d6b3f85..6711cbf5694b9 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -33,7 +33,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargs from vllm.multimodal.inputs import PlaceholderRange from vllm.multimodal.utils import (cached_get_tokenizer, - consecutive_placeholder_ranges) + consecutive_placeholder_ranges, + resolve_visual_encoder_outputs) from vllm.sequence import IntermediateTensors, SequenceData from vllm.transformers_utils.processor import cached_get_processor from vllm.utils import is_list_of @@ -970,9 +971,18 @@ def forward( x: torch.Tensor, attention_mask: torch.Tensor, position_embeddings: torch.Tensor, + return_all_hidden_states: bool, ) -> torch.Tensor: + hidden_states_pool = [] + for layer in self.layers: x = layer(x, attention_mask, position_embeddings) + if return_all_hidden_states: + hidden_states_pool.append(x) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return x @@ -990,6 +1000,7 @@ def __init__( super().__init__() self.config = config + self.patch_conv = nn.Conv2d( in_channels=config.num_channels, out_channels=config.hidden_size, @@ -1024,6 +1035,7 @@ def __init__( def forward( self, pixel_values: List[torch.Tensor], + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: """ Args: @@ -1031,6 +1043,9 @@ def forward( in pixel_values. This means it will be a list of tensors because multiple requests batched can have multiple images, each with their own shape potentially + feature_sample_layers: Layer indices whose features should be + concatenated and used as the visual encoder output. If none + are provided, the last layer is used. Returns: image_features: tensor of token features for @@ -1065,8 +1080,15 @@ def forward( [p.shape[-2] * p.shape[-1] for p in patch_embeds_list], patch_embeds) - out = self.transformer(patch_embeds, attention_mask, - position_embedding) + return_all_hidden_states = feature_sample_layers is not None + out = self.transformer( + patch_embeds, + attention_mask, + position_embedding, + return_all_hidden_states=return_all_hidden_states) + + out = resolve_visual_encoder_outputs(out, feature_sample_layers, None, + self.config.num_hidden_layers) return out diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 44ce6eda42943..8f001200308fe 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -442,6 +442,7 @@ def __init__( rope_scaling: Optional[Dict[str, Any]] = None, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.hidden_size = hidden_size @@ -478,7 +479,8 @@ def __init__( self.head_dim, self.scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -502,6 +504,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.ln_1 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -514,7 +517,8 @@ def __init__( rope_theta=rope_theta, rope_scaling=rope_scaling, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") self.ln_2 = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) @@ -568,7 +572,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.h = make_layers( config.num_hidden_layers, - lambda prefix: QWenBlock(config, cache_config, quant_config), + lambda prefix: QWenBlock( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) self.make_empty_intermediate_tensors = ( @@ -1023,6 +1028,18 @@ class QWenLLM(QWenBaseModel): embedding_modules = {} embedding_padding_modules = [] + default_bitsandbytes_target_modules = [ + ".c_attn.", + ".c_proj.", + ".w1.", + ".w2.", + ] + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "w2": ("gate_up_proj", 0), + "w1": ("gate_up_proj", 1), + } + class QWenVL(QWenBaseModel, SupportsMultiModal): packed_modules_mapping = { diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 370cff5fa153f..46640226d4cf8 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -27,7 +27,7 @@ from torch import nn from transformers import Qwen2Config -from vllm.attention import Attention, AttentionMetadata +from vllm.attention import Attention, AttentionMetadata, AttentionType from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size @@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -164,11 +165,17 @@ def forward( hidden_states: torch.Tensor, kv_cache: torch.Tensor, attn_metadata: AttentionMetadata, + attn_type: str = AttentionType.DECODER, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) q, k = self.rotary_emb(positions, q, k) - attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + attn_output = self.attn(q, + k, + v, + kv_cache, + attn_metadata, + attn_type=attn_type) output, _ = self.o_proj(attn_output) return output @@ -210,6 +217,15 @@ def __init__( self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + # By default, Qwen2 uses causal attention as it is a decoder-only model. + # You can override the HF config with `is_causal=False` to enable + # bidirectional attention, which is used in some embedding models + # (e.g. Alibaba-NLP/gte-Qwen2-7B-instruct) + if getattr(config, "is_causal", True): + self._attn_type = AttentionType.DECODER + else: + self._attn_type = AttentionType.ENCODER_ONLY + def forward( self, positions: torch.Tensor, @@ -230,6 +246,7 @@ def forward( hidden_states=hidden_states, kv_cache=kv_cache, attn_metadata=attn_metadata, + attn_type=self._attn_type, ) # Fully Connected @@ -569,8 +586,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["lm_head."]) - return loader.load_weights(weights) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py index a4965f34b1ca8..0c2374c3c3fc9 100644 --- a/vllm/model_executor/models/qwen2_audio.py +++ b/vllm/model_executor/models/qwen2_audio.py @@ -212,7 +212,7 @@ def input_processor_for_qwen2_audio( return token_inputs( prompt_token_ids=new_input_ids, - prompt=inputs['prompt'], + prompt=inputs.get("prompt"), multi_modal_data=multi_modal_data, ) diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index 96a9bc451f4df..ba70243c6533d 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -53,7 +53,7 @@ from vllm.utils import print_warning_once from .interfaces import SupportsPP -from .utils import (is_pp_missing_parameter, +from .utils import (extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -168,6 +168,7 @@ def __init__( max_position_embeddings: int = 8192, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -220,7 +221,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -242,9 +244,9 @@ class Qwen2MoeDecoderLayer(nn.Module): def __init__( self, config: PretrainedConfig, - layer_idx: int, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -261,10 +263,12 @@ def __init__( max_position_embeddings=max_position_embeddings, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.self_attn", ) # Note: Qwen/Qwen2-57B-A14B-Instruct does not have # `mlp_only_layers` in the config. + layer_idx = extract_layer_index(prefix) mlp_only_layers = ([] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers) if (layer_idx not in mlp_only_layers) and ( @@ -333,10 +337,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen2MoeDecoderLayer(config=config, - layer_idx=int( - prefix.split(".")[-1]), cache_config=cache_config, - quant_config=quant_config), + quant_config=quant_config, + prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index 0ac81387b1bd8..531608a877f2f 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -51,9 +51,10 @@ RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.pooler import Pooler, PoolingType -from vllm.model_executor.layers.quantization import (GPTQConfig, - GPTQMarlinConfig, - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.quantization.gptq import GPTQConfig +from vllm.model_executor.layers.quantization.gptq_marlin import ( + GPTQMarlinConfig) from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py index 22c2e328bfb65..184f4b2bc1526 100644 --- a/vllm/model_executor/models/registry.py +++ b/vllm/model_executor/models/registry.py @@ -21,7 +21,8 @@ from vllm.platforms import current_platform from .interfaces import (has_inner_state, is_attention_free, - supports_multimodal, supports_pp) + supports_cross_encoding, supports_multimodal, + supports_pp) from .interfaces_base import is_embedding_model, is_text_generation_model logger = init_logger(__name__) @@ -100,6 +101,7 @@ # [Text-only] "BertModel": ("bert", "BertEmbeddingModel"), "RobertaModel": ("roberta", "RobertaEmbeddingModel"), + "RobertaForMaskedLM": ("roberta", "RobertaEmbeddingModel"), "XLMRobertaModel": ("roberta", "RobertaEmbeddingModel"), "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), "Gemma2Model": ("gemma2", "Gemma2EmbeddingModel"), @@ -121,8 +123,17 @@ "Qwen2VLForConditionalGeneration": ("qwen2_vl", "Qwen2VLForConditionalGeneration") # noqa: E501, } +_CROSS_ENCODER_MODELS = { + "BertForSequenceClassification": ("bert", "BertForSequenceClassification"), + "RobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), + "XLMRobertaForSequenceClassification": ("roberta", + "RobertaForSequenceClassification"), +} + _MULTIMODAL_MODELS = { # [Decoder-only] + "AriaForConditionalGeneration": ("aria", "AriaForConditionalGeneration"), "Blip2ForConditionalGeneration": ("blip2", "Blip2ForConditionalGeneration"), "ChameleonForConditionalGeneration": ("chameleon", "ChameleonForConditionalGeneration"), # noqa: E501 "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), @@ -159,6 +170,7 @@ _VLLM_MODELS = { **_TEXT_GENERATION_MODELS, **_EMBEDDING_MODELS, + **_CROSS_ENCODER_MODELS, **_MULTIMODAL_MODELS, **_SPECULATIVE_DECODING_MODELS, } @@ -193,6 +205,7 @@ class _ModelInfo: is_text_generation_model: bool is_embedding_model: bool + supports_cross_encoding: bool supports_multimodal: bool supports_pp: bool has_inner_state: bool @@ -203,6 +216,7 @@ def from_model_cls(model: Type[nn.Module]) -> "_ModelInfo": return _ModelInfo( is_text_generation_model=is_text_generation_model(model), is_embedding_model=is_embedding_model(model), + supports_cross_encoding=supports_cross_encoding(model), supports_multimodal=supports_multimodal(model), supports_pp=supports_pp(model), has_inner_state=has_inner_state(model), @@ -415,6 +429,12 @@ def is_embedding_model( ) -> bool: return self.inspect_model_cls(architectures).is_embedding_model + def is_cross_encoder_model( + self, + architectures: Union[str, List[str]], + ) -> bool: + return self.inspect_model_cls(architectures).supports_cross_encoding + def is_multimodal_model( self, architectures: Union[str, List[str]], @@ -489,4 +509,4 @@ def _run() -> None: if __name__ == "__main__": - _run() \ No newline at end of file + _run() diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index c1dcdd36ec3de..ba1a78ac640fd 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import Iterable, List, Optional, Tuple import torch from torch import nn @@ -6,10 +6,18 @@ from vllm.attention import AttentionMetadata from vllm.config import VllmConfig +from vllm.model_executor.layers.pooler import CrossEncodingPooler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) +from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.sequence import IntermediateTensors +from vllm.model_executor.models.utils import maybe_prefix +from vllm.model_executor.pooling_metadata import PoolingMetadata +from vllm.sequence import IntermediateTensors, PoolerOutput +from vllm.transformers_utils.config import ( + get_cross_encoder_activation_function) + +from .interfaces import SupportsCrossEncoding class RobertaEmbedding(nn.Module): @@ -39,34 +47,93 @@ def __init__(self, config: RobertaConfig): def forward( self, input_ids: torch.Tensor, - position_ids: Optional[torch.Tensor] = None, + seq_lens: torch.Tensor, + position_ids: torch.Tensor, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: input_shape = input_ids.size() - - # Input embeddings. inputs_embeds = self.word_embeddings(input_ids) - # TODO: figure out if there is a better way - # to make to make position ids start at padding_idx + 1 + # Replace position ids because in RoBERTa models + # they have to start at padding_idx + 1 and ignore + # existing padding tokens # References: # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L133 # - https://github.com/huggingface/transformers/blob/a3d69a8994d673899608a7c17fbf4f953f50474e/src/transformers/models/roberta/modeling_roberta.py#L1669 - position_ids += self.padding_idx + 1 + pos_list = [] + token_list = [] + offset = 0 + for seq_len in seq_lens: + pos_list.append(position_ids[offset:offset + seq_len]) + token_list.append(input_ids[offset:offset + seq_len]) + offset += seq_len + + new_pos_list = [] + for positions, tokens in zip(pos_list, token_list): + # Verify assumption that incoming position are + # always a sequence from 0 to N. + expected_pos = torch.arange(positions.size()[0], + dtype=torch.long, + device=inputs_embeds.device) + assert torch.equal(positions, expected_pos) + new_pos_list.append( + create_position_ids_from_input_ids(tokens, self.padding_idx)) + position_ids = torch.cat(new_pos_list) # Position embeddings. position_embeddings = self.position_embeddings(position_ids) + if token_type_ids is None: + token_type_ids = torch.zeros(input_shape, + dtype=torch.long, + device=inputs_embeds.device) - # Token type embeddings. (TODO: move off hotpath?) - token_type_embeddings = self.token_type_embeddings( - torch.zeros(input_shape, - dtype=torch.long, - device=inputs_embeds.device)) - + token_type_embeddings = self.token_type_embeddings(token_type_ids) embeddings = inputs_embeds + token_type_embeddings + position_embeddings embeddings = self.LayerNorm(embeddings) return embeddings +# Adapted from transformers +def create_position_ids_from_input_ids(input_ids, + padding_idx, + past_key_values_length=0): + """ + Replace non-padding symbols with their position numbers. + Position numbers begin at padding_idx+1. Padding symbols + are ignored. This is modified from fairseq's `utils.make_positions`. + + Args: + x: torch.Tensor x: + + Returns: torch.Tensor + """ + # The series of casts and type-conversions here are carefully + # balanced to both work with ONNX export and XLA. + mask = input_ids.ne(padding_idx).int() + + incremental_indices = (torch.cumsum(mask, dim=0).type_as(mask) + + past_key_values_length) * mask + + return incremental_indices.long() + padding_idx + + +# Adapted from transformers +class RobertaClassificationHead(nn.Module): + """Head for sentence-level classification tasks.""" + + def __init__(self, config: RobertaConfig): + super().__init__() + self.dense = nn.Linear(config.hidden_size, config.hidden_size) + self.out_proj = nn.Linear(config.hidden_size, config.num_labels) + + def forward(self, features, **kwargs): + x = features[0, :] # take token (equiv. to [CLS]) + x = self.dense(x) + x = torch.tanh(x) + x = self.out_proj(x) + return x + + class RobertaEmbeddingModel(BertEmbeddingModel): """A model that uses Roberta to provide embedding functionalities. @@ -85,6 +152,62 @@ def _build_model(self, prefix=prefix, embedding_class=RobertaEmbedding) + +class RobertaForSequenceClassification(nn.Module, SupportsCrossEncoding): + """A model that uses Roberta to provide embedding functionalities. + + This class encapsulates the BertModel and provides an interface for + embedding operations and customized pooling functions. + + Attributes: + roberta: An instance of BertModel used for forward operations. + _pooler: An instance of Pooler used for pooling operations. + """ + + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): + super().__init__() + config = vllm_config.model_config.hf_config + + self.default_activation_function = \ + get_cross_encoder_activation_function(config) + + self.num_labels = config.num_labels + self.roberta = BertModel(vllm_config=vllm_config, + prefix=maybe_prefix(prefix, "bert"), + embedding_class=RobertaEmbedding, + add_pooling_layer=False) + self.classifier = RobertaClassificationHead(config) + self._pooler = CrossEncodingPooler(config, self.classifier) + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + + self_weights = [] + + def weight_filter(): + for name, weight in weights: + if name.startswith("roberta."): + yield (name[len("roberta."):], weight) + else: + self_weights.append((name, weight)) + + self.roberta.load_weights(weight_filter()) + + params_dict = dict(self.named_parameters()) + + for name, loaded_weight in self_weights: + if name.startswith("classifier"): + param = params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + def pooler( + self, + hidden_states: torch.Tensor, + pooling_metadata: PoolingMetadata, + ) -> Optional[PoolerOutput]: + return self._pooler(hidden_states, pooling_metadata) + def forward( self, input_ids: Optional[torch.Tensor], @@ -93,25 +216,12 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, + token_type_ids: Optional[torch.Tensor] = None, ) -> torch.Tensor: - - # Verify assumption that position are always a sequence from - # 0 to N. (Actually here we just check 0 and N to simplify). - # This is important to fix the position which are assumed to - # start from padding_idx + 1 instead of 0 in the Roberta models. - assert hasattr(attn_metadata, "seq_lens_tensor") - cumulative = attn_metadata.seq_lens_tensor.cumsum(dim=0) - start_pos = torch.cat( - (torch.tensor([0], device=attn_metadata.seq_lens_tensor.device), - cumulative[:-1])) - assert len(torch.nonzero(positions[start_pos])) == 0 - end_pos = cumulative - 1 - last_tokens = attn_metadata.seq_lens_tensor - 1 - assert len(torch.nonzero(positions[end_pos] - last_tokens)) == 0 - - return super().forward(input_ids=input_ids, - positions=positions, - kv_caches=kv_caches, - attn_metadata=attn_metadata, - intermediate_tensors=intermediate_tensors, - inputs_embeds=inputs_embeds) + return self.roberta(input_ids=input_ids, + position_ids=positions, + kv_caches=kv_caches, + inputs_embeds=inputs_embeds, + intermediate_tensors=intermediate_tensors, + attn_metadata=attn_metadata, + token_type_ids=token_type_ids) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index c58ad99692900..deaed0ba7e4ce 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -25,7 +25,8 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.multimodal.utils import (cached_get_tokenizer, consecutive_placeholder_ranges, - repeat_and_pad_placeholder_tokens) + repeat_and_pad_placeholder_tokens, + resolve_visual_encoder_outputs) from vllm.sequence import SequenceData from .utils import get_vit_attn_backend @@ -450,11 +451,19 @@ def __init__( def forward( self, inputs_embeds: torch.Tensor, - ) -> torch.Tensor: + return_all_hidden_states: bool, + ) -> Union[torch.Tensor, list[torch.Tensor]]: + hidden_states_pool = [] hidden_states = inputs_embeds + for encoder_layer in self.layers: hidden_states, _ = encoder_layer(hidden_states) - + if return_all_hidden_states: + hidden_states_pool.append(hidden_states) + # If we have multiple feature sample layers, we return all hidden + # states in order and grab the ones we need by index. + if return_all_hidden_states: + return hidden_states_pool return hidden_states @@ -509,6 +518,7 @@ def __init__( embed_dim = config.hidden_size self.embeddings = SiglipVisionEmbeddings(config) + self.encoder = SiglipEncoder( config, quant_config=quant_config, @@ -546,23 +556,33 @@ def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = True, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: + hidden_states = self.embeddings( pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, ) - encoder_outputs = self.encoder(inputs_embeds=hidden_states) + return_all_hidden_states = feature_sample_layers is not None + + # Produces either the last layer output or all of the hidden states, + # depending on if we have feature_sample_layers or not + encoder_outputs = self.encoder( + inputs_embeds=hidden_states, + return_all_hidden_states=return_all_hidden_states, + ) - if self.post_layernorm is None: - return encoder_outputs + # Handle post-norm (if applicable) and stacks feature layers if needed + encoder_outputs = resolve_visual_encoder_outputs( + encoder_outputs, feature_sample_layers, self.post_layernorm, + self.config.num_hidden_layers) - last_hidden_state = self.post_layernorm(encoder_outputs) - # TODO: add this back when pooled_output is used in inference + # TODO: add this back when pooled_output is used in inference. # if self.use_head: - # pooled_output = self.head(last_hidden_state) + # pooled_output = self.head(encoder_outputs) - return last_hidden_state + return encoder_outputs class SiglipVisionModel(nn.Module): @@ -595,10 +615,12 @@ def forward( self, pixel_values: torch.Tensor, interpolate_pos_encoding: bool = False, + feature_sample_layers: Optional[list[int]] = None, ) -> torch.Tensor: return self.vision_model( pixel_values=pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, + feature_sample_layers=feature_sample_layers, ) def load_weights(self, weights: Iterable[Tuple[str, diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 6d6fafc5ab0eb..f58710d215056 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -167,6 +167,7 @@ def __init__( num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, + prefix=f"{prefix}.attn", ) def forward( diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index e11d2e916730a..6b2107bef0a66 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -77,7 +77,8 @@ class StablelmAttention(nn.Module): def __init__(self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "") -> None: super().__init__() self.config = config self.hidden_size = config.hidden_size @@ -131,7 +132,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_key_value_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -155,9 +157,13 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() - self.self_attn = StablelmAttention(config, cache_config, quant_config) + self.self_attn = StablelmAttention(config, + cache_config, + quant_config, + prefix=f"{prefix}.self_attn") self.mlp = StablelmMLP(config, quant_config) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) @@ -207,8 +213,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: StablelmDecoderLayer(config, cache_config, - quant_config), + lambda prefix: StablelmDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) norm_eps = getattr(config, "norm_eps", diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 74c66042226de..15e8f2af52cda 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -52,7 +52,8 @@ class Starcoder2Attention(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config @@ -105,7 +106,8 @@ def __init__(self, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -154,12 +156,14 @@ class Starcoder2DecoderLayer(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.hidden_size = config.hidden_size self.self_attn = Starcoder2Attention(config, cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.self_attn") self.mlp = Starcoder2MLP(config, quant_config=quant_config) self.input_layernorm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) @@ -213,7 +217,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Starcoder2DecoderLayer( - config, cache_config, quant_config=quant_config), + config, cache_config, quant_config=quant_config, prefix=prefix + ), prefix=f"{prefix}.layers", ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index 2ab9b19e22068..dcfd2cb7d2622 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -629,3 +629,24 @@ def maybe_prefix(prefix: str, name: str) -> str: The string "prefix.name" if prefix was non-empty, otherwise just "name". """ return name if not prefix else f"{prefix}.{name}" + + +def extract_layer_index(layer_name: str) -> int: + """ + Extract the layer index from the module name. + Examples: + - "encoder.layers.0" -> 0 + - "encoder.layers.1.self_attn" -> 1 + - "2.self_attn" -> 2 + - "model.encoder.layers.0.sub.1" -> ValueError + """ + subnames = layer_name.split(".") + int_vals: List[int] = [] + for subname in subnames: + try: + int_vals.append(int(subname)) + except ValueError: + continue + assert len(int_vals) == 1, (f"layer name {layer_name} should" + " only contain one integer") + return int_vals[0] diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index bc37a997eabb5..25a0d474e2863 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -93,6 +93,7 @@ def __init__( quant_config: Optional[QuantizationConfig] = None, bias: bool = False, cache_config: Optional[CacheConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = hidden_size @@ -138,7 +139,8 @@ def __init__( self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, - quant_config=quant_config) + quant_config=quant_config, + prefix=f"{prefix}.attn") def forward( self, @@ -162,6 +164,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.hidden_size = config.hidden_size @@ -180,6 +183,7 @@ def __init__( quant_config=quant_config, bias=getattr(config, "bias", False), cache_config=cache_config, + prefix=f"{prefix}.self_attn", ) self.mlp = XverseMLP( hidden_size=self.hidden_size, @@ -243,8 +247,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, - lambda prefix: XverseDecoderLayer(config, cache_config, - quant_config), + lambda prefix: XverseDecoderLayer( + config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) diff --git a/vllm/multimodal/inputs.py b/vllm/multimodal/inputs.py index 64a4c58d5509c..640c7c04b8817 100644 --- a/vllm/multimodal/inputs.py +++ b/vllm/multimodal/inputs.py @@ -6,7 +6,7 @@ import torch import torch.types from PIL.Image import Image -from typing_extensions import TypeAlias +from typing_extensions import NotRequired, TypeAlias from vllm.utils import JSONTree, is_list_of, json_map_leaves @@ -203,18 +203,14 @@ class MultiModalInputsV2(TypedDict): """The type of inputs.""" prompt: str - """ - The original, unprocessed prompt text. - - Note: - Since prompt text is not required by vLLM internals, we leave this - unprocessed to save CPU computation. You can still call - :code:`tokenizer.decode(prompt_token_ids)` to get the processed text. - """ + """The processed prompt text.""" prompt_token_ids: List[int] """The processed token IDs which includes placeholder tokens.""" + token_type_ids: NotRequired[List[int]] + """The token type IDs of the prompt.""" + mm_kwargs: MultiModalKwargs """Keyword arguments to be directly passed to the model after batching.""" diff --git a/vllm/multimodal/processing.py b/vllm/multimodal/processing.py index 88a924da174a6..28c8dda581982 100644 --- a/vllm/multimodal/processing.py +++ b/vllm/multimodal/processing.py @@ -1,34 +1,91 @@ +import re +from abc import ABC, abstractmethod +from collections.abc import Callable, ItemsView, Iterable, Mapping, Sequence from dataclasses import dataclass -from functools import lru_cache, partial -from typing import (Any, Callable, Collection, Generic, List, Mapping, - Optional, TypedDict, TypeVar, final) +from functools import lru_cache +from itertools import groupby +from typing import Any, Generic, NamedTuple, Optional, Protocol, TypeVar, Union +import numpy as np from transformers import BatchFeature -from typing_extensions import TypeAlias +from typing_extensions import TypeAlias, TypedDict from vllm.inputs import InputProcessingContext from vllm.transformers_utils.tokenizer import AnyTokenizer, MistralTokenizer -from vllm.utils import is_list_of +from vllm.utils import flatten_2d_lists, full_groupby, is_list_of from .inputs import (AudioItem, ImageItem, MultiModalDataDict, MultiModalInputsV2, MultiModalKwargs, PlaceholderRange, VideoItem) + +def bind_prompt_sequence( + seq: Union[str, list[int]], + tokenizer: AnyTokenizer, +) -> "_BoundPromptSequence": + """ + Bind a text or token sequence to a tokenizer so that it can be + lazily converted into the other format on demand. + """ + return _BoundPromptSequence( + tokenizer=tokenizer, + _text=seq if isinstance(seq, str) else None, + _token_ids=seq if isinstance(seq, list) else None, + ) + + _T = TypeVar("_T") +_S = TypeVar("_S", str, list[int]) -ReplacementFunc: TypeAlias = Callable[[_T, BatchFeature, int], List[int]] -""" -Given the original data item, HF-processed data, and index of the processed -item, output the replacement token IDs to be allocated in vLLM. -""" + +@dataclass +class PromptReplacement(Generic[_S, _T]): + target: _S + """The text or token sequence to find and replace.""" + + repl_unit: _S + """ + The unit making up the replacement text or token sequence. + + See :code:`repl_count` for more details. + """ + + repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] + """ + Given the original multi-modal items for this modality, HF-processed data, + and index of the processed item, output the number of repetitions of + :code:`repl_unit` to build up the replacement text or token sequence. + + For convenience, you can pass in an integer if the number of repetitions is + a constant. + """ + + def __repr__(self) -> str: + return (f"{type(self).__name__}(target={self.target!r}, " + f"repl_unit={self.repl_unit!r})") + + def bind( + self, + modality: str, + tokenizer: AnyTokenizer, + ) -> "_BoundPromptReplacement[_T]": + return _BoundPromptReplacement( + modality=modality, + target=bind_prompt_sequence(self.target, tokenizer), + repl_unit=bind_prompt_sequence(self.repl_unit, tokenizer), + repl_count=self.repl_count, + ) @dataclass class ModalityProcessingMetadata(Generic[_T]): - placeholder_replacements: Mapping[str, ReplacementFunc] + prompt_repls: Sequence[Union[PromptReplacement[str, _T], + PromptReplacement[list[int], _T]]] """ - A dictionary where each item represents the original placeholder in the - prompt text and the corresponding replacement. + Defines each text or token sequence to replace in the HF-processed prompt. + + This is skipped if the HF-processed prompt is found to already contain + the replacement prompts. """ @@ -52,46 +109,138 @@ class MultiModalProcessingMetadataBuiltins(TypedDict, total=False): Read more on that :ref:`here `. """ -MultiModalMultiData: TypeAlias = List[_T] -""" -A list of data items, where the number of data items allowed -per modality is restricted by :code:`--limit-mm-per-prompt`. -""" +def _encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: bool = False, +) -> list[int]: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.encode(text, add_special_tokens=...)`. + """ + if isinstance(tokenizer, MistralTokenizer): + return tokenizer.tokenizer.encode(text, + bos=add_special_tokens, + eos=add_special_tokens) -@final -class MultiModalMultiDataBuiltins(TypedDict, total=False): - """Type annotations for modality types predefined by vLLM.""" + return tokenizer.encode(text, add_special_tokens=add_special_tokens) - image: MultiModalMultiData[ImageItem] - """The input images.""" - video: MultiModalMultiData[VideoItem] - """The input videos.""" +@lru_cache(maxsize=2048) +def _cached_encode( + tokenizer: AnyTokenizer, + text: str, + *, + add_special_tokens: bool = False, +) -> list[int]: + return _encode(tokenizer, text, add_special_tokens=add_special_tokens) - audio: MultiModalMultiData[AudioItem] - """The input audios.""" +def _decode( + tokenizer: AnyTokenizer, + token_ids: list[int], + *, + skip_special_tokens: bool = False, +) -> str: + """ + Backend-agnostic equivalent of HF's + :code:`tokenizer.decode(token_ids, skip_special_tokens=...)`. + """ + return tokenizer.decode(token_ids, skip_special_tokens=skip_special_tokens) -MultiModalMultiDataDict: TypeAlias = Mapping[str, MultiModalMultiData[Any]] -""" -A dictionary containing an entry for each modality type to input. -Note: - This dictionary also accepts modality keys defined outside - :class:`MultiModalMultiDataBuiltins` as long as a customized plugin - is registered through the :class:`~vllm.multimodal.MULTIMODAL_REGISTRY`. - Read more on that :ref:`here `. -""" +@lru_cache(maxsize=2048) +def _cached_decode( + tokenizer: AnyTokenizer, + token_ids: tuple[int, ...], + *, + skip_special_tokens: bool = False, +) -> str: + return _decode(tokenizer, + list(token_ids), + skip_special_tokens=skip_special_tokens) + + +class _HasModalityAttr(Protocol): + modality: str + +class _HasModalityProp(Protocol): -def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: + @property + def modality(self) -> str: + ... + + +_M = TypeVar("_M", bound=Union[_HasModalityAttr, _HasModalityProp]) + + +def full_groupby_modality(values: Iterable[_M]) -> ItemsView[str, list[_M]]: + """Convenience function to apply :func:`full_groupby` based on modality.""" + return full_groupby(values, key=lambda x: x.modality) + + +@dataclass +class _BoundPromptSequence: + tokenizer: AnyTokenizer + _text: Optional[str] + _token_ids: Optional[list[int]] + + def __post_init__(self) -> None: + if self._text is None and self._token_ids is None: + raise ValueError("At least one of 'text' and 'token_ids' must be " + "specified") + + @property + def text(self) -> str: + if self._text is None: + assert self._token_ids is not None + self._text = _cached_decode(self.tokenizer, tuple(self._token_ids)) + + return self._text + + @property + def token_ids(self) -> list[int]: + if self._token_ids is None: + assert self._text is not None + self._token_ids = _cached_encode(self.tokenizer, self._text) + + return self._token_ids + + def __repr__(self) -> str: + return (f"{type(self).__name__}(_text={self._text!r}, " + f"_token_ids={self._token_ids!r})") + + +@dataclass +class _BoundPromptReplacement(Generic[_T]): + modality: str + target: _BoundPromptSequence + repl_unit: _BoundPromptSequence + repl_count: Union[Callable[[list[_T], BatchFeature, int], int], int] + + def get_count( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> int: + repl_count = self.repl_count + if isinstance(repl_count, int): + return repl_count + + return repl_count(mm_items, hf_inputs, item_idx) + + +def to_multi_format(data: MultiModalDataDict) -> dict[str, list[Any]]: """ Convert a :class:`MultiModalDataDict` containing single data items to a :class:`MultiModalMultiDataDict` containing multiple data items per entry. """ - multi_data: Mapping[str, MultiModalMultiData[Any]] = {} + multi_data = dict[str, list[Any]]() for k, v in data.items(): # yapf: disable @@ -107,86 +256,279 @@ def to_multi_format(data: MultiModalDataDict) -> MultiModalMultiDataDict: return multi_data -def encode_no_special_tokens( - tokenizer: AnyTokenizer, - text: str, -) -> List[int]: +class _TokenRun(NamedTuple): + token_id: int + + start_idx: int + length: int + + +def iter_token_runs(token_ids: list[int]) -> Iterable[_TokenRun]: """ - Backend-agnostic equivalent of HF's - :code:`tokenizer.encode(text, add_special_tokens=False)`. + Yield the starting index and length of each run of tokens that are the same. """ - if isinstance(tokenizer, MistralTokenizer): - return tokenizer.tokenizer.encode(text, bos=False, eos=False) + start_idx = 0 + + for token_id, it in groupby(token_ids): + length = sum(1 for _ in it) + yield _TokenRun(token_id=token_id, start_idx=start_idx, length=length) + + start_idx += length + + +class _PlaceholderInfo(NamedTuple): + modality: str + offset: int + length: int + + def to_range(self) -> PlaceholderRange: + return PlaceholderRange(offset=self.offset, length=self.length) + + +def iter_placeholders( + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + token_ids: list[int], + *, + min_placeholder_count: int, +) -> Iterable[_PlaceholderInfo]: + """Yield each set of placeholder tokens found in :code:`token_ids`.""" + placeholder_ids_by_modality = { + modality: { + token_id + for prompt_repl in repls + for token_id in prompt_repl.repl_unit.token_ids + } + for modality, repls in full_groupby_modality(prompt_repls) + } - return tokenizer.encode(text, add_special_tokens=False) + for run_info in iter_token_runs(token_ids): + if run_info.length > min_placeholder_count: + for (modality, + placeholder_ids) in placeholder_ids_by_modality.items(): + if run_info.token_id in placeholder_ids: + yield _PlaceholderInfo( + modality=modality, + offset=run_info.start_idx, + length=run_info.length, + ) -@lru_cache -def candidate_placeholders( - tokenizer: AnyTokenizer, - placeholder_text: str, -) -> Collection[List[int]]: - """Generate token ID sequences that may represent a placeholder text.""" - # When the placeholder text is not mapped to a special token ID, - # it may be tokenized differently based on whether it is at the start/end - # of the string. So, we go through each combination of whether the text - # is at the start and end boundaries of the string - - # Matches the placeholder when it is in the middle of the string - start_id, = encode_no_special_tokens(tokenizer, "a") - end_id, = encode_no_special_tokens(tokenizer, "b") - - candidate_basic = encode_no_special_tokens(tokenizer, placeholder_text) - - start_id_, *candidate_a = encode_no_special_tokens( - tokenizer, - f"a{placeholder_text}", - ) - assert start_id == start_id_ +class _TokenMatch(NamedTuple): + start_idx: int + end_idx: int - start_id_, *candidate_ab, end_id_ = encode_no_special_tokens( - tokenizer, - f"a{placeholder_text}b", - ) - assert start_id == start_id_ and end_id == end_id_ - *candidate_b, end_id_ = encode_no_special_tokens( - tokenizer, - f"{placeholder_text}b", - ) - assert end_id == end_id_ +def iter_token_matches( + token_ids: list[int], + match_ids: list[int], +) -> Iterable[_TokenMatch]: + """Yield each occurrence of :code:`match_ids` in :code:`token_ids`.""" + match_len = len(match_ids) - # Remove duplicates (need to convert to tuple to be hashable) - unique_candidates = { - tuple(c) - for c in [candidate_basic, candidate_a, candidate_ab, candidate_b] - } + last_end_idx = 0 + for start_idx in range(len(token_ids) - match_len + 1): + if start_idx < last_end_idx: + continue # Exclude overlapping matches - # Convert back to list - return [list(c) for c in unique_candidates] + end_idx = start_idx + match_len + if token_ids[start_idx:end_idx] == match_ids: + yield _TokenMatch(start_idx=start_idx, end_idx=end_idx) + last_end_idx = end_idx -def apply_placeholders( - token_ids: List[int], - placeholder_ids: List[int], - get_replacement_ids: Callable[[], List[int]], -) -> Optional[PlaceholderRange]: - """ - Find the first occurrence of :code:`placeholder_ids`, - and replace it with the output of :code:`get_replacement_ids`. +class _PromptReplacementMatch(ABC, Generic[_T, _S]): + prompt_repl: _BoundPromptReplacement[_T] + + @property + def modality(self) -> str: + return self.prompt_repl.modality + + @property + @abstractmethod + def start_idx(self) -> int: + raise NotImplementedError + + @property + @abstractmethod + def end_idx(self) -> int: + raise NotImplementedError + + @abstractmethod + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> _S: + raise NotImplementedError + + def __repr__(self) -> str: + return (f"{type(self).__name__}(modality={self.modality!r}, " + f"start_idx={self.start_idx!r}, end_idx={self.end_idx!r})") + + +@dataclass(repr=False) +class _PromptReplacementTokenMatch(_PromptReplacementMatch[_T, list[int]]): + prompt_repl: _BoundPromptReplacement[_T] + match: _TokenMatch + + @property + def start_idx(self) -> int: + return self.match.start_idx + + @property + def end_idx(self) -> int: + return self.match.end_idx + + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> list[int]: + prompt_repl = self.prompt_repl + count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) + return prompt_repl.repl_unit.token_ids * count - This function updates :code:`token_ids` in place. + +@dataclass(repr=False) +class _PromptReplacementTextMatch(_PromptReplacementMatch[_T, str]): + prompt_repl: _BoundPromptReplacement[_T] + match: re.Match[str] + + @property + def start_idx(self) -> int: + return self.match.start() + + @property + def end_idx(self) -> int: + return self.match.end() + + def get_repl( + self, + mm_items: list[_T], + hf_inputs: BatchFeature, + item_idx: int, + ) -> str: + prompt_repl = self.prompt_repl + count = prompt_repl.get_count(mm_items, hf_inputs, item_idx) + return prompt_repl.repl_unit.text * count + + +def find_token_matches( + prompt: list[int], + prompt_repls: Sequence[_BoundPromptReplacement[_T]], +) -> list[_PromptReplacementTokenMatch[_T]]: + """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + return [ + _PromptReplacementTokenMatch(prompt_repl, match) + for prompt_repl in prompt_repls + for match in iter_token_matches(prompt, prompt_repl.target.token_ids) + ] + + +def find_text_matches( + prompt: str, + prompt_repls: Sequence[_BoundPromptReplacement[_T]], +) -> list[_PromptReplacementTextMatch[_T]]: + """Return each target of :code:`prompt_repls` found in :code:`prompt`.""" + return [ + _PromptReplacementTextMatch(prompt_repl, match) + for prompt_repl in prompt_repls + for match in re.finditer(re.escape(prompt_repl.target.text), prompt) + ] + + +def _resolve_matches( + prompt: _S, + matches: Sequence[_PromptReplacementMatch[_T, _S]], +) -> list[_PromptReplacementMatch[_T, _S]]: + """ + Resolve :code:`matches` to ensure that there are no overlapping matches, + and sort them such that earlier matches take priority over later ones. """ - placeholder_length = len(placeholder_ids) + num_matches_by_idx = np.zeros(len(prompt), dtype=int) + for match in matches: + num_matches_by_idx[match.start_idx:match.end_idx] += 1 + + duplicate_matches_idxs, = np.nonzero(num_matches_by_idx > 1) + if len(duplicate_matches_idxs) > 0: + raise ValueError("Unable to find a unique replacement " + f"at indices={duplicate_matches_idxs} " + f"of prompt={prompt}") + + return sorted(matches, key=lambda x: x.start_idx) + + +def _replace_matches( + prompt: _S, + matches: Sequence[_PromptReplacementMatch[_T, _S]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> list[_S]: + out_seqs = list[_S]() + prev_end_idx = 0 + next_idx_by_modality = {modality: 0 for modality in mm_items_by_modality} + + for match in _resolve_matches(prompt, matches): + modality = match.modality + mm_items = mm_items_by_modality[modality] + + item_idx = next_idx_by_modality[modality] + if item_idx >= len(mm_items): + continue + + start_idx = match.start_idx + end_idx = match.end_idx + repl_ids = match.get_repl(mm_items, hf_inputs, item_idx) + + out_seqs.append(prompt[prev_end_idx:start_idx] + repl_ids) + prev_end_idx = end_idx + next_idx_by_modality[modality] += 1 + + out_seqs.append(prompt[prev_end_idx:]) + + return out_seqs + + +def replace_token_matches( + prompt: list[int], + matches: Sequence[_PromptReplacementMatch[_T, list[int]]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> list[int]: + """Apply :code:`prompt_repls` to :code:`prompt`.""" + if not matches: + return prompt + + token_id_seqs = _replace_matches( + prompt, + matches, + mm_items_by_modality, + hf_inputs, + ) + + return flatten_2d_lists(token_id_seqs) - for start_idx in range(len(token_ids) - placeholder_length + 1): - if token_ids[start_idx:placeholder_length] == placeholder_ids: - token_ids[start_idx:placeholder_length] = get_replacement_ids() - return PlaceholderRange(offset=start_idx, - length=placeholder_length) +def replace_text_matches( + prompt: str, + matches: Sequence[_PromptReplacementMatch[_T, str]], + mm_items_by_modality: Mapping[str, list[_T]], + hf_inputs: BatchFeature, +) -> str: + """Apply :code:`prompt_repls` to :code:`prompt`.""" + if not matches: + return prompt - return None + texts = _replace_matches( + prompt, + matches, + mm_items_by_modality, + hf_inputs, + ) + + return "".join(texts) class MultiModalProcessor: @@ -212,62 +554,166 @@ def __call__( ) -> MultiModalInputsV2: return self.apply(prompt, mm_data, mm_processor_kwargs) - def apply( + def _find_placeholders( + self, + all_prompt_repls: Sequence[_BoundPromptReplacement[Any]], + new_token_ids: list[int], + *, + # To avoid false positives from multi-input when detecting + # whether placeholder tokens have been inserted, in case + # the target sequence is a subset of the replacement tokens + min_placeholder_count: int = 16, + ) -> list[_PlaceholderInfo]: + return list( + iter_placeholders( + all_prompt_repls, + new_token_ids, + min_placeholder_count=min_placeholder_count, + )) + + def _apply_hf_processor( self, prompt: str, mm_data: MultiModalDataDict, mm_processor_kwargs: Mapping[str, object], - ) -> MultiModalInputsV2: - tokenizer = self.ctx.tokenizer + ) -> BatchFeature: hf_processor = self.ctx.get_hf_processor() - processed_inputs = hf_processor( + return hf_processor( text=prompt, # type: ignore **mm_data, **mm_processor_kwargs, ) - new_token_ids, = processed_inputs.pop("input_ids").tolist() - mm_kwargs = MultiModalKwargs(processed_inputs) - mm_placeholders: Mapping[str, List[PlaceholderRange]] = {} + def _bind_prompt_replacements( + self, + mm_data: MultiModalDataDict, + ) -> list[_BoundPromptReplacement[Any]]: + tokenizer = self.ctx.tokenizer - for modality, orig_inputs in to_multi_format(mm_data).items(): - assert isinstance(orig_inputs, list) + return [ + prompt_repl.bind(modality, tokenizer) + for modality, metadata in self.metadata.items() + if modality in mm_data for prompt_repl in metadata.prompt_repls + ] - metadata = self.metadata[modality] - placeholder_replacements = metadata.placeholder_replacements + def _apply_prompt_replacements( + self, + mm_data: MultiModalDataDict, + hf_inputs: BatchFeature, + token_ids: list[int], + prompt_repls: Sequence[_BoundPromptReplacement[Any]], + ) -> tuple[list[int], str, list[_PlaceholderInfo]]: + tokenizer = self.ctx.tokenizer - modality_placeholders: List[PlaceholderRange] = [] + mm_items = to_multi_format(mm_data) + token_matches = find_token_matches(token_ids, prompt_repls) + + # If the search text does not represent a special token, + # it may have different token IDs in the prompt, because + # the tokens may go across the boundaries of the search text. + # ---- + # e.g. when searching for "foo" in "food", if "food" itself makes + # up a token, then the token ID of "foo" will not appear at all + # ---- + # Since it is inefficient to search for all possible tokenizations + # of the search text in the prompt, we instead perform string + # replacement on the decoded token IDs, then encode them back. + if all( + len(matches) >= len(mm_data[modality]) + for modality, matches in full_groupby_modality(token_matches) + ): # yapf: disable + token_ids = replace_token_matches( + token_ids, + token_matches, + mm_items, + hf_inputs, + ) + + text = _decode(tokenizer, token_ids) + matched_repls = [match.prompt_repl for match in token_matches] + else: + text = _decode(tokenizer, token_ids) + + text_matches = find_text_matches(text, prompt_repls) + text = replace_text_matches( + text, + text_matches, + mm_items, + hf_inputs, + ) + + token_ids = _encode(tokenizer, text) + matched_repls = [match.prompt_repl for match in text_matches] + + placeholders = self._find_placeholders(matched_repls, token_ids) + + # Sanity check + assert len(placeholders) == len(matched_repls), dict( + # Log this information for easier debugging + text=text, + token_ids=token_ids, + placeholders=placeholders, + matched_repls=matched_repls, + ) - for item_idx, orig_item in enumerate(orig_inputs): - for match_text, replace_fn in placeholder_replacements.items(): - candidates = candidate_placeholders(tokenizer, match_text) - get_replacement_ids = partial( - replace_fn, - orig_item, - processed_inputs, - item_idx, - ) + return token_ids, text, placeholders - for match_ids in candidates: - # TODO(youkaichao): Don't update new_token_ids - placeholders = apply_placeholders( - new_token_ids, - match_ids, - get_replacement_ids, - ) + def apply( + self, + prompt_text: str, + mm_data: MultiModalDataDict, + mm_processor_kwargs: Mapping[str, object], + ) -> MultiModalInputsV2: + """ + Process multi-modal inputs to be used in vLLM. + + The main steps are: + + 1. Apply HF Processor on prompt text and multi-modal data together, + outputting token IDs and processed tensors. + 2. Find and replace sequences in the token IDs with placeholder tokens. + The number of placeholder tokens equals the feature size of the + multi-modal data outputted by the multi-modal encoder. + 3. Extract information about the placeholder tokens from the + processed token IDs. + """ + tokenizer = self.ctx.tokenizer + + hf_inputs = self._apply_hf_processor(prompt_text, mm_data, + mm_processor_kwargs) + prompt_ids, = hf_inputs.pop("input_ids").tolist() + mm_kwargs = MultiModalKwargs(hf_inputs) - if placeholders is not None: - modality_placeholders.append(placeholders) + all_prompt_repls = self._bind_prompt_replacements(mm_data) - # yapf: disable - mm_placeholders[modality] = modality_placeholders # type: ignore[index] - # yapf: enable + # If HF processor already inserts placeholder tokens, + # there is no need for us to insert them + all_placeholders = self._find_placeholders(all_prompt_repls, + prompt_ids) + if all_placeholders: + prompt_text = _decode(tokenizer, prompt_ids) + else: + ( + prompt_ids, + prompt_text, + all_placeholders, + ) = self._apply_prompt_replacements( + mm_data, + hf_inputs, + prompt_ids, + all_prompt_repls, + ) + + mm_placeholders = { + modality: [item.to_range() for item in items] + for modality, items in full_groupby_modality(all_placeholders) + } return MultiModalInputsV2( type="multimodal", - prompt=prompt, - prompt_token_ids=new_token_ids, + prompt=prompt_text, + prompt_token_ids=prompt_ids, mm_kwargs=mm_kwargs, mm_placeholders=mm_placeholders, ) diff --git a/vllm/multimodal/utils.py b/vllm/multimodal/utils.py index 40194716bbf94..d4333b7519b47 100644 --- a/vllm/multimodal/utils.py +++ b/vllm/multimodal/utils.py @@ -6,6 +6,7 @@ import numpy as np import numpy.typing as npt +import torch from PIL import Image import vllm.envs as envs @@ -392,6 +393,49 @@ def encode_video_base64(frames: npt.NDArray): return ",".join(base64_frames) +def resolve_visual_encoder_outputs( + encoder_outputs: Union[torch.Tensor, list[torch.Tensor]], + feature_sample_layers: Optional[list[int]], + post_layer_norm: Optional[torch.nn.LayerNorm], + max_possible_layers: int, +) -> torch.Tensor: + """Given the outputs a visual encoder module that may correspond to the + output of the last layer, or a list of hidden states to be stacked, + handle post normalization and resolve it into a single output tensor. + + Args: + encoder_outputs: Output of encoder's last layer or all hidden states. + feature_sample_layers: Optional layer indices to grab from the encoder + outputs; if provided, encoder outputs must be a list. + post_layer_norm: Post norm to apply to the output of the encoder. + max_possible_layers: Total layers in the fully loaded visual encoder. + + """ + if feature_sample_layers is None: + if post_layer_norm is not None: + return post_layer_norm(encoder_outputs) + return encoder_outputs + + # Get the hidden states corresponding to the layer indices. + # Negative values are relative to the full visual encoder, + # so offset them depending on how many layers were loaded. + # NOTE: this assumes that encoder_outputs contains a list + # of hidden states in the same order as the encoder layers + # that produced them. + offset = max_possible_layers - len(encoder_outputs) + hs_pool = [ + encoder_outputs[layer_idx] + if layer_idx >= 0 else encoder_outputs[layer_idx + offset] + for layer_idx in feature_sample_layers + ] + + # Apply post-norm on the final hidden state if we are using it + uses_last_layer = feature_sample_layers[-1] in (len(hs_pool) - 1, -1) + if post_layer_norm is not None and uses_last_layer: + hs_pool[-1] = post_layer_norm(encoder_outputs) + return torch.cat(hs_pool, dim=-1) + + # Utilities for input processors _T = TypeVar("_T", str, int) diff --git a/vllm/outputs.py b/vllm/outputs.py index 4ae9b377ae693..2d256803edfe8 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -60,7 +60,6 @@ class EmbeddingOutput: embedding: The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the embedding guide. """ - embedding: List[float] def __repr__(self) -> str: @@ -363,6 +362,50 @@ def __repr__(self): f"finished={self.finished})") +@dataclass +class ScoreOutput: + """The output data of one completion output of a request. + + Args: + score: The score, which is a list of floats. + index: The correspondent text index of the score. + """ + index: int + score: List[float] + + def __repr__(self) -> str: + return (f"ScoreOutput(" + f"score={self.score}), " + f"index={self.index})") + + +class ScoreRequestOutput: + """ + The output data of an score request to the LLM. + + Args: + request_id (str): A unique identifier for the score request. + outputs (score): The embedding results for the given input. + """ + + def __init__(self, request_id: str, outputs: "ScoreOutput"): + self.request_id = request_id + self.outputs = outputs + + def __repr__(self): + """ + Returns a string representation of an ScoreRequestOutput instance. + + The representation includes the request_id and the number of outputs, + providing a quick overview of the embedding request's results. + + Returns: + str: A string representation of the ScoreRequestOutput instance. + """ + return (f"ScoreRequestOutput(request_id='{self.request_id}', " + f"outputs={repr(self.outputs)}") + + class RequestOutputFactory: @staticmethod diff --git a/vllm/platforms/cpu.py b/vllm/platforms/cpu.py index f9a34a47959ec..cbc982752c6b4 100644 --- a/vllm/platforms/cpu.py +++ b/vllm/platforms/cpu.py @@ -19,6 +19,8 @@ class CpuPlatform(Platform): _enum = PlatformEnum.CPU + device_type: str = "cpu" + dispatch_key: str = "CPU" @classmethod def get_device_name(cls, device_id: int = 0) -> str: @@ -53,11 +55,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: cache_config = vllm_config.cache_config - if cache_config.enable_prefix_caching: - logger.warning( - "Prefix caching is not supported on CPU, disable it.") - cache_config.enable_prefix_caching = False - kv_cache_space = envs.VLLM_CPU_KVCACHE_SPACE if kv_cache_space >= 0: @@ -74,10 +71,12 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: f" {kv_cache_space}, expect a positive integer value.") scheduler_config = vllm_config.scheduler_config - if scheduler_config.chunked_prefill_enabled: - logger.warning( - "Chunked prefill is not supported on CPU, disable it.") - scheduler_config.chunked_prefill_enabled = False + if ((scheduler_config.chunked_prefill_enabled + or cache_config.enable_prefix_caching) + and model_config.dtype == torch.half): + logger.warning("Chunked-prefill on the CPU backend only does not" + " support fp16 for now, cast to bf16.") + model_config.dtype = torch.bfloat16 parallel_config = vllm_config.parallel_config if (parallel_config.distributed_executor_backend is not None @@ -86,3 +85,5 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: "distributed executor backend."), parallel_config.distributed_executor_backend) parallel_config.distributed_executor_backend = "mp" + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.cpu_worker.CPUWorker" diff --git a/vllm/platforms/cuda.py b/vllm/platforms/cuda.py index 9c5212ace1346..70724b8be4c45 100644 --- a/vllm/platforms/cuda.py +++ b/vllm/platforms/cuda.py @@ -4,16 +4,23 @@ import os from functools import lru_cache, wraps -from typing import Callable, List, Tuple, TypeVar +from typing import TYPE_CHECKING, Callable, List, Tuple, TypeVar import pynvml import torch from typing_extensions import ParamSpec +# import custom ops, trigger op registration +import vllm._C # noqa from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) _P = ParamSpec("_P") @@ -97,8 +104,14 @@ def device_id_to_physical_device_id(device_id: int) -> int: if "CUDA_VISIBLE_DEVICES" in os.environ: device_ids = os.environ["CUDA_VISIBLE_DEVICES"].split(",") if device_ids == [""]: - raise RuntimeError("CUDA_VISIBLE_DEVICES is set to empty string," - " which means GPU support is disabled.") + msg = ( + "CUDA_VISIBLE_DEVICES is set to empty string, which means" + " GPU support is disabled. If you are using ray, please unset" + " the environment variable `CUDA_VISIBLE_DEVICES` inside the" + " worker/actor. " + "Check https://github.com/vllm-project/vllm/issues/8402 for" + " more information.") + raise RuntimeError(msg) physical_device_id = device_ids[device_id] return int(physical_device_id) else: @@ -107,6 +120,8 @@ def device_id_to_physical_device_id(device_id: int) -> int: class CudaPlatform(Platform): _enum = PlatformEnum.CUDA + device_type: str = "cuda" + dispatch_key: str = "CUDA" @classmethod def get_device_capability(cls, device_id: int = 0) -> DeviceCapability: @@ -148,3 +163,17 @@ def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool: " machine has no NVLink equipped.") return False return True + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" diff --git a/vllm/platforms/hpu.py b/vllm/platforms/hpu.py index 1e0888a30ba96..3071136e43b85 100644 --- a/vllm/platforms/hpu.py +++ b/vllm/platforms/hpu.py @@ -1,10 +1,19 @@ +from typing import TYPE_CHECKING + import torch from .interface import Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class HpuPlatform(Platform): _enum = PlatformEnum.HPU + device_type: str = "hpu" + dispatch_key: str = "HPU" @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -13,3 +22,19 @@ def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @staticmethod def inference_mode(): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + + scheduler_config = vllm_config.scheduler_config + if scheduler_config.is_multi_step: + raise NotImplementedError( + "Multi-step execution is not implemented for HPU") + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "Speculative decoding is not implemented for HPU") + + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.hpu_worker.HPUWorker" diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index f4849fa2ccfb0..3328665029039 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -56,6 +56,11 @@ def to_int(self) -> int: class Platform: _enum: PlatformEnum + device_type: str + # available dispatch keys: + # check https://github.com/pytorch/pytorch/blob/313dac6c1ca0fa0cde32477509cce32089f8532a/torchgen/model.py#L134 # noqa + # use "CPU" as a fallback for platforms not registered in PyTorch + dispatch_key: str = "CPU" def is_cuda(self) -> bool: return self._enum == PlatformEnum.CUDA @@ -169,3 +174,4 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: class UnspecifiedPlatform(Platform): _enum = PlatformEnum.UNSPECIFIED + device_type = "" diff --git a/vllm/platforms/neuron.py b/vllm/platforms/neuron.py index 07d8398eda525..4c4d778ed3dd4 100644 --- a/vllm/platforms/neuron.py +++ b/vllm/platforms/neuron.py @@ -1,9 +1,24 @@ +from typing import TYPE_CHECKING + from .interface import Platform, PlatformEnum +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + class NeuronPlatform(Platform): _enum = PlatformEnum.NEURON + device_type: str = "neuron" @classmethod def get_device_name(cls, device_id: int = 0) -> str: return "neuron" + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.neuron_worker.NeuronWorker" diff --git a/vllm/platforms/openvino.py b/vllm/platforms/openvino.py index ad69ced5417b3..ea5ec7b40b95c 100644 --- a/vllm/platforms/openvino.py +++ b/vllm/platforms/openvino.py @@ -1,3 +1,5 @@ +from typing import TYPE_CHECKING + import torch import vllm.envs as envs @@ -5,11 +7,24 @@ from .interface import Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) +try: + import openvino as ov + import openvino.properties.hint as hints +except ImportError as e: + logger.warning("Failed to import OpenVINO with %r", e) + class OpenVinoPlatform(Platform): _enum = PlatformEnum.OPENVINO + device_type: str = "openvino" + dispatch_key: str = "CPU" @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -37,3 +52,81 @@ def is_openvino_gpu(self) -> bool: def is_pin_memory_available(self) -> bool: logger.warning("Pin memory is not supported on OpenViNO.") return False + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + from vllm.utils import GiB_bytes + + parallel_config = vllm_config.parallel_config + assert ( + parallel_config.world_size == 1 + ), "OpenVINOExecutor only supports single CPU socket currently." + + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = \ + "vllm.worker.openvino_worker.OpenVINOWorker" + + # check and update model config + model_config = vllm_config.model_config + if model_config.dtype != torch.float32: + logger.warning( + f"Only float32 dtype is supported on OpenVINO, casting from {model_config.dtype}." # noqa: G004, E501 + ) + model_config.dtype = torch.float32 + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on OpenVINO backend, fallback to " + "the eager mode.") + model_config.enforce_eager = True + + # check and update cache config + ov_core = ov.Core() + cache_config = vllm_config.cache_config + if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": + if not OpenVinoPlatform.is_openvino_cpu(): + logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" + "ignored for GPU, f16 data type will be used.") + cache_config.cache_dtype = ov.Type.f16 + else: + logger.info("KV cache type is overridden to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + cache_config.cache_dtype = ov.Type.u8 + else: + if OpenVinoPlatform.is_openvino_cpu(): + ov_device = envs.VLLM_OPENVINO_DEVICE + inference_precision = ov_core.get_property( + ov_device, hints.inference_precision) + if inference_precision == ov.Type.bf16: + cache_config.cache_dtype = ov.Type.bf16 + else: + cache_config.cache_dtype = ov.Type.f16 + else: + cache_config.cache_dtype = ov.Type.f16 + + if OpenVinoPlatform.is_openvino_cpu(): + if cache_config.block_size != 32: + logger.info( + f"OpenVINO CPU optimal block size is 32, overriding currently set {cache_config.block_size}" # noqa: G004, E501 + ) + cache_config.block_size = 32 + else: + if cache_config.block_size != 16: + logger.info( + f"OpenVINO GPU optimal block size is 16, overriding currently set {cache_config.block_size}" # noqa: G004, E501 + ) + cache_config.block_size = 16 + + kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE + if kv_cache_space >= 0: + if kv_cache_space == 0 and OpenVinoPlatform.is_openvino_cpu(): + cache_config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore + logger.warning( + "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " + "for OpenVINO backend is not set, using 4 by default.") + else: + cache_config.openvino_kvcache_space_bytes = ( # type: ignore + kv_cache_space * GiB_bytes) + else: + raise RuntimeError( + "Invalid environment variable VLLM_OPENVINO_KVCACHE_SPACE" + f" {kv_cache_space}, expect a positive integer value.") diff --git a/vllm/platforms/rocm.py b/vllm/platforms/rocm.py index 022256996f97b..d2f44c3e423e3 100644 --- a/vllm/platforms/rocm.py +++ b/vllm/platforms/rocm.py @@ -1,5 +1,6 @@ import os from functools import lru_cache +from typing import TYPE_CHECKING import torch @@ -7,8 +8,24 @@ from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) +try: + import vllm._C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._C with %r", e) + +# import custom ops, trigger op registration +try: + import vllm._rocm_C # noqa: F401 +except ImportError as e: + logger.warning("Failed to import from vllm._rocm_C with %r", e) + if os.environ.get("VLLM_WORKER_MULTIPROC_METHOD", None) in ["fork", None]: logger.warning("`fork` method is not supported by ROCm. " "VLLM_WORKER_MULTIPROC_METHOD is overridden to" @@ -18,6 +35,8 @@ class RocmPlatform(Platform): _enum = PlatformEnum.ROCM + device_type: str = "cuda" + dispatch_key: str = "CUDA" @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -46,3 +65,17 @@ def get_device_name(cls, device_id: int = 0) -> str: def get_device_total_memory(cls, device_id: int = 0) -> int: device_props = torch.cuda.get_device_properties(device_id) return device_props.total_memory + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_worker.MultiStepWorker" + elif vllm_config.speculative_config: + parallel_config.worker_cls = \ + "vllm.spec_decode.spec_decode_worker.create_spec_worker" + else: + parallel_config.worker_cls = "vllm.worker.worker.Worker" diff --git a/vllm/platforms/tpu.py b/vllm/platforms/tpu.py index 2a7ca9fb8c576..137af57023ea9 100644 --- a/vllm/platforms/tpu.py +++ b/vllm/platforms/tpu.py @@ -16,6 +16,8 @@ class TpuPlatform(Platform): _enum = PlatformEnum.TPU + device_type: str = "tpu" + dispatch_key: str = "XLA" @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -47,3 +49,15 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: if compilation_config.backend == "": compilation_config.backend = "openxla" + + assert vllm_config.speculative_config is None, \ + "TPU does not support speculative decoding" + + parallel_config = vllm_config.parallel_config + scheduler_config = vllm_config.scheduler_config + if parallel_config.worker_cls == "auto": + if scheduler_config.is_multi_step: + parallel_config.worker_cls = \ + "vllm.worker.multi_step_tpu_worker.MultiStepTPUWorker" + else: + parallel_config.worker_cls = "vllm.worker.tpu_worker.TPUWorker" diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index d0b3dca9a4195..69388a8e0f27c 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -1,14 +1,23 @@ +from typing import TYPE_CHECKING + import torch from vllm.logger import init_logger from .interface import DeviceCapability, Platform, PlatformEnum, _Backend +if TYPE_CHECKING: + from vllm.config import VllmConfig +else: + VllmConfig = None + logger = init_logger(__name__) class XPUPlatform(Platform): _enum = PlatformEnum.XPU + device_type: str = "xpu" + dispatch_key: str = "XPU" @classmethod def get_default_attn_backend(cls, selected_backend: _Backend) -> _Backend: @@ -34,3 +43,33 @@ def get_device_total_memory(cls, device_id: int = 0) -> int: @staticmethod def inference_mode(): return torch.no_grad() + + @classmethod + def check_and_update_config(cls, vllm_config: VllmConfig) -> None: + # check and update model config + model_config = vllm_config.model_config + if model_config.dtype == torch.bfloat16: + logger.warning( + "bfloat16 is not fully supported on XPU, casting to float16.") + model_config.dtype = torch.float16 + if not model_config.enforce_eager: + logger.warning( + "CUDA graph is not supported on XPU, fallback to the eager " + "mode.") + model_config.enforce_eager = True + + if vllm_config.speculative_config is not None: + raise NotImplementedError( + "XPU does not support speculative decoding") + + # check and update parallel config + parallel_config = vllm_config.parallel_config + if (parallel_config.distributed_executor_backend is not None + and parallel_config.distributed_executor_backend != "ray"): + logger.warning( + "%s is not supported on XPU, fallback to ray distributed" + " executor backend.", + parallel_config.distributed_executor_backend) + parallel_config.distributed_executor_backend = "ray" + if parallel_config.worker_cls == "auto": + parallel_config.worker_cls = "vllm.worker.xpu_worker.XPUWorker" diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index dc183dbfc9b96..3c64726ca3344 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -1,11 +1,9 @@ import logging -from contextlib import contextmanager -from typing import TYPE_CHECKING, Optional +import os -import vllm.envs as envs +import torch -if TYPE_CHECKING: - from vllm.config import VllmConfig +import vllm.envs as envs logger = logging.getLogger(__name__) @@ -18,6 +16,15 @@ def load_general_plugins(): processes. They should be designed in a way that they can be loaded multiple times without causing issues. """ + + # all processes created by vllm will load plugins, + # and here we can inject some common environment variables + # for all processes. + + # see https://github.com/vllm-project/vllm/issues/10480 + os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' + # see https://github.com/vllm-project/vllm/issues/10619 + torch._inductor.config.compile_threads = 1 global plugins_loaded if plugins_loaded: return @@ -52,39 +59,3 @@ def load_general_plugins(): logger.info("plugin %s loaded.", plugin.name) except Exception: logger.exception("Failed to load plugin %s", plugin.name) - - -_current_vllm_config: Optional["VllmConfig"] = None - - -@contextmanager -def set_current_vllm_config(vllm_config: "VllmConfig"): - """ - Temporarily set the current VLLM config. - Used during model initialization. - We save the current VLLM config in a global variable, - so that all modules can access it, e.g. custom ops - can access the VLLM config to determine how to dispatch. - """ - global _current_vllm_config - old_vllm_config = _current_vllm_config - try: - _current_vllm_config = vllm_config - yield - finally: - logger.debug("enabled custom ops: %s", - vllm_config.compilation_config.enabled_custom_ops) - logger.debug("disabled custom ops: %s", - vllm_config.compilation_config.disabled_custom_ops) - _current_vllm_config = old_vllm_config - - -def get_current_vllm_config() -> "VllmConfig": - if _current_vllm_config is None: - # in ci, usually when we test custom ops/modules directly, - # we don't set the vllm config. In that case, we set a default - # config. - logger.warning("Current VLLM config is not set.") - from vllm.config import VllmConfig - return VllmConfig() - return _current_vllm_config diff --git a/vllm/sequence.py b/vllm/sequence.py index 3b41d25a2fe42..669124319c4f4 100644 --- a/vllm/sequence.py +++ b/vllm/sequence.py @@ -449,6 +449,10 @@ def prompt_token_ids(self) -> List[int]: def prompt_embeds(self) -> Optional[torch.Tensor]: return self.inputs.prompt_embeds + @property + def token_type_ids(self) -> List[int]: + return self.inputs.token_type_ids + @property def multi_modal_data(self) -> "MultiModalDataDict": return self.inputs.multi_modal_data @@ -579,6 +583,9 @@ def get_num_new_tokens(self) -> int: return 1 return self.data.get_num_uncomputed_tokens() + def get_num_computed_tokens(self) -> int: + return self.data.get_num_computed_tokens() + def is_prefill(self) -> bool: return self.data.stage == SequenceStage.PREFILL @@ -684,6 +691,10 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]: return (self.encoder_seq.prompt_token_ids if self.encoder_seq is not None else None) + @property + def token_type_ids(self) -> Optional[List[int]]: + return self.first_seq.token_type_ids + @property def multi_modal_data(self) -> MultiModalDataDict: return self.first_seq.multi_modal_data @@ -906,6 +917,7 @@ class SequenceGroupMetadata( default_factory=lambda: SequenceGroupState()) # "MultiModalDataDict" types. We have to use Any due to msgspec # doesn't allow to have union of 2 different dicts. + token_type_ids: Optional[List[int]] = None multi_modal_data: Optional[Any] = None multi_modal_placeholders: Optional[MultiModalPlaceholderDict] = None mm_processor_kwargs: Optional[Dict[str, Any]] = None diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cd4d7eb0e6e4e..cf166e3eb5bad 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -273,7 +273,8 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - with set_forward_context(model_input.attn_metadata): + with set_forward_context(model_input.attn_metadata, + self.vllm_config): hidden_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/transformers_utils/config.py b/vllm/transformers_utils/config.py index 054845584c2ef..70d18d40b7aa7 100644 --- a/vllm/transformers_utils/config.py +++ b/vllm/transformers_utils/config.py @@ -9,6 +9,7 @@ from huggingface_hub.utils import (EntryNotFoundError, LocalEntryNotFoundError, RepositoryNotFoundError, RevisionNotFoundError) +from torch import nn from transformers import GenerationConfig, PretrainedConfig from transformers.models.auto.image_processing_auto import ( get_image_processor_config) @@ -31,6 +32,7 @@ UltravoxConfig) # yapf: enable from vllm.transformers_utils.utils import check_gguf_file +from vllm.utils import resolve_obj_by_qualname if VLLM_USE_MODELSCOPE: from modelscope import AutoConfig @@ -107,6 +109,15 @@ def patch_rope_scaling(config: PretrainedConfig) -> None: def patch_rope_scaling_dict(rope_scaling: Dict[str, Any]) -> None: + if "rope_type" in rope_scaling and "type" in rope_scaling: + rope_type = rope_scaling["rope_type"] + rope_type_legacy = rope_scaling["type"] + if rope_type != rope_type_legacy: + raise ValueError( + f"Found conflicts between 'rope_type={rope_type}' (modern " + f"field) and 'type={rope_type_legacy}' (legacy field). " + "You should only specify one of them.") + if "rope_type" not in rope_scaling and "type" in rope_scaling: rope_scaling["rope_type"] = rope_scaling["type"] logger.info("Replacing legacy 'type' key with 'rope_type'") @@ -568,3 +579,16 @@ def try_get_generation_config( return GenerationConfig.from_model_config(config) except OSError: # Not found return None + + +def get_cross_encoder_activation_function(config: PretrainedConfig): + if (hasattr(config, "sbert_ce_default_activation_function") + and config.sbert_ce_default_activation_function is not None): + + function_name = config.sbert_ce_default_activation_function + assert function_name.startswith("torch.nn.modules."), \ + "Loading of activation functions is restricted to " \ + "torch.nn.modules for security reasons" + return resolve_obj_by_qualname(function_name)() + else: + return nn.Sigmoid() if config.num_labels == 1 else nn.Identity() diff --git a/vllm/transformers_utils/configs/aria.py b/vllm/transformers_utils/configs/aria.py new file mode 100644 index 0000000000000..d253da0d96a34 --- /dev/null +++ b/vllm/transformers_utils/configs/aria.py @@ -0,0 +1,47 @@ +from transformers.models.idefics2.configuration_idefics2 import ( + Idefics2VisionConfig) +from transformers.models.llama.configuration_llama import LlamaConfig + + +class AriaVisionConfig(Idefics2VisionConfig): + model_type = "aria_vision_model" + + +class AriaMoELMConfig(LlamaConfig): + """ + Configuration class for AriaMoE language model. + + This class extends the LlamaConfig to include additional parameters specific + to the Mixture of Experts (MoE) architecture. + """ + + model_type = "aria_moe_lm" + + def __init__( + self, + moe_intermediate_size: int = 4096, + moe_num_experts: int = 8, + moe_topk: int = 2, + moe_num_shared_experts: int = 2, + **kwargs, + ): + """ + Initialize the AriaMoELMConfig. + + Args: + moe_intermediate_size (int): The intermediate size for MoE layers. + Default is 4096. + moe_num_experts (int): The number of experts in the MoE layer. + Default is 8. + moe_topk (int): The number of top experts to route to for each + token. Default is 2. + moe_num_shared_experts (int): The number of shared experts. Default + is 2. + **kwargs: Additional keyword arguments to be passed to the parent + LlamaConfig. + """ + super().__init__(**kwargs) + self.moe_intermediate_size = moe_intermediate_size + self.moe_num_experts = moe_num_experts + self.moe_topk = moe_topk + self.moe_num_shared_experts = moe_num_shared_experts diff --git a/vllm/utils.py b/vllm/utils.py index 5d0514cd9d168..dd4283e3ac381 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -19,7 +19,8 @@ import warnings import weakref from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task -from collections.abc import Mapping +from collections import defaultdict +from collections.abc import Iterable, Mapping from functools import lru_cache, partial, wraps from platform import uname from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, @@ -705,6 +706,12 @@ def create_kv_caches_with_random( return key_caches, value_caches +@lru_cache +def print_info_once(msg: str) -> None: + # Set the stacklevel to 2 to print the caller's line info + logger.info(msg, stacklevel=2) + + @lru_cache def print_warning_once(msg: str) -> None: # Set the stacklevel to 2 to print the caller's line info @@ -899,6 +906,23 @@ def flatten_2d_lists(lists: List[List[T]]) -> List[T]: return [item for sublist in lists for item in sublist] +_K = TypeVar("_K", bound=Hashable) +_V = TypeVar("_V") + + +def full_groupby(values: Iterable[_V], *, key: Callable[[_V], _K]): + """ + Unlike :class:`itertools.groupby`, groups are not broken by + non-contiguous data. + """ + groups = defaultdict[_K, list[_V]](list) + + for value in values: + groups[key(value)].append(value) + + return groups.items() + + # TODO: This function can be removed if transformer_modules classes are # serialized by value when communicating between processes def init_cached_hf_modules() -> None: @@ -1186,6 +1210,10 @@ def parse_args(self, args=None, namespace=None): else: processed_args.append('--' + arg[len('--'):].replace('_', '-')) + elif arg.startswith('-O') and arg != '-O' and len(arg) == 2: + # allow -O flag to be used without space, e.g. -O3 + processed_args.append('-O') + processed_args.append(arg[2:]) else: processed_args.append(arg) @@ -1491,6 +1519,9 @@ def __getitem__(self, key) -> T: self._dict[key] = self._factory[key]() return self._dict[key] + def __setitem__(self, key: str, value: Callable[[], T]): + self._factory[key] = value + def __iter__(self): return iter(self._factory) @@ -1498,15 +1529,6 @@ def __len__(self): return len(self._factory) -def combine_fx_passes(passes: List[Callable]) -> Callable: - - def combined_fx(graph) -> None: - for fx in passes: - fx(graph) - - return combined_fx - - def weak_ref_tensor(tensor: torch.Tensor) -> torch.Tensor: """ Create a weak reference to a tensor. @@ -1569,6 +1591,7 @@ def direct_register_custom_op( mutates_args: List[str], fake_impl: Optional[Callable] = None, target_lib: Optional[Library] = None, + dispatch_key: str = "CUDA", ): """ `torch.library.custom_op` can have significant overhead because it @@ -1597,7 +1620,7 @@ def direct_register_custom_op( schema_str = torch._custom_op.impl.infer_schema(op_func, mutates_args) my_lib = target_lib or vllm_lib my_lib.define(op_name + schema_str) - my_lib.impl(op_name, op_func, "CUDA") + my_lib.impl(op_name, op_func, dispatch_key=dispatch_key) if fake_impl is not None: my_lib._register_fake(op_name, fake_impl) diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index e73a1e60b2730..5f8535eaa303f 100644 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -19,7 +19,7 @@ def get_supported_head_sizes() -> List[int]: @staticmethod def get_name() -> str: - return "flash-attn-vllm-v1" + return "FLASH_ATTN_VLLM_V1" @staticmethod def get_impl_cls() -> Type["FlashAttentionImpl"]: @@ -173,7 +173,8 @@ def unified_v1_flash_attention( alibi_slopes: Optional[torch.Tensor] = None, logits_soft_cap: Optional[float] = None, ) -> None: - current_metadata = get_forward_context() + context = get_forward_context() + current_metadata = context.dynamic_forward_context if current_metadata is None: # Profiling run. return diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 38f1c03a4d3ac..8eb3fb976eb87 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -79,6 +79,9 @@ def get_computed_blocks(self, request: Request) -> List[KVCacheBlock]: return [] computed_blocks = [] + + # TODO(rickyx): potentially we could cache this so we don't have to + # recompute it every time. block_hashes = hash_request_tokens(self.block_size, request.all_token_ids) @@ -120,47 +123,45 @@ def append_slots( # slots, but we cannot allocate new blocks due to the limit. return None - # When caching is enabled, assign token IDs to already allocated blocks. - new_token_ids = None - parent_block = None - if self.enable_caching: - # Figure out the token IDs to add to the blocks. - new_token_ids = request.all_token_ids[ - request.num_computed_tokens:request.num_computed_tokens + - num_tokens] - - # Find the last full block index. - # TODO: This may be optimized by calculating the computed tokens. - last_full_block_idx = len(req_blocks) - 1 - while (last_full_block_idx >= 0 - and req_blocks[last_full_block_idx].block_hash is None): - last_full_block_idx -= 1 - - parent_block = (req_blocks[last_full_block_idx] - if last_full_block_idx >= 0 else None) - token_id_idx = self._add_token_ids_to_blocks( - blocks=req_blocks[last_full_block_idx + 1:], - token_ids=new_token_ids, - parent_block=parent_block) - - new_token_ids = new_token_ids[token_id_idx:] - parent_block = req_blocks[-1] - - # No new block is needed. When caching is enabled, we make sure - # token_id_idx is equal to len(new_token_ids), meaning that all tokens - # are added to allocated blocks. - if num_required_blocks <= len(req_blocks): - assert not self.enable_caching or token_id_idx == num_tokens, \ - f"{token_id_idx=} != {num_tokens=}" - return [] + if num_new_blocks <= 0: + # No new block is needed. + new_blocks = [] + else: + # Get new blocks from the free block pool considering + # preallocated blocks. + num_new_blocks = min( + num_new_blocks + self.num_preallocate_blocks, + self.free_block_queue.num_free_blocks, + ) + + new_blocks = self._get_new_blocks(num_new_blocks) + req_blocks.extend(new_blocks) + + if not self.enable_caching: + return new_blocks + + num_computed_full_blocks = (request.num_computed_tokens // + self.block_size) + + # NOTE(rickyx): We are assuming the `num_tokens` are actual + # tokens rather than lookahead slots (e.g. for speculative decoding). + # TODO(rickyx): When supporting speculative decoding, we will need to + # differentiate between them so that we can know how many blocks are + # full after appending the actual tokens. + num_full_blocks_after_append = (request.num_computed_tokens + + num_tokens) // self.block_size + assert num_full_blocks_after_append <= len(req_blocks) + + new_full_blocks = req_blocks[ + num_computed_full_blocks:num_full_blocks_after_append] + self._cache_full_blocks( + request=request, + blk_start_idx=num_computed_full_blocks, + full_blocks=new_full_blocks, + prev_block=req_blocks[num_computed_full_blocks - 1] + if num_computed_full_blocks >= 1 else None, + ) - # Allocate new blocks considering preallocated blocks, and - # add token IDs to them if caching is enabled. - num_new_blocks = min(num_new_blocks + self.num_preallocate_blocks, - self.free_block_queue.num_free_blocks) - new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - parent_block) - req_blocks.extend(new_blocks) return new_blocks def allocate_slots( @@ -184,11 +185,20 @@ def allocate_slots( raise ValueError( f"num_tokens must be greater than 0, got {num_tokens}") - # If a computed block of a request is an eviction candidate (in the - # free queue and ref_cnt == 0), it cannot be counted as a free block - # when allocating this request. - num_evictable_computed_blocks = len( - [blk for blk in computed_blocks if blk.ref_cnt == 0]) + # Touch the computed blocks to make sure they won't be evicted. + num_evictable_computed_blocks = 0 + if self.enable_caching: + self._touch(computed_blocks) + + # If a computed block of a request is an eviction candidate (in the + # free queue and ref_cnt == 0), it cannot be counted as a free block + # when allocating this request. + num_evictable_computed_blocks = len( + [blk for blk in computed_blocks if blk.ref_cnt == 0]) + else: + assert not computed_blocks, ( + "Computed blocks should be empty when " + "prefix caching is disabled") num_required_blocks = cdiv(num_tokens, self.block_size) if (num_required_blocks > self.free_block_queue.num_free_blocks - @@ -201,35 +211,28 @@ def allocate_slots( num_new_blocks = min( num_required_blocks + self.num_preallocate_blocks, self.free_block_queue.num_free_blocks - - num_evictable_computed_blocks) - - num_computed_tokens = len(computed_blocks) * self.block_size + num_evictable_computed_blocks, + ) - # When caching is enabled, get the new token IDs and the parent block - # ID to generate cache keys. - new_token_ids = None - parent_block = None - if self.enable_caching: - # Touch the computed blocks to make sure they won't be evicted. - self._touch(computed_blocks) + # Concatenate the computed block IDs and the new block IDs. + new_blocks = self._get_new_blocks(num_new_blocks) + self.req_to_blocks[request.request_id] = computed_blocks + new_blocks - # Get the token IDs for the blocks being allocated for hashing. - new_token_ids = request.all_token_ids[ - num_computed_tokens:num_computed_tokens + num_tokens] - if not new_token_ids: - raise RuntimeError( - "Failed to infer the token IDs for allocation. " - f"#all_tokens={len(request.all_token_ids)} < " - f"#computed_tokens={num_computed_tokens}") + if not self.enable_caching: + return new_blocks - # Get the parent block ID to construct the block chain. - parent_block = computed_blocks[-1] if computed_blocks else None + num_computed_tokens = len(computed_blocks) * self.block_size + num_full_blocks = (num_computed_tokens + num_tokens) // self.block_size - new_blocks = self._get_new_blocks(num_new_blocks, new_token_ids, - parent_block) + self._cache_full_blocks( + request=request, + blk_start_idx=len(computed_blocks), + # The new full blocks are the full blocks that are not computed. + full_blocks=self.req_to_blocks[request.request_id] + [len(computed_blocks):num_full_blocks], + prev_block=computed_blocks[-1] if computed_blocks else None, + ) - # Concatenate the computed block IDs and the new block IDs. - self.req_to_blocks[request.request_id] = computed_blocks + new_blocks return new_blocks def free(self, request: Request) -> None: @@ -248,24 +251,17 @@ def free(self, request: Request) -> None: blocks = reversed(blocks) for block in blocks: - block.ref_cnt -= 1 + block.decr_ref() if block.ref_cnt == 0: self.free_block_queue.append(block) - def _get_new_blocks( - self, - num_blocks: int, - token_ids: Optional[List[int]] = None, - parent_block: Optional[int] = None) -> List[KVCacheBlock]: - """Get new blocks from the free block pool, and add token IDs to - allocated blocks if caching is enabled. + def _get_new_blocks(self, num_blocks: int) -> List[KVCacheBlock]: + """Get new blocks from the free block pool. + Note that we do not check block cache in this function. Args: num_blocks: The number of blocks to allocate. - token_ids: The token IDs in the blocks. None if caching is disabled. - parent_block: The parent block. Used to include block chain - in the block hash. Returns: A list of new block. @@ -274,56 +270,38 @@ def _get_new_blocks( raise ValueError( f"Cannot get {num_blocks} free blocks from the pool") - # First allocate blocks. ret: List[KVCacheBlock] = [] idx = 0 while idx < num_blocks: + # First allocate blocks. curr_block = self.free_block_queue.popleft() assert curr_block.ref_cnt == 0 - # Evict blocks from the cache. + # If the block is cached, evict it. if self.enable_caching: - block_hash = curr_block.block_hash - if (block_hash is not None - and block_hash in self.cached_block_hash_to_block): - if len(self.cached_block_hash_to_block[block_hash]) == 1: - del self.cached_block_hash_to_block[block_hash] - else: - del self.cached_block_hash_to_block[block_hash][ - curr_block.block_id] - curr_block.reset() - - curr_block.ref_cnt = 1 + self._evict_cached_block(curr_block) + + curr_block.incr_ref() ret.append(curr_block) idx += 1 - # Then assign token IDs to the allocated blocks. - if self.enable_caching: - assert token_ids is not None - token_id_idx = self._add_token_ids_to_blocks( - blocks=ret, token_ids=token_ids, parent_block=parent_block) - assert token_id_idx == len(token_ids) - return ret - def _cache_full_block(self, - block: KVCacheBlock, - parent_block: Optional[KVCacheBlock] = None) -> None: - """Cache a full block for prefix caching. + def _evict_cached_block(self, block: KVCacheBlock) -> None: + """ + If a block is cached in `cached_block_hash_to_block`, we reset its hash + metadata and evict it from the cache. Args: - block: The block to cache. - parent_block: The parent block. None if this is the first block. + block: The block to evict. """ - parent_block_hash = (parent_block.block_hash - if parent_block is not None else None) - assert len(block.token_ids) == self.block_size - block.token_ids = tuple(block.token_ids) - block_hash = hash_block_tokens(parent_block_hash, block.token_ids) - block.block_hash = block_hash - block.num_hashed_tokens = self.block_size + ( - parent_block.num_hashed_tokens if parent_block is not None else 0) - self.cached_block_hash_to_block[block_hash][block.block_id] = block + block_hash = block.block_hash + if block_hash and block_hash in self.cached_block_hash_to_block: + block.reset_hash() + del self.cached_block_hash_to_block[block_hash][block.block_id] + + if len(self.cached_block_hash_to_block[block_hash]) == 0: + del self.cached_block_hash_to_block[block_hash] def _get_cached_block(self, block_hash: BlockHashType) -> Optional[KVCacheBlock]: @@ -355,43 +333,50 @@ def _touch(self, blocks: List[KVCacheBlock]) -> None: # candidate), so remove it. if block.ref_cnt == 0: self.free_block_queue.remove(block) - block.ref_cnt += 1 - - def _add_token_ids_to_blocks( - self, - blocks: List[KVCacheBlock], - token_ids: List[int], - parent_block: Optional[KVCacheBlock] = None) -> int: - """Add token IDs to a list of allocated blocks. - If a block becomes full after adding token IDs, cache it. - Return the token ID index that has not been added to the blocks - if the blocks are not enough to hold all the token IDs. + block.incr_ref() - Args: - blocks: A list of blocks to add token IDs. - token_ids: A list of token IDs to add. - parent_block: The parent block. None if this is the - first block. + def _cache_full_blocks( + self, + request: Request, + blk_start_idx: int, + full_blocks: List[KVCacheBlock], + prev_block: Optional[KVCacheBlock], + ) -> None: + """Cache a list of full blocks for prefix caching. - Returns: - The starting token ID index that has not been added to the blocks - due to insufficient given blocks. + This function takes a list of blocks that will have their block hash + metadata to be updated and cached. Given a request, it computes the + block hashes for the blocks starting from `blk_start_idx` to the end + of the request's full blocks, updating the metadata for each block + and caching them in the `cached_block_hash_to_block`. + + Args: + request: The request to cache the blocks. + blk_start_idx: The index of the first block in the request's blocks + to cache. + full_blocks: The list of blocks to update hash metadata. + prev_block: The previous block in the chain. """ - token_id_start = 0 - for curr_block in blocks: - # If all token IDs are added, then the rest of the blocks are - # preallocated blocks, so we only need to update the - # parent_block_id. FIXME - if token_id_start == len(token_ids): - continue - - # Add token IDs to the empty slots in the block. - empty_slots = self.block_size - len(curr_block.token_ids) - token_id_end = min(token_id_start + empty_slots, len(token_ids)) - curr_block.token_ids.extend(token_ids[token_id_start:token_id_end]) - # Cache the block if it becomes full. - if len(curr_block.token_ids) == self.block_size: - self._cache_full_block(curr_block, parent_block) - parent_block = curr_block - token_id_start = token_id_end - return token_id_start + # Update the new blocks with the block hashes through the chain. + prev_block_hash = (prev_block.block_hash + if prev_block is not None else None) + for i, blk in enumerate(full_blocks): + blk_idx = blk_start_idx + i + + block_tokens = request.all_token_ids[blk_idx * + self.block_size:(blk_idx + + 1) * + self.block_size] + assert len(block_tokens) == self.block_size, ( + f"Expected {self.block_size} tokens, got {len(block_tokens)} " + f"at {blk_idx}th block for request " + f"{request.request_id}({request})") + + # Compute the hash of the current block. + block_hash = hash_block_tokens(prev_block_hash, + tuple(block_tokens)) + + # Update and added the full block to the cache. + blk.block_hash = block_hash + self.cached_block_hash_to_block[block_hash][blk.block_id] = blk + prev_block_hash = block_hash diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index 33dbfb7377bfd..fb666c364bfb2 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -1,6 +1,6 @@ """KV-Cache Utilities.""" -from dataclasses import dataclass, field -from typing import List, Optional, Tuple, Union +from dataclasses import dataclass +from typing import List, Optional, Tuple from vllm.logger import init_logger @@ -16,27 +16,34 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # Token IDs in the block. When the block is full, the type of token_ids - # should be Tuple[int] for fast matching. - token_ids: Union[List[int], Tuple[int]] = field(default_factory=list) # The hash of the block composed of (block hash, tuple of token IDs). # It is only available when the block is full. - block_hash: Optional[BlockHashType] = None - # The number of hashed tokens. More hashed tokens means the block - # is closer to the end of a prompt and more likely to be evicted. - num_hashed_tokens: int = 0 + _block_hash: Optional[BlockHashType] = None # Used to construct a doubly linked list for free blocks. # These two attributes should only be manipulated by FreeKVCacheBlockQueue. prev_free_block: Optional["KVCacheBlock"] = None next_free_block: Optional["KVCacheBlock"] = None - def reset(self): - """Reset the block metadata.""" - self.ref_cnt = 0 - self.token_ids = [] - self.block_hash = None - self.num_hashed_tokens = 0 + def incr_ref(self): + self.ref_cnt += 1 + + def decr_ref(self): + self.ref_cnt -= 1 + + @property + def block_hash(self) -> Optional[BlockHashType]: + return self._block_hash + + @block_hash.setter + def block_hash(self, block_hash: BlockHashType): + assert self.block_hash is None, ( + "The block already has a hash. This should not happen.") + self._block_hash = block_hash + + def reset_hash(self): + """Reset the block hash when the block is evicted.""" + self._block_hash = None class FreeKVCacheBlockQueue: diff --git a/vllm/v1/engine/__init__.py b/vllm/v1/engine/__init__.py index edfb8bd7c2fc1..967124fd850ea 100644 --- a/vllm/v1/engine/__init__.py +++ b/vllm/v1/engine/__init__.py @@ -68,6 +68,11 @@ class EngineCoreOutputs(msgspec.Struct, outputs: List[EngineCoreOutput] +@dataclass +class EngineCoreProfile: + is_start: bool + + class EngineCoreRequestType(enum.Enum): """ Request types defined as hex byte strings, so it can be sent over sockets @@ -75,3 +80,4 @@ class EngineCoreRequestType(enum.Enum): """ ADD = b'\x00' ABORT = b'\x01' + PROFILE = b'\x02' diff --git a/vllm/v1/engine/async_llm.py b/vllm/v1/engine/async_llm.py index b5428bc82f742..a17c8eac4b77c 100644 --- a/vllm/v1/engine/async_llm.py +++ b/vllm/v1/engine/async_llm.py @@ -346,10 +346,10 @@ async def check_health(self) -> None: logger.debug("Called check_health.") async def start_profile(self) -> None: - raise ValueError("Not supported on V1 yet.") + await self.engine_core.profile(True) async def stop_profile(self) -> None: - raise ValueError("Not supported on V1 yet.") + await self.engine_core.profile(False) @property def is_running(self) -> bool: diff --git a/vllm/v1/engine/core.py b/vllm/v1/engine/core.py index 495c4e3222649..34f99dd30ef2e 100644 --- a/vllm/v1/engine/core.py +++ b/vllm/v1/engine/core.py @@ -1,4 +1,5 @@ import multiprocessing +import pickle import queue import threading import time @@ -16,7 +17,8 @@ from vllm.usage.usage_lib import UsageContext from vllm.v1.core.scheduler import Scheduler from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreRequest, EngineCoreRequestType) + EngineCoreProfile, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.mm_input_mapper import MMInputMapper from vllm.v1.executor.gpu_executor import GPUExecutor from vllm.v1.request import Request, RequestStatus @@ -113,6 +115,9 @@ def step(self) -> List[EngineCoreOutput]: scheduler_output, output) return engine_core_outputs + def profile(self, is_start=True): + self.model_executor.worker.profile(is_start) + class EngineCoreProc(EngineCore): """ZMQ-wrapper for running EngineCore in background process.""" @@ -299,11 +304,14 @@ def _log_stats(self): self._last_logging_time = now def _handle_client_request( - self, request: Union[EngineCoreRequest, List[str]]) -> None: + self, request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: """Handle EngineCoreRequest or EngineCoreABORT from Client.""" if isinstance(request, EngineCoreRequest): self.add_request(request) + elif isinstance(request, EngineCoreProfile): + self.model_executor.worker.profile(request.is_start) else: # TODO: make an EngineCoreAbort wrapper assert isinstance(request, list) @@ -328,6 +336,8 @@ def process_input_socket(self, input_path: str): request = decoder_add_req.decode(request_data) elif request_type == EngineCoreRequestType.ABORT.value: request = decoder_abort_req.decode(request_data) + elif request_type == EngineCoreRequestType.PROFILE.value: + request = pickle.loads(request_data) else: raise ValueError(f"Unknown RequestType: {request_type}") diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index 09801e20e16ca..835963f7ee86c 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -9,7 +9,8 @@ from vllm.logger import init_logger from vllm.utils import get_open_zmq_ipc_path from vllm.v1.engine import (EngineCoreOutput, EngineCoreOutputs, - EngineCoreRequest, EngineCoreRequestType) + EngineCoreProfile, EngineCoreRequest, + EngineCoreRequestType) from vllm.v1.engine.core import EngineCore, EngineCoreProc from vllm.v1.serial_utils import PickleEncoder @@ -58,6 +59,9 @@ def get_output(self) -> List[EngineCoreOutput]: def add_request(self, request: EngineCoreRequest) -> None: raise NotImplementedError + async def profile(self, is_start=True) -> None: + raise NotImplementedError + def abort_requests(self, request_ids: List[str]) -> None: raise NotImplementedError @@ -95,6 +99,9 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self.engine_core.abort_requests(request_ids) + async def profile(self, is_start=True) -> None: + self.engine_core.profile(is_start) + class MPClient(EngineCoreClient): """ @@ -177,8 +184,10 @@ def get_output(self) -> List[EngineCoreOutput]: engine_core_outputs = self.decoder.decode(frame.buffer).outputs return engine_core_outputs - def _send_input(self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, List[str]]) -> None: + def _send_input( + self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: # (RequestType, SerializedRequest) msg = (request_type.value, self.encoder.encode(request)) @@ -190,6 +199,10 @@ def add_request(self, request: EngineCoreRequest) -> None: def abort_requests(self, request_ids: List[str]) -> None: self._send_input(EngineCoreRequestType.ABORT, request_ids) + async def profile(self, is_start=True) -> None: + self._send_input(EngineCoreRequestType.PROFILE, + EngineCoreProfile(is_start)) + class AsyncMPClient(MPClient): """Asyncio-compatible client for multi-proc EngineCore.""" @@ -205,8 +218,9 @@ async def get_output_async(self) -> List[EngineCoreOutput]: return engine_core_outputs async def _send_input( - self, request_type: EngineCoreRequestType, - request: Union[EngineCoreRequest, List[str]]) -> None: + self, request_type: EngineCoreRequestType, + request: Union[EngineCoreRequest, EngineCoreProfile, + List[str]]) -> None: msg = (request_type.value, self.encoder.encode(request)) await self.input_socket.send_multipart(msg, copy=False) @@ -217,3 +231,7 @@ async def add_request_async(self, request: EngineCoreRequest) -> None: async def abort_requests_async(self, request_ids: List[str]) -> None: if len(request_ids) > 0: await self._send_input(EngineCoreRequestType.ABORT, request_ids) + + async def profile(self, is_start=True) -> None: + await self._send_input(EngineCoreRequestType.PROFILE, + EngineCoreProfile(is_start)) diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1f9b544637bf7..02f9498142bb7 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1,3 +1,4 @@ +import gc import time from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Optional, Set, Tuple @@ -446,7 +447,7 @@ def execute_model( # Run the decoder. # Use persistent buffers for CUDA graphs. - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, self.vllm_config): hidden_states = self.model( input_ids=None, positions=self.positions[:num_input_tokens], @@ -515,7 +516,25 @@ def load_model(self) -> None: logger.info("Loading model weights took %.4f GB", self.model_memory_usage / float(2**30)) - def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: + @torch.inference_mode() + def _dummy_run( + self, + model: nn.Module, + num_tokens: int, + kv_caches: List[torch.Tensor], + ) -> torch.Tensor: + with set_forward_context(None, self.vllm_config): + hidden_states = model( + input_ids=None, + positions=self.positions[:num_tokens], + kv_caches=kv_caches, + attn_metadata=None, + inputs_embeds=self.inputs_embeds[:num_tokens]) + return hidden_states + + def profile_run(self) -> None: + # TODO(woosuk): Profile the max memory usage of the encoder and + # the encoder cache. # use an empty tensor instead of `None`` to force Dynamo to pass # it by reference, rather by specializing on the value `None`. # the `dtype` argument does not matter, and we use `float32` as @@ -527,45 +546,32 @@ def _dummy_run(self, model: nn.Module, num_tokens: int) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] - with set_forward_context(None): # noqa: SIM117 - with set_compile_context(self.cudagraph_batch_sizes): - # Trigger compilation for general shape. - model(input_ids=None, - positions=self.positions, - kv_caches=dummy_kv_caches, - attn_metadata=None, - inputs_embeds=self.inputs_embeds) - - @torch.inference_mode() - def profile_run(self) -> None: - # TODO(woosuk): Profile the max memory usage of the encoder and - # the encoder cache. - self._dummy_run(self.model, self.max_num_tokens) + with set_compile_context(self.cudagraph_batch_sizes): + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) + logits = self.model.compute_logits(hidden_states, None) + logits = logits[:self.max_num_tokens] + # TODO(woosuk): Consider the memory usage of the sampler. torch.cuda.synchronize() + del hidden_states, logits + gc.collect() - @torch.inference_mode() def capture_model(self) -> None: if not self.use_cuda_graph: logger.warning( "Skipping CUDA graph capture. Please add " - "-O 3 to use CUDA graphs.", CompilationLevel.PIECEWISE) + "-O %s to use CUDA graphs.", CompilationLevel.PIECEWISE) return start_time = time.perf_counter() start_free_gpu_memory = torch.cuda.mem_get_info()[0] - with set_forward_context(None): - # Trigger CUDA graph capture for specific shapes. - # Capture the large shapes first so that the smaller shapes - # can reuse the memory pool allocated for the large shapes. - for num_tokens in reversed(self.cudagraph_batch_sizes): - self.model( - input_ids=None, - positions=self.positions[:num_tokens], - kv_caches=self.kv_caches, - attn_metadata=None, - inputs_embeds=self.inputs_embeds[:num_tokens], - ) + # Trigger CUDA graph capture for specific shapes. + # Capture the large shapes first so that the smaller shapes + # can reuse the memory pool allocated for the large shapes. + for num_tokens in reversed(self.cudagraph_batch_sizes): + self._dummy_run(self.model, num_tokens, self.kv_caches) end_time = time.perf_counter() end_free_gpu_memory = torch.cuda.mem_get_info()[0] diff --git a/vllm/v1/worker/gpu_worker.py b/vllm/v1/worker/gpu_worker.py index c8192b7f86eb0..d33b55a8a9f9a 100644 --- a/vllm/v1/worker/gpu_worker.py +++ b/vllm/v1/worker/gpu_worker.py @@ -6,6 +6,7 @@ import torch import torch.distributed +import vllm.envs as envs from vllm.config import CacheConfig, ModelConfig, ParallelConfig, VllmConfig from vllm.distributed import (ensure_model_parallel_initialized, init_distributed_environment, @@ -56,6 +57,22 @@ def __init__( init_cached_hf_modules() self.model_runner = GPUModelRunner(vllm_config) + # Torch profiler. Enabled and configured through env vars: + # VLLM_TORCH_PROFILER_DIR=/path/to/save/trace + if envs.VLLM_TORCH_PROFILER_DIR: + torch_profiler_trace_dir = envs.VLLM_TORCH_PROFILER_DIR + logger.info("Profiling enabled. Traces will be saved to: %s", + torch_profiler_trace_dir) + self.profiler = torch.profiler.profile( + activities=[ + torch.profiler.ProfilerActivity.CPU, + torch.profiler.ProfilerActivity.CUDA, + ], + with_stack=True, + on_trace_ready=torch.profiler.tensorboard_trace_handler( + torch_profiler_trace_dir, use_gzip=True)) + else: + self.profiler = None def initialize(self): if self.device_config.device.type == "cuda": @@ -105,35 +122,48 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: # Profile the memory usage of the model and get the maximum number of # cache blocks that can be allocated with the remaining free memory. torch.cuda.empty_cache() + torch.cuda.reset_peak_memory_stats() + _, total_gpu_memory = torch.cuda.mem_get_info() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. self.model_runner.profile_run() - - # Calculate the number of blocks that can be allocated with the - # profiled peak memory. torch.cuda.synchronize() - free_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info() + + free_gpu_memory, _ = torch.cuda.mem_get_info() # NOTE(woosuk): Here we assume that the other processes using the same # GPU did not change their memory usage during the profiling. - peak_memory = self.init_gpu_memory - free_gpu_memory - assert peak_memory > 0, ( + assert self.init_gpu_memory > free_gpu_memory, ( "Error in memory profiling. " f"Initial free memory {self.init_gpu_memory}, current free memory" f" {free_gpu_memory}. This happens when the GPU memory was " "not properly cleaned up before initializing the vLLM instance.") + # Get the peak memory allocation recorded by torch + peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"] + + # Check for any memory left around that may have been allocated on the + # gpu outside of `torch`. NCCL operations, for example, can use a few + # GB during a forward pass + torch.cuda.empty_cache() + torch_allocated_bytes = torch.cuda.memory_stats( + )["allocated_bytes.all.current"] + total_allocated_bytes = torch.cuda.mem_get_info( + )[1] - torch.cuda.mem_get_info()[0] + non_torch_allocations = total_allocated_bytes - torch_allocated_bytes + if non_torch_allocations > 0: + peak_memory += non_torch_allocations + available_kv_cache_memory = ( + total_gpu_memory * self.cache_config.gpu_memory_utilization - + peak_memory) + + # Calculate the number of blocks that can be allocated with the + # profiled peak memory. cache_block_size = _get_cache_block_size(self.cache_config, self.model_config, self.parallel_config) - num_gpu_blocks = int( - (total_gpu_memory * self.cache_config.gpu_memory_utilization - - peak_memory) // cache_block_size) + num_gpu_blocks = int(available_kv_cache_memory // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) - # if self.model_runner.lora_manager: - # self.model_runner.remove_all_loras() - gc.collect() - torch.cuda.empty_cache() return num_gpu_blocks, 0 def initialize_cache(self, num_gpu_blocks: int) -> None: @@ -171,6 +201,14 @@ def execute_model( # TODO(woosuk): Send the output to the engine process. return output + def profile(self, is_start=True): + if self.profiler is None: + raise RuntimeError("Profiler is not enabled.") + if is_start: + self.profiler.start() + else: + self.profiler.stop() + def init_worker_distributed_environment( parallel_config: ParallelConfig, diff --git a/vllm/worker/cpu_embedding_model_runner.py b/vllm/worker/cpu_embedding_model_runner.py index d0b8fec48d74f..3954e4c4c8a5b 100644 --- a/vllm/worker/cpu_embedding_model_runner.py +++ b/vllm/worker/cpu_embedding_model_runner.py @@ -3,6 +3,7 @@ import torch +from vllm.forward_context import set_forward_context from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalKwargs from vllm.pooling_params import PoolingParams @@ -49,6 +50,9 @@ def execute_model( ] model_executable = self.model + cross_enc_kwargs = {} + if model_input.token_type_ids is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_type_ids execute_model_kwargs = { "input_ids": model_input.input_tokens, @@ -60,11 +64,13 @@ def execute_model( model_input.attn_metadata, **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, device=self.device), + **cross_enc_kwargs, "intermediate_tensors": intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/cpu_enc_dec_model_runner.py b/vllm/worker/cpu_enc_dec_model_runner.py index d040831870bd8..cc24cfe04d2ba 100644 --- a/vllm/worker/cpu_enc_dec_model_runner.py +++ b/vllm/worker/cpu_enc_dec_model_runner.py @@ -4,6 +4,7 @@ import torch from vllm.attention import AttentionMetadata +from vllm.forward_context import set_forward_context from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.sampler import SamplerOutput from vllm.multimodal import MultiModalKwargs @@ -34,6 +35,7 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: "input_positions": self.input_positions, "encoder_input_tokens": self.encoder_input_tokens, "encoder_input_positions": self.encoder_input_positions, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -303,7 +305,8 @@ def execute_model( intermediate_tensors, } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable(**execute_model_kwargs) # Compute the logits. logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index d3e1202c15e61..b08171d79f002 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -2,14 +2,15 @@ import weakref from collections import defaultdict from dataclasses import dataclass -from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type, - TypeVar, Union) +from typing import (TYPE_CHECKING, Any, Dict, List, Optional, Type, TypeVar, + Union) import torch from torch import nn from vllm.attention import AttentionMetadata, get_attn_backend from vllm.config import VllmConfig +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding @@ -19,7 +20,6 @@ MultiModalKwargs, MultiModalPlaceholderMap) from vllm.sequence import (IntermediateTensors, SequenceData, SequenceGroupMetadata) -from vllm.utils import make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, ModelRunnerInputBuilderBase, _add_attn_metadata_broadcastable_dict, @@ -43,6 +43,7 @@ class ModelInputForCPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_type_ids: Optional[torch.Tensor] = None attn_metadata: Optional["AttentionMetadata"] = None multi_modal_kwargs: Optional[BatchedTensorInputs] = None virtual_engine: Optional[int] = None @@ -54,6 +55,7 @@ def as_broadcastable_tensor_dict( tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) @@ -83,6 +85,8 @@ def as_broadcastable_tensor_dict(self) -> Dict[str, Any]: tensor_dict = { "input_tokens": self.input_tokens, "input_positions": self.input_positions, + "token_type_ids": self.token_type_ids, + "multi_modal_kwargs": self.multi_modal_kwargs, } _add_attn_metadata_broadcastable_dict(tensor_dict, self.attn_metadata) _add_sampling_metadata_broadcastable_dict(tensor_dict, @@ -104,65 +108,233 @@ def from_broadcasted_tensor_dict( class ModelInputForCPUBuilder(ModelRunnerInputBuilderBase[ModelInputForCPU]): + class ModelInputData: + + def __init__(self, use_mrope: bool): + self.use_mrope = use_mrope + self.input_tokens: List[int] = [] + self.input_positions: Optional[ + List[int]] = [] if not self.use_mrope else None + self.token_type_ids: Optional[List[int]] = [] + self.seq_lens: List[int] = [] + self.query_lens: List[int] = [] + self.prefill_block_tables: List[List[int]] = [] + self.decode_block_tables: List[List[int]] = [] + self.max_decode_seq_len: int = 0 + self.num_prefills: int = 0 + self.num_prefill_tokens: int = 0 + self.num_decode_tokens: int = 0 + self.slot_mapping: List[int] = [] + self.multi_modal_inputs_list: List[MultiModalKwargs] = [] + self.multi_modal_placeholder_maps: Dict[ + str, MultiModalPlaceholderMap] = defaultdict( + MultiModalPlaceholderMap) + self.input_mrope_positions: Optional[List[List[int]]] = [ + [] for _ in range(3) + ] if self.use_mrope else None + def __init__(self, runner: "CPUModelRunner", finished_requests_ids: Optional[List[str]] = None) -> None: super().__init__() self.seq_group_metadata_list: List[SequenceGroupMetadata] = [] self.runner = runner + + self.chunked_prefill = (runner.scheduler_config.chunked_prefill_enabled + or runner.cache_config.enable_prefix_caching) self.model_input_cls = self.runner._model_input_cls self.attn_backend = self.runner.attn_backend - self.sliding_window = self.runner.sliding_window - self.block_size = self.runner.block_size - self.device = self.runner.device self.multi_modal_input_mapper = self.runner.multi_modal_input_mapper + self.input_data = ModelInputForCPUBuilder.ModelInputData( + self.runner.model_config.uses_mrope) + self.att_metadata_builder = self.runner.attn_backend.get_builder_cls()( + self) def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): self.seq_group_metadata_list.append(seq_group_metadata) + def set_seq_group_list( + self, seq_group_metadata_list: List[SequenceGroupMetadata]): + self.seq_group_metadata_list = seq_group_metadata_list + def build(self) -> ModelInputForCPU: + self._build_input_data() + + input_data = self.input_data + input_tokens = torch.tensor(input_data.input_tokens, + dtype=torch.long, + device="cpu") + input_positions = torch.tensor( + input_data.input_positions + if not input_data.use_mrope else input_data.input_mrope_positions, + dtype=torch.long, + device="cpu") + token_type_ids = torch.tensor(input_data.token_type_ids, + dtype=torch.long, + device="cpu") \ + if input_data.token_type_ids else None + + # For multi-modal models multi_modal_kwargs = None - # NOTE: We assume that all sequences in the group are all prompts or - # all decodes. - is_prompt = self.seq_group_metadata_list[0].is_prompt - # Prepare input tensors. - if is_prompt: - (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) = self._prepare_prompt( - self.seq_group_metadata_list) - else: - (input_tokens, input_positions, - attn_metadata) = self._prepare_decode( - self.seq_group_metadata_list) - seq_lens = None + if len(input_data.multi_modal_inputs_list) != 0: + multi_modal_kwargs = MultiModalKwargs.batch( + input_data.multi_modal_inputs_list) + + attn_metadata = self.att_metadata_builder.build( + input_data.seq_lens, input_data.query_lens, -1, -1) return self.model_input_cls( input_tokens=input_tokens, input_positions=input_positions, + token_type_ids=token_type_ids, + seq_lens=input_data.seq_lens, + query_lens=input_data.query_lens, attn_metadata=attn_metadata, multi_modal_kwargs=multi_modal_kwargs, - # query_lens is not needed if chunked prefill is not - # supported. Since CPU worker doesn't support chunked prefill - # just use seq_lens instead. - seq_lens=seq_lens, - query_lens=seq_lens, ) - def _compute_multi_modal_input( - self, - seq_data: SequenceData, - computed_len: int, - seq_group_metadata: SequenceGroupMetadata, - ): + def _build_input_data(self): + for seq_group_metadata in self.seq_group_metadata_list: + for seq_id, seq_data in seq_group_metadata.seq_data.items(): + if seq_group_metadata.is_prompt: + self._compute_prompt_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + if seq_group_metadata.multi_modal_data: + self._compute_multi_modal_input( + seq_group_metadata, seq_data) + else: + self._compute_decode_input_tokens(self.input_data, + seq_group_metadata, + seq_data, seq_id) + + def _compute_decode_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute decode input tokens, positions, block table and slot mapping. + """ + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + + tokens = seq_data.get_last_token_id() + token_positions = seq_len - 1 + block_number = block_table[token_positions // block_size] + block_offset = token_positions % block_size + slot = block_number * block_size + block_offset + + # For paged_attention kernel + if self.runner.sliding_window: + start_idx = max(0, seq_len - self.runner.sliding_window) + start_block = start_idx // block_size + start_idx = start_block * block_size + seq_len = seq_len - start_idx + block_table = block_table[start_block:] + + # For MRotaryEmbedding + if data.input_positions is None: + next_pos = MRotaryEmbedding.get_next_input_positions( + seq_data.mrope_position_delta, + context_len, + seq_len, + ) + for idx in range(3): + data.input_mrope_positions[idx].extend( # type: ignore + next_pos[idx]) + else: + data.input_positions.append(token_positions) # type: ignore + + # Update fields + data.input_tokens.append(tokens) + data.max_decode_seq_len = max(data.max_decode_seq_len, seq_len) + data.num_decode_tokens += 1 + data.slot_mapping.append(slot) + data.decode_block_tables.append(block_table) + data.query_lens.append(1) + data.seq_lens.append(seq_len) + + def _compute_prompt_input_tokens(self, data: ModelInputData, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData, seq_id: int): + """ + Compute prompt input tokens, positions, block table and slot mapping. + """ + token_chunk_size = seq_group_metadata.token_chunk_size + block_size = self.runner.block_size + + block_table = seq_group_metadata.block_tables[seq_id] + seq_len = seq_data.get_len() + context_len = seq_data.get_num_computed_tokens() + seq_len = min(seq_len, context_len + token_chunk_size) + + # For prefix caching + prefix_cache_block_num = len(seq_group_metadata.computed_block_nums) + if prefix_cache_block_num > 0: + prefix_cache_len = (prefix_cache_block_num * + self.runner.block_size) + if prefix_cache_len <= context_len: + # We already passed the cache hit region, + # so do normal computation. + pass + elif context_len < prefix_cache_len < seq_len: + # Partial hit. Compute the missing part. + context_len = prefix_cache_len + token_chunk_size = seq_len - context_len + elif seq_len <= prefix_cache_len: + # Full hit. Only compute the last token to avoid + # erroneous behavior. FIXME: Ideally we should directly + # mark all tokens as computed in the scheduler and do not + # schedule this sequence, so this case should not happen. + context_len = seq_len - 1 + token_chunk_size = 1 + + tokens = seq_data.get_token_ids() + tokens = tokens[context_len:seq_len] + token_positions = range(context_len, seq_len) + token_types = seq_group_metadata.token_type_ids + + # For encoder-only models, the block_table is None, + # and there is no need to initialize the slot_mapping. + if block_table is not None: + slot_mapping = [_PAD_SLOT_ID] * len(token_positions) + for i, pos in enumerate(token_positions): + block_number = block_table[pos // block_size] + block_offset = pos % block_size + slot = block_number * block_size + block_offset + slot_mapping[i] = slot + data.slot_mapping.extend(slot_mapping) + + # The MROPE positions are prepared in _compute_multi_modal_input + if data.input_positions is not None: + data.input_positions.extend(token_positions) + + if data.token_type_ids is not None: + data.token_type_ids.extend(token_types if token_types else []) + + # Update fields + data.input_tokens.extend(tokens) + data.num_prefills += 1 + data.num_prefill_tokens += len(tokens) + data.query_lens.append(len(tokens)) + data.prefill_block_tables.append(block_table) + data.seq_lens.append(seq_len) + + def _compute_multi_modal_input(self, + seq_group_metadata: SequenceGroupMetadata, + seq_data: SequenceData): + computed_len = seq_data.get_num_computed_tokens() + seq_len = self.input_data.seq_lens[-1] + # NOTE: mm_data only includes the subset of multi-modal items that # intersect with the current prefill positions. mm_data, placeholder_maps = MultiModalPlaceholderMap.from_seq_group( - seq_group_metadata, - range(computed_len, len(seq_data.get_token_ids())), - ) + seq_group_metadata, range(computed_len, seq_len)) if not mm_data: - return None, None, None + return if self.runner.mm_registry.has_processor(self.runner.model_config): mm_kwargs = mm_data @@ -173,8 +345,10 @@ def _compute_multi_modal_input( ) # special processing for mrope position deltas. - mrope_positions = None if self.runner.model_config.uses_mrope: + assert not self.chunked_prefill, \ + "MROPE on CPU does not support chunked-prefill." + image_grid_thw = mm_kwargs.get("image_grid_thw", None) video_grid_thw = mm_kwargs.get("video_grid_thw", None) assert image_grid_thw is not None or video_grid_thw is not None, ( @@ -198,226 +372,15 @@ def _compute_multi_modal_input( context_len=computed_len, ) seq_data.mrope_position_delta = mrope_position_delta - return mm_kwargs, placeholder_maps, mrope_positions - - def _prepare_prompt( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata, List[int], - BatchedTensorInputs]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - multi_modal_kwargs_list: List[MultiModalKwargs] = [] - multi_modal_placeholder_maps: Dict[ - str, - MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap) - - for seq_group_metadata in seq_group_metadata_list: - assert seq_group_metadata.is_prompt - seq_ids = list(seq_group_metadata.seq_data.keys()) - assert len(seq_ids) == 1 - seq_id = seq_ids[0] - - seq_data = seq_group_metadata.seq_data[seq_id] - prompt_tokens = seq_data.get_token_ids() - computed_len = seq_data.get_num_computed_tokens() - seq_len = len(prompt_tokens) - - seq_lens.append(seq_len) # Prompt token num - input_tokens.extend(prompt_tokens) # Token ids - - mrope_positions = None - if seq_group_metadata.multi_modal_data: - ( - mm_kwargs, - placeholder_maps, - mrope_positions, - ) = self._compute_multi_modal_input(seq_data, computed_len, - seq_group_metadata) - - multi_modal_kwargs_list.append(mm_kwargs) - for modality, placeholder_map in placeholder_maps.items(): - multi_modal_placeholder_maps[modality].extend( - placeholder_map) - - # Token position ids - # NOTE(woosuk): Here we assume that the first token in the prompt - # is always the first token in the sequence. - if mrope_positions: - for idx in range(3): - input_mrope_positions[idx].extend(mrope_positions[idx]) - else: - input_positions.extend(list(range(computed_len, seq_len))) - - # Compute the slot mapping. - block_table = seq_group_metadata.block_tables[seq_id] - # Mask the [0, start_idx) tokens of the prompt with _PAD_SLOT_ID, - # where start_idx is max(0, seq_len - sliding_window). - # For example, if the prompt len is 10, sliding window is 8, and - # block size is 4, the first two tokens are masked and the slot - # mapping will be [-1, -1, 2, 3, 4, 5, 6, 7, 0, 1]. - start_idx = 0 - if self.sliding_window is not None: - start_idx = max(0, seq_len - self.sliding_window) - - for i in range(computed_len, seq_len): - if i < start_idx: - slot_mapping.append(_PAD_SLOT_ID) - continue - - # For encoder-only models, the block_table is None, - # and there is no need to initialize the slot_mapping. - if block_table is not None: - block_number = block_table[i // - self.block_size] # type: ignore - block_offset = i % self.block_size # type: ignore - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore - num_prompt_tokens = len(input_tokens) + for i in range(3): + self.input_data.input_mrope_positions[ # type: ignore + i].extend(mrope_positions[i]) - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) # type: ignore - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) # type: ignore - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) # type: ignore - placeholder_index_maps = { - modality: placeholder_map.index_map() - for modality, placeholder_map in - multi_modal_placeholder_maps.items() - } - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=True, - seq_lens=seq_lens, - seq_lens_tensor=torch.tensor([]), - max_decode_seq_len=0, - num_prefills=len(seq_lens), - num_prefill_tokens=num_prompt_tokens, - num_decode_tokens=0, - block_tables=torch.tensor([]), - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=placeholder_index_maps, - ) - - multi_modal_kwargs = MultiModalKwargs.batch(multi_modal_kwargs_list) - - return (input_tokens, input_positions, attn_metadata, seq_lens, - multi_modal_kwargs) - - def _prepare_decode( - self, - seq_group_metadata_list: List[SequenceGroupMetadata], - ) -> Tuple[torch.Tensor, torch.Tensor, AttentionMetadata]: - assert len(seq_group_metadata_list) > 0 - input_tokens: List[int] = [] - input_positions: List[int] = [] - input_mrope_positions: List[List[int]] = [[] for _ in range(3)] - slot_mapping: List[int] = [] - seq_lens: List[int] = [] - block_tables: List[List[int]] = [] - - for seq_group_metadata in seq_group_metadata_list: - assert not seq_group_metadata.is_prompt - assert seq_group_metadata.token_chunk_size == 1 - - seq_ids = list(seq_group_metadata.seq_data.keys()) - - for seq_id in seq_ids: - seq_data = seq_group_metadata.seq_data[seq_id] - generation_token = seq_data.get_last_token_id() - input_tokens.append(generation_token) - - seq_len = seq_data.get_len() - position = seq_len - 1 - if seq_data.mrope_position_delta is not None: - context_len = seq_data.get_num_computed_tokens() - next_pos = MRotaryEmbedding.get_next_input_positions( - seq_data.mrope_position_delta, - context_len, - seq_len, - ) - for idx in range(3): - input_mrope_positions[idx].extend(next_pos[idx]) - else: - input_positions.append(position) - - seq_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) - seq_lens.append(seq_len) - - block_table = seq_group_metadata.block_tables[seq_id] - block_number = block_table[position // self.block_size] - block_offset = position % self.block_size - slot = block_number * self.block_size + block_offset - slot_mapping.append(slot) - - if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) - block_table = block_table[-sliding_window_blocks:] - block_tables.append(block_table) - - if any(input_mrope_positions): - input_positions = None # type: ignore - else: - input_mrope_positions = None # type: ignore - - max_decode_seq_len = max(seq_lens) - - input_tokens = torch.tensor(input_tokens, - dtype=torch.long, - device=self.device) - input_positions = torch.tensor(input_positions - or input_mrope_positions, - dtype=torch.long, - device=self.device) - slot_mapping = torch.tensor(slot_mapping, - dtype=torch.long, - device=self.device) - seq_lens_tensor = torch.tensor(seq_lens, - dtype=torch.int, - device=self.device) - - block_tables = make_tensor_with_pad( - block_tables, - pad=0, - dtype=torch.int, - device=self.device, - ) - - attn_metadata = self.attn_backend.make_metadata( - is_prompt=False, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - seq_lens=seq_lens, - seq_lens_tensor=seq_lens_tensor, - max_decode_seq_len=max_decode_seq_len, - num_prefill_tokens=0, - num_decode_tokens=len(input_tokens), - num_prefills=0, - block_tables=block_tables, - ) - return ( - input_tokens, - input_positions, - attn_metadata, - ) + self.input_data.multi_modal_inputs_list.append(mm_kwargs) + for modality, placeholder_map in placeholder_maps.items(): + self.input_data.multi_modal_placeholder_maps[modality].extend( + placeholder_map) class CPUModelRunnerBase(ModelRunnerBase[TModelInputForCPU]): @@ -436,8 +399,6 @@ def __init__( **kwargs, ): ModelRunnerBase.__init__(self, vllm_config) - # Currently, CPU worker doesn't support chunked prefill. - assert self.scheduler_config.chunked_prefill_enabled is False model_config = self.model_config cache_config = self.cache_config @@ -479,8 +440,7 @@ def _prepare_model_input_tensors( """ builder = self._builder_cls(weakref.proxy(self), finished_requests_ids) - for seq_group_metadata in seq_group_metadata_list: - builder.add_seq_group(seq_group_metadata) + builder.set_seq_group_list(seq_group_metadata_list) return builder.build() # type: ignore @@ -537,22 +497,20 @@ def execute_model( "CPU worker does not support multi-step execution.") model_executable = self.model - execute_model_kwargs = { - "input_ids": - model_input.input_tokens, - "positions": - model_input.input_positions, - "kv_caches": - kv_caches, - "attn_metadata": - model_input.attn_metadata, - **MultiModalKwargs.as_kwargs(model_input.multi_modal_kwargs or {}, - device=self.device), - "intermediate_tensors": - intermediate_tensors, - } - - hidden_states = model_executable(**execute_model_kwargs) + multimodal_kwargs = {} + if model_input.multi_modal_kwargs is not None: + multimodal_kwargs = MultiModalKwargs.as_kwargs( + model_input.multi_modal_kwargs, device=self.device) + + with set_forward_context(model_input.attn_metadata, self.vllm_config): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **multimodal_kwargs, + ) # Compute the logits. logits = self.model.compute_logits(hidden_states, diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 37cfcbf13d7a3..f56805918fd15 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -97,7 +97,11 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata): + cross_enc_kwargs = {} + if model_input.token_types is not None: + cross_enc_kwargs["token_type_ids"] = model_input.token_types + + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, @@ -105,7 +109,8 @@ def execute_model( attn_metadata=model_input.attn_metadata, intermediate_tensors=intermediate_tensors, **MultiModalKwargs.as_kwargs(multi_modal_kwargs, - device=self.device)) + device=self.device), + **cross_enc_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 687d2cc79360f..ae18c79c980c8 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -176,7 +176,7 @@ def execute_model( } if self.has_inner_state else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - with set_forward_context(model_input.attn_metadata): + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index ed0360fb7f727..1f654a9cce465 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -92,6 +92,7 @@ class ModelInputForGPU(ModelRunnerInputBase): """ input_tokens: Optional[torch.Tensor] = None input_positions: Optional[torch.Tensor] = None + token_types: Optional[torch.Tensor] = None seq_lens: Optional[List[int]] = None query_lens: Optional[List[int]] = None lora_mapping: Optional["LoRAMapping"] = None @@ -200,6 +201,7 @@ class InterDataForSeqGroup: def simple_reinit(self): self.input_tokens[0].clear() # type: ignore self.input_positions[0].clear() # type: ignore + self.token_types[0].clear() # type: ignore self.mrope_input_positions = None # type: ignore self.seq_lens[0] = 0 # type: ignore self.orig_seq_lens[0] = 0 # type: ignore @@ -226,6 +228,7 @@ def __init__( # Input tokens and positions. input_tokens: Optional[List[List[int]]] = None, input_positions: Optional[List[List[int]]] = None, + token_types: Optional[List[List[int]]] = None, mrope_input_positions: Optional[List[List[List[int]]]] = None, # The sequence length (may be capped to the sliding window). @@ -291,6 +294,12 @@ def __init__( for seq_id in range(len(self.seq_ids)): self.input_positions[seq_id].clear() + if token_types: + self.token_types = token_types + else: + for seq_id in range(len(self.seq_ids)): + self.token_types[seq_id].clear() + self.mrope_input_positions = None if seq_lens: @@ -354,6 +363,7 @@ def __init__( else: self.input_tokens = input_tokens or [] self.input_positions = input_positions or [] + self.token_types = token_types or [] self.mrope_input_positions = mrope_input_positions or None self.seq_lens = seq_lens or [] self.orig_seq_lens = orig_seq_lens or [] @@ -386,6 +396,7 @@ def __post_init__(self): self.input_tokens = [[] for _ in range(self.n_seqs)] self.input_positions = [[] for _ in range(self.n_seqs)] + self.token_types = [[] for _ in range(self.n_seqs)] self.mrope_input_positions = None self.seq_lens = [0] * self.n_seqs self.orig_seq_lens = [0] * self.n_seqs @@ -498,12 +509,15 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute tokens. tokens = seq_data.get_token_ids()[context_len:seq_len] + token_types = seq_group_metadata.token_type_ids inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len inter_data.input_tokens[seq_idx].extend(tokens) inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.token_types[seq_idx].extend( + token_types if token_types else []) inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: @@ -561,6 +575,8 @@ def _compute_for_prefix_cache_hit( seq_idx][uncomputed_start:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][uncomputed_start:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + uncomputed_start:] context_len = prefix_cache_len inter_data.context_lens[seq_idx] = context_len @@ -575,6 +591,8 @@ def _compute_for_prefix_cache_hit( seq_idx][-1:] inter_data.input_positions[seq_idx] = inter_data.input_positions[ seq_idx][-1:] + inter_data.token_types[seq_idx] = inter_data.token_types[seq_idx][ + -1:] inter_data.query_lens[seq_idx] = 1 inter_data.context_lens[seq_idx] = inter_data.seq_lens[seq_idx] - 1 @@ -803,9 +821,12 @@ def build(self) -> ModelInputForGPU: """ # Combine and flatten intermediate data. input_tokens = [] + token_types = [] for inter_data in self.inter_data_list: for cur_input_tokens in inter_data.input_tokens: input_tokens.extend(cur_input_tokens) + for cur_token_types in inter_data.token_types: + token_types.extend(cur_token_types) if not input_tokens: # This may happen when all prefill requests hit @@ -874,6 +895,12 @@ def build(self) -> ModelInputForGPU: input_tokens_tensor = async_tensor_h2d(input_tokens, torch.long, self.runner.device, self.runner.pin_memory) + + token_types_tensor = async_tensor_h2d(token_types, torch.long, + self.runner.device, + self.runner.pin_memory) \ + if token_types else None + if mrope_input_positions is not None: for idx in range(3): mrope_input_positions[idx].extend( @@ -952,6 +979,7 @@ def build(self) -> ModelInputForGPU: return self.model_input_cls( input_tokens=input_tokens_tensor, input_positions=input_positions_tensor, + token_types=token_types_tensor, attn_metadata=attn_metadata, seq_lens=seq_lens, query_lens=query_lens, @@ -1503,7 +1531,7 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - with set_forward_context(attn_metadata): + with set_forward_context(attn_metadata, self.vllm_config): graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( @@ -1649,7 +1677,7 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - with set_forward_context(model_input.attn_metadata): + with set_forward_context(model_input.attn_metadata, self.vllm_config): hidden_or_intermediate_states = model_executable( input_ids=model_input.input_tokens, positions=model_input.input_positions, diff --git a/vllm/worker/tpu_model_runner.py b/vllm/worker/tpu_model_runner.py index d7a641857a613..9a054eb8a4cf7 100644 --- a/vllm/worker/tpu_model_runner.py +++ b/vllm/worker/tpu_model_runner.py @@ -1,3 +1,4 @@ +import enum import time from dataclasses import dataclass from typing import (TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, @@ -11,7 +12,6 @@ import torch_xla.runtime as xr from vllm.attention import AttentionMetadata, get_attn_backend -from vllm.compilation.wrapper import TorchCompileWrapperWithCustomDispatcher from vllm.config import VllmConfig from vllm.logger import init_logger from vllm.model_executor.layers.sampler import SamplerOutput @@ -39,6 +39,15 @@ _MAX_NUM_SAMPLES = 128 +class ExecutionMode(enum.Enum): + PREFILL = enum.auto() + DECODE = enum.auto() + PREFIX_PREFILL = enum.auto() + + def is_prefill(self) -> bool: + return self in (ExecutionMode.PREFILL, ExecutionMode.PREFIX_PREFILL) + + @dataclass(frozen=True) class ModelInputForTPU(ModelRunnerInputBase): token_ids: torch.Tensor @@ -140,16 +149,21 @@ def load_model(self) -> None: model = get_model(vllm_config=self.vllm_config) model = model.eval() xm.wait_device_ops() - self.model = ModelWrapper(model, self.vllm_config) + model = ModelWrapper(model) + self.model = torch.compile(model, + backend="openxla", + fullgraph=True, + dynamic=False) def _dummy_run( self, batch_size: int, seq_len: int, kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], - is_prompt: bool, + exec_mode: ExecutionMode, ) -> None: - if is_prompt: + exec_mode = ExecutionMode(exec_mode) + if exec_mode.is_prefill(): seq_len = (seq_len + 15) // 16 * 16 token_ids = torch.zeros((batch_size, seq_len), dtype=torch.int32, @@ -160,18 +174,38 @@ def _dummy_run( slot_mapping = torch.zeros((batch_size, seq_len), dtype=torch.int64, device=self.device) - attn_metadata = self.attn_backend.make_metadata( - num_prefills=batch_size, - num_prefill_tokens=batch_size * seq_len, - num_decode_tokens=0, - slot_mapping=slot_mapping, - multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, - ) input_lens = torch.ones((batch_size, ), dtype=torch.int32, device=self.device) + if exec_mode == ExecutionMode.PREFILL: + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=None, + context_lens=None, + effective_query_lens=None, + ) + else: + context_lens = torch.ones((batch_size, ), + dtype=torch.int32, + device=self.device) + block_tables = torch.tensor(self.block_tables[:batch_size], + dtype=torch.int32, + device=self.device) + effective_query_lens = torch.ones_like(context_lens) + attn_metadata = self.attn_backend.make_metadata( + num_prefills=batch_size, + num_prefill_tokens=batch_size * seq_len, + num_decode_tokens=0, + slot_mapping=slot_mapping, + multi_modal_placeholder_index_maps=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=effective_query_lens, + ) else: assert seq_len == 1 token_ids = torch.zeros((batch_size, seq_len), @@ -204,7 +238,7 @@ def _dummy_run( ) t = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) p = torch.ones((batch_size, ), dtype=torch.float32, device=self.device) - num_samples = _MAX_NUM_SAMPLES if is_prompt else 1 + num_samples = _MAX_NUM_SAMPLES if exec_mode.is_prefill() else 1 # NOTE(woosuk): There are two stages of compilation: torch.compile and # XLA compilation. Using `mark_dynamic` can reduce the torch.compile @@ -213,7 +247,7 @@ def _dummy_run( # be re-compiled for every different shapes. This overhead is inevitable # in the first run, but can be skipped afterwards as we cache the XLA # graphs in the disk (VLLM_XLA_CACHE_PATH). - if is_prompt: + if exec_mode.is_prefill(): # Prefll torch._dynamo.mark_dynamic(token_ids, 1) torch._dynamo.mark_dynamic(position_ids, 1) @@ -229,15 +263,8 @@ def _dummy_run( torch._dynamo.mark_dynamic(t, 0) torch._dynamo.mark_dynamic(p, 0) # Dummy run. - self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, - num_samples, - kv_caches, - is_prompt=is_prompt) + self.model(token_ids, position_ids, attn_metadata, input_lens, t, p, + num_samples, kv_caches) def warmup_model( self, @@ -248,13 +275,13 @@ def warmup_model( start = time.time() for batch_size in [1]: seq_len = 16 - while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=True) + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFILL) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) - - if seq_len >= self.model_config.max_model_len: - break num_tokens = batch_size * seq_len if num_tokens >= self.scheduler_config.max_num_batched_tokens: break @@ -263,12 +290,39 @@ def warmup_model( end = time.time() logger.info("Compilation for prefill done in %.2f s.", end - start) + # Prefix prefill + if self.cache_config.enable_prefix_caching: + logger.info("Compiling the model with different input shapes for " + "prefix prefill...") + start = time.time() + for batch_size in [1]: + seq_len = 16 + while seq_len <= self.model_config.max_model_len: + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.PREFIX_PREFILL) + xm.wait_device_ops() + logger.info("batch_size: %d, seq_len: %d", batch_size, + seq_len) + num_tokens = batch_size * seq_len + if (num_tokens >= + self.scheduler_config.max_num_batched_tokens): + break + seq_len = seq_len * 2 + end = time.time() + logger.info("Compilation for prefix prefill done in %.2f s.", + end - start) + # Decode start = time.time() seq_len = 1 batch_size = 8 # Must be in sync with _get_padded_batch_size() while True: - self._dummy_run(batch_size, seq_len, kv_caches, is_prompt=False) + self._dummy_run(batch_size, + seq_len, + kv_caches, + exec_mode=ExecutionMode.DECODE) xm.wait_device_ops() logger.info("batch_size: %d, seq_len: %d", batch_size, seq_len) @@ -287,9 +341,11 @@ def _prepare_prompt( input_tokens: List[int] = [] input_positions: List[int] = [] prompt_lens: List[int] = [] + context_lens: List[int] = [] slot_mapping: List[int] = [] - for seq_group_metadata in seq_group_metadata_list: + for batch_idx, seq_group_metadata in enumerate( + seq_group_metadata_list): assert seq_group_metadata.is_prompt seq_ids = list(seq_group_metadata.seq_data.keys()) assert len(seq_ids) == 1 @@ -298,19 +354,31 @@ def _prepare_prompt( seq_data = seq_group_metadata.seq_data[seq_id] # Could include output tokens when a request is preempted. prompt_tokens = seq_data.get_token_ids() + seq_len = len(prompt_tokens) + + num_computed_blocks = len(seq_group_metadata.computed_block_nums) + num_computed_tokens = num_computed_blocks * self.block_size + if num_computed_tokens > 0: + prompt_tokens = prompt_tokens[num_computed_tokens:] + context_lens.append(seq_len) + else: + context_lens.append(0) + prompt_len = len(prompt_tokens) prompt_lens.append(prompt_len) input_tokens.extend(prompt_tokens) - input_positions.extend(list(range(prompt_len))) + input_positions.extend(range(num_computed_tokens, seq_len)) assert seq_group_metadata.block_tables is not None block_table = seq_group_metadata.block_tables[seq_id] - for i in range(prompt_len): + for i in range(num_computed_tokens, seq_len): block_number = block_table[i // self.block_size] block_offset = i % self.block_size slot = block_number * self.block_size + block_offset slot_mapping.append(slot) + if num_computed_tokens > 0: + self.block_tables[batch_idx, :len(block_table)] = block_table # Add paddings to EACH prompt to the smallest power of 2 that is # greater than or equal to the prompt length. @@ -338,14 +406,21 @@ def _prepare_prompt( prompt_lens = torch.tensor(prompt_lens, dtype=torch.int32, device="cpu") + context_lens = torch.tensor(context_lens, + dtype=torch.int32, + device="cpu") + block_tables = torch.tensor(self.block_tables[:num_prefills], + dtype=torch.int32, + device="cpu") attn_metadata = self.attn_backend.make_metadata( num_prefills=num_prefills, num_prefill_tokens=0, # NOTE: This is not used. num_decode_tokens=0, slot_mapping=slot_mapping, multi_modal_placeholder_index_maps=None, - block_tables=None, - context_lens=None, + block_tables=block_tables, + context_lens=context_lens, + effective_query_lens=prompt_lens, ) return input_tokens, input_positions, attn_metadata, prompt_lens @@ -550,6 +625,10 @@ def execute_model( # process them separately. This is a temporary hack that should be # optimized by using SplashAttention. orig_slot_mapping = model_input.attn_metadata.slot_mapping + orig_block_tables = model_input.attn_metadata.block_tables + orig_context_lens = model_input.attn_metadata.context_lens + orig_effective_query_lens = \ + model_input.attn_metadata.effective_query_lens batch_size = model_input.input_lens.shape[0] start_idx = 0 next_token_ids = [] @@ -568,18 +647,24 @@ def execute_model( attn_metadata.num_prefills = 1 attn_metadata.slot_mapping = orig_slot_mapping[ None, start_idx:end_idx].to(self.device) + if orig_context_lens[i].item() > 0: + attn_metadata.context_lens = orig_context_lens[i:i + 1].to( + self.device) + attn_metadata.block_tables = orig_block_tables[ + i].unsqueeze(0).to(self.device) + attn_metadata.effective_query_lens = \ + orig_effective_query_lens[i:i + 1].to(self.device) + else: + attn_metadata.context_lens = None + attn_metadata.block_tables = None + attn_metadata.effective_query_lens = None input_lens = model_input.input_lens[i:i + 1].to(self.device) t = model_input.t[i:i + 1].to(self.device) p = model_input.p[i:i + 1].to(self.device) - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=True) + kv_caches) next_token_ids.append(output_token_ids[0]) start_idx = end_idx @@ -624,15 +709,10 @@ def execute_model( input_lens = model_input.input_lens.to(self.device) for i in range(num_steps): slot_mapping = attn_metadata.slot_mapping - output_token_ids = self.model(token_ids, - position_ids, - attn_metadata, - input_lens, - t, - p, + output_token_ids = self.model(token_ids, position_ids, + attn_metadata, input_lens, t, p, model_input.num_samples, - kv_caches, - is_prompt=False) + kv_caches) self.cached_step_outputs.append(output_token_ids) if i < num_steps - 1: @@ -667,34 +747,11 @@ def execute_model( return [sampler_output] -class ModelWrapper(TorchCompileWrapperWithCustomDispatcher): +class ModelWrapper(nn.Module): - def __init__(self, model: nn.Module, vllm_config: VllmConfig): + def __init__(self, model: nn.Module): + super().__init__() self.model = model - compiled_callable = torch.compile(self.forward, - backend="openxla", - fullgraph=True, - dynamic=False) - super().__init__( - compiled_callable, - compilation_level=vllm_config.compilation_config.level) - - def __call__(self, *args, is_prompt: bool, **kwargs): - if len(self.compiled_codes) < 3 or not self.use_custom_dispatcher: - # not fully compiled yet, or not using the custom dispatcher, - # let PyTorch handle it - return self.compiled_callable(*args, **kwargs) - # the 3 compiled codes are: - # 0: for profiling - # 1: for prompt - # 2: for decode - # dispatch to the compiled code directly, skip PyTorch - if is_prompt: - with self.dispatch_to_code(1): - return self.forward(*args, **kwargs) - else: - with self.dispatch_to_code(2): - return self.forward(*args, **kwargs) def forward( self, diff --git a/vllm/worker/tpu_worker.py b/vllm/worker/tpu_worker.py index 096cb23416909..8754f7538f251 100644 --- a/vllm/worker/tpu_worker.py +++ b/vllm/worker/tpu_worker.py @@ -13,7 +13,7 @@ from vllm.model_executor import set_random_seed from vllm.sequence import ExecuteModelRequest from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, get_dtype_size -from vllm.worker.tpu_model_runner import TPUModelRunner +from vllm.worker.tpu_model_runner import ExecutionMode, TPUModelRunner from vllm.worker.worker_base import (LocalOrDistributedWorkerBase, LoraNotSupportedWorkerBase, WorkerBase, WorkerInput) @@ -112,7 +112,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: batch_size=1, seq_len=self.scheduler_config.max_num_batched_tokens, kv_caches=kv_caches, - is_prompt=True, + exec_mode=ExecutionMode.PREFILL, ) # Synchronize before measuring the memory usage. xm.wait_device_ops() diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index d3ca6d9d0b17e..80fd7bc3b67cc 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -1,6 +1,7 @@ """A GPU worker class.""" import gc import os +import time from typing import Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -189,6 +190,7 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: torch.cuda.reset_peak_memory_stats() free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info() + start_time = time.time() # Execute a forward pass with dummy inputs to profile the memory usage # of the model. @@ -229,12 +231,18 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) + end_time = time.time() logger.info( - "Memory profiling results: total_gpu_memory=%.2fGiB" - " initial_memory_usage=%.2fGiB peak_torch_memory=%.2fGiB" - " memory_usage_post_profile=%.2fGiB" - " non_torch_memory=%.2fGiB kv_cache_size=%.2fGiB" - " gpu_memory_utilization=%.2f", total_gpu_memory / (1024**3), + "Memory profiling results: " + "duration=%.2f seconds, " + "total_gpu_memory=%.2fGiB, " + "initial_memory_usage=%.2fGiB, " + "peak_torch_memory=%.2fGiB, " + "memory_usage_post_profile=%.2fGiB, " + "non_torch_memory=%.2fGiB, " + "kv_cache_size=%.2fGiB, " + "gpu_memory_utilization=%.2f.", end_time - start_time, + total_gpu_memory / (1024**3), (total_gpu_memory - free_memory_pre_profile) / (1024**3), (peak_memory - non_torch_allocations) / (1024**3), total_allocated_bytes / (1024**3), diff --git a/vllm/worker/worker_base.py b/vllm/worker/worker_base.py index cf8a4946a71c4..e7fec6d17eecd 100644 --- a/vllm/worker/worker_base.py +++ b/vllm/worker/worker_base.py @@ -1,9 +1,8 @@ import dataclasses -import importlib import os import time from abc import ABC, abstractmethod -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Type, Union +from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union import torch @@ -15,7 +14,7 @@ from vllm.platforms import current_platform from vllm.sequence import ExecuteModelRequest, IntermediateTensors from vllm.utils import (enable_trace_function_call_for_thread, - update_environment_variables) + resolve_obj_by_qualname, update_environment_variables) from vllm.worker.model_runner_base import (BroadcastableModelInput, ModelRunnerBase, ModelRunnerInputBase) @@ -411,23 +410,14 @@ class WorkerWrapperBase: We first instantiate the WorkerWrapper, which remembers the worker module and class name. Then, when we call `update_environment_variables`, and the real initialization happens in `init_worker`. - - If worker_class_fn is specified, it will be executed to get the worker - class. - Otherwise, the worker class will be obtained by dynamically importing it - using worker_module_name and worker_class_name. """ def __init__( self, - worker_module_name: str, - worker_class_name: str, - trust_remote_code: bool = False, - worker_class_fn: Optional[Callable[[], - Type[WorkerBase]]] = None) -> None: - self.worker_module_name = worker_module_name - self.worker_class_name = worker_class_name - self.worker_class_fn = worker_class_fn + vllm_config: VllmConfig, + ) -> None: + self.vllm_config = vllm_config + trust_remote_code = vllm_config.model_config.trust_remote_code self.worker: Optional[WorkerBase] = None if trust_remote_code: # note: lazy import to avoid importing torch before initializing @@ -456,12 +446,8 @@ def init_worker(self, *args, **kwargs): from vllm.plugins import load_general_plugins load_general_plugins() - if self.worker_class_fn: - worker_class = self.worker_class_fn() - else: - mod = importlib.import_module(self.worker_module_name) - worker_class = getattr(mod, self.worker_class_name) - + worker_class = resolve_obj_by_qualname( + self.vllm_config.parallel_config.worker_cls) self.worker = worker_class(*args, **kwargs) assert self.worker is not None