diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index b12bf7b382d0f..427dc14513d45 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -146,7 +146,9 @@ steps: source_file_dependencies: - vllm/ - tests/test_regression - command: pytest -v -s test_regression.py + commands: + - pip install modelscope + - pytest -v -s test_regression.py working_dir: "/vllm-workspace/tests" # optional - label: Engine Test # 10min @@ -208,7 +210,7 @@ steps: - tests/spec_decode commands: - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py + - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py - label: LoRA Test %N # 15min each mirror_hardwares: [amd] diff --git a/CMakeLists.txt b/CMakeLists.txt index e531a410ec8c8..7b24c4abc650e 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -143,6 +143,19 @@ else() message(FATAL_ERROR "Can't find CUDA or HIP installation.") endif() + +# +# For cuda we want to be able to control which architectures we compile for on +# a per-file basis in order to cut down on compile time. So here we extract +# the set of architectures we want to compile for and remove the from the +# CMAKE_CUDA_FLAGS so that they are not applied globally. +# +if(VLLM_GPU_LANG STREQUAL "CUDA") + clear_cuda_arches(CUDA_ARCH_FLAGS) + extract_unique_cuda_archs_ascending(CUDA_ARCHS "${CUDA_ARCH_FLAGS}") + message(STATUS "CUDA target architectures: ${CUDA_ARCHS}") +endif() + # # Override the GPU architectures detected by cmake/torch and filter them by # the supported versions for the current language. @@ -223,30 +236,89 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "csrc/mamba/causal_conv1d/causal_conv1d.cu" "csrc/quantization/aqlm/gemm_kernels.cu" "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" - "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" - "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" - "csrc/quantization/gptq_marlin/gptq_marlin.cu" - "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" - "csrc/quantization/gptq_marlin/awq_marlin_repack.cu" "csrc/quantization/gguf/gguf_kernel.cu" - "csrc/quantization/fp8/fp8_marlin.cu" "csrc/custom_all_reduce.cu" "csrc/permute_cols.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu" - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + "csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu") + + set_gencode_flags_for_srcs( + SRCS "${VLLM_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + + # Only build Marlin kernels if we are building for at least some compatible archs. + # Keep building Marlin for 9.0 as there are some group sizes and shapes that + # are not supported by Machete yet. 
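As a rough illustration of what the call below computes (the values here are hypothetical; the helper itself is added to cmake/utils.cmake by this change): suppose the detected arch list is CUDA_ARCHS="7.5;8.6;9.0".

    cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}")
    # 7.5 has no Marlin-supported arch at or below it and is dropped; 8.6 and
    # 9.0 match exactly, so MARLIN_ARCHS becomes "8.6;9.0" and the Marlin
    # sources get -gencode flags only for sm_86 and sm_90.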
+ cuda_archs_loose_intersection(MARLIN_ARCHS "8.0;8.6;8.9;9.0" ${CUDA_ARCHS}) + if (MARLIN_ARCHS) + set(MARLIN_SRCS + "csrc/quantization/fp8/fp8_marlin.cu" + "csrc/quantization/marlin/dense/marlin_cuda_kernel.cu" + "csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu" + "csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu" + "csrc/quantization/gptq_marlin/gptq_marlin.cu" + "csrc/quantization/gptq_marlin/gptq_marlin_repack.cu" + "csrc/quantization/gptq_marlin/awq_marlin_repack.cu") + set_gencode_flags_for_srcs( + SRCS "${MARLIN_SRCS}" + CUDA_ARCHS "${MARLIN_ARCHS}") + list(APPEND VLLM_EXT_SRC "${MARLIN_SRCS}") + message(STATUS "Building Marlin kernels for archs: ${MARLIN_ARCHS}") + else() + message(STATUS "Not building Marlin kernels as no compatible archs found" + "in CUDA target architectures") + endif() # - # The CUTLASS kernels for Hopper require sm90a to be enabled. - # This is done via the below gencode option, BUT that creates kernels for both sm90 and sm90a. - # That adds an extra 17MB to compiled binary, so instead we selectively enable it. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) - set_source_files_properties( - "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu" - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later (and only work on Hopper, 9.0/9.0a for now). + cuda_archs_loose_intersection(SCALED_MM_3X_ARCHS "9.0;9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c3x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_3X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C3X=1") + message(STATUS "Building scaled_mm_c3x for archs: ${SCALED_MM_3X_ARCHS}") + else() + # clear SCALED_MM_3X_ARCHS so the scaled_mm_c2x kernels know we didn't + # build any 3x kernels + set(SCALED_MM_3X_ARCHS) + + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c3x as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) + # kernels for the remaining archs that are not already built for 3x. 
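The "remaining archs" computation below reduces to a list subtraction. A minimal standalone sketch with illustrative values, assuming CUDA_ARCHS="8.0;8.6;9.0" so that the c3x branch above selected the Hopper archs:

    # Assumed outputs of the two loose-intersection calls for this example.
    set(SCALED_MM_3X_ARCHS "9.0;9.0a")
    set(SCALED_MM_2X_ARCHS "8.0;8.6;9.0;9.0a")
    list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS})
    # SCALED_MM_2X_ARCHS is now "8.0;8.6": only the pre-Hopper archs still need
    # the CUTLASS 2.x kernels; Hopper is covered by scaled_mm_c3x.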
+ cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.6;8.9;9.0;9.0a" "${CUDA_ARCHS}") + # subtract out the archs that are already built for 3x + list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) + if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/quantization/cutlass_w8a8/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") + else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() endif() @@ -254,47 +326,72 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Machete kernels # The machete kernels only work on hopper and require CUDA 12.0 or later. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + # Only build Machete kernels if we are building for something compatible with sm90a + cuda_archs_loose_intersection(MACHETE_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 AND MACHETE_ARCHS) # # For the Machete kernels we automatically generate sources for various # preselected input type pairs and schedules. # Generate sources: - execute_process( - COMMAND ${CMAKE_COMMAND} -E env - PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH - ${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py - RESULT_VARIABLE machete_generation_result - OUTPUT_VARIABLE machete_generation_output - OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log - ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log - ) - - if (NOT machete_generation_result EQUAL 0) - message(FATAL_ERROR "Machete generation failed." - " Result: \"${machete_generation_result}\"" - "\nCheck the log for details: " - "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + set(MACHETE_GEN_SCRIPT + ${CMAKE_CURRENT_SOURCE_DIR}/csrc/quantization/machete/generate.py) + file(MD5 ${MACHETE_GEN_SCRIPT} MACHETE_GEN_SCRIPT_HASH) + + message(STATUS "Machete generation script hash: ${MACHETE_GEN_SCRIPT_HASH}") + message(STATUS "Last run machete generate script hash: $CACHE{MACHETE_GEN_SCRIPT_HASH}") + + if (NOT DEFINED CACHE{MACHETE_GEN_SCRIPT_HASH} + OR NOT $CACHE{MACHETE_GEN_SCRIPT_HASH} STREQUAL ${MACHETE_GEN_SCRIPT_HASH}) + execute_process( + COMMAND ${CMAKE_COMMAND} -E env + PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/csrc/cutlass_extensions/:${CUTLASS_DIR}/python/:${VLLM_PYTHON_PATH}:$PYTHONPATH + ${Python_EXECUTABLE} ${MACHETE_GEN_SCRIPT} + RESULT_VARIABLE machete_generation_result + OUTPUT_VARIABLE machete_generation_output + OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log + ) + + if (NOT machete_generation_result EQUAL 0) + message(FATAL_ERROR "Machete generation failed." 
+ " Result: \"${machete_generation_result}\"" + "\nCheck the log for details: " + "${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log") + else() + set(MACHETE_GEN_SCRIPT_HASH ${MACHETE_GEN_SCRIPT_HASH} + CACHE STRING "Last run machete generate script hash" FORCE) + message(STATUS "Machete generation completed successfully.") + endif() else() - message(STATUS "Machete generation completed successfully.") + message(STATUS "Machete generation script has not changed, skipping generation.") endif() # Add machete generated sources file(GLOB MACHETE_GEN_SOURCES "csrc/quantization/machete/generated/*.cu") list(APPEND VLLM_EXT_SRC ${MACHETE_GEN_SOURCES}) - message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}") - set_source_files_properties( - ${MACHETE_GEN_SOURCES} - PROPERTIES - COMPILE_FLAGS - "-gencode arch=compute_90a,code=sm_90a") + # forward compatible + set_gencode_flags_for_srcs( + SRCS "${MACHETE_GEN_SOURCES}" + CUDA_ARCHS "${MACHETE_ARCHS}") + + list(APPEND VLLM_EXT_SRC + csrc/quantization/machete/machete_pytorch.cu) + + message(STATUS "Building Machete kernels for archs: ${MACHETE_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0 + AND MACHETE_ARCHS) + message(STATUS "Not building Machete kernels as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running w4a16 quantized models on " + "Hopper.") + else() + message(STATUS "Not building Machete kernels as no compatible archs " + "found in CUDA target architectures") + endif() endif() - - # Add pytorch binding for machete (add on even CUDA < 12.0 so that we can - # raise an error if the user that this was built with an incompatible - # CUDA version) - list(APPEND VLLM_EXT_SRC - csrc/quantization/machete/machete_pytorch.cu) +# if CUDA endif endif() message(STATUS "Enabling C extension.") @@ -323,14 +420,31 @@ set(VLLM_MOE_EXT_SRC "csrc/moe/torch_bindings.cpp" "csrc/moe/topk_softmax_kernels.cu") +set_gencode_flags_for_srcs( + SRCS "${VLLM_MOE_EXT_SRC}" + CUDA_ARCHS "${CUDA_ARCHS}") + if(VLLM_GPU_LANG STREQUAL "CUDA") - list(APPEND VLLM_MOE_EXT_SRC - "csrc/moe/marlin_kernels/marlin_moe_kernel.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" - "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" - "csrc/moe/marlin_moe_ops.cu") + cuda_archs_loose_intersection(MARLIN_MOE_ARCHS "8.0;8.6;8.9;9.0" "${CUDA_ARCHS}") + if (MARLIN_MOE_ARCHS) + set(MARLIN_MOE_SRC + "csrc/moe/marlin_kernels/marlin_moe_kernel.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku4b8.cu" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.h" + "csrc/moe/marlin_kernels/marlin_moe_kernel_ku8b128.cu" + "csrc/moe/marlin_moe_ops.cu") + + set_gencode_flags_for_srcs( + SRCS "${MARLIN_MOE_SRC}" + CUDA_ARCHS "${MARLIN_MOE_ARCHS}") + + list(APPEND VLLM_MOE_EXT_SRC "${MARLIN_MOE_SRC}") + message(STATUS "Building Marlin MOE kernels for archs: ${MARLIN_MOE_ARCHS}") + else() + message(STATUS "Not building Marlin MOE kernels as no compatible archs found" + "in CUDA target architectures") + endif() endif() message(STATUS "Enabling moe extension.") @@ -368,6 +482,17 @@ if (NOT VLLM_TARGET_DEVICE STREQUAL "cuda") return() endif () +# vLLM flash attention requires VLLM_GPU_ARCHES to contain the set of target +# arches in the CMake syntax (75-real, 89-virtual, etc), since we clear the +# arches in the CUDA case (and 
instead set the gencodes on a per file basis) +# we need to manually set VLLM_GPU_ARCHES here. +if(VLLM_GPU_LANG STREQUAL "CUDA") + foreach(_ARCH ${CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + list(APPEND VLLM_GPU_ARCHES "${_ARCH}-real") + endforeach() +endif() + # # Build vLLM flash attention from source # diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 996a92d2a8b3d..56c37b241a359 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -90,7 +90,7 @@ def sample_sharegpt_requests( fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int, None]]: # Load the dataset. - with open(dataset_path) as f: + with open(dataset_path, encoding='utf-8') as f: dataset = json.load(f) # Filter out the conversations with less than 2 turns. dataset = [data for data in dataset if len(data["conversations"]) >= 2] @@ -139,7 +139,7 @@ def sample_sonnet_requests( ), "'args.sonnet-input-len' must be greater than 'args.prefix-input-len'." # Load the dataset. - with open(dataset_path) as f: + with open(dataset_path, encoding='utf-8') as f: poem_lines = f.readlines() # Tokenize the poem lines. @@ -726,7 +726,7 @@ def main(args: argparse.Namespace): file_name = args.result_filename if args.result_dir: file_name = os.path.join(args.result_dir, file_name) - with open(file_name, "w") as outfile: + with open(file_name, "w", encoding='utf-8') as outfile: json.dump(result_json, outfile) diff --git a/cmake/utils.cmake b/cmake/utils.cmake index 10fa0a25bde15..24bb7299338ac 100644 --- a/cmake/utils.cmake +++ b/cmake/utils.cmake @@ -133,10 +133,181 @@ macro(string_to_ver OUT_VER IN_STR) string(REGEX REPLACE "\([0-9]+\)\([0-9]\)" "\\1.\\2" ${OUT_VER} ${IN_STR}) endmacro() +# +# Clear all `-gencode` flags from `CMAKE_CUDA_FLAGS` and store them in +# `CUDA_ARCH_FLAGS`. +# +# Example: +# CMAKE_CUDA_FLAGS="-Wall -gencode arch=compute_70,code=sm_70 -gencode arch=compute_75,code=sm_75" +# clear_cuda_arches(CUDA_ARCH_FLAGS) +# CUDA_ARCH_FLAGS="-gencode arch=compute_70,code=sm_70;-gencode arch=compute_75,code=sm_75" +# CMAKE_CUDA_FLAGS="-Wall" +# +macro(clear_cuda_arches CUDA_ARCH_FLAGS) + # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` + string(REGEX MATCHALL "-gencode arch=[^ ]+" CUDA_ARCH_FLAGS + ${CMAKE_CUDA_FLAGS}) + + # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified + # and passed back via the `CUDA_ARCHITECTURES` property. + string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS + ${CMAKE_CUDA_FLAGS}) +endmacro() + +# +# Extract unique CUDA architectures from a list of compute capabilities codes in +# the form `[]`, convert them to the form sort +# `.`, dedupes them and then sorts them in ascending order and +# stores them in `OUT_ARCHES`. 
+# +# Example: +# CUDA_ARCH_FLAGS="-gencode arch=compute_75,code=sm_75;...;-gencode arch=compute_90a,code=sm_90a" +# extract_unique_cuda_archs_ascending(OUT_ARCHES CUDA_ARCH_FLAGS) +# OUT_ARCHES="7.5;...;9.0" +function(extract_unique_cuda_archs_ascending OUT_ARCHES CUDA_ARCH_FLAGS) + set(_CUDA_ARCHES) + foreach(_ARCH ${CUDA_ARCH_FLAGS}) + string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) + if (_COMPUTE) + set(_COMPUTE ${CMAKE_MATCH_1}) + endif() + + string_to_ver(_COMPUTE_VER ${_COMPUTE}) + list(APPEND _CUDA_ARCHES ${_COMPUTE_VER}) + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHES) + list(SORT _CUDA_ARCHES COMPARE NATURAL ORDER ASCENDING) + set(${OUT_ARCHES} ${_CUDA_ARCHES} PARENT_SCOPE) +endfunction() + +# +# For a specific file set the `-gencode` flag in compile options conditionally +# for the CUDA language. +# +# Example: +# set_gencode_flag_for_srcs( +# SRCS "foo.cu" +# ARCH "compute_75" +# CODE "sm_75") +# adds: "-gencode arch=compute_75,code=sm_75" to the compile options for +# `foo.cu` (only for the CUDA language). +# +macro(set_gencode_flag_for_srcs) + set(options) + set(oneValueArgs ARCH CODE) + set(multiValueArgs SRCS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + set(_FLAG -gencode arch=${arg_ARCH},code=${arg_CODE}) + set_property( + SOURCE ${arg_SRCS} + APPEND PROPERTY + COMPILE_OPTIONS "$<$:${_FLAG}>" + ) + + message(DEBUG "Setting gencode flag for ${arg_SRCS}: ${_FLAG}") +endmacro(set_gencode_flag_for_srcs) + +# +# For a list of source files set the `-gencode` flags in the files specific +# compile options (specifically for the CUDA language). +# +# arguments are: +# SRCS: list of source files +# CUDA_ARCHS: list of CUDA architectures in the form `.[letter]` +# BUILD_PTX_FOR_ARCH: if set to true, then the PTX code will be built +# for architecture `BUILD_PTX_FOR_ARCH` if there is a CUDA_ARCH in CUDA_ARCHS +# that is larger than BUILD_PTX_FOR_ARCH. +# +macro(set_gencode_flags_for_srcs) + set(options) + set(oneValueArgs BUILD_PTX_FOR_ARCH) + set(multiValueArgs SRCS CUDA_ARCHS) + cmake_parse_arguments(arg "${options}" "${oneValueArgs}" + "${multiValueArgs}" ${ARGN} ) + + foreach(_ARCH ${arg_CUDA_ARCHS}) + string(REPLACE "." "" _ARCH "${_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_ARCH}" + CODE "sm_${_ARCH}") + endforeach() + + if (${arg_BUILD_PTX_FOR_ARCH}) + list(SORT arg_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + list(GET arg_CUDA_ARCHS -1 _HIGHEST_ARCH) + if (_HIGHEST_ARCH VERSION_GREATER_EQUAL ${arg_BUILD_PTX_FOR_ARCH}) + string(REPLACE "." "" _PTX_ARCH "${arg_BUILD_PTX_FOR_ARCH}") + set_gencode_flag_for_srcs( + SRCS ${arg_SRCS} + ARCH "compute_${_PTX_ARCH}" + CODE "compute_${_PTX_ARCH}") + endif() + endif() +endmacro() + +# +# For the given `SRC_CUDA_ARCHS` list of gencode versions in the form +# `.[letter]` compute the "loose intersection" with the +# `TGT_CUDA_ARCHS` list of gencodes. +# The loose intersection is defined as: +# { max{ x \in tgt | x <= y } | y \in src, { x \in tgt | x <= y } != {} } +# where `<=` is the version comparison operator. +# In other words, for each version in `TGT_CUDA_ARCHS` find the highest version +# in `SRC_CUDA_ARCHS` that is less or equal to the version in `TGT_CUDA_ARCHS`. +# We have special handling for 9.0a, if 9.0a is in `SRC_CUDA_ARCHS` and 9.0 is +# in `TGT_CUDA_ARCHS` then we should remove 9.0a from `SRC_CUDA_ARCHS` and add +# 9.0a to the result. +# The result is stored in `OUT_CUDA_ARCHS`. 
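To make the 9.0a handling concrete, here is a small worked call with illustrative values (the helper itself is defined just below):

    set(CUDA_ARCHS "8.6;9.0")
    cuda_archs_loose_intersection(KERNEL_ARCHS "7.5;8.0;9.0;9.0a" "${CUDA_ARCHS}")
    # 8.6 maps down to 8.0, the highest source entry not above it, and because
    # the source list carries 9.0a while the target list carries 9.0, the 9.0a
    # entry is kept as well: KERNEL_ARCHS ends up containing 9.0a, 8.0 and 9.0.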
+# +# Example: +# SRC_CUDA_ARCHS="7.5;8.0;8.6;9.0;9.0a" +# TGT_CUDA_ARCHS="8.0;8.9;9.0" +# cuda_archs_loose_intersection(OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) +# OUT_CUDA_ARCHS="8.0;8.6;9.0;9.0a" +# +function(cuda_archs_loose_intersection OUT_CUDA_ARCHS SRC_CUDA_ARCHS TGT_CUDA_ARCHS) + list(REMOVE_DUPLICATES SRC_CUDA_ARCHS) + + # if 9.0a is in SRC_CUDA_ARCHS and 9.0 is in CUDA_ARCHS then we should + # remove 9.0a from SRC_CUDA_ARCHS and add 9.0a to _CUDA_ARCHS + set(_CUDA_ARCHS) + if ("9.0a" IN_LIST SRC_CUDA_ARCHS) + list(REMOVE_ITEM SRC_CUDA_ARCHS "9.0a") + if ("9.0" IN_LIST TGT_CUDA_ARCHS) + set(_CUDA_ARCHS "9.0a") + endif() + endif() + + list(SORT SRC_CUDA_ARCHS COMPARE NATURAL ORDER ASCENDING) + + # for each ARCH in CUDA_ARCHS find the highest arch in SRC_CUDA_ARCHS that is + # less or eqault to ARCH + foreach(_ARCH ${CUDA_ARCHS}) + set(_TMP_ARCH) + foreach(_SRC_ARCH ${SRC_CUDA_ARCHS}) + if (_SRC_ARCH VERSION_LESS_EQUAL _ARCH) + set(_TMP_ARCH ${_SRC_ARCH}) + else() + break() + endif() + endforeach() + if (_TMP_ARCH) + list(APPEND _CUDA_ARCHS ${_TMP_ARCH}) + endif() + endforeach() + + list(REMOVE_DUPLICATES _CUDA_ARCHS) + set(${OUT_CUDA_ARCHS} ${_CUDA_ARCHS} PARENT_SCOPE) +endfunction() + # # Override the GPU architectures detected by cmake/torch and filter them by # `GPU_SUPPORTED_ARCHES`. Sets the final set of architectures in -# `GPU_ARCHES`. +# `GPU_ARCHES`. This only applies to the HIP language since for CUDA we set +# the architectures on a per file basis. # # Note: this is defined as a macro since it updates `CMAKE_CUDA_FLAGS`. # @@ -174,109 +345,7 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES) "None of the detected ROCm architectures: ${HIP_ARCHITECTURES} is" " supported. Supported ROCm architectures are: ${_GPU_SUPPORTED_ARCHES_LIST}.") endif() - - elseif(${GPU_LANG} STREQUAL "CUDA") - # - # Setup/process CUDA arch flags. - # - # The torch cmake setup hardcodes the detected architecture flags in - # `CMAKE_CUDA_FLAGS`. Since `CMAKE_CUDA_FLAGS` is a "global" variable, it - # can't modified on a per-target basis. - # So, all the `-gencode` flags need to be extracted and removed from - # `CMAKE_CUDA_FLAGS` for processing so they can be passed by another method. - # Since it's not possible to use `target_compiler_options` for adding target - # specific `-gencode` arguments, the target's `CUDA_ARCHITECTURES` property - # must be used instead. This requires repackaging the architecture flags - # into a format that cmake expects for `CUDA_ARCHITECTURES`. - # - # This is a bit fragile in that it depends on torch using `-gencode` as opposed - # to one of the other nvcc options to specify architectures. - # - # Note: torch uses the `TORCH_CUDA_ARCH_LIST` environment variable to override - # detected architectures. - # - message(DEBUG "initial CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - - # Extract all `-gencode` flags from `CMAKE_CUDA_FLAGS` - string(REGEX MATCHALL "-gencode arch=[^ ]+" _CUDA_ARCH_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # Remove all `-gencode` flags from `CMAKE_CUDA_FLAGS` since they will be modified - # and passed back via the `CUDA_ARCHITECTURES` property. - string(REGEX REPLACE "-gencode arch=[^ ]+ *" "" CMAKE_CUDA_FLAGS - ${CMAKE_CUDA_FLAGS}) - - # If this error is triggered, it might mean that torch has changed how it sets - # up nvcc architecture code generation flags. - if (NOT _CUDA_ARCH_FLAGS) - message(FATAL_ERROR - "Could not find any architecture related code generation flags in " - "CMAKE_CUDA_FLAGS. 
(${CMAKE_CUDA_FLAGS})") - endif() - - message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}") - message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}") - - # Initialize the architecture lists to empty. - set(${GPU_ARCHES}) - - # Process each `gencode` flag. - foreach(_ARCH ${_CUDA_ARCH_FLAGS}) - # For each flag, extract the version number and whether it refers to PTX - # or native code. - # Note: if a regex matches then `CMAKE_MATCH_1` holds the binding - # for that match. - - string(REGEX MATCH "arch=compute_\([0-9]+a?\)" _COMPUTE ${_ARCH}) - if (_COMPUTE) - set(_COMPUTE ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=sm_\([0-9]+a?\)" _SM ${_ARCH}) - if (_SM) - set(_SM ${CMAKE_MATCH_1}) - endif() - - string(REGEX MATCH "code=compute_\([0-9]+a?\)" _CODE ${_ARCH}) - if (_CODE) - set(_CODE ${CMAKE_MATCH_1}) - endif() - - # Make sure the virtual architecture can be matched. - if (NOT _COMPUTE) - message(FATAL_ERROR - "Could not determine virtual architecture from: ${_ARCH}.") - endif() - - # One of sm_ or compute_ must exist. - if ((NOT _SM) AND (NOT _CODE)) - message(FATAL_ERROR - "Could not determine a codegen architecture from: ${_ARCH}.") - endif() - - if (_SM) - # -real suffix let CMake to only generate elf code for the kernels. - # we want this, otherwise the added ptx (default) will increase binary size. - set(_VIRT "-real") - set(_CODE_ARCH ${_SM}) - else() - # -virtual suffix let CMake to generate ptx code for the kernels. - set(_VIRT "-virtual") - set(_CODE_ARCH ${_CODE}) - endif() - - # Check if the current version is in the supported arch list. - string_to_ver(_CODE_VER ${_CODE_ARCH}) - if (NOT _CODE_VER IN_LIST _GPU_SUPPORTED_ARCHES_LIST) - message(STATUS "discarding unsupported CUDA arch ${_VER}.") - continue() - endif() - - # Add it to the arch list. - list(APPEND ${GPU_ARCHES} "${_CODE_ARCH}${_VIRT}") - endforeach() endif() - message(STATUS "${GPU_LANG} target arches: ${${GPU_ARCHES}}") endmacro() # diff --git a/csrc/core/registration.h b/csrc/core/registration.h index e5396e9a8b137..4d0ce1c572c1c 100644 --- a/csrc/core/registration.h +++ b/csrc/core/registration.h @@ -12,6 +12,11 @@ // could be a macro instead of a literal token. #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE) +// A version of the TORCH_LIBRARY_IMPL macro that expands the NAME, i.e. so NAME +// could be a macro instead of a literal token. +#define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \ + TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE) + // REGISTER_EXTENSION allows the shared library to be loaded and initialized // via python's import statement. 
#define REGISTER_EXTENSION(NAME) \ diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index c97b5dbd2a54e..661490d95e791 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -27,6 +27,7 @@ #include "core/exception.hpp" #include "core/scalar_type.hpp" +#include "core/registration.h" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" @@ -552,3 +553,7 @@ torch::Tensor marlin_gemm_moe( thread_n, sms, max_par, replicate_input, apply_weights); return c; } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_gemm_moe", &marlin_gemm_moe); +} diff --git a/csrc/moe/marlin_moe_ops.h b/csrc/moe/marlin_moe_ops.h deleted file mode 100644 index adee8399a4d6f..0000000000000 --- a/csrc/moe/marlin_moe_ops.h +++ /dev/null @@ -1,15 +0,0 @@ -#pragma once - -#include - -#include "core/scalar_type.hpp" - -torch::Tensor marlin_gemm_moe( - const torch::Tensor& a, const torch::Tensor& b_q_weights, - const torch::Tensor& sorted_ids, const torch::Tensor& topk_weights, - const torch::Tensor& topk_ids, const torch::Tensor& b_scales, - const torch::Tensor& g_idx, const torch::Tensor& perm, - torch::Tensor& workspace, vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, bool is_k_full, - int64_t num_experts, int64_t topk, int64_t moe_block_size, - bool replicate_input, bool apply_weights); diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index cd65a8ee92b94..cbc8754f7a5b2 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -1,6 +1,5 @@ #include "core/registration.h" #include "moe_ops.h" -#include "marlin_moe_ops.h" TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { // Apply topk softmax to the gating outputs. 
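The binding change just below follows the pattern used throughout this diff: the operator schema stays in the always-built torch_bindings.cpp, while the backend implementation now registers itself from the conditionally compiled kernel source via TORCH_LIBRARY_IMPL_EXPAND. A self-contained sketch of the two halves with a made-up op name (in vLLM they live in separate files):

    #include <torch/library.h>
    #include <torch/torch.h>

    // Stand-ins for core/registration.h and the build system's
    // -DTORCH_EXTENSION_NAME=... definition.
    #define TORCH_EXTENSION_NAME _example_C
    #define TORCH_LIBRARY_EXPAND(NAME, MODULE) TORCH_LIBRARY(NAME, MODULE)
    #define TORCH_LIBRARY_IMPL_EXPAND(NAME, DEVICE, MODULE) \
      TORCH_LIBRARY_IMPL(NAME, DEVICE, MODULE)

    // Placeholder for a real CUDA kernel launcher.
    torch::Tensor my_gemm(torch::Tensor const& a, torch::Tensor const& b) {
      return torch::matmul(a, b);
    }

    // torch_bindings.cpp side: always compiled, declares the schema only.
    TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
      ops.def("my_gemm(Tensor a, Tensor b) -> Tensor");
    }

    // Kernel-source side: compiled only when a compatible arch is selected, so
    // the CUDA impl registration sits next to the kernel it binds.
    TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, ops) {
      ops.impl("my_gemm", &my_gemm);
    }

If the kernel source is excluded from the build, the schema is still registered, and calling the op should fail at dispatch time with a missing-backend error instead of breaking the extension at import.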
@@ -18,7 +17,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "int size_n, int size_k, bool is_k_full, int num_experts, int topk, " "int moe_block_size, bool replicate_input, bool apply_weights)" " -> Tensor"); - m.impl("marlin_gemm_moe", torch::kCUDA, &marlin_gemm_moe); + // conditionally compiled so impl registration is in source file #endif } diff --git a/csrc/ops.h b/csrc/ops.h index 3e31ddb286e80..fce545f95a7cc 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -90,63 +90,8 @@ torch::Tensor awq_dequantize(torch::Tensor _kernel, torch::Tensor _zeros, int64_t split_k_iters, int64_t thx, int64_t thy); -torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t size_m, int64_t size_n, int64_t size_k); - -namespace machete { - -std::vector supported_schedules( - vllm::ScalarTypeTorchPtr const& btype); - -torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, - vllm::ScalarTypeTorchPtr const& btype, - c10::optional const& scales, - c10::optional const& zeros, - c10::optional group_size, - c10::optional const& C, - c10::optional alpha, c10::optional beta, - c10::optional schedule); - -torch::Tensor prepack_B(torch::Tensor const& B, - vllm::ScalarTypeTorchPtr const& btype); - -}; // namespace machete - torch::Tensor permute_cols(torch::Tensor const& A, torch::Tensor const& perm); -torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_meta, - torch::Tensor& b_scales, - torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, - int64_t size_k); - -torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& b_zeros, - torch::Tensor& g_idx, torch::Tensor& perm, - torch::Tensor& workspace, - vllm::ScalarTypeTorchPtr const& b_q_type, - int64_t size_m, int64_t size_n, int64_t size_k, - bool is_k_full, bool has_zp, - bool use_fp32_reduce); - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits); - -torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, - torch::Tensor& perm, c10::SymInt size_k, - c10::SymInt size_n, int64_t num_bits); - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, - int64_t size_n, int64_t num_bits); - -torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, - c10::SymInt size_k, c10::SymInt size_n, - int64_t num_bits); - torch::Tensor ggml_dequantize(torch::Tensor W, int64_t type, int64_t m, int64_t n); @@ -156,11 +101,6 @@ torch::Tensor ggml_mul_mat_vec_a8(torch::Tensor W, torch::Tensor X, torch::Tensor ggml_mul_mat_a8(torch::Tensor W, torch::Tensor X, int64_t type, int64_t row); -torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, - torch::Tensor& b_scales, torch::Tensor& workspace, - int64_t num_bits, int64_t size_m, int64_t size_n, - int64_t size_k); - bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); void cutlass_scaled_mm(torch::Tensor& out, torch::Tensor const& a, @@ -175,14 +115,6 @@ void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& azp_adj, c10::optional const& azp, c10::optional const& bias); - -torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, - torch::Tensor const& b_q_weight, - torch::Tensor const& s_tok, - torch::Tensor const& s_ch, - torch::Tensor const& s_group, - torch::Tensor& workspace, int64_t size_m, - int64_t size_n, 
int64_t size_k); #endif void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input, diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index 195eb27dee749..46fef79f439fb 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -17,6 +17,17 @@ __global__ void advance_step_flashattn_kernel( long const* sampled_token_ids_ptr, long* input_positions_ptr, int* seq_lens_ptr, long* slot_mapping_ptr, int const* block_tables_ptr, int64_t const block_tables_stride) { + int const n_pad = num_seqs - num_queries; + if (n_pad && blockIdx.x == 0) { + // Handle cuda graph padding + int const offset = num_queries; + for (int i = threadIdx.x; i < n_pad; i += blockDim.x) { + input_tokens_ptr[offset + i] = 0; + input_positions_ptr[offset + i] = 0; + slot_mapping_ptr[offset + i] = -1; + } + } + int num_query_blocks = div_ceil(num_queries, num_threads); if (blockIdx.x >= num_query_blocks) { diff --git a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu index 0b1d5cfe1b338..1657f7d0b16e8 100644 --- a/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu +++ b/csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu @@ -21,7 +21,7 @@ void cutlass_scaled_mm_sm89(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b_scales, c10::optional const& bias); -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a, torch::Tensor const& b, torch::Tensor const& a_scales, @@ -114,26 +114,39 @@ void cutlass_scaled_mm(torch::Tensor& c, torch::Tensor const& a, at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); int32_t version_num = get_sm_version_num(); - if (version_num >= 90) { - // Hopper + // Hopper - // Guard against compilation issues for sm90 kernels -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 + // Guard against compilation issues for sm90 kernels +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if (version_num >= 90) { cutlass_scaled_mm_sm90(c, a, b, a_scales, b_scales, bias); -#else - cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); + return; + } #endif - } else if (version_num == 89) { + +#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X + if (version_num == 89) { // Ada Lovelace cutlass_scaled_mm_sm89(c, a, b, a_scales, b_scales, bias); - } else if (version_num >= 80) { + return; + } + + if (version_num >= 80) { // Ampere cutlass_scaled_mm_sm80(c, a, b, a_scales, b_scales, bias); - } else { - // Turing - TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); + return; } + + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_sm75(c, a, b, a_scales, b_scales, bias); +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm for a compute capability less than " + "CUDA device capability: ", + version_num); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, @@ -174,25 +187,38 @@ void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, "currently bias dtype must match output dtype ", c.dtype()); at::cuda::OptionalCUDAGuard const device_guard(device_of(a)); + int32_t version_num = get_sm_version_num(); - if (version_num >= 90) { - // Hopper - // Guard against compilation issues for sm90 kernels -#if defined CUDA_VERSION && CUDA_VERSION >= 12000 +#if defined ENABLE_SCALED_MM_C3X && ENABLE_SCALED_MM_C3X + if 
(version_num >= 90) { cutlass_scaled_mm_azp_sm90(c, a, b, a_scales, b_scales, azp_adj, azp, bias); -#else - cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); + return; + } #endif - } else if (version_num == 89) { + +#if defined ENABLE_SCALED_MM_C2X && ENABLE_SCALED_MM_C2X + if (version_num == 89) { // Ada Lovelace cutlass_scaled_mm_azp_sm89(c, a, b, a_scales, b_scales, azp_adj, azp, bias); - } else if (version_num >= 80) { + return; + } + + if (version_num >= 80) { // Ampere cutlass_scaled_mm_azp_sm80(c, a, b, a_scales, b_scales, azp_adj, azp, bias); - } else { - // Turing - TORCH_CHECK(version_num >= 75); - cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias); + return; } + + // Turing + TORCH_CHECK(version_num >= 75); + cutlass_scaled_mm_azp_sm75(c, a, b, a_scales, b_scales, azp_adj, azp, bias); + return; +#endif + + TORCH_CHECK_NOT_IMPLEMENTED( + false, + "No compiled cutlass_scaled_mm_azp for a compute capability less than " + "CUDA device capability: ", + version_num); } \ No newline at end of file diff --git a/csrc/quantization/fp8/fp8_marlin.cu b/csrc/quantization/fp8/fp8_marlin.cu index eef6dc6ebdf4a..376bbd498ca52 100644 --- a/csrc/quantization/fp8/fp8_marlin.cu +++ b/csrc/quantization/fp8/fp8_marlin.cu @@ -22,6 +22,8 @@ #include "../gptq_marlin/marlin.cuh" #include "../gptq_marlin/marlin_dtypes.cuh" +#include "core/registration.h" + using namespace marlin; #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ @@ -1303,3 +1305,7 @@ torch::Tensor fp8_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } #endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("fp8_marlin_gemm", &fp8_marlin_gemm); +} \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu index de8d9ef2ee63e..3e2f87dbc4553 100644 --- a/csrc/quantization/gptq_marlin/awq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/awq_marlin_repack.cu @@ -1,25 +1,6 @@ #include "marlin.cuh" -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void awq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) 
requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else +#include "core/registration.h" namespace marlin { @@ -122,7 +103,7 @@ __global__ void awq_marlin_repack_kernel( } uint32_t vals[8]; - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; @@ -143,7 +124,7 @@ __global__ void awq_marlin_repack_kernel( constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -155,7 +136,7 @@ __global__ void awq_marlin_repack_kernel( uint32_t res1 = 0; uint32_t res2 = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -167,21 +148,21 @@ __global__ void awq_marlin_repack_kernel( }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; - #pragma unroll +#pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -195,15 +176,15 @@ __global__ void awq_marlin_repack_kernel( } // namespace marlin - #define CALL_IF(NUM_BITS) \ - else if (num_bits == NUM_BITS) { \ - cudaFuncSetAttribute( \ - marlin::awq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::awq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, out_ptr, size_k, size_n); \ - } +#define CALL_IF(NUM_BITS) \ + else if (num_bits == NUM_BITS) { \ + cudaFuncSetAttribute( \ + marlin::awq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::awq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, out_ptr, size_k, size_n); \ + } torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, int64_t size_n, int64_t num_bits) { @@ -266,8 +247,6 @@ torch::Tensor awq_marlin_repack(torch::Tensor& b_q_weight, int64_t size_k, return out; } -#endif - torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { @@ -279,3 +258,11 @@ torch::Tensor awq_marlin_repack_meta(torch::Tensor& b_q_weight, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options); } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("awq_marlin_repack", &awq_marlin_repack); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) { + m.impl("awq_marlin_repack", &awq_marlin_repack_meta); +} \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin.cu b/csrc/quantization/gptq_marlin/gptq_marlin.cu index 9b4a6a515107d..227bc19b914a0 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin.cu @@ -23,6 +23,8 @@ #include "marlin_dtypes.cuh" #include "core/scalar_type.hpp" +#include "core/registration.h" + #define STATIC_ASSERT_SCALAR_TYPE_VALID(scalar_t) \ static_assert(std::is_same::value || \ std::is_same::value, \ @@ -2297,3 +2299,7 @@ torch::Tensor gptq_marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, } #endif + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, 
CUDA, m) { + m.impl("gptq_marlin_gemm", &gptq_marlin_gemm); +} \ No newline at end of file diff --git a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu index 70d48de12ab05..5cd078555046d 100644 --- a/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu +++ b/csrc/quantization/gptq_marlin/gptq_marlin_repack.cu @@ -1,26 +1,6 @@ #include "marlin.cuh" -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 - -namespace marlin { - -template -__global__ void gptq_marlin_repack_kernel( - uint32_t const* __restrict__ b_q_weight_ptr, - uint32_t const* __restrict__ perm_ptr, uint32_t* __restrict__ out_ptr, - int size_k, int size_n) {} - -} // namespace marlin - -torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, - int64_t size_k, int64_t size_n, - int64_t num_bits) { - TORCH_CHECK_NOT_IMPLEMENTED( - false, "marlin_repack_from_gptq(..) requires CUDA_ARCH >= 8.0"); - return torch::empty({1, 1}); -} - -#else +#include "core/registration.h" namespace marlin { @@ -174,13 +154,13 @@ __global__ void gptq_marlin_repack_kernel( uint32_t b1_vals[tile_ints]; uint32_t b2_vals[tile_ints]; - #pragma unroll +#pragma unroll for (int i = 0; i < tile_ints; i++) { b1_vals[i] = sh_stage_int_ptr[cur_n + sh_stride * i]; b2_vals[i] = sh_stage_int_ptr[cur_n + 8 + sh_stride * i]; } - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { int cur_elem = tc_row + tc_offsets[i]; int cur_int = cur_elem / pack_factor; @@ -200,7 +180,7 @@ __global__ void gptq_marlin_repack_kernel( constexpr int pack_idx[8] = {0, 2, 4, 6, 1, 3, 5, 7}; uint32_t res = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 8; i++) { res |= vals[pack_idx[i]] << (i * 4); } @@ -212,7 +192,7 @@ __global__ void gptq_marlin_repack_kernel( uint32_t res1 = 0; uint32_t res2 = 0; - #pragma unroll +#pragma unroll for (int i = 0; i < 4; i++) { res1 |= vals[pack_idx[i]] << (i * 8); res2 |= vals[4 + pack_idx[i]] << (i * 8); @@ -224,14 +204,14 @@ __global__ void gptq_marlin_repack_kernel( }; auto start_pipes = [&](int k_tile_id, int n_tile_id) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages - 1; pipe++) { fetch_to_shared(pipe, k_tile_id, n_tile_id + pipe); } wait_for_stage(); }; - #pragma unroll +#pragma unroll for (int k_tile_id = start_k_tile; k_tile_id < finish_k_tile; k_tile_id++) { int n_tile_id = 0; @@ -242,7 +222,7 @@ __global__ void gptq_marlin_repack_kernel( start_pipes(k_tile_id, n_tile_id); while (n_tile_id < n_tiles) { - #pragma unroll +#pragma unroll for (int pipe = 0; pipe < repack_stages; pipe++) { fetch_to_shared((pipe + repack_stages - 1) % repack_stages, k_tile_id, n_tile_id + pipe + repack_stages - 1); @@ -256,17 +236,17 @@ __global__ void gptq_marlin_repack_kernel( } // namespace marlin - #define CALL_IF(NUM_BITS, HAS_PERM) \ - else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ - cudaFuncSetAttribute( \ - marlin::gptq_marlin_repack_kernel, \ - cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ - marlin::gptq_marlin_repack_kernel \ - <<>>( \ - b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ - } +#define CALL_IF(NUM_BITS, HAS_PERM) \ + else if (num_bits == NUM_BITS && has_perm == HAS_PERM) { \ + cudaFuncSetAttribute( \ + marlin::gptq_marlin_repack_kernel, \ + cudaFuncAttributeMaxDynamicSharedMemorySize, max_shared_mem); \ + marlin::gptq_marlin_repack_kernel \ + <<>>( \ + b_q_weight_ptr, perm_ptr, out_ptr, size_k, size_n); \ + } torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, 
int64_t size_k, int64_t size_n, @@ -341,8 +321,6 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm, return out; } -#endif - torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, torch::Tensor& perm, c10::SymInt size_k, c10::SymInt size_n, int64_t num_bits) { @@ -354,3 +332,11 @@ torch::Tensor gptq_marlin_repack_meta(torch::Tensor& b_q_weight, {size_k / marlin::tile_size, size_n * marlin::tile_size / pack_factor}, options); } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("gptq_marlin_repack", &gptq_marlin_repack); +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, Meta, m) { + m.impl("gptq_marlin_repack", &gptq_marlin_repack_meta); +} \ No newline at end of file diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index c35dfe94c9c41..ebbe76cfb944a 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -284,7 +284,7 @@ def create_template(template_str): prepack_dispatch_template = create_template(PREPACK_TEMPLATE) -def create_sources(impl_config: ImplConfig, num_impl_files=2): +def create_sources(impl_config: ImplConfig, num_impl_files=1): sources = [] type_name = generate_type_signature(impl_config.type_config) diff --git a/csrc/quantization/machete/machete_prepack_kernel.cuh b/csrc/quantization/machete/machete_prepack_kernel.cuh index 8e02104587d17..f23483f928b47 100644 --- a/csrc/quantization/machete/machete_prepack_kernel.cuh +++ b/csrc/quantization/machete/machete_prepack_kernel.cuh @@ -34,10 +34,9 @@ static __global__ void prepack_B_kernel(BInTensor B_in, } template -static void prepack_B(cudaStream_t stream, - typename PrepackedLayoutB::ElementB const* B_in_ptr, - InLayout B_layout, - typename PrepackedLayoutB::ElementB* B_out_ptr) { +static void prepack_B_template( + cudaStream_t stream, typename PrepackedLayoutB::ElementB const* B_in_ptr, + InLayout B_layout, typename PrepackedLayoutB::ElementB* B_out_ptr) { using TileShapeNKL = decltype(append(typename PrepackedLayoutB::PPBlockShape_NK{}, _1{})); auto ilvd_NKbNbKL_to_offset = diff --git a/csrc/quantization/machete/machete_prepack_launcher.cuh b/csrc/quantization/machete/machete_prepack_launcher.cuh index df78312997fb0..a33d8f9484cfe 100644 --- a/csrc/quantization/machete/machete_prepack_launcher.cuh +++ b/csrc/quantization/machete/machete_prepack_launcher.cuh @@ -55,8 +55,8 @@ torch::Tensor prepack_impl(torch::Tensor const B) { // Allocate output torch::Tensor D = torch::empty_like(B, {}, at::MemoryFormat::Contiguous); - prepack_B(stream, B_ptr, layout_Bt, - static_cast(D.mutable_data_ptr())); + prepack_B_template( + stream, B_ptr, layout_Bt, static_cast(D.mutable_data_ptr())); return D; }; diff --git a/csrc/quantization/machete/machete_pytorch.cu b/csrc/quantization/machete/machete_pytorch.cu index a78cccb2358ee..a27f1e7c83df9 100644 --- a/csrc/quantization/machete/machete_pytorch.cu +++ b/csrc/quantization/machete/machete_pytorch.cu @@ -2,6 +2,8 @@ #include "machete_prepack_launcher.cuh" #include "core/scalar_type.hpp" +#include "core/registration.h" + namespace machete { using namespace vllm; @@ -78,14 +80,16 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B, } torch::Tensor prepack_B(torch::Tensor const& B, - ScalarTypeTorchPtr const& btype) { -#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12 + vllm::ScalarTypeTorchPtr const& btype) { return scalar_type_dispatch(*btype, [&](auto BType) { return PrepackBDispatcher::dispatch(B); }); -#else - 
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later"); -#endif +} + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("machete_prepack_B", &prepack_B); + m.impl("machete_gemm", &gemm); + m.impl("machete_supported_schedules", &supported_schedules); } }; // namespace machete diff --git a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu index 1ce734c9d90de..c03fef886e4db 100644 --- a/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu +++ b/csrc/quantization/marlin/dense/marlin_cuda_kernel.cu @@ -26,6 +26,7 @@ #include #include "common/base.h" +#include "core/registration.h" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include "common/mem.h" @@ -1066,3 +1067,7 @@ torch::Tensor marlin_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_gemm", &marlin_gemm); +} diff --git a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu index 4162a38af1035..103a6444f3a21 100644 --- a/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu +++ b/csrc/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu @@ -30,6 +30,7 @@ #include #include "../dense/common/base.h" +#include "core/registration.h" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 #include "../dense/common/mem.h" @@ -1241,3 +1242,7 @@ torch::Tensor marlin_qqq_gemm(torch::Tensor const& a, return d; } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("marlin_qqq_gemm", &marlin_qqq_gemm); +} diff --git a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu index 93445a386593b..908e4f70ab1e6 100644 --- a/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu +++ b/csrc/quantization/marlin/sparse/marlin_24_cuda_kernel.cu @@ -28,6 +28,7 @@ #include "common/base.h" #include "core/scalar_type.hpp" +#include "core/registration.h" #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800 @@ -1134,3 +1135,7 @@ torch::Tensor gptq_marlin_24_gemm(torch::Tensor& a, torch::Tensor& b_q_weight, return c; } + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("gptq_marlin_24_gemm", &gptq_marlin_24_gemm); +} diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 3538f2850f915..a0100b4a85edd 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -167,7 +167,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "Tensor! workspace, int size_m, int size_n, int size_k) -> Tensor"); - ops.impl("marlin_gemm", torch::kCUDA, &marlin_gemm); + // conditionally compiled so impl in source file // Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ. ops.def( @@ -175,22 +175,24 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor b_scales, Tensor workspace, " "__torch__.torch.classes._core_C.ScalarType b_q_type, " "int size_m, int size_n, int size_k) -> Tensor"); - ops.impl("gptq_marlin_24_gemm", torch::kCUDA, &gptq_marlin_24_gemm); + // conditionally compiled so impl in source file // Machete (Dense) Optimized Mixed Precision GEMM for Hopper. - ops.def("machete_supported_schedules", &machete::supported_schedules); + ops.def( + "machete_supported_schedules(" + " __torch__.torch.classes._core_C.ScalarType btype" + ") -> str[]"); ops.def( "machete_gemm(Tensor A, Tensor B," " __torch__.torch.classes._core_C.ScalarType btype," " Tensor? 
scales, Tensor? zeros, int? group_size," " Tensor? C, float? alpha, float? beta, str? schedule)" "-> Tensor"); - ops.impl("machete_gemm", torch::kCUDA, &machete::gemm); ops.def( "machete_prepack_B(Tensor B," " __torch__.torch.classes._core_C.ScalarType btype)" "-> Tensor"); - ops.impl("machete_prepack_B", torch::kCUDA, &machete::prepack_B); + // conditionally compiled so impl registration is in source file ops.def("permute_cols(Tensor A, Tensor perm) -> Tensor"); ops.impl("permute_cols", torch::kCUDA, &permute_cols); @@ -202,21 +204,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "__torch__.torch.classes._core_C.ScalarType b_q_type, " "int size_m, int size_n, int size_k, bool is_k_full, " "bool has_zp, bool use_fp32_reduce) -> Tensor"); - ops.impl("gptq_marlin_gemm", torch::kCUDA, &gptq_marlin_gemm); + // conditionally compiled so impl registration is in source file // gptq_marlin repack from GPTQ. ops.def( "gptq_marlin_repack(Tensor b_q_weight, Tensor perm, " "SymInt size_k, SymInt size_n, int num_bits) -> Tensor"); - ops.impl("gptq_marlin_repack", torch::kCUDA, &gptq_marlin_repack); - ops.impl("gptq_marlin_repack", torch::kMeta, &gptq_marlin_repack_meta); + // conditionally compiled so impl registrations are in source file // awq_marlin repack from AWQ. ops.def( "awq_marlin_repack(Tensor b_q_weight, SymInt size_k, " "SymInt size_n, int num_bits) -> Tensor"); - ops.impl("awq_marlin_repack", torch::kCUDA, &awq_marlin_repack); - ops.impl("awq_marlin_repack", torch::kMeta, &awq_marlin_repack_meta); + // conditionally compiled so impl registrations are in source file // Dequantization for GGML. ops.def("ggml_dequantize(Tensor W, int type, int m, int n) -> Tensor"); @@ -237,7 +237,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "fp8_marlin_gemm(Tensor a, Tensor b_q_weight, Tensor b_scales, " "Tensor! workspace, int num_bits, int size_m, int size_n, " "int size_k) -> Tensor"); - ops.impl("fp8_marlin_gemm", torch::kCUDA, &fp8_marlin_gemm); + // conditionally compiled so impl registration is in source file // marlin_qqq_gemm for QQQ. ops.def( @@ -245,7 +245,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "Tensor s_tok, Tensor s_ch, Tensor s_group, " "Tensor! 
workspace, int size_m, int size_n, " "int size_k) -> Tensor"); - ops.impl("marlin_qqq_gemm", torch::kCUDA, &marlin_qqq_gemm); + // conditionally compiled so impl registration is in source file // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 6687929c0bebe..80037dda20015 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -12,4 +12,5 @@ torch py-cpuinfo transformers mistral_common >= 1.3.4 -openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file +openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file diff --git a/docs/source/getting_started/openvino-installation.rst b/docs/source/getting_started/openvino-installation.rst index b67e0410f7441..5eeb7c78f7e51 100644 --- a/docs/source/getting_started/openvino-installation.rst +++ b/docs/source/getting_started/openvino-installation.rst @@ -3,7 +3,7 @@ Installation with OpenVINO ========================== -vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support. OpenVINO vLLM backend supports the following advanced vLLM features: +vLLM powered by OpenVINO supports all LLM models from :doc:`vLLM supported models list <../models/supported_models>` and can perform optimal model serving on all x86-64 CPUs with, at least, AVX2 support, as well as on both integrated and discrete IntelĀ® GPUs (`the list of supported GPUs `_). OpenVINO vLLM backend supports the following advanced vLLM features: - Prefix caching (``--enable-prefix-caching``) - Chunked prefill (``--enable-chunked-prefill``) @@ -53,34 +53,57 @@ Install from source $ pip install --upgrade pip $ pip install -r requirements-build.txt --extra-index-url https://download.pytorch.org/whl/cpu -- Finally, install vLLM with OpenVINO backend: +- Finally, install vLLM with OpenVINO backend: .. code-block:: console $ PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" VLLM_TARGET_DEVICE=openvino python -m pip install -v . +- [Optional] To use vLLM OpenVINO backend with a GPU device, ensure your system is properly set up. Follow the instructions provided here: `https://docs.openvino.ai/2024/get-started/configurations/configurations-intel-gpu.html `_. + .. _openvino_backend_performance_tips: Performance tips ---------------- -vLLM OpenVINO backend uses the following environment variables to control behavior: +vLLM OpenVINO backend environment variables +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + +- ``VLLM_OPENVINO_DEVICE`` to specify which device utilize for the inference. If there are multiple GPUs in the system, additional indexes can be used to choose the proper one (e.g, ``VLLM_OPENVINO_DEVICE=GPU.1``). If the value is not specified, CPU device is used by default. + +- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. 
You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `` + +CPU performance tips +~~~~~~~~~~~~~~~~~~~~ + +CPU uses the following environment variables to control behavior: - ``VLLM_OPENVINO_KVCACHE_SPACE`` to specify the KV Cache size (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=40`` means 40 GB space for KV cache), larger setting will allow vLLM running more requests in parallel. This parameter should be set based on the hardware configuration and memory management pattern of users. - ``VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8`` to control KV cache precision. By default, FP16 / BF16 is used depending on platform. -- ``VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON`` to enable U8 weights compression during model loading stage. By default, compression is turned off. You can also export model with different compression techniques using `optimum-cli` and pass exported folder as `` - To enable better TPOT / TTFT latency, you can use vLLM's chunked prefill feature (``--enable-chunked-prefill``). Based on the experiments, the recommended batch size is ``256`` (``--max-num-batched-tokens``) -OpenVINO best known configuration is: +OpenVINO best known configuration for CPU is: .. code-block:: console $ VLLM_OPENVINO_KVCACHE_SPACE=100 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION=u8 VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \ python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json --enable-chunked-prefill --max-num-batched-tokens 256 +GPU performance tips +~~~~~~~~~~~~~~~~~~~~ +GPU device implements the logic for automatic detection of available GPU memory and, by default, tries to reserve as much memory as possible for the KV cache (taking into account ``gpu_memory_utilization`` option). However, this behavior can be overridden by explicitly specifying the desired amount of memory for the KV cache using ``VLLM_OPENVINO_KVCACHE_SPACE`` environment variable (e.g, ``VLLM_OPENVINO_KVCACHE_SPACE=8`` means 8 GB space for KV cache). + +Currently, the best performance using GPU can be achieved with the default vLLM execution parameters for models with quantized weights (8 and 4-bit integer data types are supported) and `preemption-mode=swap`. + +OpenVINO best known configuration for GPU is: + +.. code-block:: console + + $ VLLM_OPENVINO_DEVICE=GPU VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS=ON \ + python3 vllm/benchmarks/benchmark_throughput.py --model meta-llama/Llama-2-7b-chat-hf --dataset vllm/benchmarks/ShareGPT_V3_unfiltered_cleaned_split.json + .. _openvino_backend_limitations: Limitations diff --git a/docs/source/models/adding_model.rst b/docs/source/models/adding_model.rst index 5cffb58cafd96..1f220b723cacd 100644 --- a/docs/source/models/adding_model.rst +++ b/docs/source/models/adding_model.rst @@ -99,7 +99,7 @@ This method should load the weights from the HuggingFace's checkpoint file and a 5. Register your model ---------------------- -Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/__init__.py `_. +Finally, register your :code:`*ForCausalLM` class to the :code:`_MODELS` in `vllm/model_executor/models/registry.py `_. 6. 
Out-of-Tree Model Integration -------------------------------------------- diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index b05cba3b5d423..23f08bfa9756e 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -12,185 +12,249 @@ Alongside each architecture, we include some popular models that use it. Decoder-only Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: - :widths: 25 25 50 5 + :widths: 25 25 50 5 5 :header-rows: 1 * - Architecture - Models - Example HuggingFace Models - :ref:`LoRA ` + - :ref:`PP ` * - :code:`AquilaForCausalLM` - - Aquila & Aquila2 + - Aquila, Aquila2 - :code:`BAAI/Aquila-7B`, :code:`BAAI/AquilaChat-7B`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`ArcticForCausalLM` - Arctic - :code:`Snowflake/snowflake-arctic-base`, :code:`Snowflake/snowflake-arctic-instruct`, etc. - + - āœ…ļøŽ * - :code:`BaiChuanForCausalLM` - - Baichuan & Baichuan2 + - Baichuan2, Baichuan - :code:`baichuan-inc/Baichuan2-13B-Chat`, :code:`baichuan-inc/Baichuan-7B`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`BloomForCausalLM` - BLOOM, BLOOMZ, BLOOMChat - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc. - + - āœ…ļøŽ * - :code:`ChatGLMModel` - ChatGLM - :code:`THUDM/chatglm2-6b`, :code:`THUDM/chatglm3-6b`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`CohereForCausalLM` - Command-R - :code:`CohereForAI/c4ai-command-r-v01`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`DbrxForCausalLM` - DBRX - :code:`databricks/dbrx-base`, :code:`databricks/dbrx-instruct`, etc. - + - āœ…ļøŽ * - :code:`DeciLMForCausalLM` - DeciLM - :code:`Deci/DeciLM-7B`, :code:`Deci/DeciLM-7B-instruct`, etc. - + - āœ…ļøŽ + * - :code:`DeepseekForCausalLM` + - DeepSeek + - :code:`deepseek-ai/deepseek-llm-67b-base`, :code:`deepseek-ai/deepseek-llm-7b-chat` etc. + - + - āœ…ļøŽ + * - :code:`DeepseekV2ForCausalLM` + - DeepSeek-V2 + - :code:`deepseek-ai/DeepSeek-V2`, :code:`deepseek-ai/DeepSeek-V2-Chat` etc. + - + - āœ…ļøŽ * - :code:`ExaoneForCausalLM` - EXAONE-3 - :code:`LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`FalconForCausalLM` - Falcon - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc. - + - āœ…ļøŽ * - :code:`GemmaForCausalLM` - Gemma - :code:`google/gemma-2b`, :code:`google/gemma-7b`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`Gemma2ForCausalLM` - Gemma2 - :code:`google/gemma-2-9b`, :code:`google/gemma-2-27b`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`GPT2LMHeadModel` - GPT-2 - :code:`gpt2`, :code:`gpt2-xl`, etc. - + - āœ…ļøŽ * - :code:`GPTBigCodeForCausalLM` - StarCoder, SantaCoder, WizardCoder - :code:`bigcode/starcoder`, :code:`bigcode/gpt_bigcode-santacoder`, :code:`WizardLM/WizardCoder-15B-V1.0`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`GPTJForCausalLM` - GPT-J - :code:`EleutherAI/gpt-j-6b`, :code:`nomic-ai/gpt4all-j`, etc. - + - āœ…ļøŽ * - :code:`GPTNeoXForCausalLM` - GPT-NeoX, Pythia, OpenAssistant, Dolly V2, StableLM - :code:`EleutherAI/gpt-neox-20b`, :code:`EleutherAI/pythia-12b`, :code:`OpenAssistant/oasst-sft-4-pythia-12b-epoch-3.5`, :code:`databricks/dolly-v2-12b`, :code:`stabilityai/stablelm-tuned-alpha-7b`, etc. - + - āœ…ļøŽ + * - :code:`GraniteForCausalLM` + - PowerLM + - :code:`ibm/PowerLM-3b` etc. + - āœ…ļøŽ + - āœ…ļøŽ + * - :code:`GraniteMoeForCausalLM` + - PowerMoE + - :code:`ibm/PowerMoE-3b` etc. + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`InternLMForCausalLM` - InternLM - :code:`internlm/internlm-7b`, :code:`internlm/internlm-chat-7b`, etc. 
- āœ…ļøŽ + - āœ…ļøŽ * - :code:`InternLM2ForCausalLM` - InternLM2 - :code:`internlm/internlm2-7b`, :code:`internlm/internlm2-chat-7b`, etc. - + - āœ…ļøŽ * - :code:`JAISLMHeadModel` - Jais - :code:`core42/jais-13b`, :code:`core42/jais-13b-chat`, :code:`core42/jais-30b-v3`, :code:`core42/jais-30b-chat-v3`, etc. - + - āœ…ļøŽ * - :code:`JambaForCausalLM` - Jamba - - :code:`ai21labs/Jamba-v0.1`, etc. + - :code:`ai21labs/AI21-Jamba-1.5-Large`, :code:`ai21labs/AI21-Jamba-1.5-Mini`, :code:`ai21labs/Jamba-v0.1`, etc. - āœ…ļøŽ + - * - :code:`LlamaForCausalLM` - Llama 3.1, Llama 3, Llama 2, LLaMA, Yi - :code:`meta-llama/Meta-Llama-3.1-405B-Instruct`, :code:`meta-llama/Meta-Llama-3.1-70B`, :code:`meta-llama/Meta-Llama-3-70B-Instruct`, :code:`meta-llama/Llama-2-70b-hf`, :code:`01-ai/Yi-34B`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MiniCPMForCausalLM` - MiniCPM - :code:`openbmb/MiniCPM-2B-sft-bf16`, :code:`openbmb/MiniCPM-2B-dpo-bf16`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MiniCPM3ForCausalLM` - MiniCPM3 - :code:`openbmb/MiniCPM3-4B`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MistralForCausalLM` - Mistral, Mistral-Instruct - :code:`mistralai/Mistral-7B-v0.1`, :code:`mistralai/Mistral-7B-Instruct-v0.1`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MixtralForCausalLM` - Mixtral-8x7B, Mixtral-8x7B-Instruct - :code:`mistralai/Mixtral-8x7B-v0.1`, :code:`mistralai/Mixtral-8x7B-Instruct-v0.1`, :code:`mistral-community/Mixtral-8x22B-v0.1`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MPTForCausalLM` - MPT, MPT-Instruct, MPT-Chat, MPT-StoryWriter - :code:`mosaicml/mpt-7b`, :code:`mosaicml/mpt-7b-storywriter`, :code:`mosaicml/mpt-30b`, etc. - + - āœ…ļøŽ * - :code:`NemotronForCausalLM` - Nemotron-3, Nemotron-4, Minitron - :code:`nvidia/Minitron-8B-Base`, :code:`mgoin/Nemotron-4-340B-Base-hf-FP8`, etc. - āœ…ļøŽ - * - :code:`OLMoEForCausalLM` - - OLMoE - - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. - - + - āœ…ļøŽ * - :code:`OLMoForCausalLM` - OLMo - :code:`allenai/OLMo-1B-hf`, :code:`allenai/OLMo-7B-hf`, etc. - + - āœ…ļøŽ + * - :code:`OLMoEForCausalLM` + - OLMoE + - :code:`allenai/OLMoE-1B-7B-0924`, :code:`allenai/OLMoE-1B-7B-0924-Instruct`, etc. + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`OPTForCausalLM` - OPT, OPT-IML - :code:`facebook/opt-66b`, :code:`facebook/opt-iml-max-30b`, etc. - + - āœ…ļøŽ * - :code:`OrionForCausalLM` - Orion - :code:`OrionStarAI/Orion-14B-Base`, :code:`OrionStarAI/Orion-14B-Chat`, etc. - + - āœ…ļøŽ * - :code:`PhiForCausalLM` - Phi - :code:`microsoft/phi-1_5`, :code:`microsoft/phi-2`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`Phi3ForCausalLM` - Phi-3 - :code:`microsoft/Phi-3-mini-4k-instruct`, :code:`microsoft/Phi-3-mini-128k-instruct`, :code:`microsoft/Phi-3-medium-128k-instruct`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`Phi3SmallForCausalLM` - Phi-3-Small - :code:`microsoft/Phi-3-small-8k-instruct`, :code:`microsoft/Phi-3-small-128k-instruct`, etc. - + - āœ…ļøŽ * - :code:`PhiMoEForCausalLM` - Phi-3.5-MoE - :code:`microsoft/Phi-3.5-MoE-instruct`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`PersimmonForCausalLM` - Persimmon - :code:`adept/persimmon-8b-base`, :code:`adept/persimmon-8b-chat`, etc. - + - āœ…ļøŽ * - :code:`QWenLMHeadModel` - Qwen - :code:`Qwen/Qwen-7B`, :code:`Qwen/Qwen-7B-Chat`, etc. - + - āœ…ļøŽ * - :code:`Qwen2ForCausalLM` - Qwen2 - :code:`Qwen/Qwen2-beta-7B`, :code:`Qwen/Qwen2-beta-7B-Chat`, etc. - āœ…ļøŽ + - āœ…ļøŽ * - :code:`Qwen2MoeForCausalLM` - Qwen2MoE - :code:`Qwen/Qwen1.5-MoE-A2.7B`, :code:`Qwen/Qwen1.5-MoE-A2.7B-Chat`, etc. 
- + - āœ…ļøŽ * - :code:`StableLmForCausalLM` - StableLM - - :code:`stabilityai/stablelm-3b-4e1t/` , :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. + - :code:`stabilityai/stablelm-3b-4e1t`, :code:`stabilityai/stablelm-base-alpha-7b-v2`, etc. - + - āœ…ļøŽ * - :code:`Starcoder2ForCausalLM` - Starcoder2 - :code:`bigcode/starcoder2-3b`, :code:`bigcode/starcoder2-7b`, :code:`bigcode/starcoder2-15b`, etc. - + - āœ…ļøŽ * - :code:`SolarForCausalLM` - - EXAONE-3 + - Solar Pro - :code:`upstage/solar-pro-preview-instruct`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`XverseForCausalLM` - - Xverse + - XVERSE - :code:`xverse/XVERSE-7B-Chat`, :code:`xverse/XVERSE-13B-Chat`, :code:`xverse/XVERSE-65B-Chat`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ .. note:: Currently, the ROCm version of vLLM supports Mistral and Mixtral only for context lengths up to 4096. @@ -201,7 +265,7 @@ Multimodal Language Models ^^^^^^^^^^^^^^^^^^^^^^^^^^^^ .. list-table:: - :widths: 25 25 25 25 5 + :widths: 25 25 25 25 5 5 :header-rows: 1 * - Architecture @@ -209,86 +273,103 @@ Multimodal Language Models - Modalities - Example HuggingFace Models - :ref:`LoRA ` + - :ref:`PP ` * - :code:`Blip2ForConditionalGeneration` - BLIP-2 - Image\ :sup:`E` - :code:`Salesforce/blip2-opt-2.7b`, :code:`Salesforce/blip2-opt-6.7b`, etc. - + - āœ…ļøŽ * - :code:`ChameleonForConditionalGeneration` - Chameleon - Image - :code:`facebook/chameleon-7b` etc. - + - āœ…ļøŽ * - :code:`FuyuForCausalLM` - Fuyu - Image - :code:`adept/fuyu-8b` etc. - + - āœ…ļøŽ * - :code:`InternVLChatModel` - InternVL2 - Image\ :sup:`E+` - :code:`OpenGVLab/InternVL2-4B`, :code:`OpenGVLab/InternVL2-8B`, etc. - + - āœ…ļøŽ * - :code:`LlavaForConditionalGeneration` - LLaVA-1.5 - Image\ :sup:`E+` - :code:`llava-hf/llava-1.5-7b-hf`, :code:`llava-hf/llava-1.5-13b-hf`, etc. - + - āœ…ļøŽ * - :code:`LlavaNextForConditionalGeneration` - LLaVA-NeXT - Image\ :sup:`E+` - :code:`llava-hf/llava-v1.6-mistral-7b-hf`, :code:`llava-hf/llava-v1.6-vicuna-7b-hf`, etc. - + - āœ…ļøŽ * - :code:`LlavaNextVideoForConditionalGeneration` - LLaVA-NeXT-Video - Video - :code:`llava-hf/LLaVA-NeXT-Video-7B-hf`, etc. - + - āœ…ļøŽ * - :code:`LlavaOnevisionForConditionalGeneration` - LLaVA-Onevision - Image\ :sup:`+` / Video - :code:`llava-hf/llava-onevision-qwen2-7b-ov-hf`, :code:`llava-hf/llava-onevision-qwen2-0.5b-ov-hf`, etc. - + - āœ…ļøŽ * - :code:`MiniCPMV` - MiniCPM-V - Image\ :sup:`+` - :code:`openbmb/MiniCPM-V-2` (see note), :code:`openbmb/MiniCPM-Llama3-V-2_5`, :code:`openbmb/MiniCPM-V-2_6`, etc. - - + - āœ…ļøŽ + - āœ…ļøŽ * - :code:`MllamaForConditionalGeneration` - Llama 3.2 - Image - :code:`meta-llama/Llama-3.2-90B-Vision-Instruct`, :code:`meta-llama/Llama-3.2-11B-Vision`, etc. - + - * - :code:`PaliGemmaForConditionalGeneration` - PaliGemma - Image\ :sup:`E` - :code:`google/paligemma-3b-pt-224`, :code:`google/paligemma-3b-mix-224`, etc. - + - āœ…ļøŽ * - :code:`Phi3VForCausalLM` - Phi-3-Vision, Phi-3.5-Vision - Image\ :sup:`E+` - :code:`microsoft/Phi-3-vision-128k-instruct`, :code:`microsoft/Phi-3.5-vision-instruct` etc. - + - āœ…ļøŽ * - :code:`PixtralForConditionalGeneration` - Pixtral - Image\ :sup:`+` - :code:`mistralai/Pixtral-12B-2409` - + - āœ…ļøŽ * - :code:`QWenLMHeadModel` - Qwen-VL - Image\ :sup:`E+` - :code:`Qwen/Qwen-VL`, :code:`Qwen/Qwen-VL-Chat`, etc. - + - āœ…ļøŽ * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - Image\ :sup:`E+` / Video\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. 
- + - āœ…ļøŽ * - :code:`UltravoxModel` - Ultravox - Audio\ :sup:`E+` - :code:`fixie-ai/ultravox-v0_3` - + - āœ…ļøŽ | :sup:`E` Pre-computed embeddings can be inputted for this modality. | :sup:`+` Multiple items can be inputted per text prompt for this modality. diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index e0eba7f09bd65..8bb7067faa97c 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,8 +157,9 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers -will continue to be added in the future. +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral`, `llama3_json`, or `internlm`. Additional tool parsers +will continue to be added in the future, and you can also register your own tool parsers via `--tool-parser-plugin`. +* `--tool-parser-plugin` -- **optional** tool parser plugin used to register user-defined tool parsers into vLLM; the registered tool parser name can then be specified in `--tool-call-parser`. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat @@ -218,4 +219,73 @@ it works better with vLLM. Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` +#### Internlm Models +Supported models: +* `internlm/internlm2_5-7b-chat` (confirmed) +* Additional internlm2.5 function-calling models are compatible as well + +Known issues: +* Although this implementation also supports Internlm2, the tool call results are not stable when testing with the `internlm/internlm2-chat-7b` model. + +Recommended flags: `--tool-call-parser internlm --chat-template examples/tool_chat_template_internlm2_tool.jinja` + + +### How to write a tool parser plugin + +A tool parser plugin is a Python file containing one or more ToolParser implementations. You can write a ToolParser similar to the `Hermes2ProToolParser` in vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py. + +Here is a summary of a plugin file: + +```python + +# import the required packages + +# define a tool parser and register it to vllm +# the name list in register_module can be used +# in --tool-call-parser. you can define as many +# tool parsers as you want here. +@ToolParserManager.register_module(["example"]) +class ExampleToolParser(ToolParser): + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + + # adjust request. e.g.: set skip special tokens + # to False for tool call output.
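+    # (illustrative sketch only: for example, a parser could set
+    #  `request.skip_special_tokens = False` here; `delta` and `text` in the
+    #  methods below are placeholders for the DeltaMessage / content string
+    #  that a real parser would construct from the model output)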
+ def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + return request + + # implement the tool call parse for stream call + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + return delta + + # implement the tool parse for non-stream call + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + +``` + +Then you can use this plugin in the command line like this. +``` + --enable-auto-tool-choice \ + --tool-parser-plugin + --tool-call-parser example \ + --chat-template \ +``` diff --git a/examples/tool_chat_template_internlm2_tool.jinja b/examples/tool_chat_template_internlm2_tool.jinja new file mode 100644 index 0000000000000..ac99666e93bc4 --- /dev/null +++ b/examples/tool_chat_template_internlm2_tool.jinja @@ -0,0 +1,60 @@ +{%- if messages[0]["role"] == "system" %} + {%- set system_message = messages[0]["content"] %} + {%- set loop_messages = messages[1:] %} +{%- else %} + {%- set loop_messages = messages %} +{%- endif %} + +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{{- bos_token }} +{%- if system_message is defined %} +{{- "<|im_start|>system\n" + system_message + "<|im_end|>\n" }} +{%- endif %} + +{%- if tools is not none %} + {{- "<|im_start|>system name=<|plugin|>\n[" }} + {%- for tool in tools %} + {{- tool.function|tojson }} + {%- if not loop.last %} + {{- ", " }} + {%- else %} + {{- "]" }} + {%- endif %} + {%- endfor %} + {{- "<|im_end|>\n" }} +{%- endif %} + +{%- for message in loop_messages %} + {%- if message["role"] == "user" %} + {{- "<|im_start|>user\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message.tool_calls is defined and message.tool_calls is not none %} + {%- set content = message["content"] if message["content"] else "" %} + {{- "<|im_start|>assistant\n" + content }} + {%- for tool_call in message.tool_calls %} + {%- set function=tool_call.function %} + {{- "<|action_start|><|plugin|>\n" }} + {{- '{"name": "' + function.name + '", '}} + {{- '"arguments": ' + function.arguments|tojson + '}' }} + {{- "<|action_end|>" }} + {%- endfor %} + {{- "<|im_end|>\n" }} + {%- elif message["role"] == "assistant" %} + {{- "<|im_start|>assistant\n" + message["content"] + "<|im_end|>\n"}} + {%- elif message["role"] == "tool_results" or message["role"] == "tool" or message["role"] == "function" %} + {%- if message.content is defined and message.content.content is defined %} + {%- set content = message.content.content %} + {%- else %} + {%- set content = message.content %} + {%- endif %} + {{- "<|im_start|>environment name=<|plugin|>\n" + content|string + "<|im_end|>\n" }} + {%- else %} + {{- raise_exception("Only user and assistant and tool_results and tool and function roles are supported, with the exception of an initial optional system message!") }} + {%- endif %} +{%- endfor %} + +{%- if add_generation_prompt %} +{{- '<|im_start|>assistant\n' }} +{%- endif %} \ No newline at end of file diff --git a/examples/tool_chat_template_mistral_parallel.jinja b/examples/tool_chat_template_mistral_parallel.jinja index a294cbfd026be..2ef4bedf86211 100644 --- a/examples/tool_chat_template_mistral_parallel.jinja +++ 
b/examples/tool_chat_template_mistral_parallel.jinja @@ -6,8 +6,7 @@ {%- endif %} {%- if not tools is defined %} {%- set tools = none %} -{%- endif %} -{%- if tools is defined %} +{%- elif tools is not none %} {%- set parallel_tool_prompt = "You are a helpful assistant that can call tools. If you call one or more tools, format them in a single JSON array or objects, where each object is a tool call, not as separate objects outside of an array or multiple arrays. Use the format [{\"name\": tool call name, \"arguments\": tool call arguments}, additional tool calls] if you call more than one tool. If you call tools, do not attempt to interpret them or otherwise provide a response until you receive a tool call result that you can interpret for the user." %} {%- if system_message is defined %} {%- set system_message = parallel_tool_prompt + "\n\n" + system_message %} diff --git a/requirements-openvino.txt b/requirements-openvino.txt index 419294aa75626..800d59e2b9483 100644 --- a/requirements-openvino.txt +++ b/requirements-openvino.txt @@ -3,5 +3,6 @@ # OpenVINO dependencies torch >= 2.1.2 -openvino ~= 2024.3.0 -optimum-intel[openvino] >= 1.18.2 +openvino ~= 2024.4.0 +openvino-tokenizers[transformers] ~= 2024.4.0 +optimum-intel[openvino] >= 1.19.0 diff --git a/requirements-test.txt b/requirements-test.txt index 9c6fadb88865a..37c3bd8ba8794 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -10,8 +10,8 @@ pytest-shard awscli einops # required for MPT, qwen-vl and Mamba httpx -librosa # required for audio test -opencv-python # required for video test +librosa # required for audio tests +opencv-python # required for video tests peft requests ray[adag]==2.35 diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 2e8e83c3d271b..1f62cdc7e06a8 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -6,6 +6,8 @@ to fail. 
""" import os +from dataclasses import dataclass +from typing import List, NamedTuple, Optional import pytest @@ -18,49 +20,256 @@ VLLM_MULTI_NODE = os.getenv("VLLM_MULTI_NODE", "0") == "1" +class ParallelSetup(NamedTuple): + tp_size: int + pp_size: int + eager_mode: bool + chunked_prefill: bool + + +@dataclass +class PPTestSettings: + parallel_setups: List[ParallelSetup] + distributed_backends: List[str] + trust_remote_code: bool + tokenizer_mode: Optional[str] + + @staticmethod + def detailed( + *, + tp_base: int = 1, + pp_base: int = 2, + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + ): + return PPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=False), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=tp_base, + pp_size=2 * pp_base, + eager_mode=True, + chunked_prefill=False), + ParallelSetup(tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=False, + chunked_prefill=True), + ParallelSetup(tp_size=2 * tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp", "ray"], + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) + + @staticmethod + def fast( + *, + tp_base: int = 1, + pp_base: int = 2, + trust_remote_code: bool = False, + tokenizer_mode: Optional[str] = None, + ): + return PPTestSettings( + parallel_setups=[ + ParallelSetup(tp_size=tp_base, + pp_size=pp_base, + eager_mode=True, + chunked_prefill=False), + ], + distributed_backends=["mp"], + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) + + def iter_params(self, model_name: str): + for parallel_setup in self.parallel_setups: + for distributed_backend in self.distributed_backends: + yield (model_name, parallel_setup, distributed_backend, + self.trust_remote_code, self.tokenizer_mode) + + +# yapf: disable +GENERATION_MODEL_SETTINGS = { + # [DETAILED TESTS] + "meta-llama/Meta-Llama-3-8B": PPTestSettings.detailed(), + # [FAST TESTS] + # Uses Llama + # "BAAI/AquilaChat-7B": PPTestSettings.fast(), + # TODO: Test on larger GPU + # "Snowflake/snowflake-arctic-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "baichuan-inc/Baichuan-7B": PPTestSettings.fast(trust_remote_code=True), + "baichuan-inc/Baichuan2-13B-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "bigscience/bloomz-1b1": PPTestSettings.fast(), + "THUDM/chatglm3-6b": PPTestSettings.fast(trust_remote_code=True), + "CohereForAI/c4ai-command-r-v01": PPTestSettings.fast(tp_base=2, trust_remote_code=True), # noqa: E501 + # TODO: Test on larger GPU + # "databricks/dbrx-instruct": PPTestSettings.fast(), + "Deci/DeciLM-7B-instruct": PPTestSettings.fast(trust_remote_code=True), + "deepseek-ai/deepseek-llm-7b-chat": PPTestSettings.fast(), + "deepseek-ai/DeepSeek-V2-Lite-Chat": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "LGAI-EXAONE/EXAONE-3.0-7.8B-Instruct": PPTestSettings.fast(), + "tiiuae/falcon-7b": PPTestSettings.fast(), + "google/gemma-2b": PPTestSettings.fast(), + "google/gemma-2-9b": PPTestSettings.fast(), + "gpt2": PPTestSettings.fast(), + "bigcode/starcoder": PPTestSettings.fast(), + "EleutherAI/gpt-j-6b": PPTestSettings.fast(), + "EleutherAI/pythia-12b": PPTestSettings.fast(), + "ibm/PowerLM-3b": PPTestSettings.fast(), + "ibm/PowerMoE-3b": PPTestSettings.fast(), + # Uses Llama + # "internlm/internlm-chat-7b": PPTestSettings.fast(), + 
"internlm/internlm2-chat-7b": PPTestSettings.fast(trust_remote_code=True), + "core42/jais-13b-chat": PPTestSettings.fast(), + # TODO: Implement PP + # "ai21labs/AI21-Jamba-1.5-Mini": PPTestSettings.fast(), + "openbmb/MiniCPM-2B-sft-bf16": PPTestSettings.fast(trust_remote_code=True), + "openbmb/MiniCPM3-4B": PPTestSettings.fast(trust_remote_code=True), + # Uses Llama + # "mistralai/Mistral-7B-Instruct-v0.1": PPTestSettings.fast(), + "mistralai/Mixtral-8x7B-Instruct-v0.1": PPTestSettings.fast(tp_base=4), + "mosaicml/mpt-7b": PPTestSettings.fast(), + "nvidia/Minitron-8B-Base": PPTestSettings.fast(), + "allenai/OLMoE-1B-7B-0924-Instruct": PPTestSettings.fast(), + "allenai/OLMo-1B-hf": PPTestSettings.fast(), + "facebook/opt-iml-max-1.3b": PPTestSettings.fast(), + "OrionStarAI/Orion-14B-Chat": PPTestSettings.fast(trust_remote_code=True), + "microsoft/phi-2": PPTestSettings.fast(), + "microsoft/Phi-3-mini-4k-instruct": PPTestSettings.fast(), + "microsoft/Phi-3-small-8k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + # FIXME: https://github.com/vllm-project/vllm/issues/8553 + # "microsoft/Phi-3.5-MoE-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "adept/persimmon-8b-chat": PPTestSettings.fast(), + "Qwen/Qwen-7B-Chat": PPTestSettings.fast(trust_remote_code=True), + "Qwen/Qwen2-beta-7B-Chat": PPTestSettings.fast(), + "Qwen/Qwen1.5-MoE-A2.7B-Chat": PPTestSettings.fast(), + "stabilityai/stablelm-3b-4e1t": PPTestSettings.fast(), + "bigcode/starcoder2-3b": PPTestSettings.fast(), + "upstage/solar-pro-preview-instruct": PPTestSettings.fast(tp_base=2), + # FIXME: Cannot load tokenizer in latest transformers version + # "xverse/XVERSE-7B-Chat": PPTestSettings.fast(trust_remote_code=True), +} + +EMBEDDING_MODEL_SETTINGS = { # type: ignore[var-annotated] + # [FAST TESTS] + # Uses Llama + # "intfloat/e5-mistral-7b-instruct": PPTestSettings.fast(), +} + +MULTIMODAL_MODEL_SETTINGS = { + # [FAST TESTS] + "Salesforce/blip2-opt-2.7b": PPTestSettings.fast(), + "facebook/chameleon-7b": PPTestSettings.fast(), + "adept/fuyu-8b": PPTestSettings.fast(), + "OpenGVLab/InternVL2-1B": PPTestSettings.fast(trust_remote_code=True), + "llava-hf/llava-1.5-7b-hf": PPTestSettings.fast(), + "llava-hf/llava-v1.6-mistral-7b-hf": PPTestSettings.fast(), + "llava-hf/LLaVA-NeXT-Video-7B-hf": PPTestSettings.fast(), + "llava-hf/llava-onevision-qwen2-0.5b-ov-hf": PPTestSettings.fast(), + "openbmb/MiniCPM-Llama3-V-2_5": PPTestSettings.fast(trust_remote_code=True), + # TODO: Implement PP + # "meta-llama/Llama-3.2-11B-Vision-Instruct": PPTestSettings.fast(), + "microsoft/Phi-3-vision-128k-instruct": PPTestSettings.fast(trust_remote_code=True), # noqa: E501 + "mistralai/Pixtral-12B-2409": PPTestSettings.fast(tp_base=2, tokenizer_mode="mistral"), # noqa: E501 + "Qwen/Qwen-VL-Chat": PPTestSettings.fast(trust_remote_code=True), + "Qwen/Qwen2-VL-2B-Instruct": PPTestSettings.fast(), + "fixie-ai/ultravox-v0_3": PPTestSettings.fast(), +} + +CONDITIONAL_GENERATION_MODEL_SETTINGS = { # type: ignore[var-annotated] + # [FAST TESTS] + # TODO: Implement PP + # "facebook/bart-base": PPTestSettings.fast(), +} +# yapf: enable + +MODEL_SETTINGS = { + **GENERATION_MODEL_SETTINGS, + **EMBEDDING_MODEL_SETTINGS, + **MULTIMODAL_MODEL_SETTINGS, +} + +# You can update this on your local machine to run specific tests +TEST_MODELS = [ + "meta-llama/Meta-Llama-3-8B", + "facebook/chameleon-7b", + "OpenGVLab/InternVL2-1B", + "microsoft/Phi-3-vision-128k-instruct", + "mistralai/Pixtral-12B-2409", + "fixie-ai/ultravox-v0_3", +] + 
+ @pytest.mark.parametrize( - ("TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, TRUST_REMOTE_CODE, " - "MODEL_NAME, DIST_BACKEND"), + ("model_name", "parallel_setup", "distributed_backend", + "trust_remote_code", "tokenizer_mode"), [ - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "mp"), - (1, 3, 0, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (1, 4, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 1, 0, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - (2, 2, 0, 1, 0, "meta-llama/Meta-Llama-3-8B", "ray"), - # NOTE: InternVL2 multi-node tests are flaky, - # use mp backend to skip the multi-node tests - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), - (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), - (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), - (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"), - # TP only models - (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"), + params for model_name, settings in MODEL_SETTINGS.items() + for params in settings.iter_params(model_name) + if model_name in TEST_MODELS ], ) @fork_new_process_for_each_test -def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, - TRUST_REMOTE_CODE, MODEL_NAME, DIST_BACKEND): - if VLLM_MULTI_NODE and DIST_BACKEND == "mp": +def test_compare_tp(model_name: str, parallel_setup: ParallelSetup, + distributed_backend: str, trust_remote_code: bool, + tokenizer_mode: Optional[str], num_gpus_available): + tp_size, pp_size, eager_mode, chunked_prefill = parallel_setup + + if num_gpus_available < tp_size: + pytest.skip(f"Need at least {tp_size} GPUs to run the test") + if VLLM_MULTI_NODE and distributed_backend == "mp": pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") - pp_args = [ + common_args = [ # use half precision for speed and memory savings in CI environment "--dtype", "float16", "--max-model-len", - "8192", + "2048", + "--max-num-seqs", + "8", + ] + if chunked_prefill: + common_args.append("--enable-chunked-prefill") + if eager_mode: + common_args.append("--enforce-eager") + if trust_remote_code: + common_args.append("--trust-remote-code") + if tokenizer_mode: + common_args.extend(["--tokenizer-mode", tokenizer_mode]) + + if (distributed_backend == "ray" and tp_size == 2 and pp_size == 2 + and chunked_prefill): + # Test Ray ADAG for a subset of the tests + pp_env = { + "VLLM_USE_RAY_COMPILED_DAG": "1", + "VLLM_USE_RAY_SPMD_WORKER": "1", + "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", + } + # Temporary. Currently when zeromq + SPMD is used, it does not properly + # terminate because of aDAG issue. + common_args.append("--disable-frontend-multiprocessing") + else: + pp_env = None + + pp_args = [ + *common_args, "--pipeline-parallel-size", - str(PP_SIZE), + str(pp_size), "--tensor-parallel-size", - str(TP_SIZE), + str(tp_size), "--distributed-executor-backend", - DIST_BACKEND, + distributed_backend, ] # compare without pipeline parallelism @@ -69,41 +278,15 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, # schedule all workers in a node other than the head node, # which can cause the test to fail. 
tp_args = [ - # use half precision for speed and memory savings in CI environment - "--dtype", - "float16", - "--max-model-len", - "8192", + *common_args, "--tensor-parallel-size", - str(max(TP_SIZE, 2)), # We only use 2 GPUs in the CI. + str(tp_size), "--distributed-executor-backend", "mp", ] - if CHUNKED_PREFILL: - pp_args.append("--enable-chunked-prefill") - tp_args.append("--enable-chunked-prefill") - if EAGER_MODE: - pp_args.append("--enforce-eager") - tp_args.append("--enforce-eager") - if TRUST_REMOTE_CODE: - pp_args.append("--trust-remote-code") - tp_args.append("--trust-remote-code") - pp_env = None - if (DIST_BACKEND == "ray" and TP_SIZE == 2 and PP_SIZE == 2 - and CHUNKED_PREFILL): - # Test Ray ADAG for a subset of the tests - pp_env = { - "VLLM_USE_RAY_COMPILED_DAG": "1", - "VLLM_USE_RAY_SPMD_WORKER": "1", - "VLLM_USE_RAY_COMPILED_DAG_NCCL_CHANNEL": "1", - } - # Temporary. Currently when zeromq + SPMD is used, it does not properly - # terminate because of aDAG issue. - pp_args.append("--disable-frontend-multiprocessing") - tp_args.append("--disable-frontend-multiprocessing") try: - compare_two_settings(MODEL_NAME, pp_args, tp_args, pp_env) + compare_two_settings(model_name, pp_args, tp_args, pp_env) except Exception: if pp_env is None: raise diff --git a/tests/engine/test_arg_utils.py b/tests/engine/test_arg_utils.py index 360ac1bfbad93..f7dc167fea6e4 100644 --- a/tests/engine/test_arg_utils.py +++ b/tests/engine/test_arg_utils.py @@ -42,22 +42,42 @@ def test_bad_nullable_kvs(arg): nullable_kvs(arg) -@pytest.mark.parametrize(("arg", "expected"), [ - (None, None), - ("{}", {}), - ('{"num_crops": 4}', { - "num_crops": 4 - }), - ('{"foo": {"bar": "baz"}}', { - "foo": { - "bar": "baz" - } - }), +# yapf: disable +@pytest.mark.parametrize(("arg", "expected", "option"), [ + (None, None, "mm-processor-kwargs"), + ("{}", {}, "mm-processor-kwargs"), + ( + '{"num_crops": 4}', + { + "num_crops": 4 + }, + "mm-processor-kwargs" + ), + ( + '{"foo": {"bar": "baz"}}', + { + "foo": + { + "bar": "baz" + } + }, + "mm-processor-kwargs" + ), + ( + '{"cast_logits_dtype":"bfloat16","sequence_parallel_norm":true,"sequence_parallel_norm_threshold":2048}', + { + "cast_logits_dtype": "bfloat16", + "sequence_parallel_norm": True, + "sequence_parallel_norm_threshold": 2048, + }, + "override-neuron-config" + ), ]) -def test_mm_processor_kwargs_prompt_parser(arg, expected): +# yapf: enable +def test_composite_arg_parser(arg, expected, option): parser = EngineArgs.add_cli_args(FlexibleArgumentParser()) if arg is None: args = parser.parse_args([]) else: - args = parser.parse_args(["--mm-processor-kwargs", arg]) - assert args.mm_processor_kwargs == expected + args = parser.parse_args([f"--{option}", arg]) + assert getattr(args, option.replace("-", "_")) == expected diff --git a/tests/kernels/test_flash_attn.py b/tests/kernels/test_flash_attn.py index 71f61c19dd951..3e9b4d9a4f8a0 100644 --- a/tests/kernels/test_flash_attn.py +++ b/tests/kernels/test_flash_attn.py @@ -3,9 +3,9 @@ import pytest import torch -import vllm.attention.backends.flash_attn # noqa: F401 -from tests.kernels.utils import opcheck from vllm.utils import seed_everything +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) NUM_HEADS = [(4, 4), (8, 2), (16, 2)] HEAD_SIZES = [128, 256] @@ -112,10 +112,10 @@ def test_flash_attn_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - 
value_cache=value_cache, + output = flash_attn_with_kvcache( + q=query.unsqueeze(1), + k_cache=key_cache, + v_cache=value_cache, softmax_scale=scale, causal=True, block_table=block_tables, @@ -123,25 +123,6 @@ def test_flash_attn_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ).squeeze(1) - if num_blocks <= 2048: - test_utils = ["test_faketensor", "test_schema"] - else: - test_utils = ["test_faketensor"] - - opcheck(torch.ops.vllm.flash_attn_with_kvcache, - args=tuple(), - kwargs=dict( - decode_query=query.unsqueeze(1), - key_cache=key_cache, - value_cache=value_cache, - softmax_scale=scale, - causal=True, - block_table=block_tables, - cache_seqlens=kv_lens_tensor, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) - ref_output = ref_paged_attn( query=query, key_cache=key_cache, @@ -213,7 +194,7 @@ def test_varlen_with_paged_kv( (num_seqs, max_num_blocks_per_seq), dtype=torch.int32) - output = torch.ops.vllm.flash_attn_varlen_func( + output = flash_attn_varlen_func( q=query, k=key_cache, v=value_cache, @@ -228,29 +209,6 @@ def test_varlen_with_paged_kv( softcap=soft_cap if soft_cap is not None else 0, ) - if num_blocks <= 2048: - test_utils = ["test_faketensor", "test_schema"] - else: - test_utils = ["test_faketensor"] - - opcheck(torch.ops.vllm.flash_attn_varlen_func, - args=tuple(), - kwargs=dict( - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=cu_query_lens, - cu_seqlens_k=cu_kv_lens, - max_seqlen_q=max_query_len, - max_seqlen_k=max_kv_len, - softmax_scale=scale, - causal=True, - window_size=window_size, - block_table=block_tables, - softcap=soft_cap if soft_cap is not None else 0, - ), - test_utils=test_utils) - ref_output = ref_paged_attn( query=query, key_cache=key_cache, diff --git a/tests/models/decoder_only/language/test_granitemoe.py b/tests/models/decoder_only/language/test_granitemoe.py new file mode 100644 index 0000000000000..ba73375229eb3 --- /dev/null +++ b/tests/models/decoder_only/language/test_granitemoe.py @@ -0,0 +1,39 @@ +"""Compare the outputs of HF and vLLM for Granite models using greedy sampling. + +Run `pytest tests/models/test_granite.py`. 
+""" +import pytest + +from ...utils import check_logprobs_close + +MODELS = [ + "ibm/PowerMoE-3b", +] + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [64]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models( + hf_runner, + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, + num_logprobs: int, +) -> None: + with hf_runner(model, dtype=dtype) as hf_model: + hf_outputs = hf_model.generate_greedy_logprobs_limit( + example_prompts, max_tokens, num_logprobs) + + with vllm_runner(model, dtype=dtype) as vllm_model: + vllm_outputs = vllm_model.generate_greedy_logprobs( + example_prompts, max_tokens, num_logprobs) + check_logprobs_close( + outputs_0_lst=hf_outputs, + outputs_1_lst=vllm_outputs, + name_0="hf", + name_1="vllm", + ) diff --git a/tests/models/test_registry.py b/tests/models/test_registry.py index b058e2755c245..299aeacb9f337 100644 --- a/tests/models/test_registry.py +++ b/tests/models/test_registry.py @@ -1,9 +1,55 @@ +import warnings + import pytest +import torch.cuda + +from vllm.model_executor.models import ModelRegistry +from vllm.platforms import current_platform -from vllm.model_executor.models import _MODELS, ModelRegistry +from ..utils import fork_new_process_for_each_test -@pytest.mark.parametrize("model_cls", _MODELS) -def test_registry_imports(model_cls): +@pytest.mark.parametrize("model_arch", ModelRegistry.get_supported_archs()) +def test_registry_imports(model_arch): # Ensure all model classes can be imported successfully - ModelRegistry.resolve_model_cls([model_cls]) + ModelRegistry.resolve_model_cls(model_arch) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_mm,init_cuda", [ + ("LlamaForCausalLM", False, False), + ("MllamaForConditionalGeneration", True, False), + ("LlavaForConditionalGeneration", True, True), +]) +def test_registry_is_multimodal(model_arch, is_mm, init_cuda): + assert ModelRegistry.is_multimodal_model(model_arch) is is_mm + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. " + "Please test using a different one.", + stacklevel=2) + + +@fork_new_process_for_each_test +@pytest.mark.parametrize("model_arch,is_pp,init_cuda", [ + ("MLPSpeculatorPreTrainedModel", False, False), + ("DeepseekV2ForCausalLM", True, False), + ("Qwen2VLForConditionalGeneration", True, True), +]) +def test_registry_is_pp(model_arch, is_pp, init_cuda): + assert ModelRegistry.is_pp_supported_model(model_arch) is is_pp + + if init_cuda and current_platform.is_cuda_alike(): + assert not torch.cuda.is_initialized() + + ModelRegistry.resolve_model_cls(model_arch) + if not torch.cuda.is_initialized(): + warnings.warn( + "This model no longer initializes CUDA on import. 
" + "Please test using a different one.", + stacklevel=2) diff --git a/tests/multi_step/test_correctness_llm.py b/tests/multi_step/test_correctness_llm.py index ff413e8e2da3f..f45428675bde8 100644 --- a/tests/multi_step/test_correctness_llm.py +++ b/tests/multi_step/test_correctness_llm.py @@ -1,5 +1,6 @@ # Test the LLMEngine with multi-step-decoding +import copy from typing import Optional import pytest @@ -196,3 +197,160 @@ def test_multi_step_llm_w_prompt_logprobs( name_0="hf", name_1="vllm", ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["half"]) +@pytest.mark.parametrize("tp_size", [1]) +@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("enforce_eager", [True]) +@pytest.mark.parametrize("num_scheduler_steps", NUM_SCHEDULER_STEPS) +@pytest.mark.parametrize("num_prompts", NUM_PROMPTS) +@pytest.mark.parametrize("num_logprobs", [None, 5]) +def test_multi_step_llm_chunked_prefill_prefix_cache( + vllm_runner, + example_prompts, + model: str, + dtype: str, + tp_size: int, + max_tokens: int, + enforce_eager: int, + num_scheduler_steps: int, + num_prompts: int, + num_logprobs: Optional[int], +) -> None: + """Test vLLM engine with multi-step+"single-step chunked prefill"+APC. + + Set up contrived scenario which tests for a possible failure mode of + scheduling with multi-step+"single-step chunked prefill"+APC + + "single-step chunked prefill" here refers to the current vLLM multi-step+ + chunked-prefill implementation, which requires that a prefill may only + be scheduled in the same step as decodes if the prefill prompt fits in a + single chunk (note that "complete" multi-step+chunked-prefill would allow + a prefill to span multiple chunks & multiple steps but that is not yet + the case.) + + "APC" is short for "automatic prefix caching". + + This test creates a scenario where the scheduler must decide whether/how + to schedule a prefill with a prompt that exceeds the available token budget. + The correct behavior for multi-step+"single-step chunked prefill"+APC is to + put off scheduling the prefill until a future step. + + Validate that: + * Multi-step kernels do not raise an exception due to incorrect scheduler + behavior + * Generated tokens match between + multi-step+"single-step chunked prefill"+APC and + single-step scheduling. + * (If logprobs are enabled) check logprobs are close enough + + Args: + vllm_runner: vLLM model runner fixture + example_prompts: test fixture providing example prompts + model: model under test (same for single- and multi-step engines) + dtype: tensor datatype for engine to utilize + tp_size: degree of tensor-parallelism + max_tokens: the maximum number of tokens to generate + enforce_eager + num_scheduler_steps: for multi-step scheduling, GPU-side steps per + GPU -> CPU output transfer + num_prompts: number of example prompts under test + num_logprobs: corresponds to the `logprobs` argument to the OpenAI + completions endpoint; `None` -> 1 logprob returned. + """ + + # Set up contrived test for correct scheduling behavior with + # multi-step+"single-step chunked prefill"+APC. + # + # Assume block_size=16 + # + # Assume max_num_batched_tokens=48 + # => Per-step token budget=48 + # + # 1. Scheduler schedules 0th prompt (24 tokens) + # => Remaining token budget=24 + # 2. 
Scheduler attempts to schedule 1st prompt (30 tokens) + # * 30 tokens exceeds 24 token remaining budget + # * Correct behavior: do not schedule this prompt in this step + # * Incorrect behavior: schedule prompt chunk + # * `do_sample=False` for this prompt in this step + # * Chunk size = (remaining tokens // block size) * block size + # + # The Incorrect scheduling behavior - if it occurs - will cause an exception + # in the model runner resulting from `do_sample=False`. + assert len(example_prompts) >= 2 + challenge_prompts = copy.deepcopy(example_prompts) + challenge_prompts[0] = ('vLLM is a high-throughput and memory-efficient ' + 'inference and serving engine for LLMs.\n' + ) # 24 tok + challenge_prompts[1] = ( + 'Briefly describe the major milestones in the ' + 'development of artificial intelligence from 1950 to 2020.\n' + ) # 30 tok + + # If necessary, adjust the length of `challenge_prompts` to match + # `num_prompts` + if len(challenge_prompts) < num_prompts: + challenge_prompts = (challenge_prompts * + ((num_prompts // len(challenge_prompts)) + 1)) + challenge_prompts = challenge_prompts[:num_prompts] + assert len(challenge_prompts) == num_prompts + + # Single-step scheduler baseline + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + num_scheduler_steps=num_scheduler_steps, + max_model_len=48, + max_num_batched_tokens=48, + max_num_seqs=4, + block_size=16, + ) as vllm_model: + outputs_baseline = (vllm_model.generate_greedy( + challenge_prompts, max_tokens) if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + challenge_prompts, max_tokens, num_logprobs)) + + # multi-step+"single-step chunked prefill"+APC + with vllm_runner( + model, + dtype=dtype, + enforce_eager=enforce_eager, + gpu_memory_utilization=0.7, + tensor_parallel_size=tp_size, + use_v2_block_manager=True, + enable_chunked_prefill=True, + enable_prefix_caching=True, + num_scheduler_steps=num_scheduler_steps, + max_model_len=48, + max_num_batched_tokens=48, + max_num_seqs=4, + block_size=16, + ) as vllm_model: + outputs_w_features = (vllm_model.generate_greedy( + challenge_prompts, max_tokens) if num_logprobs is None else + vllm_model.generate_greedy_logprobs( + challenge_prompts, max_tokens, num_logprobs)) + + if num_logprobs is None: + # No-logprobs test + check_outputs_equal( + outputs_0_lst=outputs_baseline, + outputs_1_lst=outputs_w_features, + name_0="multi-step", + name_1="multi-step+features", + ) + else: + # Yes-logprobs test + check_logprobs_close( + outputs_0_lst=outputs_baseline, + outputs_1_lst=outputs_w_features, + name_0="multi-step", + name_1="multi-step+features", + ) diff --git a/tests/samplers/test_sampler.py b/tests/samplers/test_sampler.py index 467cea3659d49..b8c3cc3b18e2a 100644 --- a/tests/samplers/test_sampler.py +++ b/tests/samplers/test_sampler.py @@ -434,7 +434,7 @@ def run_test_case(*, expected_penalization: List[bool], sampling_metadata = SamplingMetadata.prepare( seq_group_metadata_list, seq_lens=seq_lens if seq_lens else None, - query_lens=seq_lens if seq_lens else None, + query_lens=seq_lens if seq_lens else [1] * batch_size, device=device, pin_memory=is_pin_memory_available()) # the logits tensor is modified in-place by the sampler diff --git a/tests/spec_decode/e2e/test_compatibility.py b/tests/spec_decode/e2e/test_compatibility.py index 81f91c5e10b0d..9f0af211e264a 100644 --- a/tests/spec_decode/e2e/test_compatibility.py +++ 
b/tests/spec_decode/e2e/test_compatibility.py @@ -100,6 +100,7 @@ def test_spec_decode_xfail_spec_max_model_len(test_llm_generator): "model": "JackFram/llama-68m", "speculative_model": "JackFram/llama-68m", "num_speculative_tokens": 5, + "use_v2_block_manager": False, }]) @pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) @pytest.mark.parametrize("test_llm_kwargs", [{}]) diff --git a/tests/spec_decode/e2e/test_integration.py b/tests/spec_decode/e2e/test_integration.py index 4a427d4c3e287..d04e312689bcc 100644 --- a/tests/spec_decode/e2e/test_integration.py +++ b/tests/spec_decode/e2e/test_integration.py @@ -102,3 +102,47 @@ def test_speculative_model_quantization_config(vllm_runner, common_llm_kwargs, max_output_len=32, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": "JackFram/llama-68m", + "num_speculative_tokens": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_medusa_correctness.py b/tests/spec_decode/e2e/test_medusa_correctness.py index 8c90e147df23a..0b36e712a11b2 100644 --- a/tests/spec_decode/e2e/test_medusa_correctness.py +++ b/tests/spec_decode/e2e/test_medusa_correctness.py @@ -350,6 +350,55 @@ def test_medusa_disable_queue(vllm_runner, common_llm_kwargs, temperature=0.0) +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + + # Precision + "dtype": PRECISION, + + # Main model + "model_name": MAIN_MODEL, + "speculative_model": SPEC_MODEL, + "num_speculative_tokens": MAX_SPEC_TOKENS, + "speculative_disable_by_batch_size": 4 + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. 
+ """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) + + if __name__ == "__main__": import pytest pytest.main([__file__]) diff --git a/tests/spec_decode/e2e/test_mlp_correctness.py b/tests/spec_decode/e2e/test_mlp_correctness.py index 7f3180befaffc..52b48a33c3097 100644 --- a/tests/spec_decode/e2e/test_mlp_correctness.py +++ b/tests/spec_decode/e2e/test_mlp_correctness.py @@ -460,3 +460,46 @@ def test_mlp_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": MAIN_MODEL, + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": SPEC_MODEL, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_mqa_scorer(vllm_runner, common_llm_kwargs, per_test_common_llm_kwargs, + baseline_llm_kwargs, test_llm_kwargs, batch_size: int, + output_len: int, seed: int): + """Verify that speculative decoding generates the same output + with batch expansion scorer and mqa scorer. + """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/e2e/test_ngram_correctness.py b/tests/spec_decode/e2e/test_ngram_correctness.py index 850114eb7f5a8..5862459383167 100644 --- a/tests/spec_decode/e2e/test_ngram_correctness.py +++ b/tests/spec_decode/e2e/test_ngram_correctness.py @@ -292,3 +292,49 @@ def test_ngram_disable_queue(vllm_runner, common_llm_kwargs, max_output_len=output_len, seed=seed, temperature=0.0) + + +@pytest.mark.parametrize( + "common_llm_kwargs", + [{ + "model_name": "JackFram/llama-68m", + + # Skip cuda graph recording for fast test. + "enforce_eager": True, + + # Required for spec decode. + "use_v2_block_manager": True, + "speculative_model": "[ngram]", + "num_speculative_tokens": 5, + "ngram_prompt_lookup_max": 3, + }]) +@pytest.mark.parametrize("per_test_common_llm_kwargs", [{}]) +@pytest.mark.parametrize("baseline_llm_kwargs", [{}]) +@pytest.mark.parametrize("test_llm_kwargs", + [{ + "speculative_disable_mqa_scorer": True, + }]) +@pytest.mark.parametrize("batch_size", [1, 5]) +@pytest.mark.parametrize( + "output_len", + [ + # Use smaller output len for fast test. + 32, + ]) +@pytest.mark.parametrize("seed", [1]) +def test_ngram_scorer(vllm_runner, common_llm_kwargs, + per_test_common_llm_kwargs, baseline_llm_kwargs, + test_llm_kwargs, batch_size: int, output_len: int, + seed: int): + """Verify that ngram speculative decoding generates the same output + with batch expansion scorer and mqa scorer. 
+ """ + run_equality_correctness_test(vllm_runner, + common_llm_kwargs, + per_test_common_llm_kwargs, + baseline_llm_kwargs, + test_llm_kwargs, + batch_size, + max_output_len=output_len, + seed=seed, + temperature=0.0) diff --git a/tests/spec_decode/test_multi_step_worker.py b/tests/spec_decode/test_multi_step_worker.py index 6fa386ffab12f..e6f7f480eebb2 100644 --- a/tests/spec_decode/test_multi_step_worker.py +++ b/tests/spec_decode/test_multi_step_worker.py @@ -173,7 +173,6 @@ def test_same_output_for_multi_step(): block_size, num_gpu_blocks, seed, - model_runner_cls=TP1DraftModelRunner, ) worker = create_worker( diff --git a/tests/spec_decode/test_scorer.py b/tests/spec_decode/test_scorer.py new file mode 100644 index 0000000000000..5f703b03ab7fe --- /dev/null +++ b/tests/spec_decode/test_scorer.py @@ -0,0 +1,65 @@ +import pytest +import torch + +from vllm.sequence import ExecuteModelRequest +from vllm.spec_decode.batch_expansion import BatchExpansionTop1Scorer +from vllm.spec_decode.interfaces import SpeculativeProposals, SpeculativeScores +from vllm.spec_decode.mqa_scorer import MQAScorer +from vllm.worker.worker import Worker + +from .utils import create_batch, create_worker + + +def create_proposal(batch_size: int, propose_len: int, vocab_size: int, + device: str) -> SpeculativeProposals: + proposal_probs = torch.rand((batch_size, propose_len, vocab_size), + device=device) + proposal_token_ids = torch.argmax(proposal_probs, dim=-1) + proposal_lens = torch.tensor([propose_len] * batch_size, device=device) + return SpeculativeProposals(proposal_token_ids, proposal_probs, + proposal_lens) + + +def assert_score_equal(score1: SpeculativeScores, + score2: SpeculativeScores) -> None: + assert torch.allclose(score1.probs, score2.probs) + assert torch.allclose(score1.logprobs, score2.logprobs) + assert torch.equal(score1.token_ids, score2.token_ids) + + +@pytest.mark.parametrize('model_name', ['facebook/opt-125m']) +@pytest.mark.parametrize('batch_size', [1, 2, 4, 8, 16]) +@pytest.mark.parametrize('propose_len', [1, 3, 5]) +@pytest.mark.parametrize('device', ['cuda']) +def test_scoroer(model_name: str, batch_size: int, propose_len: int, + device: str) -> None: + """ + Compare the batch expansion scorer and mqa scorer return the same score + """ + seed = 0 + block_size = 32 + num_gpu_blocks = 2048 // block_size + scorer_worker = create_worker(Worker, model_name, block_size, + num_gpu_blocks, seed) + scorer_worker.model_runner.model.sampler.include_gpu_probs_tensor = True + scorer_worker.model_runner.model.sampler.\ + should_modify_greedy_probs_inplace = True + + vocab_size = scorer_worker.vocab_size + proposals = create_proposal(batch_size, propose_len, vocab_size, device) + seq_group_metadatalist, _, _ = create_batch(batch_size, + propose_len, + block_size=block_size, + num_gpu_blocks=num_gpu_blocks) + requests = ExecuteModelRequest(seq_group_metadatalist, + num_lookahead_slots=propose_len) + + batch_expansion_scorer = BatchExpansionTop1Scorer(scorer_worker, device, + vocab_size) + batch_expansion_score = batch_expansion_scorer.score_proposals( + requests, proposals) + + mqa_scorer = MQAScorer(scorer_worker, device, vocab_size) + mqa_score = mqa_scorer.score_proposals(requests, proposals) + + assert_score_equal(batch_expansion_score, mqa_score) diff --git a/tests/spec_decode/test_spec_decode_worker.py b/tests/spec_decode/test_spec_decode_worker.py index 501d05756e01c..e0b7b7d47f1f1 100644 --- a/tests/spec_decode/test_spec_decode_worker.py +++ b/tests/spec_decode/test_spec_decode_worker.py @@ 
-63,10 +63,10 @@ def test_correctly_calls_draft_model(k: int, batch_size: int, @pytest.mark.parametrize("acceptance_sampler_method", ["rejection_sampler", "typical_acceptance_sampler"]) @torch.inference_mode() -def test_correctly_calls_target_model(k: int, batch_size: int, - acceptance_sampler_method: str): +def test_batch_expansion_correctly_calls_target_model( + k: int, batch_size: int, acceptance_sampler_method: str): """Verify SpecDecodeWorker calls the target model with correct - inputs. Everything else is mocked out. + inputs with batch expansion. Everything else is mocked out. """ draft_worker = mock_worker(cls=MultiStepWorker, use_spec=False) target_worker = mock_worker(use_spec=False) @@ -82,7 +82,8 @@ def test_correctly_calls_target_model(k: int, batch_size: int, target_worker, mock_spec_decode_sampler(acceptance_sampler_method), disable_logprobs=False, - metrics_collector=metrics_collector) + metrics_collector=metrics_collector, + disable_mqa_scorer=True) worker.init_device() vocab_size = 32_000 diff --git a/tests/spec_decode/utils.py b/tests/spec_decode/utils.py index f17e872881633..f683942a5854b 100644 --- a/tests/spec_decode/utils.py +++ b/tests/spec_decode/utils.py @@ -131,19 +131,22 @@ def create_seq_group_metadata_from_prompts( for i, final_len in enumerate(final_prompt_lens) } - return [ - SequenceGroupMetadata( - request_id=str(i), - is_prompt=len(cont_token_ids) == 0, - seq_data={ - i: SequenceData.from_seqs(prompt_token_ids[:], - cont_token_ids[:]), - }, - sampling_params=SamplingParams(temperature=0.0, ), - block_tables={i: block_allocations[i][:]}, - ) for i, (prompt_token_ids, - cont_token_ids) in enumerate(zip(prompts, continuations)) - ] + seq_grou_metadata_list = [] + for i, (prompt_token_ids, + cont_token_ids) in enumerate(zip(prompts, continuations)): + data = SequenceData.from_seqs(prompt_token_ids, cont_token_ids) + data.update_num_computed_tokens( + len(prompt_token_ids) + len(cont_token_ids) - 1) + seq_data = {i: data} + seq_grou_metadata_list.append( + SequenceGroupMetadata( + request_id=str(i), + is_prompt=len(cont_token_ids) == 0, + seq_data=seq_data, + sampling_params=SamplingParams(temperature=0.0), + block_tables={i: block_allocations[i][:]}, + )) + return seq_grou_metadata_list def assert_logprobs_dict_allclose( diff --git a/tests/tool_use/test_parallel_tool_calls.py b/tests/tool_use/test_parallel_tool_calls.py index ed7ac8afe1b4e..cff3c8a556ca4 100644 --- a/tests/tool_use/test_parallel_tool_calls.py +++ b/tests/tool_use/test_parallel_tool_calls.py @@ -45,7 +45,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, assert tool_call.type == "function" assert tool_call.function is not None assert isinstance(tool_call.id, str) - assert len(tool_call.id) > 16 + assert len(tool_call.id) >= 9 # make sure the weather tool was called correctly assert tool_call.function.name == WEATHER_TOOL["function"]["name"] @@ -108,7 +108,7 @@ async def test_parallel_tool_calls(client: openai.AsyncOpenAI, if tool_call.id: tool_call_id_count += 1 assert (isinstance(tool_call.id, str) - and (len(tool_call.id) > 16)) + and (len(tool_call.id) >= 9)) # if parts of the function start being streamed if tool_call.function: diff --git a/tests/tool_use/test_tool_calls.py b/tests/tool_use/test_tool_calls.py index c3abe9e1f5060..9e6d715f44fcf 100644 --- a/tests/tool_use/test_tool_calls.py +++ b/tests/tool_use/test_tool_calls.py @@ -33,7 +33,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert tool_calls[0].type == 'function' assert 
tool_calls[0].function is not None assert isinstance(tool_calls[0].id, str) - assert len(tool_calls[0].id) > 16 + assert len(tool_calls[0].id) >= 9 # make sure the weather tool was called (classic example) with arguments assert tool_calls[0].function.name == WEATHER_TOOL["function"]["name"] @@ -106,7 +106,7 @@ async def test_tool_call_and_choice(client: openai.AsyncOpenAI): assert finish_reason_count == 1 assert role_name == 'assistant' - assert isinstance(tool_call_id, str) and (len(tool_call_id) > 16) + assert isinstance(tool_call_id, str) and (len(tool_call_id) >= 9) # validate the name and arguments assert function_name == WEATHER_TOOL["function"]["name"] diff --git a/tests/tool_use/utils.py b/tests/tool_use/utils.py index 1a840f8a51c9f..ce36515a2381c 100644 --- a/tests/tool_use/utils.py +++ b/tests/tool_use/utils.py @@ -87,6 +87,18 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "call the tool. Otherwise, answer the user's query directly " "without calling a tool. DO NOT CALL A TOOL THAT IS IRRELEVANT " "to the user's question - just respond to it normally." + }, + "internlm": { + "model": + "internlm/internlm2_5-7b-chat", + "arguments": [ + "--tool-call-parser", "internlm", "--chat-template", + str(VLLM_PATH / + "examples/tool_chat_template_internlm2_tool.jinja"), + "--trust_remote_code" + ], + "supports_parallel": + False, } } @@ -109,7 +121,7 @@ def ensure_system_prompt(messages: List[Dict[str, Any]], "type": "string", "description": - "the two-letter abbreviation for the state " + "must the two-letter abbreviation for the state " "that the city is in, e.g. 'CA' which would " "mean 'California'" }, diff --git a/tests/utils.py b/tests/utils.py index 49bd4f236f658..8c8a7c4bf0c70 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -14,7 +14,6 @@ import pytest import requests from openai.types.completion import Completion -from transformers import AutoTokenizer from typing_extensions import ParamSpec from tests.models.utils import TextTextLogprobs @@ -24,6 +23,7 @@ from vllm.entrypoints.openai.cli_args import make_arg_parser from vllm.model_executor.model_loader.loader import get_model_loader from vllm.platforms import current_platform +from vllm.transformers_utils.tokenizer import get_tokenizer from vllm.utils import (FlexibleArgumentParser, GB_bytes, cuda_device_count_stateless, get_open_port, is_hip) @@ -181,15 +181,26 @@ def compare_two_settings(model: str, env2: The second set of environment variables to pass to the API server. 
""" - trust_remote_code = "--trust-remote-code" - if trust_remote_code in arg1 or trust_remote_code in arg2: - tokenizer = AutoTokenizer.from_pretrained(model, - trust_remote_code=True) - else: - tokenizer = AutoTokenizer.from_pretrained(model) + trust_remote_code = False + for args in (arg1, arg2): + if "--trust-remote-code" in args: + trust_remote_code = True + break + + tokenizer_mode = "auto" + for args in (arg1, arg2): + if "--tokenizer-mode" in args: + tokenizer_mode = args[args.index("--tokenizer-mode") + 1] + break + + tokenizer = get_tokenizer( + model, + trust_remote_code=trust_remote_code, + tokenizer_mode=tokenizer_mode, + ) prompt = "Hello, my name is" - token_ids = tokenizer(prompt)["input_ids"] + token_ids = tokenizer(prompt).input_ids results = [] for args, env in ((arg1, env1), (arg2, env2)): with RemoteOpenAIServer(model, diff --git a/tools/report_build_time_ninja.py b/tools/report_build_time_ninja.py new file mode 100644 index 0000000000000..3f9b68c2eccbe --- /dev/null +++ b/tools/report_build_time_ninja.py @@ -0,0 +1,311 @@ +#!/usr/bin/env python3 +# Copyright (c) 2018 The Chromium Authors. All rights reserved. +# Use of this source code is governed by a BSD-style license that can be +# found in the LICENSE file. + +# Modified version of: https://chromium.googlesource.com/chromium/tools/depot_tools.git/+/refs/heads/main/post_build_ninja_summary.py +"""Summarize the last ninja build, invoked with ninja's -C syntax. + +> python3 tools/report_build_time_ninja.py -C build/.. + +Typical output looks like this: +``` + Longest build steps for .cpp.o: + 1.0 weighted s to build ...torch_bindings.cpp.o (12.4 s elapsed time) + 2.0 weighted s to build ..._attn_c.dir/csrc... (23.5 s elapsed time) + 2.6 weighted s to build ...torch_bindings.cpp.o (31.5 s elapsed time) + 3.2 weighted s to build ...torch_bindings.cpp.o (38.5 s elapsed time) + Longest build steps for .so (linking): + 0.1 weighted s to build _core_C.abi3.so (0.7 s elapsed time) + 0.1 weighted s to build _moe_C.abi3.so (1.0 s elapsed time) + 0.5 weighted s to build ...flash_attn_c.abi3.so (1.1 s elapsed time) + 6.2 weighted s to build _C.abi3.so (6.2 s elapsed time) + Longest build steps for .cu.o: + 15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time) + 15.3 weighted s to build ...machete_mm_... (183.5 s elapsed time) + 15.3 weighted s to build ...machete_mm_... (183.6 s elapsed time) + 15.3 weighted s to build ...machete_mm_... (183.7 s elapsed time) + 15.5 weighted s to build ...machete_mm_... (185.6 s elapsed time) + 15.5 weighted s to build ...machete_mm_... (185.9 s elapsed time) + 15.5 weighted s to build ...machete_mm_... (186.2 s elapsed time) + 37.4 weighted s to build ...scaled_mm_c3x.cu... (449.0 s elapsed time) + 43.9 weighted s to build ...scaled_mm_c2x.cu... 
(527.4 s elapsed time) + 344.8 weighted s to build ...attention_...cu.o (1087.2 s elapsed time) + 1110.0 s weighted time (10120.4 s elapsed time sum, 9.1x parallelism) + 134 build steps completed, average of 0.12/s +``` +""" + +import argparse +import errno +import fnmatch +import os +import sys +from collections import defaultdict + +# The number of long build times to report: +long_count = 10 +# The number of long times by extension to report +long_ext_count = 10 + + +class Target: + """Represents a single line read for a .ninja_log file.""" + + def __init__(self, start, end): + """Creates a target object by passing in the start/end times in seconds + as a float.""" + self.start = start + self.end = end + # A list of targets, appended to by the owner of this object. + self.targets = [] + self.weighted_duration = 0.0 + + def Duration(self): + """Returns the task duration in seconds as a float.""" + return self.end - self.start + + def SetWeightedDuration(self, weighted_duration): + """Sets the duration, in seconds, passed in as a float.""" + self.weighted_duration = weighted_duration + + def WeightedDuration(self): + """Returns the task's weighted duration in seconds as a float. + + Weighted_duration takes the elapsed time of the task and divides it + by how many other tasks were running at the same time. Thus, it + represents the approximate impact of this task on the total build time, + with serialized or serializing steps typically ending up with much + longer weighted durations. + weighted_duration should always be the same or shorter than duration. + """ + # Allow for modest floating-point errors + epsilon = 0.000002 + if (self.weighted_duration > self.Duration() + epsilon): + print('%s > %s?' % (self.weighted_duration, self.Duration())) + assert (self.weighted_duration <= self.Duration() + epsilon) + return self.weighted_duration + + def DescribeTargets(self): + """Returns a printable string that summarizes the targets.""" + # Some build steps generate dozens of outputs - handle them sanely. + # The max_length was chosen so that it can fit most of the long + # single-target names, while minimizing word wrapping. + result = ', '.join(self.targets) + max_length = 65 + if len(result) > max_length: + result = result[:max_length] + '...' + return result + + +# Copied with some modifications from ninjatracing +def ReadTargets(log, show_all): + """Reads all targets from .ninja_log file |log_file|, sorted by duration. + + The result is a list of Target objects.""" + header = log.readline() + assert header == '# ninja log v5\n', \ + 'unrecognized ninja log version %r' % header + targets_dict = {} + last_end_seen = 0.0 + for line in log: + parts = line.strip().split('\t') + if len(parts) != 5: + # If ninja.exe is rudely halted then the .ninja_log file may be + # corrupt. Silently continue. + continue + start, end, _, name, cmdhash = parts # Ignore restat. + # Convert from integral milliseconds to float seconds. + start = int(start) / 1000.0 + end = int(end) / 1000.0 + if not show_all and end < last_end_seen: + # An earlier time stamp means that this step is the first in a new + # build, possibly an incremental build. Throw away the previous + # data so that this new build will be displayed independently. + # This has to be done by comparing end times because records are + # written to the .ninja_log file when commands complete, so end + # times are guaranteed to be in order, but start times are not. 
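The weighting scheme documented in WeightedDuration above can be restated outside the script. The sketch below is not part of the patch and uses made-up intervals; it splits each stretch of wall-clock time evenly across the tasks running during it, which is the same start/stop sweep SummarizeEntries performs further down.

# Toy illustration of the weighted-duration idea: every moment of
# wall-clock time is shared equally by the tasks running at that moment.
# The intervals are invented for the example.
def weighted_durations(intervals):
    events = []
    for idx, (start, end) in enumerate(intervals):
        events.append((start, 'start', idx))
        events.append((end, 'stop', idx))
    # 'start' sorts before 'stop' at equal timestamps, as in the script.
    events.sort(key=lambda e: (e[0], e[1]))
    running, last_time = set(), events[0][0]
    weighted = [0.0] * len(intervals)
    for time, kind, idx in events:
        if running:
            share = (time - last_time) / len(running)
            for r in running:
                weighted[r] += share
        if kind == 'start':
            running.add(idx)
        else:
            running.remove(idx)
        last_time = time
    return weighted

# Two fully overlapping 10 s tasks each get 5 s of weighted time, while a
# task that ran alone keeps its full duration.
print(weighted_durations([(0, 10), (0, 10), (10, 14)]))  # [5.0, 5.0, 4.0]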
+ targets_dict = {} + target = None + if cmdhash in targets_dict: + target = targets_dict[cmdhash] + if not show_all and (target.start != start or target.end != end): + # If several builds in a row just run one or two build steps + # then the end times may not go backwards so the last build may + # not be detected as such. However in many cases there will be a + # build step repeated in the two builds and the changed + # start/stop points for that command, identified by the hash, + # can be used to detect and reset the target dictionary. + targets_dict = {} + target = None + if not target: + targets_dict[cmdhash] = target = Target(start, end) + last_end_seen = end + target.targets.append(name) + return list(targets_dict.values()) + + +def GetExtension(target, extra_patterns): + """Return the file extension that best represents a target. + + For targets that generate multiple outputs it is important to return a + consistent 'canonical' extension. Ultimately the goal is to group build steps + by type.""" + for output in target.targets: + if extra_patterns: + for fn_pattern in extra_patterns.split(';'): + if fnmatch.fnmatch(output, '*' + fn_pattern + '*'): + return fn_pattern + # Not a true extension, but a good grouping. + if output.endswith('type_mappings'): + extension = 'type_mappings' + break + + # Capture two extensions if present. For example: file.javac.jar should + # be distinguished from file.interface.jar. + root, ext1 = os.path.splitext(output) + _, ext2 = os.path.splitext(root) + extension = ext2 + ext1 # Preserve the order in the file name. + + if len(extension) == 0: + extension = '(no extension found)' + + if ext1 in ['.pdb', '.dll', '.exe']: + extension = 'PEFile (linking)' + # Make sure that .dll and .exe are grouped together and that the + # .dll.lib files don't cause these to be listed as libraries + break + if ext1 in ['.so', '.TOC']: + extension = '.so (linking)' + # Attempt to identify linking, avoid identifying as '.TOC' + break + # Make sure .obj files don't get categorized as mojo files + if ext1 in ['.obj', '.o']: + break + # Jars are the canonical output of java targets. + if ext1 == '.jar': + break + # Normalize all mojo related outputs to 'mojo'. + if output.count('.mojom') > 0: + extension = 'mojo' + break + return extension + + +def SummarizeEntries(entries, extra_step_types): + """Print a summary of the passed in list of Target objects.""" + + # Create a list that is in order by time stamp and has entries for the + # beginning and ending of each build step (one time stamp may have multiple + # entries due to multiple steps starting/stopping at exactly the same time). + # Iterate through this list, keeping track of which tasks are running at all + # times. At each time step calculate a running total for weighted time so + # that when each task ends its own weighted time can easily be calculated. + task_start_stop_times = [] + + earliest = -1 + latest = 0 + total_cpu_time = 0 + for target in entries: + if earliest < 0 or target.start < earliest: + earliest = target.start + if target.end > latest: + latest = target.end + total_cpu_time += target.Duration() + task_start_stop_times.append((target.start, 'start', target)) + task_start_stop_times.append((target.end, 'stop', target)) + length = latest - earliest + weighted_total = 0.0 + + # Sort by the time/type records and ignore |target| + task_start_stop_times.sort(key=lambda times: times[:2]) + # Now we have all task start/stop times sorted by when they happen. 
If a + # task starts and stops on the same time stamp then the start will come + # first because of the alphabet, which is important for making this work + # correctly. + # Track the tasks which are currently running. + running_tasks = {} + # Record the time we have processed up to so we know how to calculate time + # deltas. + last_time = task_start_stop_times[0][0] + # Track the accumulated weighted time so that it can efficiently be added + # to individual tasks. + last_weighted_time = 0.0 + # Scan all start/stop events. + for event in task_start_stop_times: + time, action_name, target = event + # Accumulate weighted time up to now. + num_running = len(running_tasks) + if num_running > 0: + # Update the total weighted time up to this moment. + last_weighted_time += (time - last_time) / float(num_running) + if action_name == 'start': + # Record the total weighted task time when this task starts. + running_tasks[target] = last_weighted_time + if action_name == 'stop': + # Record the change in the total weighted task time while this task + # ran. + weighted_duration = last_weighted_time - running_tasks[target] + target.SetWeightedDuration(weighted_duration) + weighted_total += weighted_duration + del running_tasks[target] + last_time = time + assert (len(running_tasks) == 0) + + # Warn if the sum of weighted times is off by more than half a second. + if abs(length - weighted_total) > 500: + print('Warning: Possible corrupt ninja log, results may be ' + 'untrustworthy. Length = %.3f, weighted total = %.3f' % + (length, weighted_total)) + + entries_by_ext = defaultdict(list) + for target in entries: + extension = GetExtension(target, extra_step_types) + entries_by_ext[extension].append(target) + + for key, values in entries_by_ext.items(): + print(' Longest build steps for %s:' % key) + values.sort(key=lambda x: x.WeightedDuration()) + for target in values[-long_count:]: + print(' %8.1f weighted s to build %s (%.1f s elapsed time)' % + (target.WeightedDuration(), target.DescribeTargets(), + target.Duration())) + + print(' %.1f s weighted time (%.1f s elapsed time sum, %1.1fx ' + 'parallelism)' % + (length, total_cpu_time, total_cpu_time * 1.0 / length)) + print(' %d build steps completed, average of %1.2f/s' % + (len(entries), len(entries) / (length))) + + +def main(): + log_file = '.ninja_log' + parser = argparse.ArgumentParser() + parser.add_argument('-C', dest='build_directory', help='Build directory.') + parser.add_argument( + '-s', + '--step-types', + help='semicolon separated fnmatch patterns for build-step grouping') + parser.add_argument('--log-file', + help="specific ninja log file to analyze.") + args, _extra_args = parser.parse_known_args() + if args.build_directory: + log_file = os.path.join(args.build_directory, log_file) + if args.log_file: + log_file = args.log_file + if args.step_types: + # Make room for the extra build types. + global long_ext_count + long_ext_count += len(args.step_types.split(';')) + + try: + with open(log_file, 'r') as log: + entries = ReadTargets(log, False) + SummarizeEntries(entries, args.step_types) + except IOError: + print('Log file %r not found, no build summary created.' 
% log_file) + return errno.ENOENT + + +if __name__ == '__main__': + sys.exit(main()) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 25bb9a44d2dd8..912cfe5df84f2 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -32,6 +32,15 @@ def hint_on_error(fn): def wrapper(*args, **kwargs): try: return fn(*args, **kwargs) + + except NotImplementedError as e: + msg = ( + "Error in calling custom op %s: %s\n" + "Not implemented or built, mostly likely because the current current device " + "does not support this kernel (less likely TORCH_CUDA_ARCH_LIST was set " + "incorrectly while building)") + logger.error(msg, fn.__name__, e) + raise NotImplementedError(msg % (fn.__name__, e)) from e except AttributeError as e: msg = ( "Error in calling custom op %s: %s\n" diff --git a/vllm/attention/backends/blocksparse_attn.py b/vllm/attention/backends/blocksparse_attn.py index 656cfd124ab44..57ac152d9edb6 100644 --- a/vllm/attention/backends/blocksparse_attn.py +++ b/vllm/attention/backends/blocksparse_attn.py @@ -186,6 +186,12 @@ class BlocksparseFlashAttentionMetadata(AttentionMetadata): # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + _cached_prefill_metadata: Optional[ "BlocksparseFlashAttentionMetadata"] = None _cached_decode_metadata: Optional[ diff --git a/vllm/attention/backends/flash_attn.py b/vllm/attention/backends/flash_attn.py index 43ca6c9ff160e..bba80262e52d3 100644 --- a/vllm/attention/backends/flash_attn.py +++ b/vllm/attention/backends/flash_attn.py @@ -13,152 +13,15 @@ compute_slot_mapping, compute_slot_mapping_start_idx, is_block_tables_empty) +from vllm.forward_context import get_forward_context from vllm.utils import async_tensor_h2d, make_tensor_with_pad if TYPE_CHECKING: from vllm.worker.model_runner import (ModelInputForGPUBuilder, ModelInputForGPUWithSamplingMetadata) -# yapf: disable -from vllm.vllm_flash_attn import ( - flash_attn_varlen_func as _flash_attn_varlen_func) -from vllm.vllm_flash_attn import ( - flash_attn_with_kvcache as _flash_attn_with_kvcache) - -# yapf: enable - - -@torch.library.custom_op("vllm::flash_attn_varlen_func", mutates_args=[]) -def flash_attn_varlen_func( - q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - # custom op does not support tuple input - real_window_size: Tuple[int, int] - if window_size is None: - real_window_size = (-1, -1) - else: - assert len(window_size) == 2 - real_window_size = (window_size[0], window_size[1]) - return _flash_attn_varlen_func( - q=q, - k=k, - v=v, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=max_seqlen_q, - max_seqlen_k=max_seqlen_k, - softmax_scale=softmax_scale, - causal=causal, - window_size=real_window_size, - softcap=softcap, - alibi_slopes=alibi_slopes, - block_table=block_table, - ) - - -@flash_attn_varlen_func.register_fake # type: ignore -def _( - 
q: torch.Tensor, - k: torch.Tensor, - v: torch.Tensor, - cu_seqlens_q: torch.Tensor, - cu_seqlens_k: torch.Tensor, - max_seqlen_q: int, - max_seqlen_k: int, - softmax_scale: Optional[float] = None, - causal: bool = False, - window_size: Optional[List[int]] = None, - softcap: float = 0.0, - alibi_slopes: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, -) -> torch.Tensor: - return torch.empty_like(q) - - -@torch.library.custom_op("vllm::flash_attn_with_kvcache", mutates_args=[]) -def flash_attn_with_kvcache( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return _flash_attn_with_kvcache( - decode_query, - key_cache, - value_cache, - cache_seqlens=cache_seqlens, - block_table=block_table, - softmax_scale=softmax_scale, - causal=causal, - alibi_slopes=alibi_slopes, - softcap=softcap, - ) - - -@flash_attn_with_kvcache.register_fake # type: ignore -def _( - decode_query: torch.Tensor, - key_cache: torch.Tensor, - value_cache: torch.Tensor, - cache_seqlens: Optional[torch.Tensor] = None, - block_table: Optional[torch.Tensor] = None, - softmax_scale: Optional[float] = None, - causal: bool = False, - alibi_slopes: Optional[torch.Tensor] = None, - softcap: float = 0.0, -) -> torch.Tensor: - return torch.empty_like(decode_query) - - -@torch.library.custom_op("vllm::reshape_and_cache_flash", - mutates_args=["kv_cache"]) -def reshape_and_cache_flash( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - """Inductor cannot deal with inplace operations on views. - See https://github.com/pytorch/pytorch/issues/131192 - and https://github.com/pytorch/pytorch/issues/130174 - This is a workaround to hide the view operation from the inductor. - """ - return torch.ops._C_cache_ops.reshape_and_cache_flash( - key, value, kv_cache[0], kv_cache[1], slot_mapping, kv_cache_dtype, - k_scale, v_scale) - - -@reshape_and_cache_flash.register_fake # type: ignore -def _( - key: torch.Tensor, - value: torch.Tensor, - kv_cache: torch.Tensor, - slot_mapping: torch.Tensor, - kv_cache_dtype: str, - k_scale: float, - v_scale: float, -) -> None: - pass +from vllm.vllm_flash_attn import (flash_attn_varlen_func, + flash_attn_with_kvcache) class FlashAttentionBackend(AttentionBackend): @@ -245,8 +108,15 @@ class FlashAttentionMetadata(AttentionMetadata): # |-------------------- seq_len ---------------------| # |-- query_len ---| - # Maximum query length in the batch. None for decoding. + # Maximum query length in the batch. max_query_len: Optional[int] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] + # Maximum sequence length among prefill batch. 0 if there are decoding # requests only. 
max_prefill_seq_len: int @@ -303,6 +173,7 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[:self.num_prefill_tokens], seq_lens=self.seq_lens[:self.num_prefills], seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills], + decode_query_len=0, max_query_len=self.max_query_len, max_prefill_seq_len=self.max_prefill_seq_len, max_decode_seq_len=0, @@ -331,7 +202,8 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]: slot_mapping=self.slot_mapping[self.num_prefill_tokens:], seq_lens=None, seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:], - max_query_len=None, + decode_query_len=self.decode_query_len, + max_query_len=self.max_query_len, max_prefill_seq_len=0, max_decode_seq_len=self.max_decode_seq_len, query_start_loc=None, @@ -461,9 +333,6 @@ def _add_seq_group( self.num_prefill_tokens += token_len self.prefill_seq_lens.append(seq_len) else: - assert query_len == 1, ( - "seq_len: {}, context_len: {}, query_len: {}".format( - seq_len, context_len, query_len)) self.num_decode_tokens += query_len self.curr_seq_lens.append(curr_seq_len) @@ -494,6 +363,30 @@ def _add_seq_group( seq_len, context_len, start_idx, self.block_size, inter_data.block_tables) + def _get_graph_runner_block_tables( + self, num_seqs: int, + block_tables: List[List[int]]) -> torch.Tensor: + # The shape of graph_block_tables is + # [max batch size, max context len // block size]. + max_batch_size, max_blocks = self.runner.graph_block_tables.shape + assert max_batch_size >= num_seqs + + graph_block_tables = self.runner.graph_block_tables[:num_seqs] + for i, block_table in enumerate(block_tables): + if block_table: + num_blocks = len(block_table) + if num_blocks <= max_blocks: + graph_block_tables[i, :num_blocks] = block_table + else: + # It may be possible to have more blocks allocated due + # to lookahead slots of multi-step, however, they are + # not used anyway, so can be safely ignored. + graph_block_tables[ + i, :max_blocks] = block_table[:max_blocks] + + return torch.from_numpy(graph_block_tables).to( + device=self.runner.device, non_blocking=True) + def build(self, seq_lens: List[int], query_lens: List[int], cuda_graph_pad_size: int, batch_size: int): """Build attention metadata with on-device tensors. @@ -518,33 +411,22 @@ def build(self, seq_lens: List[int], query_lens: List[int], use_captured_graph = cuda_graph_pad_size != -1 max_query_len = max(query_lens) + decode_query_lens = query_lens[self.num_prefills:] + if len(decode_query_lens) > 0: + decode_query_len = max(decode_query_lens) + else: + decode_query_len = 1 max_prefill_seq_len = max(self.prefill_seq_lens, default=0) max_decode_seq_len = max(self.curr_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens + num_seqs = len(seq_lens) if use_captured_graph: self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size) self.block_tables.extend([] * cuda_graph_pad_size) - num_decode_tokens = batch_size - - # The shape of graph_block_tables is - # [max batch size, max context len // block size]. - input_block_tables = self.runner.graph_block_tables[:batch_size] - max_blocks = input_block_tables.shape[1] - for i, block_table in enumerate(self.block_tables): - if block_table: - num_blocks = len(block_table) - if num_blocks <= max_blocks: - input_block_tables[i, :num_blocks] = block_table - else: - # It may be possible to have more blocks allocated due - # to lookahead slots of multi-step, however, they are - # not used anyway, so can be safely ignored. 
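The decode_query_len value derived in build() above follows a simple rule, restated here as a self-contained sketch (derive_decode_query_len is an illustrative name, not a vLLM helper): the decode portion of the batch is assumed to have a uniform query length, which is greater than 1 only when speculative decoding is active.

from typing import List


def derive_decode_query_len(query_lens: List[int], num_prefills: int) -> int:
    # Query lengths of the decode requests come after the prefill requests.
    decode_query_lens = query_lens[num_prefills:]
    return max(decode_query_lens) if decode_query_lens else 1


# Two prefill requests followed by two decode requests that each carry
# 4 query tokens (e.g. 1 accepted token plus 3 speculative tokens).
assert derive_decode_query_len([17, 9, 4, 4], num_prefills=2) == 4
# Decode-only batch without speculative decoding.
assert derive_decode_query_len([1, 1, 1], num_prefills=0) == 1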
- input_block_tables[ - i, :max_blocks] = block_table[:max_blocks] - - block_tables = torch.from_numpy(input_block_tables).to( - device=device, non_blocking=True) + num_decode_tokens = batch_size - self.num_prefill_tokens + block_tables = self._get_graph_runner_block_tables( + num_seqs, self.block_tables) else: block_tables = make_tensor_with_pad( self.block_tables, @@ -586,6 +468,7 @@ def build(self, seq_lens: List[int], query_lens: List[int], seq_lens=seq_lens, seq_lens_tensor=seq_lens_tensor, max_query_len=max_query_len, + decode_query_len=decode_query_len, max_prefill_seq_len=max_prefill_seq_len, max_decode_seq_len=max_decode_seq_len, query_start_loc=query_start_loc, @@ -701,108 +584,182 @@ def forward( assert k_scale == 1.0 and v_scale == 1.0, ( "key/v_scale is not supported in FlashAttention.") - num_tokens, hidden_size = query.shape - # Reshape the query, key, and value tensors. - query = query.view(-1, self.num_heads, self.head_size) - key = key.view(-1, self.num_kv_heads, self.head_size) - value = value.view(-1, self.num_kv_heads, self.head_size) - - if kv_cache.numel() > 0: - key_cache = kv_cache[0] - value_cache = kv_cache[1] - - # Reshape the input keys and values and store them in the cache. - # If kv_cache is not provided, the new key and value tensors are - # not cached. This happens during the initial memory profiling run. - torch.ops.vllm.reshape_and_cache_flash( - key, - value, - kv_cache, - attn_metadata.slot_mapping.flatten(), - self.kv_cache_dtype, - k_scale, - v_scale, - ) + output = torch.ops.vllm.unified_flash_attention( + query, + key, + value, + self.num_heads, + self.head_size, + self.num_kv_heads, + kv_cache, + self.kv_cache_dtype, + k_scale, + v_scale, + self.scale, + self.sliding_window, + self.alibi_slopes, + self.logits_soft_cap, + ) - num_prefill_tokens = attn_metadata.num_prefill_tokens - num_decode_tokens = attn_metadata.num_decode_tokens - assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa - assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ - f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa - - # Query for decode. KV is not needed because it is already cached. - decode_query = query[num_prefill_tokens:] - # QKV for prefill. - query = query[:num_prefill_tokens] - key = key[:num_prefill_tokens] - value = value[:num_prefill_tokens] - - assert query.shape[0] == num_prefill_tokens - assert decode_query.shape[0] == num_decode_tokens - - prefill_output: Optional[torch.Tensor] = None - decode_output: Optional[torch.Tensor] = None - - if prefill_meta := attn_metadata.prefill_metadata: - # Prompt run. - if (kv_cache.numel() == 0 or prefill_meta.block_tables is None - or prefill_meta.block_tables.numel() == 0): - # normal attention - # When block_tables are not filled, it means q and k are the - # prompt, and they have the same length. 
- prefill_output = torch.ops.vllm.flash_attn_varlen_func( - q=query, - k=key, - v=value, - cu_seqlens_q=prefill_meta.seq_start_loc, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_q=prefill_meta.max_prefill_seq_len, - max_seqlen_k=prefill_meta.max_prefill_seq_len, - softmax_scale=self.scale, - causal=True, - window_size=self.sliding_window, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ) - else: - # prefix-enabled attention - assert prefill_meta.seq_lens is not None - max_seq_len = max(prefill_meta.seq_lens) - prefill_output = torch.ops.vllm.flash_attn_varlen_func( # noqa - q=query, - k=key_cache, - v=value_cache, - cu_seqlens_q=prefill_meta.query_start_loc, - max_seqlen_q=prefill_meta.max_query_len, - cu_seqlens_k=prefill_meta.seq_start_loc, - max_seqlen_k=max_seq_len, - softmax_scale=self.scale, - causal=True, - alibi_slopes=self.alibi_slopes, - block_table=prefill_meta.block_tables, - softcap=self.logits_soft_cap, - ) - - if decode_meta := attn_metadata.decode_metadata: - # Decoding run. - decode_output = torch.ops.vllm.flash_attn_with_kvcache( - decode_query.unsqueeze(1), - key_cache, - value_cache, - block_table=decode_meta.block_tables, - cache_seqlens=decode_meta.seq_lens_tensor, - softmax_scale=self.scale, + return output + + +@torch.library.custom_op("vllm::unified_flash_attention", + mutates_args=["kv_cache"]) +def unified_flash_attention( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + + current_metadata = get_forward_context() + assert current_metadata is not None + assert isinstance(current_metadata, FlashAttentionMetadata) + attn_metadata: FlashAttentionMetadata = current_metadata + + num_tokens, hidden_size = query.shape + # Reshape the query, key, and value tensors. + query = query.view(-1, num_heads, head_size) + key = key.view(-1, num_kv_heads, head_size) + value = value.view(-1, num_kv_heads, head_size) + + if kv_cache.numel() > 0: + key_cache = kv_cache[0] + value_cache = kv_cache[1] + + # Reshape the input keys and values and store them in the cache. + # If kv_cache is not provided, the new key and value tensors are + # not cached. This happens during the initial memory profiling run. + torch.ops._C_cache_ops.reshape_and_cache_flash( + key, + value, + kv_cache[0], + kv_cache[1], + attn_metadata.slot_mapping.flatten(), + kv_cache_dtype, + k_scale, + v_scale, + ) + + num_prefill_tokens = attn_metadata.num_prefill_tokens + num_decode_tokens = attn_metadata.num_decode_tokens + assert key.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"key : {key.shape} : #prefill tokens {num_prefill_tokens} : #decode tokens {num_decode_tokens}" # noqa + assert value.shape[0] == num_prefill_tokens + num_decode_tokens, \ + f"value : {value.shape} : #prefill toks {num_prefill_tokens} : #decode toks {num_decode_tokens}" # noqa + + # Query for decode. KV is not needed because it is already cached. + decode_query = query[num_prefill_tokens:] + # QKV for prefill. 
+ query = query[:num_prefill_tokens] + key = key[:num_prefill_tokens] + value = value[:num_prefill_tokens] + + assert query.shape[0] == num_prefill_tokens + assert decode_query.shape[0] == num_decode_tokens + + prefill_output: Optional[torch.Tensor] = None + decode_output: Optional[torch.Tensor] = None + + if prefill_meta := attn_metadata.prefill_metadata: + # Prompt run. + if (kv_cache.numel() == 0 or prefill_meta.block_tables is None + or prefill_meta.block_tables.numel() == 0): + # normal attention + # When block_tables are not filled, it means q and k are the + # prompt, and they have the same length. + prefill_output = flash_attn_varlen_func( + q=query, + k=key, + v=value, + cu_seqlens_q=prefill_meta.seq_start_loc, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_q=prefill_meta.max_prefill_seq_len, + max_seqlen_k=prefill_meta.max_prefill_seq_len, + softmax_scale=softmax_scale, + causal=True, + window_size=window_size, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ) + else: + # prefix-enabled attention + assert prefill_meta.seq_lens is not None + max_seq_len = max(prefill_meta.seq_lens) + prefill_output = flash_attn_varlen_func( # noqa + q=query, + k=key_cache, + v=value_cache, + cu_seqlens_q=prefill_meta.query_start_loc, + max_seqlen_q=prefill_meta.max_query_len, + cu_seqlens_k=prefill_meta.seq_start_loc, + max_seqlen_k=max_seq_len, + softmax_scale=softmax_scale, causal=True, - alibi_slopes=self.alibi_slopes, - softcap=self.logits_soft_cap, - ).squeeze(1) - - if prefill_output is None: - assert decode_output is not None - return decode_output.view(num_decode_tokens, hidden_size) - if decode_output is None: - assert prefill_output is not None - return prefill_output.view(num_prefill_tokens, hidden_size) - output = torch.cat([prefill_output, decode_output], dim=0) - return output.view(num_tokens, hidden_size) + alibi_slopes=alibi_slopes, + block_table=prefill_meta.block_tables, + softcap=logits_soft_cap, + ) + + if decode_meta := attn_metadata.decode_metadata: + # Decoding run. + _, num_head, head_dim = decode_query.shape + decode_query = decode_query.reshape(-1, decode_meta.decode_query_len, + num_head, head_dim) + decode_output = flash_attn_with_kvcache( + q=decode_query, + k_cache=key_cache, + v_cache=value_cache, + block_table=decode_meta.block_tables, + cache_seqlens=decode_meta.seq_lens_tensor, + softmax_scale=softmax_scale, + causal=True, + alibi_slopes=alibi_slopes, + softcap=logits_soft_cap, + ).squeeze(1) + + if prefill_output is None: + assert decode_output is not None + return decode_output.view(num_decode_tokens, hidden_size) + if decode_output is None: + assert prefill_output is not None + return prefill_output.view(num_prefill_tokens, hidden_size) + + # Chunked prefill does not work with speculative decoding. + # Therefore, the query length for decode should be 1 in chunked prefill. 
+ assert decode_meta is not None + assert decode_meta.decode_query_len == 1 + decode_output = decode_output.squeeze(1) + output = torch.cat([prefill_output, decode_output], dim=0) + return output.view(num_tokens, hidden_size) + + +@unified_flash_attention.register_fake +def _( + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + num_heads: int, + head_size: int, + num_kv_heads: int, + kv_cache: torch.Tensor, + kv_cache_dtype: str, + k_scale: float, + v_scale: float, + softmax_scale: float, + window_size: Optional[List[int]] = None, + alibi_slopes: Optional[torch.Tensor] = None, + logits_soft_cap: Optional[float] = None, +) -> torch.Tensor: + return torch.empty_like(query) diff --git a/vllm/attention/backends/flashinfer.py b/vllm/attention/backends/flashinfer.py index a64bf34596f99..40e804934cbdd 100644 --- a/vllm/attention/backends/flashinfer.py +++ b/vllm/attention/backends/flashinfer.py @@ -7,7 +7,7 @@ from flashinfer.decode import CUDAGraphBatchDecodeWithPagedKVCacheWrapper from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper - import vllm.attention.backends.flash_attn # noqa + from vllm.vllm_flash_attn import flash_attn_varlen_func FLASHINFER_WORKSPACE_BUFFER_SIZE = 256 * 1024 * 1024 except ImportError: BatchDecodeWithPagedKVCacheWrapper = None @@ -595,7 +595,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], device = self.runner.device use_captured_graph = cuda_graph_pad_size != -1 - max_query_len = max(query_lens) max_prefill_seq_len = max(self.prefill_seq_lens, default=0) num_decode_tokens = self.num_decode_tokens @@ -634,7 +633,6 @@ def build(self, seq_lens: List[int], query_lens: List[int], dtype=torch.int, device=device, ) - assert max_query_len > 0, ("query_lens: {}".format(query_lens)) assert device is not None seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device, @@ -801,7 +799,7 @@ def forward( # This happens when vllm runs the profiling to # determine the number of blocks. 
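The unified_flash_attention wrapper above follows the standard torch.library custom-op pattern: a real eager implementation plus a registered fake that only propagates shapes, so torch.compile can trace the op without running the kernel. A minimal sketch of that pattern, with an illustrative op name unrelated to vLLM, assuming PyTorch 2.4 or later:

import torch


@torch.library.custom_op("demo::scaled_add", mutates_args=())
def scaled_add(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Real implementation, executed in eager mode.
    return x + scale * y


@scaled_add.register_fake
def _(x: torch.Tensor, y: torch.Tensor, scale: float) -> torch.Tensor:
    # Shape/dtype-only implementation used during tracing, mirroring how
    # unified_flash_attention's fake returns torch.empty_like(query).
    return torch.empty_like(x)


a, b = torch.randn(4), torch.randn(4)
out = torch.ops.demo.scaled_add(a, b, 0.5)
assert torch.allclose(out, a + 0.5 * b)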
if kv_cache.numel() == 0: - output = torch.ops.vllm.flash_attn_varlen_func( + output = flash_attn_varlen_func( q=query, k=key, v=value, diff --git a/vllm/attention/backends/openvino.py b/vllm/attention/backends/openvino.py index 7992c70f52659..8b36230730380 100644 --- a/vllm/attention/backends/openvino.py +++ b/vllm/attention/backends/openvino.py @@ -9,6 +9,31 @@ from vllm.attention.backends.utils import CommonAttentionState +def copy_cache_block(src_tensor: ov.Tensor, dst_tensor: ov.Tensor, + src_offset: int, dst_offset: int) -> None: + + def create_roi_tensor( + tensor: ov.Tensor, + block_number: int, + ) -> ov.Tensor: + roi_begin = ov.runtime.Coordinate([0, 0, 0, 0]) + roi_end = ov.runtime.Coordinate(tensor.get_shape()) + + roi_begin[0] = block_number + roi_end[0] = block_number + 1 + + if isinstance(tensor, ov.Tensor): + return ov.Tensor(tensor, roi_begin, roi_end) + else: + return ov.RemoteTensor(tensor, roi_begin, roi_end) + + src_roi_tensor = \ + create_roi_tensor(src_tensor, src_offset) + dst_roi_tensor = \ + create_roi_tensor(dst_tensor, dst_offset) + src_roi_tensor.copy_to(dst_roi_tensor) + + class OpenVINOAttentionBackend(AttentionBackend): @staticmethod @@ -44,13 +69,12 @@ def get_kv_cache_shape( @staticmethod def swap_blocks( - src_kv_cache: ov.Tensor, - dst_kv_cache: ov.Tensor, - src_to_dst: torch.Tensor, + src_tensor: ov.Tensor, + dst_tensor: ov.Tensor, + src_to_dists: List[Tuple[int, int]], ) -> None: - # OpenVINO currently supports only CPU, which does not require - # swap of KV cache blocks - raise NotImplementedError + for src, dst in src_to_dists: + copy_cache_block(src_tensor, dst_tensor, src, dst) @staticmethod def copy_blocks( @@ -59,8 +83,8 @@ def copy_blocks( ) -> None: for src, dst in src_to_dists: for key_cache, value_cache in kv_caches: - key_cache.data[dst, :] = key_cache.data[src, :] - value_cache.data[dst, :] = value_cache.data[src, :] + copy_cache_block(key_cache, key_cache, src, dst) + copy_cache_block(value_cache, value_cache, src, dst) @dataclass diff --git a/vllm/attention/backends/rocm_flash_attn.py b/vllm/attention/backends/rocm_flash_attn.py index 5ee3c3b69cf36..fb5cd11ec033a 100644 --- a/vllm/attention/backends/rocm_flash_attn.py +++ b/vllm/attention/backends/rocm_flash_attn.py @@ -116,9 +116,17 @@ class ROCmFlashAttentionMetadata(AttentionMetadata, PagedAttentionMetadata): # Cuda-graph is currently enabled for decoding only. # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention. use_cuda_graph: bool + # (batch_size,) A tensor of context lengths (tokens that are computed # so far). context_lens_tensor: Optional[torch.Tensor] + + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. 
+ decode_query_len: Optional[int] = None + _cached_prefill_metadata: Optional["ROCmFlashAttentionMetadata"] = None _cached_decode_metadata: Optional["ROCmFlashAttentionMetadata"] = None diff --git a/vllm/attention/backends/utils.py b/vllm/attention/backends/utils.py index 49fbb25f4547b..2b8c373178ab3 100644 --- a/vllm/attention/backends/utils.py +++ b/vllm/attention/backends/utils.py @@ -312,7 +312,8 @@ def graph_capture_get_metadata_for_batch( slot_mapping=self._graph_slot_mapping[:batch_size], seq_lens=None, seq_lens_tensor=self._graph_seq_lens[:batch_size], - max_query_len=None, + max_query_len=1, + decode_query_len=1, max_prefill_seq_len=0, max_decode_seq_len=self.runner.max_seq_len_to_capture, query_start_loc=None, diff --git a/vllm/attention/backends/xformers.py b/vllm/attention/backends/xformers.py index 143fa6ee7dea4..a3f9ff64f8b8b 100644 --- a/vllm/attention/backends/xformers.py +++ b/vllm/attention/backends/xformers.py @@ -118,6 +118,12 @@ class XFormersMetadata(AttentionMetadata, PagedAttentionMetadata): # Maximum query length in the batch. None for decoding. max_query_len: Optional[int] = None + # Number of query tokens for each request in the batch. + # Currently, we require that all requests have the same number of query + # tokens during the decoding phase. When speculavie decoding is enabled, + # decode_query_len might be greater than 1. In all other cases, it is 1. + decode_query_len: Optional[int] = None + # (batch_size + 1,). The cumulative subquery lengths of the sequences in # the batch, used to index into subquery. E.g., if the subquery length # is [4, 6], it is [0, 4, 10]. diff --git a/vllm/config.py b/vllm/config.py index 61999713f5212..786ed1586a3ea 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -31,28 +31,7 @@ logger = init_logger(__name__) _EMBEDDING_MODEL_MAX_NUM_BATCHED_TOKENS = 32768 -_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 4096 - -_PP_SUPPORTED_MODELS = [ - "AquilaForCausalLM", - "AquilaModel", - "DeepseekV2ForCausalLM", - "GPT2LMHeadModel", - "InternLM2ForCausalLM", - "InternLMForCausalLM", - "InternVLChatModel", - "JAISLMHeadModel", - "LlamaForCausalLM", - "LLaMAForCausalLM", - "MistralForCausalLM", - "MixtralForCausalLM", - "NemotronForCausalLM", - "Phi3ForCausalLM", - "Qwen2ForCausalLM", - "Qwen2MoeForCausalLM", - "QWenLMHeadModel", - "Qwen2VLForConditionalGeneration", -] +_MULTIMODAL_MODEL_MAX_NUM_BATCHED_TOKENS = 5120 class ModelConfig: @@ -228,16 +207,14 @@ def _init_multimodal_config( self, limit_mm_per_prompt: Optional[Mapping[str, int]] ) -> Optional["MultiModalConfig"]: architectures = getattr(self.hf_config, "architectures", []) - if any( - ModelRegistry.is_multimodal_model(arch) - for arch in architectures): + if ModelRegistry.is_multimodal_model(architectures): return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {}) - else: - if limit_mm_per_prompt: - raise ValueError( - "limit_mm_per_prompt is only supported for multimodal " - "models.") - return None + + if limit_mm_per_prompt: + raise ValueError("`limit_mm_per_prompt` is only supported for " + "multimodal models.") + + return None def _verify_tokenizer_mode(self) -> None: tokenizer_mode = self.tokenizer_mode.lower() @@ -249,8 +226,7 @@ def _verify_tokenizer_mode(self) -> None: def _verify_embedding_mode(self) -> None: architectures = getattr(self.hf_config, "architectures", []) - self.embedding_mode = any( - ModelRegistry.is_embedding_model(arch) for arch in architectures) + self.embedding_mode = ModelRegistry.is_embedding_model(architectures) def 
_parse_quant_hf_config(self): quant_cfg = getattr(self.hf_config, "quantization_config", None) @@ -418,17 +394,17 @@ def verify_with_parallel_config( f"({tensor_parallel_size}).") pipeline_parallel_size = parallel_config.pipeline_parallel_size - architectures = getattr(self.hf_config, "architectures", []) - if not all(arch in _PP_SUPPORTED_MODELS - for arch in architectures) and pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is only supported for the following " - f" architectures: {_PP_SUPPORTED_MODELS}.") + if pipeline_parallel_size > 1: + architectures = getattr(self.hf_config, "architectures", []) + if not ModelRegistry.is_pp_supported_model(architectures): + raise NotImplementedError( + "Pipeline parallelism is not supported for this model. " + "Supported models implement the `SupportsPP` interface.") - if pipeline_parallel_size > 1 and self.use_async_output_proc: - logger.warning("Async output processor is not supported with " - "pipeline parallelism currently. Disabling it.") - self.use_async_output_proc = False + if self.use_async_output_proc: + logger.warning("Async output processor is not supported with " + "pipeline parallelism currently. Disabling it.") + self.use_async_output_proc = False def get_hf_config_sliding_window(self) -> Optional[int]: """Get the sliding window size, or None if disabled.""" @@ -981,7 +957,7 @@ def __init__(self, max_num_batched_tokens: Optional[int], max_num_seqs: int, max_model_len: int, - use_v2_block_manager: bool = False, + use_v2_block_manager: bool = True, num_lookahead_slots: int = 0, delay_factor: float = 0.0, enable_chunked_prefill: bool = False, @@ -1129,6 +1105,7 @@ def maybe_create_spec_config( speculative_model_quantization: Optional[str], speculative_draft_tensor_parallel_size: Optional[int], num_speculative_tokens: Optional[int], + speculative_disable_mqa_scorer: Optional[bool], speculative_max_model_len: Optional[int], enable_chunked_prefill: bool, use_v2_block_manager: bool, @@ -1163,6 +1140,9 @@ def maybe_create_spec_config( num_speculative_tokens (Optional[int]): The number of speculative tokens, if provided. Will default to the number in the draft model config if present, otherwise is required. + speculative_disable_mqa_scorer (Optional[bool]): Disable the MQA + scorer for the speculative model and fall back to batch + expansion for scoring. speculative_max_model_len (Optional[int]): The maximum model len of the speculative model. Used when testing the ability to skip speculation for some sequences. 
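For intuition about the speculative_disable_mqa_scorer option described above, the rough difference between the two scoring strategies is in how the k proposed tokens reach the target model: batch expansion scores them as roughly k + 1 single-token requests per sequence, while the MQA scorer keeps one request per sequence with k + 1 query tokens. The sketch below is illustrative arithmetic only, not vLLM code.

def scoring_shapes(batch_size: int, k: int):
    # Batch expansion: one single-token request per scored position.
    batch_expansion = {
        "target_requests": batch_size * (k + 1),
        "query_tokens_per_request": 1,
    }
    # MQA scorer: one multi-query request per sequence.
    mqa = {
        "target_requests": batch_size,
        "query_tokens_per_request": k + 1,
    }
    return batch_expansion, mqa


expanded, mqa = scoring_shapes(batch_size=8, k=3)
assert expanded["target_requests"] == 32  # 8 sequences * (3 proposals + 1)
assert mqa["target_requests"] == 8        # one request per sequence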
@@ -1317,6 +1297,7 @@ def maybe_create_spec_config( draft_model_config, draft_parallel_config, num_speculative_tokens, + speculative_disable_mqa_scorer, speculative_disable_by_batch_size, ngram_prompt_lookup_max, ngram_prompt_lookup_min, @@ -1413,6 +1394,7 @@ def __init__( draft_model_config: ModelConfig, draft_parallel_config: ParallelConfig, num_speculative_tokens: int, + speculative_disable_mqa_scorer: Optional[bool], speculative_disable_by_batch_size: Optional[int], ngram_prompt_lookup_max: Optional[int], ngram_prompt_lookup_min: Optional[int], @@ -1459,6 +1441,7 @@ def __init__( self.draft_model_config = draft_model_config self.draft_parallel_config = draft_parallel_config self.num_speculative_tokens = num_speculative_tokens + self.speculative_disable_mqa_scorer = speculative_disable_mqa_scorer self.speculative_disable_by_batch_size = \ speculative_disable_by_batch_size self.ngram_prompt_lookup_max = ngram_prompt_lookup_max or 0 diff --git a/vllm/core/scheduler.py b/vllm/core/scheduler.py index 5b7587d150843..f3a5016d0e62a 100644 --- a/vllm/core/scheduler.py +++ b/vllm/core/scheduler.py @@ -1607,10 +1607,29 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, # in a decode phase. Do not chunk. if enable_chunking and len(seqs) == 1: remaining_token_budget = budget.remaining_token_budget() - if self.cache_config.enable_prefix_caching: + if self.scheduler_config.is_multi_step: + # The current multi-step + chunked prefill capability does + # not actually support chunking prompts. + # + # Therefore, `num_new_tokens` is computed in the same fashion + # for both multi-step+chunked-prefill & + # multi-step+chunked-prefill+APC + # + # Prompts with more tokens than the current remaining budget + # are postponed to future scheduler steps + if num_new_tokens > self._get_prompt_limit(seq_group): + # If the seq_group is in prompt-stage, pass the + # num_new_tokens as-is so the caller can ignore + # the sequence. + pass + else: + num_new_tokens = 0 \ + if num_new_tokens > remaining_token_budget \ + else num_new_tokens + elif self.cache_config.enable_prefix_caching: # When prefix caching is enabled, we always allocate - # the number of new tokens that is dividable by the block size - # to avoid partial block matching. + # the number of new tokens that is dividable by the block + # size to avoid partial block matching. block_size = self.cache_config.block_size remainder = budget.token_budget % block_size if remainder != 0: @@ -1623,16 +1642,6 @@ def _get_num_new_tokens(self, seq_group: SequenceGroup, if remaining_token_budget < num_new_tokens: num_new_tokens = (remaining_token_budget // block_size) * block_size - elif self.scheduler_config.is_multi_step: - if num_new_tokens > self._get_prompt_limit(seq_group): - # If the seq_group is in prompt-stage, pass the - # num_new_tokens as-is so the caller can ignore - # the sequence. 
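The prefix-caching branch of _get_num_new_tokens above rounds the chunked-prefill budget down to a whole number of blocks so that a partially filled block never has to be matched. A minimal restatement with made-up values (block_aligned_budget is an illustrative name, not a vLLM helper):

def block_aligned_budget(num_new_tokens: int, remaining_token_budget: int,
                         block_size: int) -> int:
    # Only chunk when the prompt does not fit in the remaining budget.
    if remaining_token_budget < num_new_tokens:
        return (remaining_token_budget // block_size) * block_size
    return num_new_tokens


# A 1000-token budget with 16-token blocks admits at most 992 tokens of a
# long prompt (62 full blocks); shorter prompts pass through unchanged.
assert block_aligned_budget(5000, 1000, 16) == 992
assert block_aligned_budget(300, 1000, 16) == 300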
- pass - else: - num_new_tokens = 0 \ - if num_new_tokens > remaining_token_budget \ - else num_new_tokens else: num_new_tokens = min(num_new_tokens, remaining_token_budget) return num_new_tokens diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index b6f125d450e68..898a8d4c6eeaa 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -112,7 +112,7 @@ class EngineArgs: block_size: int = 16 if not current_platform.is_hpu() else 128 enable_prefix_caching: bool = False disable_sliding_window: bool = False - use_v2_block_manager: bool = False + use_v2_block_manager: bool = True swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 @@ -150,7 +150,7 @@ class EngineArgs: max_cpu_loras: Optional[int] = None device: str = 'auto' num_scheduler_steps: int = 1 - multi_step_stream_outputs: bool = False + multi_step_stream_outputs: bool = True ray_workers_use_nsight: bool = False num_gpu_blocks_override: Optional[int] = None num_lookahead_slots: int = 0 @@ -167,6 +167,7 @@ class EngineArgs: speculative_model_quantization: Optional[str] = None speculative_draft_tensor_parallel_size: Optional[int] = None num_speculative_tokens: Optional[int] = None + speculative_disable_mqa_scorer: Optional[bool] = False speculative_max_model_len: Optional[int] = None speculative_disable_by_batch_size: Optional[int] = None ngram_prompt_lookup_max: Optional[int] = None @@ -380,9 +381,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', help='Disables sliding window, ' 'capping to sliding window size') - parser.add_argument('--use-v2-block-manager', - action='store_true', - help='Use BlockSpaceMangerV2.') + parser.add_argument( + '--use-v2-block-manager', + default=EngineArgs.use_v2_block_manager, + action='store_true', + help='Use BlockSpaceMangerV2. By default this is set to True. ' + 'Set to False to use BlockSpaceManagerV1') parser.add_argument( '--num-lookahead-slots', type=int, @@ -611,13 +615,17 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--multi-step-stream-outputs', - action='store_true', - help='If True, then multi-step will stream outputs for every step') + action=StoreBoolean, + default=EngineArgs.multi_step_stream_outputs, + nargs="?", + const="True", + help='If False, then multi-step will stream outputs at the end ' + 'of all steps') parser.add_argument( '--scheduler-delay-factor', type=float, default=EngineArgs.scheduler_delay_factor, - help='Apply a delay (of delay factor multiplied by previous' + help='Apply a delay (of delay factor multiplied by previous ' 'prompt latency) before scheduling next prompt.') parser.add_argument( '--enable-chunked-prefill', @@ -640,7 +648,7 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=nullable_str, choices=[*QUANTIZATION_METHODS, None], default=EngineArgs.speculative_model_quantization, - help='Method used to quantize the weights of speculative model.' + help='Method used to quantize the weights of speculative model. ' 'If None, we first check the `quantization_config` ' 'attribute in the model config file. 
If that is ' 'None, we assume the model weights are not ' @@ -652,6 +660,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: default=EngineArgs.num_speculative_tokens, help='The number of speculative tokens to sample from ' 'the draft model in speculative decoding.') + parser.add_argument( + '--speculative-disable-mqa-scorer', + action='store_true', + help= + 'If set to True, the MQA scorer will be disabled in speculative ' + ' and fall back to batch expansion') parser.add_argument( '--speculative-draft-tensor-parallel-size', '-spec-draft-tp', @@ -802,13 +816,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "lower performance.") parser.add_argument( '--override-neuron-config', - type=lambda configs: { - str(key): value - for key, value in - (config.split(':') for config in configs.split(',')) - }, + type=json.loads, default=None, - help="override or set neuron device configuration.") + help="Override or set neuron device configuration. " + "e.g. {\"cast_logits_dtype\": \"bloat16\"}.'") parser.add_argument( '--scheduling-policy', @@ -985,6 +996,7 @@ def create_engine_config(self) -> EngineConfig: speculative_draft_tensor_parallel_size = \ self.speculative_draft_tensor_parallel_size, num_speculative_tokens=self.num_speculative_tokens, + speculative_disable_mqa_scorer=self.speculative_disable_mqa_scorer, speculative_disable_by_batch_size=self. speculative_disable_by_batch_size, speculative_max_model_len=self.speculative_max_model_len, @@ -1006,10 +1018,6 @@ def create_engine_config(self) -> EngineConfig: if speculative_config is not None: raise ValueError("Speculative decoding is not supported with " "multi-step (--num-scheduler-steps > 1)") - if self.enable_chunked_prefill and self.enable_prefix_caching: - raise ValueError("Multi-Step is not supported with " - "both Chunked-Prefill and Prefix-Caching " - "enabled together.") if self.enable_chunked_prefill and self.pipeline_parallel_size > 1: raise ValueError("Multi-Step Chunked-Prefill is not supported " "for pipeline-parallel-size > 1") diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index cb489084f48de..6f3b73dbeee20 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -16,7 +16,7 @@ from vllm.engine.metrics_types import StatLoggerBase from vllm.executor.executor_base import ExecutorAsyncBase from vllm.executor.gpu_executor import GPUExecutorAsync -from vllm.executor.habana_executor import HabanaExecutorAsync +from vllm.executor.hpu_executor import HPUExecutorAsync from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import PromptType from vllm.logger import init_logger @@ -620,12 +620,11 @@ def _get_executor_cls( elif engine_config.device_config.device_type == "hpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_habana_executor import ( - RayHabanaExecutorAsync) - executor_class = RayHabanaExecutorAsync + from vllm.executor.ray_hpu_executor import RayHPUExecutorAsync + executor_class = RayHPUExecutorAsync else: - from vllm.executor.habana_executor import HabanaExecutorAsync - executor_class = HabanaExecutorAsync + from vllm.executor.hpu_executor import HPUExecutorAsync + executor_class = HPUExecutorAsync elif engine_config.device_config.device_type == "openvino": assert distributed_executor_backend is None, ( "Distributed execution is not supported with " @@ -1206,7 +1205,7 @@ async def start_profile(self) -> None: # 
using type instead of isinstance to check to avoid capturing # inherited classes if type(self.engine.model_executor) == GPUExecutorAsync or \ - type(self.engine.model_executor) == HabanaExecutorAsync: # noqa: E721 + type(self.engine.model_executor) == HPUExecutorAsync: # noqa: E721 self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") @@ -1215,7 +1214,7 @@ async def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes if type(self.engine.model_executor) == GPUExecutorAsync or \ - type(self.engine.model_executor) == HabanaExecutorAsync: # noqa: E721 + type(self.engine.model_executor) == HPUExecutorAsync: # noqa: E721 self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index f41d074ad536c..af0f010781040 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -28,7 +28,7 @@ from vllm.entrypoints.openai.logits_processors import get_logits_processors from vllm.executor.executor_base import ExecutorBase from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.habana_executor import HabanaExecutor +from vllm.executor.hpu_executor import HPUExecutor from vllm.executor.ray_utils import initialize_ray_cluster from vllm.inputs import (INPUT_REGISTRY, EncoderDecoderLLMInputs, InputRegistry, LLMInputs, PromptType) @@ -533,11 +533,11 @@ def _get_executor_cls(cls, elif engine_config.device_config.device_type == "hpu": if distributed_executor_backend == "ray": initialize_ray_cluster(engine_config.parallel_config) - from vllm.executor.ray_habana_executor import RayHabanaExecutor - executor_class = RayHabanaExecutor + from vllm.executor.ray_hpu_executor import RayHPUExecutor + executor_class = RayHPUExecutor else: - from vllm.executor.habana_executor import HabanaExecutor - executor_class = HabanaExecutor + from vllm.executor.hpu_executor import HPUExecutor + executor_class = HPUExecutor elif engine_config.device_config.device_type == "openvino": from vllm.executor.openvino_executor import OpenVINOExecutor executor_class = OpenVINOExecutor @@ -1120,6 +1120,8 @@ def update_prefill_num_computed_tokens( update_prefill_num_computed_tokens(seq_group, seq_group_meta, len(output), is_first_step_output) + elif not is_async: + seq_group.update_num_computed_tokens(1) if outputs: for o in outputs: @@ -1143,8 +1145,16 @@ def update_prefill_num_computed_tokens( else: self.output_processor.process_prompt_logprob(seq_group, output) if seq_group_meta.do_sample: - self.output_processor.process_outputs( + output_token_num = self.output_processor.process_outputs( seq_group, output, is_async) + if self.speculative_config: + # We -1 here because we always + # (w/o speculative decoding) add the number of + # computed tokens by one in the decoding phase. + # Therefore, we remove that one token that + # is already added. + seq_group.update_num_computed_tokens(output_token_num - + 1) if seq_group.is_finished(): finished_now.append(i) @@ -1261,11 +1271,12 @@ def _advance_to_next_step( # decodes after the very first step. Therefore, # we skip the update to the num_computed_tokens # here. 
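The subtraction of one in the speculative-decoding branch above is easier to see with concrete numbers: the engine already credits one computed token per decode step, so when the scorer accepts several tokens in a single step only the extra ones are added. A small sketch with made-up values:

def computed_tokens_after_step(num_computed_tokens: int,
                               accepted_tokens_this_step: int) -> int:
    num_computed_tokens += 1  # regular per-step decode increment
    # Speculative decoding: credit the remaining accepted tokens.
    num_computed_tokens += accepted_tokens_this_step - 1
    return num_computed_tokens


# 100 tokens computed so far; the scorer accepts 4 tokens this step.
assert computed_tokens_after_step(100, 4) == 104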
- pass + seq_group.update_num_computed_tokens(1) else: seq_group.update_num_computed_tokens( seq_group_metadata.token_chunk_size) - + else: + seq_group.update_num_computed_tokens(1) if seq_group_metadata.do_sample: assert len(sequence_group_outputs.samples) == 1, ( "Async output processor expects a single sample" @@ -1276,7 +1287,6 @@ def _advance_to_next_step( assert len(seq_group.seqs) == 1 seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) - seq_group.update_num_computed_tokens(1) def finish_measurements(self): self.model_executor.finish_measurements() @@ -1796,7 +1806,7 @@ def start_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) if type(self.model_executor) == GPUExecutor or \ - type(self.model_executor) == HabanaExecutor: # noqa: E721 + type(self.model_executor) == HPUExecutor: # noqa: E721 self.model_executor.start_profile() else: self.model_executor._run_workers("start_profile") @@ -1805,7 +1815,7 @@ def stop_profile(self) -> None: # using type instead of isinstance to check to avoid capturing # inherited classes (MultiprocessingGPUExecutor) if type(self.model_executor) == GPUExecutor or \ - type(self.model_executor) == HabanaExecutor: # noqa: E721 + type(self.model_executor) == HPUExecutor: # noqa: E721 self.model_executor.stop_profile() else: self.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/multiprocessing/engine.py b/vllm/engine/multiprocessing/engine.py index 49500099fbcaf..3501f12c065cf 100644 --- a/vllm/engine/multiprocessing/engine.py +++ b/vllm/engine/multiprocessing/engine.py @@ -23,7 +23,7 @@ # yapf: enable from vllm.envs import VLLM_RPC_TIMEOUT from vllm.executor.gpu_executor import GPUExecutor -from vllm.executor.habana_executor import HabanaExecutor +from vllm.executor.hpu_executor import HPUExecutor from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.usage.usage_lib import UsageContext @@ -366,14 +366,14 @@ def _alive(self): def start_profile(self) -> None: if type(self.engine.model_executor) is GPUExecutor or \ - type(self.engine.model_executor) is HabanaExecutor: + type(self.engine.model_executor) is HPUExecutor: self.engine.model_executor.start_profile() else: self.engine.model_executor._run_workers("start_profile") def stop_profile(self) -> None: if type(self.engine.model_executor) is GPUExecutor or \ - type(self.engine.model_executor) is HabanaExecutor: + type(self.engine.model_executor) is HPUExecutor: self.engine.model_executor.stop_profile() else: self.engine.model_executor._run_workers("stop_profile") diff --git a/vllm/engine/output_processor/interfaces.py b/vllm/engine/output_processor/interfaces.py index 50adaf4e59188..554880a3cc438 100644 --- a/vllm/engine/output_processor/interfaces.py +++ b/vllm/engine/output_processor/interfaces.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Callable, List +from typing import Callable, List, Optional from vllm.config import SchedulerConfig from vllm.core.scheduler import Scheduler @@ -58,10 +58,14 @@ def create_output_processor( @abstractmethod def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool) -> None: + is_async: bool) -> Optional[int]: """Process new token ids for the sequence group. Handles logic such as detokenization, stop checking, and freeing/forking sequences in the scheduler. + + Return the number of new tokens generated in the sequence group. 
+ The returned value is optional because it is only used for + speculative decoding mqa scorer. """ pass diff --git a/vllm/engine/output_processor/multi_step.py b/vllm/engine/output_processor/multi_step.py index 6dac3619580bb..f35b1ba9c2bdd 100644 --- a/vllm/engine/output_processor/multi_step.py +++ b/vllm/engine/output_processor/multi_step.py @@ -1,5 +1,5 @@ import functools -from typing import Callable, List +from typing import Callable, List, Optional from vllm.core.scheduler import Scheduler from vllm.engine.output_processor.interfaces import ( @@ -69,7 +69,7 @@ def _log_prompt_logprob_unsupported_warning_once(): def process_outputs(self, sequence_group: SequenceGroup, outputs: List[SequenceGroupOutput], - is_async: bool = False) -> None: + is_async: bool = False) -> Optional[int]: """Append new tokens in the outputs to sequences in the sequence group. This only supports sequence groups of size 1. It supports greater than @@ -84,6 +84,10 @@ def process_outputs(self, tokens from the previous step. If this is true, then no tokens need to be appended since it is already done externally (before the next schedule() call) + + Returns: + The number of tokens appended to the sequence. This is optional + because only speculative decode uses this return value. """ # Sequences can be in RUNNING or FINISHED_ABORTED state # once scheduled, as a sequence is moved to FINSIHED_ABORTED @@ -106,6 +110,7 @@ def process_outputs(self, # was already appended, so we only need to do the rest of the # postprocessor: Detokenization + stopping logic self._process_decode_and_stop(seq, sequence_group.sampling_params) + return None else: # Standard multi-step case @@ -121,8 +126,8 @@ def process_outputs(self, ] assert valid_samples - self._process_seq_outputs(seq, valid_samples, - sequence_group.sampling_params) + return self._process_seq_outputs(seq, valid_samples, + sequence_group.sampling_params) def _process_decode_and_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: @@ -140,7 +145,7 @@ def _process_decode_and_stop(self, seq: Sequence, def _process_seq_outputs(self, seq: Sequence, valid_samples: List[SequenceOutput], - sampling_params: SamplingParams) -> None: + sampling_params: SamplingParams) -> int: output_token_ids = [sample.output_token for sample in valid_samples] output_logprobs = [sample.logprobs for sample in valid_samples] @@ -148,7 +153,6 @@ def _process_seq_outputs(self, seq: Sequence, remaining_tokens = sampling_params.max_tokens - (seq.get_output_len() + len(output_token_ids)) if remaining_tokens < 0: - valid_samples = valid_samples[:remaining_tokens] output_token_ids = output_token_ids[:remaining_tokens] # Truncate any tokens after EOS. 
This is required as spec decode @@ -162,7 +166,6 @@ def _process_seq_outputs(self, seq: Sequence, for i in range(len(output_token_ids)): if output_token_ids[i] == eos_token_id: output_token_ids = output_token_ids[:i + 1] - valid_samples = valid_samples[:i + 1] break # Incrementally append tokens to the sequence, as if we had only one new @@ -173,9 +176,9 @@ def _process_seq_outputs(self, seq: Sequence, token_id=output_token_id, logprobs=output_logprob, ) - seq.data.update_num_computed_tokens(1) self._process_decode_and_stop(seq, sampling_params) if seq.is_finished(): break + return len(output_token_ids) diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py index 5078a2654eb22..bf367482cd80c 100644 --- a/vllm/entrypoints/openai/api_server.py +++ b/vllm/entrypoints/openai/api_server.py @@ -53,6 +53,7 @@ from vllm.entrypoints.openai.serving_engine import BaseModelPath from vllm.entrypoints.openai.serving_tokenization import ( OpenAIServingTokenization) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.logger import init_logger from vllm.usage.usage_lib import UsageContext from vllm.utils import FlexibleArgumentParser, get_open_zmq_ipc_path @@ -526,6 +527,15 @@ async def run_server(args, **uvicorn_kwargs) -> None: logger.info("vLLM API server version %s", VLLM_VERSION) logger.info("args: %s", args) + if args.tool_parser_plugin and len(args.tool_parser_plugin) > 3: + ToolParserManager.import_tool_parser(args.tool_parser_plugin) + + valide_tool_parses = ToolParserManager.tool_parsers.keys() + if args.enable_auto_tool_choice \ + and args.tool_call_parser not in valide_tool_parses: + raise KeyError(f"invalid tool call parser: {args.tool_call_parser} " + f"(chose from {{ {','.join(valide_tool_parses)} }})") + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) temp_socket.bind(("", args.port)) diff --git a/vllm/entrypoints/openai/cli_args.py b/vllm/entrypoints/openai/cli_args.py index 446769a277f58..f59ba4e30accd 100644 --- a/vllm/entrypoints/openai/cli_args.py +++ b/vllm/entrypoints/openai/cli_args.py @@ -12,6 +12,7 @@ from vllm.engine.arg_utils import AsyncEngineArgs, nullable_str from vllm.entrypoints.openai.serving_engine import (LoRAModulePath, PromptAdapterPath) +from vllm.entrypoints.openai.tool_parsers import ToolParserManager from vllm.utils import FlexibleArgumentParser @@ -190,16 +191,27 @@ def make_arg_parser(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: "Enable auto tool choice for supported models. Use --tool-call-parser" "to specify which parser to use") + valid_tool_parsers = ToolParserManager.tool_parsers.keys() parser.add_argument( "--tool-call-parser", type=str, - choices=["mistral", "hermes", "llama3_json"], + metavar="{" + ",".join(valid_tool_parsers) + "} or name registered in " + "--tool-parser-plugin", default=None, help= "Select the tool call parser depending on the model that you're using." " This is used to parse the model-generated tool call into OpenAI API " "format. 
Required for --enable-auto-tool-choice.") + parser.add_argument( + "--tool-parser-plugin", + type=str, + default="", + help= + "Specify the tool parser plugin used to parse the model-generated tool" + " calls into OpenAI API format; the names registered in this plugin can " + "be used in --tool-call-parser.") + parser = AsyncEngineArgs.add_cli_args(parser) parser.add_argument('--max-log-len', diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index 41f131f56b51f..ce529f6f0ff58 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -29,10 +29,7 @@ OpenAIServing, PromptAdapterPath, TextTokensPrompt) -from vllm.entrypoints.openai.tool_parsers import (Hermes2ProToolParser, - Llama3JsonToolParser, - MistralToolParser, - ToolParser) +from vllm.entrypoints.openai.tool_parsers import ToolParser, ToolParserManager from vllm.inputs import TokensPrompt from vllm.logger import init_logger from vllm.outputs import CompletionOutput, RequestOutput @@ -82,15 +79,13 @@ def __init__(self, self.tool_parser: Optional[Callable[[AnyTokenizer], ToolParser]] = None if self.enable_auto_tools: - if tool_parser == "mistral": - self.tool_parser = MistralToolParser - elif tool_parser == "hermes": - self.tool_parser = Hermes2ProToolParser - elif tool_parser == "llama3_json": - self.tool_parser = Llama3JsonToolParser - else: + try: + self.tool_parser = ToolParserManager.get_tool_parser( + tool_parser) + except Exception as e: raise TypeError("Error: --enable-auto-tool-choice requires " - "--tool-call-parser") + f"tool_parser:'{tool_parser}' which has not " + "been registered") from e async def create_chat_completion( self, @@ -187,6 +182,10 @@ async def create_chat_completion( raw_request.state.request_metadata = request_metadata try: + if self.enable_auto_tools and self.tool_parser: + request = self.tool_parser(tokenizer).adjust_request( + request=request) + if isinstance(prompt, str): prompt_inputs = self._tokenize_prompt_input( request, @@ -282,11 +281,11 @@ async def chat_completion_stream_generator( num_choices = 1 if request.n is None else request.n previous_num_tokens = [0] * num_choices finish_reason_sent = [False] * num_choices - num_prompt_tokens = 0 - tool_parser: Optional[ToolParser] = self.tool_parser( - tokenizer) if self.tool_parser else None + tool_parsers: List[Optional[ToolParser]] = [ + self.tool_parser(tokenizer) if self.tool_parser else None + ] * num_choices if isinstance(request.tool_choice, ChatCompletionNamedToolChoiceParam): tool_choice_function_name = request.tool_choice.function.name @@ -324,7 +323,7 @@ async def chat_completion_stream_generator( # NOTE num_choices defaults to 1 so this usually executes # once per request for i in range(num_choices): - + tool_parser = tool_parsers[i] choice_data = ChatCompletionResponseStreamChoice( index=i, delta=DeltaMessage( @@ -399,6 +398,7 @@ async def chat_completion_stream_generator( for output in res.outputs: i = output.index + tool_parser = tool_parsers[i] if finish_reason_sent[i]: continue @@ -446,7 +446,8 @@ async def chat_completion_stream_generator( delta_text=delta_text, previous_token_ids=previous_token_ids, current_token_ids=current_token_ids, - delta_token_ids=output.token_ids)) + delta_token_ids=output.token_ids, + request=request)) # update the previous values for the next iteration previous_texts[i] = current_text @@ -685,7 +686,8 @@ async def chat_completion_full_generator( and self.tool_parser: tool_parser = self.tool_parser(tokenizer) - tool_call_info = 
tool_parser.extract_tool_calls(output.text) + tool_call_info = tool_parser.extract_tool_calls( + output.text, request=request) tools_called = tool_call_info.tools_called if tool_call_info.tools_called: message = ChatMessage(role=role, diff --git a/vllm/entrypoints/openai/tool_parsers/__init__.py b/vllm/entrypoints/openai/tool_parsers/__init__.py index 0069a2b8044b7..309d9bede489b 100644 --- a/vllm/entrypoints/openai/tool_parsers/__init__.py +++ b/vllm/entrypoints/openai/tool_parsers/__init__.py @@ -1,9 +1,10 @@ -from .abstract_tool_parser import ToolParser +from .abstract_tool_parser import ToolParser, ToolParserManager from .hermes_tool_parser import Hermes2ProToolParser +from .internlm2_tool_parser import Internlm2ToolParser from .llama_tool_parser import Llama3JsonToolParser from .mistral_tool_parser import MistralToolParser __all__ = [ - "ToolParser", "Hermes2ProToolParser", "MistralToolParser", - "Llama3JsonToolParser" + "ToolParser", "ToolParserManager", "Hermes2ProToolParser", + "MistralToolParser", "Internlm2ToolParser", "Llama3JsonToolParser" ] diff --git a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py index 873f615d43257..7e55532bc7297 100644 --- a/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py @@ -1,9 +1,14 @@ -from typing import Dict, List, Sequence, Union +import importlib +import importlib.util +import os +from typing import Callable, Dict, List, Optional, Sequence, Type, Union -from vllm.entrypoints.openai.protocol import (DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaMessage, ExtractedToolCallInformation) from vllm.logger import init_logger from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import is_list_of logger = init_logger(__name__) @@ -24,8 +29,16 @@ def __init__(self, tokenizer: AnyTokenizer): self.model_tokenizer = tokenizer - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + """ + Static method that used to adjust the request parameters. + """ + return request + + def extract_tool_calls( + self, model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Static method that should be implemented for extracting tool calls from a complete model-generated string. @@ -44,6 +57,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: """ Instance method that should be implemented for extracting tool calls @@ -55,3 +69,86 @@ def extract_tool_calls_streaming( raise NotImplementedError( "AbstractToolParser.extract_tool_calls_streaming has not been " "implemented!") + + +class ToolParserManager: + tool_parsers: Dict[str, Type] = {} + + @classmethod + def get_tool_parser(cls, name) -> Type: + """ + Get tool parser by name which is registered by `register_module`. + + Raise a KeyError exception if the name is not registered. 
+ """ + if name in cls.tool_parsers: + return cls.tool_parsers[name] + + raise KeyError(f"tool helper: '{name}' not found in tool_parsers") + + @classmethod + def _register_module(cls, + module: Type, + module_name: Optional[Union[str, List[str]]] = None, + force: bool = True) -> None: + if not issubclass(module, ToolParser): + raise TypeError( + f'module must be subclass of ToolParser, but got {type(module)}' + ) + if module_name is None: + module_name = module.__name__ + if isinstance(module_name, str): + module_name = [module_name] + for name in module_name: + if not force and name in cls.tool_parsers: + existed_module = cls.tool_parsers[name] + raise KeyError(f'{name} is already registered ' + f'at {existed_module.__module__}') + cls.tool_parsers[name] = module + + @classmethod + def register_module( + cls, + name: Optional[Union[str, List[str]]] = None, + force: bool = True, + module: Union[Type, None] = None) -> Union[type, Callable]: + """ + Register module with the given name or name list. it can be used as a + decoder(with module as None) or normal function(with module as not + None). + """ + if not isinstance(force, bool): + raise TypeError(f'force must be a boolean, but got {type(force)}') + + # raise the error ahead of time + if not (name is None or isinstance(name, str) + or is_list_of(name, str)): + raise TypeError( + 'name must be None, an instance of str, or a sequence of str, ' + f'but got {type(name)}') + + # use it as a normal method: x.register_module(module=SomeClass) + if module is not None: + cls._register_module(module=module, module_name=name, force=force) + return module + + # use it as a decorator: @x.register_module() + def _register(module): + cls._register_module(module=module, module_name=name, force=force) + return module + + return _register + + @classmethod + def import_tool_parser(cls, plugin_path: str) -> None: + """ + Import a user defined tool parser by the path of the tool parser define + file. 
+ """ + module_name = os.path.splitext(os.path.basename(plugin_path))[0] + spec = importlib.util.spec_from_file_location(module_name, plugin_path) + if spec is None or spec.loader is None: + logger.error("load %s from %s failed.", module_name, plugin_path) + return + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) diff --git a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py index ad6f536838a88..40f041767190b 100644 --- a/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py @@ -5,12 +5,13 @@ import partial_json_parser from partial_json_parser.core.options import Allow -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -20,6 +21,7 @@ logger = init_logger(__name__) +@ToolParserManager.register_module("hermes") class Hermes2ProToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): @@ -57,8 +59,11 @@ def __init__(self, tokenizer: AnyTokenizer): "Hermes 2 Pro Tool parser could not locate tool call start/end " "tokens in the tokenizer!") - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: # sanity check; avoid unnecessary processing if self.tool_call_start_token not in model_output: @@ -114,6 +119,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: logger.debug("delta_text: %s", delta_text) diff --git a/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py new file mode 100644 index 0000000000000..905ab7db3d04c --- /dev/null +++ b/vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py @@ -0,0 +1,208 @@ +import json +from typing import Dict, Sequence, Union + +import partial_json_parser +from partial_json_parser.core.options import Allow + +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, + DeltaToolCall, + ExtractedToolCallInformation, + FunctionCall, ToolCall) +from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( + ToolParser, ToolParserManager) +from vllm.entrypoints.openai.tool_parsers.utils import ( + extract_intermediate_diff) +from vllm.logger import init_logger +from vllm.transformers_utils.tokenizer import AnyTokenizer +from vllm.utils import random_uuid + +logger = init_logger(__name__) + + +@ToolParserManager.register_module(["internlm"]) +class Internlm2ToolParser(ToolParser): + + def __init__(self, tokenizer: AnyTokenizer): + super().__init__(tokenizer) + self.position = 0 + + def adjust_request( + self, request: ChatCompletionRequest) -> ChatCompletionRequest: + if request.tools and request.tool_choice != 'none': + # do not skip special tokens because internlm use the special + # tokens to indicated the 
start and end of the tool calls + # information. + request.skip_special_tokens = False + return request + + def get_argments(self, obj): + if "parameters" in obj: + return obj.get("parameters") + elif "arguments" in obj: + return obj.get("arguments") + return None + + def extract_tool_calls_streaming( + self, + previous_text: str, + current_text: str, + delta_text: str, + previous_token_ids: Sequence[int], + current_token_ids: Sequence[int], + delta_token_ids: Sequence[int], + request: ChatCompletionRequest, + ) -> Union[DeltaMessage, None]: + if '<|action_start|>' not in current_text: + self.position = len(current_text) + return DeltaMessage(content=delta_text) + # if the tool call is sended, return a empty delta message + # to make sure the finish_reason will be send correctly. + if self.current_tool_id > 0: + return DeltaMessage(content='') + + last_pos = self.position + if '<|action_start|><|plugin|>' not in current_text[last_pos:]: + return None + + new_delta = current_text[last_pos:] + text, action = new_delta.split('<|action_start|><|plugin|>') + + if len(text) > 0: + self.position = self.position + len(text) + return DeltaMessage(content=text) + + action = action.strip() + action = action.split('<|action_end|>'.strip())[0] + + # bit mask flags for partial JSON parsing. If the name hasn't been + # sent yet, don't allow sending + # an incomplete string since OpenAI only ever (as far as I have + # seen) allows sending the entire tool/ function name at once. + flags = Allow.ALL if self.current_tool_name_sent \ + else Allow.ALL & ~Allow.STR + + try: + parsable_arr = action + + # tool calls are generated in an object in inernlm2 + # it's not support parallel tool calls + try: + tool_call_arr: Dict = partial_json_parser.loads( + parsable_arr, flags) + except partial_json_parser.core.exceptions.MalformedJSON: + logger.debug('not enough tokens to parse into JSON yet') + return None + + # if the current tool name hasn't been sent, send if available + # - otherwise send nothing + if not self.current_tool_name_sent: + function_name = tool_call_arr.get("name") + if function_name: + self.current_tool_id = self.current_tool_id + 1 + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + type="function", + id=f"chatcmpl-tool-{random_uuid()}", + function=DeltaFunctionCall( + name=function_name).model_dump( + exclude_none=True)) + ]) + self.current_tool_name_sent = True + self.streamed_args_for_tool.append("") + else: + delta = None + # now we know we're on the same tool call and we're streaming + # arguments + else: + prev_arguments = self.get_argments( + self.prev_tool_call_arr[self.current_tool_id]) + cur_arguments = self.get_argments(tool_call_arr) + + # not arguments generated + if not cur_arguments and not prev_arguments: + delta = None + # will never happen + elif not cur_arguments and prev_arguments: + logger.error( + "INVARIANT - impossible to have arguments reset " + "mid-arguments") + delta = None + # first time to get parameters + elif cur_arguments and not prev_arguments: + cur_arguments_json = json.dumps(cur_arguments) + + arguments_delta = cur_arguments_json[:cur_arguments_json. + index(delta_text) + + len(delta_text)] + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=arguments_delta). 
+ model_dump(exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += arguments_delta + # both prev and cur parameters, send the increase parameters + elif cur_arguments and prev_arguments: + cur_args_json = json.dumps(cur_arguments) + prev_args_json = json.dumps(prev_arguments) + + argument_diff = extract_intermediate_diff( + cur_args_json, prev_args_json) + + delta = DeltaMessage(tool_calls=[ + DeltaToolCall(index=self.current_tool_id, + function=DeltaFunctionCall( + arguments=argument_diff).model_dump( + exclude_none=True)) + ]) + self.streamed_args_for_tool[ + self.current_tool_id] += argument_diff + + # check to see if the name is defined and has been sent. if so, + # stream the name - otherwise keep waiting + # finish by setting old and returning None as base case + tool_call_arr["arguments"] = self.get_argments(tool_call_arr) + self.prev_tool_call_arr = [tool_call_arr] + return delta + except Exception as e: + logger.error("Error trying to handle streaming tool call: %s", e) + logger.debug( + "Skipping chunk as a result of tool streaming extraction " + "error") + return None + + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: + text = model_output + tools = request.tools + if '<|action_start|><|plugin|>' in text: + text, action = text.split('<|action_start|><|plugin|>') + action = action.split('<|action_end|>'.strip())[0] + action = action[action.find('{'):] + action_dict = json.loads(action) + name, parameters = action_dict['name'], json.dumps( + action_dict.get('parameters', action_dict.get('arguments', + {}))) + + if not tools or name not in [t.function.name for t in tools]: + ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) + + tool_calls = [ + ToolCall( + function=FunctionCall(name=name, arguments=parameters)) + ] + return ExtractedToolCallInformation( + tools_called=True, + tool_calls=tool_calls, + content=text if len(text) > 0 else None) + + return ExtractedToolCallInformation(tools_called=False, + tool_calls=[], + content=text) diff --git a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py index f98dca16674d5..3cf34bc4928a5 100644 --- a/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py @@ -7,12 +7,13 @@ from partial_json_parser.core.options import Allow from transformers import PreTrainedTokenizerBase -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import find_common_prefix from vllm.logger import init_logger from vllm.utils import random_uuid @@ -41,6 +42,7 @@ def is_complete_json(input_str): return False +@ToolParserManager.register_module("llama3_json") class Llama3JsonToolParser(ToolParser): """ Tool call parser for Llama 3.1 models intended for use with the @@ -64,8 +66,9 @@ def __init__(self, tokenizer: PreTrainedTokenizerBase): add_special_tokens=False)[0] self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, 
model_output: str, + request: ChatCompletionRequest) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. """ @@ -125,6 +128,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: if not (current_text.startswith(self.bot_token) diff --git a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py index 4b0e1c91df97c..1db30797ac6fc 100644 --- a/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +++ b/vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py @@ -1,16 +1,20 @@ import json import re +from random import choices +from string import ascii_letters, digits from typing import Dict, List, Sequence, Union import partial_json_parser from partial_json_parser.core.options import Allow +from pydantic import Field -from vllm.entrypoints.openai.protocol import (DeltaFunctionCall, DeltaMessage, +from vllm.entrypoints.openai.protocol import (ChatCompletionRequest, + DeltaFunctionCall, DeltaMessage, DeltaToolCall, ExtractedToolCallInformation, FunctionCall, ToolCall) from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ( - ToolParser) + ToolParser, ToolParserManager) from vllm.entrypoints.openai.tool_parsers.utils import ( extract_intermediate_diff) from vllm.logger import init_logger @@ -19,7 +23,21 @@ logger = init_logger(__name__) +ALPHANUMERIC = ascii_letters + digits + +class MistralToolCall(ToolCall): + id: str = Field( + default_factory=lambda: MistralToolCall.generate_random_id()) + + @staticmethod + def generate_random_id(): + # Mistral Tool Call Ids must be alphanumeric with a maximum length of 9. + # https://github.com/mistralai/mistral-common/blob/21ee9f6cee3441e9bb1e6ed2d10173f90bd9b94b/src/mistral_common/protocol/instruct/validator.py#L299 + return "".join(choices(ALPHANUMERIC, k=9)) + + +@ToolParserManager.register_module("mistral") class MistralToolParser(ToolParser): """ Tool call parser for Mistral 7B Instruct v0.3, intended for use with the @@ -31,9 +49,7 @@ class MistralToolParser(ToolParser): def __init__(self, tokenizer: AnyTokenizer): super().__init__(tokenizer) - if isinstance(self.model_tokenizer, MistralTokenizer): - self.model_tokenizer = self.model_tokenizer.tokenizer - else: + if not isinstance(self.model_tokenizer, MistralTokenizer): logger.info("Non-Mistral tokenizer detected when using a Mistral " "model...") @@ -45,11 +61,14 @@ def __init__(self, tokenizer: AnyTokenizer): self.streamed_args_for_tool: List[str] = [ ] # map what has been streamed for each tool so far to a list self.bot_token = "[TOOL_CALLS]" - self.bot_token_id = self.model_tokenizer.vocab[self.bot_token] + self.bot_token_id = self.model_tokenizer.get_vocab()[self.bot_token] self.tool_call_regex = re.compile(r"\[{.*?}\]", re.DOTALL) - def extract_tool_calls(self, - model_output: str) -> ExtractedToolCallInformation: + def extract_tool_calls( + self, + model_output: str, + request: ChatCompletionRequest, + ) -> ExtractedToolCallInformation: """ Extract the tool calls from a complete model response. 
Requires find-and-replacing single quotes with double quotes for JSON parsing, @@ -71,8 +90,8 @@ def extract_tool_calls(self, # load the JSON, and then use it to build the Function and # Tool Call function_call_arr = json.loads(raw_tool_call) - tool_calls: List[ToolCall] = [ - ToolCall( + tool_calls: List[MistralToolCall] = [ + MistralToolCall( type="function", function=FunctionCall( name=raw_function_call["name"], @@ -103,6 +122,7 @@ def extract_tool_calls_streaming( previous_token_ids: Sequence[int], current_token_ids: Sequence[int], delta_token_ids: Sequence[int], + request: ChatCompletionRequest, ) -> Union[DeltaMessage, None]: # if the tool call token is not in the tokens generated so far, append diff --git a/vllm/envs.py b/vllm/envs.py index 7cbffc83a6251..0f46ac4f61fdf 100644 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -35,6 +35,7 @@ VLLM_PP_LAYER_PARTITION: Optional[str] = None VLLM_CPU_KVCACHE_SPACE: int = 0 VLLM_CPU_OMP_THREADS_BIND: str = "" + VLLM_OPENVINO_DEVICE: str = "CPU" VLLM_OPENVINO_KVCACHE_SPACE: int = 0 VLLM_OPENVINO_CPU_KV_CACHE_PRECISION: Optional[str] = None VLLM_OPENVINO_ENABLE_QUANTIZED_WEIGHTS: bool = False @@ -302,6 +303,11 @@ def get_default_config_root(): "VLLM_CPU_OMP_THREADS_BIND": lambda: os.getenv("VLLM_CPU_OMP_THREADS_BIND", "all"), + # OpenVINO device selection + # default is CPU + "VLLM_OPENVINO_DEVICE": + lambda: os.getenv("VLLM_OPENVINO_DEVICE", "CPU").upper(), + # OpenVINO key-value cache space # default is 4GB "VLLM_OPENVINO_KVCACHE_SPACE": diff --git a/vllm/executor/habana_executor.py b/vllm/executor/hpu_executor.py similarity index 97% rename from vllm/executor/habana_executor.py rename to vllm/executor/hpu_executor.py index e6d0fbc0d431d..cc5609ebe5c8e 100644 --- a/vllm/executor/habana_executor.py +++ b/vllm/executor/hpu_executor.py @@ -21,7 +21,7 @@ logger = init_logger(__name__) -class HabanaExecutor(ExecutorBase): +class HPUExecutor(ExecutorBase): uses_ray: bool = False @@ -57,8 +57,8 @@ def _create_worker(self, rank: int = 0, distributed_init_method: Optional[str] = None): wrapper = WorkerWrapperBase( - worker_module_name="vllm.worker.habana_worker", - worker_class_name="HabanaWorker", + worker_module_name="vllm.worker.hpu_worker", + worker_class_name="HPUWorker", ) wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank, distributed_init_method)) @@ -202,7 +202,7 @@ def shutdown(self) -> None: self.driver_worker.shutdown_inc() -class HabanaExecutorAsync(HabanaExecutor, ExecutorAsyncBase): +class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase): async def execute_model_async( self, diff --git a/vllm/executor/openvino_executor.py b/vllm/executor/openvino_executor.py index 78606e223aa7b..4a39839a03199 100644 --- a/vllm/executor/openvino_executor.py +++ b/vllm/executor/openvino_executor.py @@ -17,6 +17,14 @@ logger = init_logger(__name__) +def is_openvino_cpu() -> bool: + return "CPU" in envs.VLLM_OPENVINO_DEVICE + + +def is_openvino_gpu() -> bool: + return "GPU" in envs.VLLM_OPENVINO_DEVICE + + class OpenVINOExecutor(ExecutorBase): uses_ray: bool = False @@ -24,8 +32,13 @@ class OpenVINOExecutor(ExecutorBase): def _init_executor(self) -> None: assert self.device_config.device_type == "openvino" assert self.lora_config is None, "OpenVINO backend doesn't support LoRA" + assert is_openvino_cpu() or is_openvino_gpu(), \ + "OpenVINO backend supports only CPU and GPU devices" + + self.ov_core = ov.Core() self.model_config = _verify_and_get_model_config(self.model_config) - self.cache_config = _verify_and_get_cache_config(self.cache_config) + 
self.cache_config = _verify_and_get_cache_config( + self.ov_core, self.cache_config) # Instantiate the worker and load the model to CPU. self._init_worker() @@ -40,6 +53,7 @@ def _init_worker(self): distributed_init_method = get_distributed_init_method( get_ip(), get_open_port()) self.driver_worker = OpenVINOWorker( + ov_core=self.ov_core, model_config=self.model_config, parallel_config=self.parallel_config, scheduler_config=self.scheduler_config, @@ -68,10 +82,13 @@ def initialize_cache(self, num_gpu_blocks: int, # NOTE: We log here to avoid multiple logs when number of workers is # greater than one. We could log in the engine, but not all executors # have GPUs. - # NOTE: `cpu block` for OpenVINO backend is located on CPU memory but is - # referred as `gpu block`. Because we want to reuse the existing block - # management procedure. - logger.info("# CPU blocks: %d", num_gpu_blocks) + # NOTE: In case of a CPU device, `cpu block` for OpenVINO backend + # is located on CPU memory but is referred as `gpu block`. + # Because we want to reuse the existing block management procedure. + device_blocks = num_gpu_blocks + swap_blocks = num_cpu_blocks + logger.info("OpenVINO %s: # device blocks: %d; # swap blocks: %d", + envs.VLLM_OPENVINO_DEVICE, device_blocks, swap_blocks) self.driver_worker.initialize_cache(num_gpu_blocks, num_cpu_blocks) def execute_model( @@ -143,29 +160,45 @@ def _verify_and_get_model_config(config: ModelConfig) -> ModelConfig: return config -def _verify_and_get_cache_config(config: CacheConfig) -> CacheConfig: +def _verify_and_get_cache_config(ov_core: ov.Core, + config: CacheConfig) -> CacheConfig: if envs.VLLM_OPENVINO_CPU_KV_CACHE_PRECISION == "u8": - logger.info("KV cache type is overried to u8 via " - "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") - config.cache_dtype = ov.Type.u8 + if not is_openvino_cpu(): + logger.info("VLLM_OPENVINO_CPU_KV_CACHE_PRECISION is" + "ignored for GPU, f16 data type will be used.") + config.cache_dtype = ov.Type.f16 + else: + logger.info("KV cache type is overridden to u8 via " + "VLLM_OPENVINO_CPU_KV_CACHE_PRECISION env var.") + config.cache_dtype = ov.Type.u8 else: - core = ov.Core() - inference_precision = core.get_property("CPU", - hints.inference_precision) - if inference_precision == ov.Type.bf16: - config.cache_dtype = ov.Type.bf16 + if is_openvino_cpu(): + ov_device = envs.VLLM_OPENVINO_DEVICE + inference_precision = ov_core.get_property( + ov_device, hints.inference_precision) + if inference_precision == ov.Type.bf16: + config.cache_dtype = ov.Type.bf16 + else: + config.cache_dtype = ov.Type.f16 else: config.cache_dtype = ov.Type.f16 - if config.block_size != 32: - logger.info( - f"OpenVINO optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 - ) - config.block_size = 32 + if is_openvino_cpu(): + if config.block_size != 32: + logger.info( + f"OpenVINO CPU optimal block size is 32, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 32 + else: + if config.block_size != 16: + logger.info( + f"OpenVINO GPU optimal block size is 16, overriding currently set {config.block_size}" # noqa: G004, E501 + ) + config.block_size = 16 kv_cache_space = envs.VLLM_OPENVINO_KVCACHE_SPACE if kv_cache_space >= 0: - if kv_cache_space == 0: + if kv_cache_space == 0 and is_openvino_cpu(): config.openvino_kvcache_space_bytes = 4 * GiB_bytes # type: ignore logger.warning( "Environment variable VLLM_OPENVINO_KVCACHE_SPACE (GB) " diff --git a/vllm/executor/ray_habana_executor.py 
b/vllm/executor/ray_hpu_executor.py similarity index 99% rename from vllm/executor/ray_habana_executor.py rename to vllm/executor/ray_hpu_executor.py index 645bceb1af446..343fa43b0eda1 100644 --- a/vllm/executor/ray_habana_executor.py +++ b/vllm/executor/ray_hpu_executor.py @@ -29,7 +29,7 @@ logger = init_logger(__name__) -class RayHabanaExecutor(DistributedGPUExecutor): +class RayHPUExecutor(DistributedGPUExecutor): uses_ray: bool = True @@ -90,8 +90,8 @@ def _get_worker_module_and_class( raise NotImplementedError( "Speculative decoding is not implemented for HPU") else: - worker_module_name = "vllm.worker.habana_worker" - worker_class_name = "HabanaWorker" + worker_module_name = "vllm.worker.hpu_worker" + worker_class_name = "HPUWorker" return (worker_module_name, worker_class_name, worker_class_fn) def _get_worker_wrapper_args(self) -> Dict[str, Any]: @@ -479,7 +479,7 @@ def __del__(self): self.shutdown() -class RayHabanaExecutorAsync(RayHabanaExecutor, DistributedGPUExecutorAsync): +class RayHPUExecutorAsync(RayHPUExecutor, DistributedGPUExecutorAsync): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) diff --git a/vllm/forward_context.py b/vllm/forward_context.py new file mode 100644 index 0000000000000..777747505e14a --- /dev/null +++ b/vllm/forward_context.py @@ -0,0 +1,22 @@ +from contextlib import contextmanager +from typing import Any + +_forward_context: Any = None + + +def get_forward_context() -> Any: + """Get the current forward context.""" + return _forward_context + + +@contextmanager +def set_forward_context(context: Any): + """A context manager that stores the current forward context, + can be attention metadata, etc.""" + global _forward_context + prev_context = _forward_context + _forward_context = context + try: + yield + finally: + _forward_context = prev_context diff --git a/vllm/lora/models.py b/vllm/lora/models.py index 6b37ff573b5ec..2abd59d2e025c 100644 --- a/vllm/lora/models.py +++ b/vllm/lora/models.py @@ -24,8 +24,7 @@ from vllm.lora.punica import PunicaWrapper from vllm.lora.utils import (from_layer, from_layer_logits_processor, parse_fine_tuned_lora_name, replace_submodule) -from vllm.model_executor.models.interfaces import (SupportsLoRA, - supports_multimodal) +from vllm.model_executor.models import SupportsLoRA, supports_multimodal from vllm.model_executor.models.module_mapping import MultiModelKeys from vllm.model_executor.models.utils import PPMissingLayer from vllm.platforms import current_platform diff --git a/vllm/model_executor/layers/fused_moe/fused_moe.py b/vllm/model_executor/layers/fused_moe/fused_moe.py index cf17f1e240e47..52f748675f752 100644 --- a/vllm/model_executor/layers/fused_moe/fused_moe.py +++ b/vllm/model_executor/layers/fused_moe/fused_moe.py @@ -324,6 +324,9 @@ def get_moe_configs(E: int, N: int, # If no optimized configuration is available, we will use the default # configuration + logger.warning( + ("Using default MoE config. Performance might be sub-optimal! " + "Config file not found at %s"), config_file_path) return None diff --git a/vllm/model_executor/layers/sampler.py b/vllm/model_executor/layers/sampler.py index 2b3ec3ae24e34..e336232a713ef 100755 --- a/vllm/model_executor/layers/sampler.py +++ b/vllm/model_executor/layers/sampler.py @@ -1020,7 +1020,7 @@ def get_logprobs( sampling_metadata: SamplingMetadata, sample_results: SampleResultType, ) -> Tuple[List[Optional[PromptLogprobs]], List[SampleLogprobs]]: - """Return sample lobprobs and prompt logprobs. + """Return sample logprobs and prompt logprobs. 
The logic consists of 3 parts. - Select indices to compute logprob from, ranks of token ids, and diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 7485a8de57992..876da67c02436 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -41,9 +41,8 @@ get_gguf_extra_tensor_names, get_quant_config, gguf_quant_weights_iterator, initialize_dummy_weights, np_cache_weights_iterator, pt_weights_iterator, safetensors_weights_iterator) -from vllm.model_executor.models.interfaces import (has_inner_state, - supports_lora, - supports_multimodal) +from vllm.model_executor.models import (has_inner_state, supports_lora, + supports_multimodal) from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.utils import is_fake_hpu, is_pin_memory_available diff --git a/vllm/model_executor/model_loader/openvino.py b/vllm/model_executor/model_loader/openvino.py index 3c1f6fa769894..88b7ac46e5541 100644 --- a/vllm/model_executor/model_loader/openvino.py +++ b/vllm/model_executor/model_loader/openvino.py @@ -12,6 +12,7 @@ import vllm.envs as envs from vllm.attention.backends.openvino import OpenVINOAttentionMetadata from vllm.config import DeviceConfig, ModelConfig +from vllm.executor.openvino_executor import is_openvino_cpu from vllm.logger import init_logger from vllm.model_executor.layers.logits_processor import (LogitsProcessor, _prune_hidden_states) @@ -51,25 +52,15 @@ def _modify_cache_parameters(model: ov.Model, kv_cache_dtype: ov.Type, shape = parameter.get_partial_shape() # use real block size if available, just a placeholder # to provide the expected rank - x_size = 1 num_blocks = ov.Dimension() block_size = ov.Dimension() head_size = ov.Dimension() - # TODO: Negotiate required layout with plugins (CPU is ~OK, GPU is TBD), - # pass more parameters to this function to set more static dimensions if input_name.startswith("key_cache."): cpu_shape = [num_blocks, shape[1], block_size, head_size] - gpu_shape = [ - num_blocks, - shape[1], - shape[2].get_length() // - x_size if shape[2].is_static else ov.Dimension(), - block_size, - x_size, - ] + gpu_shape = [num_blocks, shape[1], shape[2], block_size] elif input_name.startswith("value_cache."): cpu_shape = [num_blocks, shape[1], block_size, head_size] - gpu_shape = [num_blocks, shape[1], shape[2], block_size] + gpu_shape = [num_blocks, shape[1], block_size, shape[2]] else: continue parameter.set_partial_shape( @@ -108,6 +99,7 @@ class OpenVINOCasualLM(nn.Module): def __init__( self, + ov_core: ov.Core, model_config: ModelConfig, device_config: DeviceConfig, kv_cache_dtype: ov.Type, @@ -141,12 +133,12 @@ def __init__( trust_remote_code=model_config.trust_remote_code, ) + ov_device = envs.VLLM_OPENVINO_DEVICE paged_attention_transformation(pt_model.model) _modify_cache_parameters(pt_model.model, kv_cache_dtype, - device_config.device.type == "cpu") + is_openvino_cpu()) - core = ov.Core() - ov_compiled = core.compile_model(pt_model.model, "CPU") + ov_compiled = ov_core.compile_model(pt_model.model, ov_device) self.ov_request = ov_compiled.create_infer_request() def forward( @@ -199,6 +191,7 @@ def get_model( **kwargs, ) -> torch.nn.Module: lora_config = kwargs.get("lora_config", None) + ov_core = kwargs.get("ov_core") if lora_config: raise ValueError( "OpenVINO modeling does not support LoRA, " @@ -206,4 +199,5 @@ def get_model( "be added in the future. 
If this is important to you, " "please open an issue on github.") - return OpenVINOCasualLM(model_config, device_config, kv_cache_dtype) + return OpenVINOCasualLM(ov_core, model_config, device_config, + kv_cache_dtype) diff --git a/vllm/model_executor/models/__init__.py b/vllm/model_executor/models/__init__.py index 682a2e71a1dbf..51054a147a06f 100644 --- a/vllm/model_executor/models/__init__.py +++ b/vllm/model_executor/models/__init__.py @@ -1,220 +1,16 @@ -import functools -import importlib -from typing import Dict, List, Optional, Tuple, Type - -import torch.nn as nn - -from vllm.logger import init_logger -from vllm.utils import is_hip - -logger = init_logger(__name__) - -_GENERATION_MODELS = { - "AquilaModel": ("llama", "LlamaForCausalLM"), - "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 - "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b - "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b - "BloomForCausalLM": ("bloom", "BloomForCausalLM"), - "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), - "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), - "CohereForCausalLM": ("commandr", "CohereForCausalLM"), - "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), - "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), - "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), - "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), - "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), - "FalconForCausalLM": ("falcon", "FalconForCausalLM"), - "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), - "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), - "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), - "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), - "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), - "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), - "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), - "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), - "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), - "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), - # For decapoda-research/llama-* - "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), - "MistralForCausalLM": ("llama", "LlamaForCausalLM"), - "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), - "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), - # transformers's mpt class has lower case - "MptForCausalLM": ("mpt", "MPTForCausalLM"), - "MPTForCausalLM": ("mpt", "MPTForCausalLM"), - "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), - "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), - "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), - "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), - "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), - "OPTForCausalLM": ("opt", "OPTForCausalLM"), - "OrionForCausalLM": ("orion", "OrionForCausalLM"), - "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), - "PhiForCausalLM": ("phi", "PhiForCausalLM"), - "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), - "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), - "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), - "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), - "Qwen2VLForConditionalGeneration": - ("qwen2_vl", "Qwen2VLForConditionalGeneration"), - "RWForCausalLM": ("falcon", "FalconForCausalLM"), - "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), - "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), - "Starcoder2ForCausalLM": 
("starcoder2", "Starcoder2ForCausalLM"), - "SolarForCausalLM": ("solar", "SolarForCausalLM"), - "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), - "XverseForCausalLM": ("xverse", "XverseForCausalLM"), - "Phi3SmallForCausalLM": ("phi3_small", "Phi3SmallForCausalLM"), - "MedusaModel": ("medusa", "Medusa"), - "EAGLEModel": ("eagle", "EAGLE"), - "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), - "JambaForCausalLM": ("jamba", "JambaForCausalLM"), - "GraniteForCausalLM": ("granite", "GraniteForCausalLM") -} - -_EMBEDDING_MODELS = { - "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), - "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), -} - -_MULTIMODAL_MODELS = { - "Blip2ForConditionalGeneration": - ("blip2", "Blip2ForConditionalGeneration"), - "ChameleonForConditionalGeneration": - ("chameleon", "ChameleonForConditionalGeneration"), - "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), - "InternVLChatModel": ("internvl", "InternVLChatModel"), - "LlavaForConditionalGeneration": ("llava", - "LlavaForConditionalGeneration"), - "LlavaNextForConditionalGeneration": ("llava_next", - "LlavaNextForConditionalGeneration"), - "LlavaNextVideoForConditionalGeneration": - ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), - "LlavaOnevisionForConditionalGeneration": - ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), - "MiniCPMV": ("minicpmv", "MiniCPMV"), - "PaliGemmaForConditionalGeneration": ("paligemma", - "PaliGemmaForConditionalGeneration"), - "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), - "PixtralForConditionalGeneration": ("pixtral", - "PixtralForConditionalGeneration"), - "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), - "Qwen2VLForConditionalGeneration": ("qwen2_vl", - "Qwen2VLForConditionalGeneration"), - "UltravoxModel": ("ultravox", "UltravoxModel"), - "MllamaForConditionalGeneration": ("mllama", - "MllamaForConditionalGeneration"), -} -_CONDITIONAL_GENERATION_MODELS = { - "BartModel": ("bart", "BartForConditionalGeneration"), - "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), -} - -_MODELS = { - **_GENERATION_MODELS, - **_EMBEDDING_MODELS, - **_MULTIMODAL_MODELS, - **_CONDITIONAL_GENERATION_MODELS, -} - -# Architecture -> type. -# out of tree models -_OOT_MODELS: Dict[str, Type[nn.Module]] = {} - -# Models not supported by ROCm. -_ROCM_UNSUPPORTED_MODELS: List[str] = [] - -# Models partially supported by ROCm. -# Architecture -> Reason. -_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " - "Triton flash attention. For half-precision SWA support, " - "please use CK flash attention by setting " - "`VLLM_USE_TRITON_FLASH_ATTN=0`") -_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { - "Qwen2ForCausalLM": - _ROCM_SWA_REASON, - "MistralForCausalLM": - _ROCM_SWA_REASON, - "MixtralForCausalLM": - _ROCM_SWA_REASON, - "PaliGemmaForConditionalGeneration": - ("ROCm flash attention does not yet " - "fully support 32-bit precision on PaliGemma"), - "Phi3VForCausalLM": - ("ROCm Triton flash attention may run into compilation errors due to " - "excessive use of shared memory. 
If this happens, disable Triton FA " - "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") -} - - -class ModelRegistry: - - @staticmethod - @functools.lru_cache(maxsize=128) - def _get_model(model_arch: str): - module_name, model_cls_name = _MODELS[model_arch] - module = importlib.import_module( - f"vllm.model_executor.models.{module_name}") - return getattr(module, model_cls_name, None) - - @staticmethod - def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: - if model_arch in _OOT_MODELS: - return _OOT_MODELS[model_arch] - if model_arch not in _MODELS: - return None - if is_hip(): - if model_arch in _ROCM_UNSUPPORTED_MODELS: - raise ValueError( - f"Model architecture {model_arch} is not supported by " - "ROCm for now.") - if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: - logger.warning( - "Model architecture %s is partially supported by ROCm: %s", - model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) - - return ModelRegistry._get_model(model_arch) - - @staticmethod - def resolve_model_cls( - architectures: List[str]) -> Tuple[Type[nn.Module], str]: - for arch in architectures: - model_cls = ModelRegistry._try_load_model_cls(arch) - if model_cls is not None: - return (model_cls, arch) - - raise ValueError( - f"Model architectures {architectures} are not supported for now. " - f"Supported architectures: {ModelRegistry.get_supported_archs()}") - - @staticmethod - def get_supported_archs() -> List[str]: - return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) - - @staticmethod - def register_model(model_arch: str, model_cls: Type[nn.Module]): - if model_arch in _MODELS: - logger.warning( - "Model architecture %s is already registered, and will be " - "overwritten by the new model class %s.", model_arch, - model_cls.__name__) - global _OOT_MODELS - _OOT_MODELS[model_arch] = model_cls - - @staticmethod - def is_embedding_model(model_arch: str) -> bool: - return model_arch in _EMBEDDING_MODELS - - @staticmethod - def is_multimodal_model(model_arch: str) -> bool: - - # TODO: find a way to avoid initializing CUDA prematurely to - # use `supports_multimodal` to determine if a model is multimodal - # model_cls = ModelRegistry._try_load_model_cls(model_arch) - # from vllm.model_executor.models.interfaces import supports_multimodal - return model_arch in _MULTIMODAL_MODELS - +from .interfaces import (HasInnerState, SupportsLoRA, SupportsMultiModal, + SupportsPP, has_inner_state, supports_lora, + supports_multimodal, supports_pp) +from .registry import ModelRegistry __all__ = [ "ModelRegistry", + "HasInnerState", + "has_inner_state", + "SupportsLoRA", + "supports_lora", + "SupportsMultiModal", + "supports_multimodal", + "SupportsPP", + "supports_pp", ] diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 3ae9003dfa3b7..2fb9dbe7c261f 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -1,12 +1,12 @@ """Inference-only Snowflake Arctic model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.logger import init_logger @@ -18,8 +18,7 @@ ReplicatedLinear, RowParallelLinear) from 
vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.deepspeedfp import ( DeepSpeedFPConfig, DeepSpeedFPParameter) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -32,6 +31,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.arctic import ArcticConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + logger = init_logger(__name__) @@ -362,6 +365,7 @@ def __init__( config: ArcticConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -370,15 +374,16 @@ def __init__( self.vocab_size, config.hidden_size, org_num_embeddings=self.vocab_size) - self.layers = nn.ModuleList([ - ArcticDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: ArcticDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") self._attn_implementation = config._attn_implementation self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -386,17 +391,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(positions, hidden_states, kv_caches[i], + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states -class ArcticForCausalLM(nn.Module): +class ArcticForCausalLM(nn.Module, SupportsPP): def __init__(self, config: ArcticConfig, @@ -420,6 +433,8 @@ def __init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -428,9 +443,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -501,6 +516,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
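The Arctic changes above follow the pipeline-parallel recipe used throughout this diff: make_layers() gives each stage only its slice of decoder layers, forward() hands hidden states to the next stage via IntermediateTensors, and load_weights() skips parameters owned by other stages via is_pp_missing_parameter(). A minimal sketch of that ownership logic, assuming an even split of layers across stages (the real partitioning helper may differ):

    # Illustrative sketch only; not the actual vLLM helpers.
    def layer_range(num_layers: int, pp_rank: int, pp_size: int) -> range:
        """Layers owned by one pipeline stage, assuming an even split."""
        per_rank = num_layers // pp_size
        start = pp_rank * per_rank
        end = num_layers if pp_rank == pp_size - 1 else start + per_rank
        return range(start, end)

    def is_missing_on_this_stage(layer_idx: int, owned: range) -> bool:
        """Mimics the is_pp_missing_parameter() check in load_weights."""
        return layer_idx not in owned

    owned = layer_range(num_layers=32, pp_rank=2, pp_size=4)
    assert list(owned) == list(range(16, 24))   # stage 2 owns layers 16..23
    assert is_missing_on_this_stage(5, owned)       # skipped on this stage
    assert not is_missing_on_this_stage(20, owned)  # loaded on this stage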
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -510,6 +527,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -520,6 +539,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -530,6 +551,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index bdd76b11384c2..54ed548ba8bc7 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -19,7 +19,7 @@ # limitations under the License. """Inference-only BaiChuan model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -27,7 +27,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -35,8 +35,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +44,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: @@ -255,7 +256,8 @@ def __init__(self, config: PretrainedConfig, position_embedding: str, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -265,12 +267,16 @@ def __init__(self, config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - BaiChuanDecoderLayer(config, position_embedding, cache_config, - quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, 
self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: BaiChuanDecoderLayer(config, position_embedding, + cache_config, quant_config), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -278,23 +284,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual, + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA): +class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "W_pack": ["W_pack"], "gate_up_proj": [ @@ -335,6 +352,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -343,9 +362,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -394,6 +413,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -402,6 +423,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) @@ -413,7 +436,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -431,7 +454,7 @@ class BaiChuanForCausalLM(BaiChuanBaseForCausalLM): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index b28d7699afa01..ca0cbef5cbf48 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -11,7 +12,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -19,7 +20,7 @@ from .blip import (BlipVisionModel, dummy_image_for_blip, get_max_blip_image_tokens) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -475,7 +476,7 @@ def input_processor_for_blip2(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_blip2_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_blip2) @INPUT_REGISTRY.register_input_processor(input_processor_for_blip2) -class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal): +class Blip2ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: Blip2Config, @@ -508,6 +509,16 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -600,7 +611,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[SamplerOutput, IntermediateTensors]: """Run forward pass for BLIP-2. 
One key thing to understand is the `input_ids` already accounts for the @@ -631,26 +642,32 @@ def forward( See also: :class:`Blip2ImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - BLIP2_IMAGE_TOKEN_ID) - + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None - - hidden_states = self.language_model.model(input_ids, - positions, - kv_caches, - attn_metadata, - inputs_embeds=inputs_embeds) + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + BLIP2_IMAGE_TOKEN_ID) + + input_ids = None + else: + inputs_embeds = None + + hidden_states = self.language_model.model( + input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors=intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 831b3f20457a9..b2c9e221690b3 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -17,7 +17,7 @@ # limitations under the License. """Inference-only BLOOM model compatible with HuggingFace weights.""" import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,15 +25,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor: closest_power_of_2 = 2**math.floor(math.log2(total_num_heads)) @@ -222,6 +225,7 @@ def __init__( config: BloomConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.embed_dim = config.hidden_size @@ -235,13 +239,16 @@ def __init__( self.embed_dim, eps=config.layer_norm_epsilon) # Transformer blocks - self.h = nn.ModuleList([ - BloomBlock(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + 
self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: BloomBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h") # Final Layer Norm self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -249,22 +256,29 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) - hidden_states = self.word_embeddings_layernorm(hidden_states) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + hidden_states = self.word_embeddings_layernorm(hidden_states) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class BloomForCausalLM(nn.Module): +class BloomForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -284,6 +298,8 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -292,9 +308,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -321,6 +337,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if not name.startswith("transformer."): name = "transformer." 
+ name + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 973e47f5f0ccd..03c7419f6f6af 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -1,6 +1,6 @@ from functools import cached_property from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict) + Tuple, TypedDict, Union) import torch import torch.nn.functional as F @@ -10,7 +10,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm @@ -33,7 +33,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import print_warning_once -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) # These configs are not part of the model config but the preprocessor # and processor files, so we hardcode them in the model file for now. @@ -822,6 +824,7 @@ def __init__( config: ChameleonConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -835,14 +838,20 @@ def __init__( config.vocabulary_map) decoder_layer = ChameleonDecoderLayer if not self.config.swin_norm \ else ChameleonSwinDecoderLayer - self.layers = nn.ModuleList([ - decoder_layer(config=config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: decoder_layer(config=config, + cache_config=cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers", + ) + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.vqmodel = ChameleonVQVAE(config.vq_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -865,22 +874,33 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, 
hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states @@ -889,7 +909,8 @@ def forward( @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_chameleon_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_chameleon) @INPUT_REGISTRY.register_input_processor(input_processor_for_chameleon) -class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal): +class ChameleonForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__( self, @@ -914,6 +935,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, logit_scale) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: @@ -956,22 +979,26 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs, - ) -> torch.Tensor: - - image_input = self._parse_and_validate_image_input(**kwargs) + ) -> Union[torch.Tensor, IntermediateTensors]: - if image_input is not None: - assert self.model.vqmodel is not None - image_tokens = self.model.get_image_tokens(image_input["data"].to( - self.config.torch_dtype)) - image_token_id = self.model.vocabulary_mapping.image_token_id - special_image_mask = input_ids == image_token_id - image_tokens = image_tokens.to(input_ids.device, input_ids.dtype) - input_ids = input_ids.masked_scatter(special_image_mask, - image_tokens) + if intermediate_tensors is not None: + input_ids = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + assert self.model.vqmodel is not None + image_tokens = self.model.get_image_tokens( + image_input["data"].to(self.config.torch_dtype)) + image_token_id = self.model.vocabulary_mapping.image_token_id + special_image_mask = input_ids == image_token_id + image_tokens = image_tokens.to(input_ids.device, + input_ids.dtype) + input_ids = input_ids.masked_scatter(special_image_mask, + image_tokens) hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -1039,6 +1066,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -1060,11 +1089,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue else: name = remapped_kv_scale_name + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) if use_default_weight_loading and name in params_dict: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 35f1ed5ef5d33..879795c0d5955 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -2,7 +2,7 @@ # Adapted from # https://github.com/THUDM/ChatGLM2-6B """Inference-only ChatGLM model compatible with THUDM weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -10,15 +10,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,14 +27,16 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import ChatGLMConfig -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class GLMAttention(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -126,7 +127,7 @@ class GLMMLP(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -169,7 +170,7 @@ class GLMBlock(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -240,9 +241,10 @@ class GLMTransformer(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.post_layer_norm = config.post_layer_norm @@ -251,10 +253,11 @@ def __init__( self.num_layers = config.num_layers # Transformer layers. 
- self.layers = nn.ModuleList([ - GLMBlock(config, cache_config, quant_config) - for i in range(self.num_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + self.num_layers, + lambda prefix: GLMBlock(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) if self.post_layer_norm: layer_norm_func = RMSNorm if config.rmsnorm else LayerNorm @@ -269,16 +272,16 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, ) -> torch.Tensor: - for i in range(self.num_layers): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( hidden_states=hidden_states, position_ids=position_ids, - kv_cache=kv_caches[i], + kv_cache=kv_caches[i - self.start_layer], attn_metadata=attn_metadata, ) # Final layer norm. - if self.post_layer_norm: + if get_pp_group().is_last_rank and self.post_layer_norm: hidden_states = self.final_layernorm(hidden_states) return hidden_states @@ -288,7 +291,7 @@ class ChatGLMModel(nn.Module): def __init__( self, - config, + config: ChatGLMConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -305,6 +308,9 @@ def __init__( self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size, quant_config=quant_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -312,8 +318,12 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.embedding(input_ids) + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.embedding(input_ids) + else: + inputs_embeds = intermediate_tensors["hidden_states"] # Run encoder. hidden_states = self.encoder( @@ -322,10 +332,13 @@ def forward( kv_caches=kv_caches, attn_metadata=attn_metadata, ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) return hidden_states -class ChatGLMForCausalLM(nn.Module, SupportsLoRA): +class ChatGLMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "query_key_value": ["query_key_value"], "dense_h_to_4h": ["dense_h_to_4h"] @@ -362,6 +375,8 @@ def __init__( self.lm_head = self.transformer.output_layer self.logits_processor = LogitsProcessor(config.padded_vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -370,9 +385,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -402,6 +417,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 649dc798d22dc..a0b8ff3a85c98 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -20,7 +20,7 @@ # This file is based on the LLama model definition file in transformers """PyTorch Cohere model.""" -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch import torch.utils.checkpoint @@ -29,14 +29,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,7 +46,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) @torch.compile @@ -82,7 +83,7 @@ class CohereMLP(nn.Module): def __init__( self, - config, + config: CohereConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -256,6 +257,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -265,12 +267,16 @@ def __init__( self.org_vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - CohereDecoderLayer(config, cache_config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: CohereDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = LayerNorm(param_shape=(config.hidden_size), eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -278,23 +284,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert 
intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class CohereForCausalLM(nn.Module, SupportsLoRA): +class CohereForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -337,6 +354,8 @@ def __init__( quant_config, lora_config=lora_config) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) @torch.no_grad() def forward( @@ -346,9 +365,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -393,6 +412,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -405,6 +426,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index 397a46a486f72..b0b07e9c03a9d 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -1,20 +1,19 @@ # coding=utf-8 -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -24,6 +23,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.dbrx import DbrxConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class DbrxRouter(nn.Module): """A Router implementation for DBRX that returns logits for each expert @@ -296,22 +299,27 @@ def __init__( config: DbrxConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.wte = VocabParallelEmbedding( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList([ - DbrxBlock(config, cache_config, quant_config) - for _ in range(config.n_layers) - ]) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: DbrxBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks", + ) self.norm_f = nn.LayerNorm(config.d_model, eps=1e-5) for module in self.modules(): if hasattr(module, "bias") and isinstance(module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. 
module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) def forward( self, @@ -319,21 +327,28 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.blocks)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): block = self.blocks[i] hidden_states = block( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states -class DbrxForCausalLM(nn.Module): +class DbrxForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -359,6 +374,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -367,9 +384,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -401,11 +418,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, weight_name) break else: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py index 65b409a2a15a0..7ed2b96e65c49 100644 --- a/vllm/model_executor/models/decilm.py +++ b/vllm/model_executor/models/decilm.py @@ -29,11 +29,12 @@ from transformers import LlamaConfig from vllm.config import CacheConfig, LoRAConfig -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaForCausalLM +from .utils import is_pp_missing_parameter + class DeciLMForCausalLM(LlamaForCausalLM): """ @@ -91,6 +92,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -99,6 +102,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 61cc917ab6207..5b4db8f258711 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Deepseek model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,6 +49,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class DeepseekMLP(nn.Module): @@ -329,6 +332,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -338,14 +342,17 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - DeepseekDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: DeepseekDecoderLayer(config, + int(prefix.split(".")[-1]), + cache_config, + quant_config=quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -353,19 +360,29 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, 
residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class DeepseekForCausalLM(nn.Module): +class DeepseekForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -384,6 +401,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -392,9 +411,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -439,6 +458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -451,6 +472,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if (("mlp.experts." in name or "mlp.shared_experts." in name) and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 8cbd9435ec7ca..702be7b7f5ed9 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only DeepseekV2 model.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -40,8 +40,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -50,7 +49,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class DeepseekV2MLP(nn.Module): @@ -439,6 +440,9 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -447,7 +451,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -472,7 +476,7 @@ def forward( return hidden_states -class DeepseekV2ForCausalLM(nn.Module): +class DeepseekV2ForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -492,6 +496,8 @@ def __init__( quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -500,7 +506,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py index 4a1c367de3f62..dfb8fe55d2fb8 100644 --- a/vllm/model_executor/models/exaone.py +++ b/vllm/model_executor/models/exaone.py @@ -38,8 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -53,8 +52,9 @@ from vllm.transformers_utils.configs.exaone import ExaoneConfig from vllm.utils import is_hip -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class 
ExaoneGatedMLP(nn.Module): @@ -354,6 +354,10 @@ def __init__( else: self.ln_f = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.wte(input_ids) @@ -397,7 +401,7 @@ def forward( return hidden_states -class ExaoneForCausalLM(nn.Module, SupportsLoRA): +class ExaoneForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -477,6 +481,9 @@ def __init__( else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -506,24 +513,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - "residual": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index b474d35baf89d..a20dd93cee18c 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -28,7 +28,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import get_act_fn @@ -36,8 +36,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -47,6 +46,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import RWConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + FalconConfig = Union[HF_FalconConfig, RWConfig] @@ -333,6 +336,7 @@ def __init__( config: FalconConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -347,35 +351,45 @@ def __init__( ) # Transformer blocks - self.h = nn.ModuleList([ - FalconDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: FalconDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.h") # Final Layer Norm self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + 
make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, - input_ids: torch.LongTensor, + input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.word_embeddings(input_ids) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.word_embeddings(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class FalconForCausalLM(nn.Module): +class FalconForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -403,6 +417,8 @@ def __init__( ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -412,12 +428,8 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, ) -> torch.Tensor: - hidden_states = self.transformer( - input_ids, - positions, - kv_caches, - attn_metadata, - ) + hidden_states = self.transformer(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -454,6 +466,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: output_dim = getattr(param, "output_dim", None) diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index 87b88da0dc05c..835931746fd4b 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -41,8 +41,9 @@ from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, SequenceData) -from .interfaces import SupportsMultiModal -from .utils import flatten_bn, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, group_weights_with_prefix, + merge_multimodal_embeddings) # Cannot find the following 2 numbers from hf config. 
_IMAGE_TOKEN_ID = 71011 @@ -217,7 +218,7 @@ def input_mapper_for_fuyu(ctx: InputContext, data: object): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_fuyu_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_fuyu) @INPUT_REGISTRY.register_input_processor(input_processor_for_fuyu) -class FuyuForCausalLM(nn.Module, SupportsMultiModal): +class FuyuForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: FuyuConfig, @@ -242,6 +243,12 @@ def __init__(self, self.language_model = PersimmonForCausalLM(config.text_config, cache_config=cache_config, quant_config=quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: @@ -297,23 +304,29 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, ): - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.embed_tokens(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.embed_tokens( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) - else: - inputs_embeds = None + else: + inputs_embeds = None hidden_states = self.language_model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -336,34 +349,16 @@ def sample( return next_tokens def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - params_dict = dict(self.named_parameters(remove_duplicate=False)) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): - # Models trained using ColossalAI may include these tensors in - # the checkpoint. Skip them. - continue - param = params_dict[name] - - if "query_key_value" in name: - # copy from vllm/model_executor/models/bloom.py - # NOTE: Fuyu's fused QKV's output_dim has the shape of - # (num_heads * 3 * head_size), while the - # required shape is (3 * num_heads * head_size). - # Thus, we need weight conversion. 
- output_dim = getattr(param, "output_dim", None) - num_heads = self.config.num_attention_heads - if output_dim is not None: - loaded_weight_shape = loaded_weight.shape - loaded_weight = loaded_weight.view( - loaded_weight_shape[:output_dim] + (num_heads, 3, -1) + - loaded_weight_shape[output_dim + 1:]) - loaded_weight = loaded_weight.transpose( - output_dim, output_dim + 1) - loaded_weight = loaded_weight.reshape(loaded_weight_shape) + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + # load vision embeddings + vision_params_dict = dict(self.vision_embed_tokens.named_parameters()) + for name, loaded_weight in weights_group["vision_embed_tokens"]: + param = vision_params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 36fd389831282..ca419891f69db 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -15,7 +15,7 @@ # limitations under the License. """Inference-only Gemma model compatible with HuggingFace weights.""" from functools import lru_cache -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -23,7 +23,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -31,8 +31,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -41,7 +40,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -245,6 +246,7 @@ def __init__( config: GemmaConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -253,10 +255,11 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - GemmaDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GemmaDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -265,6 +268,9 @@ def __init__( # See 
https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -275,29 +281,38 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + hidden_states *= self.normalizer + residual = None else: - hidden_states = self.get_input_embeddings(input_ids) - hidden_states *= self.normalizer - residual = None - for i in range(len(self.layers)): + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class GemmaForCausalLM(nn.Module, SupportsLoRA): +class GemmaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -339,6 +354,8 @@ def __init__( self.model = GemmaModel(config, cache_config, quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -347,9 +364,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -388,6 +405,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -400,6 +419,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
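make_layers, used for Gemma here and for most models in this patch, is what turns range(num_hidden_layers) into the per-rank (start_layer, end_layer) slice seen in the new forward loops; it also substitutes placeholders for layers owned by other ranks. A simplified sketch of just the index arithmetic, assuming an even split where the last rank absorbs any remainder (the real helper can also honor an explicitly configured partition):

def local_layer_range(num_layers: int, pp_rank: int, pp_size: int) -> tuple:
    """Split [0, num_layers) across pp_size ranks; the last rank
    absorbs the remainder.  Returns this rank's [start, end) slice."""
    per_rank = num_layers // pp_size
    start = pp_rank * per_rank
    end = num_layers if pp_rank == pp_size - 1 else start + per_rank
    return start, end


assert local_layer_range(32, 0, 4) == (0, 8)
assert local_layer_range(30, 3, 4) == (21, 30)   # last rank takes the tail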
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index f9d9f9e7567c8..9fddaac3a0837 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -14,7 +14,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Iterable, List, Optional, Set, Tuple +from typing import Iterable, List, Optional, Set, Tuple, Union import torch from torch import nn @@ -22,7 +22,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.logger import init_logger from vllm.model_executor.layers.activation import GeluAndMul from vllm.model_executor.layers.layernorm import GemmaRMSNorm @@ -30,8 +30,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,7 +39,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -244,6 +245,7 @@ def __init__( config: Gemma2Config, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -252,10 +254,11 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - Gemma2DecoderLayer(layer_idx, config, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Gemma2DecoderLayer(int(prefix.split(".")[ + -1]), config, cache_config, quant_config), + prefix=f"{prefix}.layers") self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps) # Normalize the embedding by sqrt(hidden_size) @@ -264,6 +267,9 @@ def __init__( # See https://github.com/huggingface/transformers/pull/29402 normalizer = self.config.hidden_size**0.5 self.register_buffer("normalizer", torch.tensor(normalizer)) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -271,25 +277,36 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - hidden_states *= self.normalizer + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if 
get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.normalizer - residual = None - for i in range(len(self.layers)): + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class Gemma2ForCausalLM(nn.Module, SupportsLoRA): +class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -338,6 +355,8 @@ def __init__( self.logits_processor = LogitsProcessor( config.vocab_size, soft_cap=config.final_logit_softcapping) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -346,9 +365,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -387,6 +406,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -399,6 +420,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
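For Gemma and Gemma2 (and the Llama-family models later in this diff) each decoder layer returns (hidden_states, residual) because the residual add is deferred and fused into the next layer's norm. That is why the IntermediateTensors shipped between ranks carries both tensors: the receiving rank has to resume the pending add exactly where the sending rank stopped. A dependency-free sketch of the pattern, with LayerNorm standing in for the fused add-plus-RMSNorm:

import torch
import torch.nn as nn


class AddThenNorm(nn.Module):
    """Stand-in for a fused residual-add + RMSNorm layer."""

    def __init__(self, hidden_size: int) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor, residual):
        if residual is None:                        # very first layer
            return self.norm(hidden_states), hidden_states
        hidden_states = hidden_states + residual    # deferred residual add
        return self.norm(hidden_states), hidden_states


blocks = nn.ModuleList([AddThenNorm(64) for _ in range(4)])
hidden_states, residual = torch.randn(2, 64), None
for block in blocks:
    hidden_states, residual = block(hidden_states, residual)
    # A pipeline cut after this block must ship BOTH tensors: the next
    # rank needs `residual` to finish the pending add in its first layer.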
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index fb5a297661ddc..975502340e5f9 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -32,8 +32,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,7 +40,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class GPT2Attention(nn.Module): @@ -204,6 +205,9 @@ def __init__( config, cache_config, quant_config, prefix=prefix), prefix=f"{prefix}.h") self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -234,7 +238,7 @@ def forward( return hidden_states -class GPT2LMHeadModel(nn.Module): +class GPT2LMHeadModel(nn.Module, SupportsPP): def __init__( self, @@ -256,6 +260,8 @@ def __init__( self.config.hidden_size) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -264,7 +270,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -286,16 +292,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index a8567f32958be..5a96c334c3d40 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only GPTBigCode model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -26,14 +26,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -42,7 +41,9 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) is_hpu = current_platform.is_hpu() @@ -197,6 +198,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -210,11 +212,15 @@ def __init__( self.embed_dim, org_num_embeddings=config.vocab_size) self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim) - self.h = nn.ModuleList([ - GPTBigCodeBlock(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.num_hidden_layers, + lambda prefix: GPTBigCodeBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -222,25 +228,32 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - inputs_embeds = self.wte(input_ids) - position_embeds = self.wpe(position_ids) - hidden_states = inputs_embeds + position_embeds + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + inputs_embeds = self.wte(input_ids) + position_embeds = self.wpe(position_ids) + hidden_states = inputs_embeds + position_embeds + else: + hidden_states = intermediate_tensors["hidden_states"] if is_hpu: import habana_frameworks.torch as htorch htorch.core.mark_step() - for i in range(len(self.h)): + for i in range(self.start_layer, self.end_layer): layer = self.h[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) if is_hpu: htorch.core.mark_step() - + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class GPTBigCodeForCausalLM(nn.Module, SupportsLoRA): +class 
GPTBigCodeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = {"c_attn": ["c_attn"]} supported_lora_modules = ["c_fc", "c_proj", "wte", "c_attn"] @@ -280,6 +293,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -288,9 +303,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -319,6 +334,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip attention mask. # NOTE: "c_attn.bias" should not be skipped. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 664d775c8ba40..d40bf8c88ee19 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only GPT-J model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -24,14 +24,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +39,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class GPTJAttention(nn.Module): @@ -178,6 +181,7 @@ def __init__( config: GPTJConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -186,11 +190,15 @@ def __init__( config.vocab_size, self.embed_dim, ) - self.h = nn.ModuleList([ - GPTJBlock(config, cache_config, quant_config) - for _ in range(config.n_layer) - ]) + self.start_layer, self.end_layer, self.h = make_layers( + config.n_layer, + lambda prefix: GPTJBlock(config, cache_config, quant_config), + prefix=f"{prefix}.h", + ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + 
make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -198,21 +206,27 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.h)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.h[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.ln_f(hidden_states) return hidden_states -class GPTJForCausalLM(nn.Module): +class GPTJForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -233,6 +247,8 @@ def __init__( ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -241,9 +257,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -283,6 +299,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -291,6 +309,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index 5f6f1e3880547..23a1ca06cc69e 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -16,7 +16,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
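A detail that recurs in every converted forward (GPT-J above included) is the shift from kv_caches[i] to kv_caches[i - self.start_layer]: with pipeline parallelism each rank is handed a cache list sized only for its local layers, so the global layer index must be mapped to a local cache index. A small standalone illustration, assuming a 24-layer model split across two ranks:

# Rank 1 of 2 owns global layers 12..23 but receives only 12 caches.
start_layer, end_layer = 12, 24
kv_caches = [f"cache_for_layer_{start_layer + j}" for j in range(12)]

for i in range(start_layer, end_layer):    # global layer index
    cache = kv_caches[i - start_layer]     # local cache index
    assert cache == f"cache_for_layer_{i}"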
"""Inference-only GPT-NeoX model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -24,14 +24,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -40,6 +39,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class GPTNeoXAttention(nn.Module): @@ -191,6 +194,7 @@ def __init__( config: GPTNeoXConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -199,12 +203,16 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - GPTNeoXLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GPTNeoXLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers", + ) self.final_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -212,21 +220,27 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_in(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_in(input_ids) + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layer_norm(hidden_states) return hidden_states -class GPTNeoXForCausalLM(nn.Module): +class GPTNeoXForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -247,6 +261,8 @@ def __init__( self.embed_out.weight = self.gpt_neox.embed_in.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.gpt_neox.make_empty_intermediate_tensors) def forward( self, @@ -255,9 +271,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: 
AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.gpt_neox(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -288,6 +304,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using OpenRLHF may include # these tensors in the checkpoint. Skip them. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py index d4853fd790098..dcf4f5b27704a 100644 --- a/vllm/model_executor/models/granite.py +++ b/vllm/model_executor/models/granite.py @@ -51,7 +51,7 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers @@ -311,13 +311,13 @@ def forward( else: hidden_states = self.get_input_embeddings(input_ids) residual = None + + hidden_states *= self.config.embedding_multiplier else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] - hidden_states *= self.config.embedding_multiplier - for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( @@ -337,7 +337,7 @@ def forward( return hidden_states -class GraniteForCausalLM(nn.Module, SupportsLoRA): +class GraniteForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -404,9 +404,12 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight logit_scale = getattr(config, "logit_scale", 1.0) + + if hasattr(config, "logits_scaling"): + logit_scale /= config.logits_scaling self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale) + scale=logit_scale) self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -428,8 +431,6 @@ def compute_logits( sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) - if logits is not None: - logits /= self.config.logits_scaling return logits def sample( diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py new file mode 100644 index 0000000000000..5266951794a80 --- /dev/null +++ b/vllm/model_executor/models/granitemoe.py @@ -0,0 +1,448 @@ +# coding=utf-8 +# Adapted from +# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py +# Copyright 2023 The vLLM team. +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
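The Granite change above removes the late division of the logits by config.logits_scaling and instead folds it into the scale handed to the LogitsProcessor. Assuming the processor applies its scale as a single multiplication over the logits, dividing the scale up front is equivalent up to floating-point rounding:

import torch

hidden = torch.randn(4, 16)
lm_head = torch.randn(100, 16)
logit_scale, logits_scaling = 1.0, 8.0

late = (hidden @ lm_head.t()) * logit_scale / logits_scaling      # old path
early = (hidden @ lm_head.t()) * (logit_scale / logits_scaling)   # new path
assert torch.allclose(late, early)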
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Inference-only GraniteMoe model.""" +from typing import Iterable, List, Optional, Tuple + +import torch +from torch import nn +from transformers.models.granitemoe import GraniteMoeConfig + +from vllm.attention import Attention, AttentionMetadata +from vllm.config import CacheConfig, LoRAConfig +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size +from vllm.model_executor.layers.fused_moe import FusedMoE +from vllm.model_executor.layers.layernorm import RMSNorm +from vllm.model_executor.layers.linear import (QKVParallelLinear, + ReplicatedLinear, + RowParallelLinear) +from vllm.model_executor.layers.logits_processor import LogitsProcessor +from vllm.model_executor.layers.quantization.base_config import ( + QuantizationConfig) +from vllm.model_executor.layers.rotary_embedding import get_rope +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.vocab_parallel_embedding import ( + DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) +from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.sequence import IntermediateTensors + +from . import mixtral +from .interfaces import SupportsLoRA, SupportsPP +from .utils import make_layers + + +class GraniteMoeMoE(nn.Module): + """A tensor-parallel MoE implementation for GraniteMoe that shards each + expert across all ranks. + Each expert's weights are sharded across all ranks and a fused MoE + kernel is used for the forward pass, and finally we reduce the outputs + across ranks. + """ + + def __init__(self, + num_experts: int, + top_k: int, + hidden_size: int, + intermediate_size: int, + params_dtype: Optional[torch.dtype] = None, + quant_config: Optional[QuantizationConfig] = None, + tp_size: Optional[int] = None, + prefix: str = ""): + super().__init__() + self.hidden_size = hidden_size + + # Gate always runs at half / full precision for now. + self.gate = ReplicatedLinear(hidden_size, + num_experts, + bias=False, + params_dtype=params_dtype, + quant_config=None, + prefix=f"{prefix}.gate") + + self.experts = FusedMoE(num_experts=num_experts, + top_k=top_k, + hidden_size=hidden_size, + intermediate_size=intermediate_size, + params_dtype=params_dtype, + reduce_results=True, + renormalize=True, + quant_config=quant_config, + tp_size=tp_size, + prefix=f"{prefix}.experts") + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + # NOTE: hidden_states can have either 1D or 2D shape. 
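GraniteMoeMoE's forward (continued below) hands the routing to the fused FusedMoE kernel, but the math it is expected to implement is ordinary top-k routing: softmax the router logits, keep the top-k experts per token, renormalize their weights (renormalize=True above), and sum the weighted SwiGLU expert outputs. A dense, unfused reference in plain PyTorch, written as an assumed specification rather than the kernel itself (the real path also handles tensor-parallel sharding and the cross-rank reduction):

import torch
import torch.nn.functional as F


def topk_moe_reference(x, gate_w, w1, w2, w3, top_k: int):
    """x: (tokens, hidden); gate_w: (E, hidden);
    w1, w3: (E, inter, hidden); w2: (E, hidden, inter)."""
    logits = x @ gate_w.t()                           # (tokens, E)
    probs = F.softmax(logits, dim=-1)
    weights, experts = probs.topk(top_k, dim=-1)      # (tokens, k)
    weights = weights / weights.sum(dim=-1, keepdim=True)   # renormalize
    out = torch.zeros_like(x)
    for tok in range(x.size(0)):
        for w, e in zip(weights[tok], experts[tok]):
            # SwiGLU expert: silu(x W1) * (x W3), then down-project with W2.
            h = F.silu(x[tok] @ w1[e].t()) * (x[tok] @ w3[e].t())
            out[tok] += w * (h @ w2[e].t())
    return out


x = torch.randn(3, 8)
out = topk_moe_reference(x, torch.randn(4, 8), torch.randn(4, 16, 8),
                         torch.randn(4, 8, 16), torch.randn(4, 16, 8), top_k=2)
assert out.shape == x.shape

This reference loops token by token; the fused kernel instead batches tokens per expert, which is what makes the single FusedMoE call worthwhile.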
+ orig_shape = hidden_states.shape + hidden_states = hidden_states.view(-1, self.hidden_size) + # router_logits: (num_tokens, n_experts) + router_logits, _ = self.gate(hidden_states) + final_hidden_states = self.experts(hidden_states, router_logits) + return final_hidden_states.view(orig_shape) + + +class GraniteMoeAttention(nn.Module): + + def __init__( + self, + hidden_size: int, + num_heads: int, + num_kv_heads: int, + max_position: int = 4096 * 32, + rope_theta: float = 10000, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + attention_multiplier: Optional[float] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = hidden_size + tp_size = get_tensor_model_parallel_world_size() + self.total_num_heads = num_heads + assert self.total_num_heads % tp_size == 0 + self.num_heads = self.total_num_heads // tp_size + self.total_num_kv_heads = num_kv_heads + if self.total_num_kv_heads >= tp_size: + # Number of KV heads is greater than TP size, so we partition + # the KV heads across multiple tensor parallel GPUs. + assert self.total_num_kv_heads % tp_size == 0 + else: + # Number of KV heads is less than TP size, so we replicate + # the KV heads across multiple tensor parallel GPUs. + assert tp_size % self.total_num_kv_heads == 0 + self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) + self.head_dim = hidden_size // self.total_num_heads + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_kv_heads * self.head_dim + self.scaling = (attention_multiplier if attention_multiplier + is not None else self.head_dim**-1) + self.rope_theta = rope_theta + + self.qkv_proj = QKVParallelLinear( + hidden_size, + self.head_dim, + self.total_num_heads, + self.total_num_kv_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + ) + self.o_proj = RowParallelLinear( + self.total_num_heads * self.head_dim, + hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.o_proj", + ) + self.rotary_emb = get_rope( + self.head_dim, + rotary_dim=self.head_dim, + max_position=max_position, + base=int(self.rope_theta), + is_neox_style=True, + ) + self.attn = Attention(self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config) + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + q, k = self.rotary_emb(positions, q, k) + attn_output = self.attn(q, k, v, kv_cache, attn_metadata) + output, _ = self.o_proj(attn_output) + return output + + +class GraniteMoeDecoderLayer(nn.Module): + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.hidden_size = config.hidden_size + # Requires transformers > 4.32.0 + rope_theta = getattr(config, "rope_theta", 10000) + self.self_attn = GraniteMoeAttention( + hidden_size=self.hidden_size, + num_heads=config.num_attention_heads, + max_position=config.max_position_embeddings, + num_kv_heads=config.num_key_value_heads, + rope_theta=rope_theta, + cache_config=cache_config, + quant_config=quant_config, + prefix=f"{prefix}.self_attn", + attention_multiplier=config.attention_multiplier) + 
self.block_sparse_moe = GraniteMoeMoE( + num_experts=config.num_local_experts, + top_k=config.num_experts_per_tok, + hidden_size=config.hidden_size, + intermediate_size=config.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.block_sparse_moe") + + self.input_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, + eps=config.rms_norm_eps) + + self.residual_multiplier = config.residual_multiplier + + def forward( + self, + positions: torch.Tensor, + hidden_states: torch.Tensor, + kv_cache: torch.Tensor, + attn_metadata: AttentionMetadata, + ) -> torch.Tensor: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states = self.self_attn( + positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata, + ) + hidden_states = residual + hidden_states * self.residual_multiplier + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.block_sparse_moe(hidden_states) + hidden_states = residual + hidden_states * self.residual_multiplier + + return hidden_states + + +class GraniteMoeModel(nn.Module): + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + prefix: str = "", + ) -> None: + super().__init__() + self.padding_idx = config.pad_token_id + lora_vocab = (lora_config.lora_extra_vocab_size * + (lora_config.max_loras or 1)) if lora_config else 0 + self.vocab_size = config.vocab_size + lora_vocab + self.org_vocab_size = config.vocab_size + + self.embed_tokens = VocabParallelEmbedding( + self.vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + ) + self.embedding_multiplier = config.embedding_multiplier + + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: GraniteMoeDecoderLayer( + config, cache_config, quant_config=quant_config, prefix=prefix + ), + prefix=f"{prefix}.layers") + + self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> torch.Tensor: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + hidden_states *= self.embedding_multiplier + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): + layer = self.layers[i] + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) + return hidden_states + + +class GraniteMoeForCausalLM(nn.Module, SupportsLoRA, SupportsPP): + fall_back_to_pt_during_load = False + + packed_modules_mapping = { + "qkv_proj": [ + "q_proj", + "k_proj", + "v_proj", + ], + } + + # LoRA specific attributes + supported_lora_modules = [ + "qkv_proj", + "o_proj", + "embed_tokens", + "lm_head", + ] + embedding_modules = { + "embed_tokens": "input_embeddings", + "lm_head": "output_embeddings", + } 
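What separates GraniteMoe from the Mixtral implementation it reuses is mostly a set of scalar multipliers: embedding_multiplier on the input embeddings, residual_multiplier on every residual branch, attention_multiplier as the attention scale, and 1/logits_scaling on the logits. The decoder-layer arithmetic above boils down to the following sketch, with the sub-blocks passed in as plain callables:

import torch


def granite_style_block(hidden, attn, moe, norm1, norm2, residual_multiplier):
    """hidden: (tokens, hidden_size).  Each sub-block's output is scaled
    by residual_multiplier before the residual add."""
    residual = hidden
    hidden = residual + attn(norm1(hidden)) * residual_multiplier
    residual = hidden
    hidden = residual + moe(norm2(hidden)) * residual_multiplier
    return hidden


h = torch.randn(2, 8)
out = granite_style_block(h, attn=lambda x: x, moe=lambda x: x,
                          norm1=lambda x: x, norm2=lambda x: x,
                          residual_multiplier=0.5)
assert out.shape == h.shape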
+ embedding_padding_modules = ["lm_head"] + + def __init__( + self, + config: GraniteMoeConfig, + cache_config: Optional[CacheConfig] = None, + quant_config: Optional[QuantizationConfig] = None, + lora_config: Optional[LoRAConfig] = None, + ) -> None: + super().__init__() + + self.config = config + self.lora_config = lora_config + + self.model = GraniteMoeModel(config, + cache_config, + quant_config, + lora_config=lora_config, + prefix="model") + self.unpadded_vocab_size = config.vocab_size + if lora_config: + self.unpadded_vocab_size += lora_config.lora_extra_vocab_size + self.lm_head = ParallelLMHead( + self.unpadded_vocab_size, + config.hidden_size, + org_num_embeddings=config.vocab_size, + padding_size=DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else lora_config.lora_vocab_padding_size, + quant_config=quant_config, + ) + if config.tie_word_embeddings: + self.lm_head.weight = self.model.embed_tokens.weight + + self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, + config.vocab_size, + scale=1 / + self.config.logits_scaling) + + self.sampler = Sampler() + + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + ) -> torch.Tensor: + hidden_states = self.model(input_ids, positions, kv_caches, + attn_metadata, intermediate_tensors) + return hidden_states + + def compute_logits( + self, hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[torch.Tensor]: + logits = self.logits_processor(self.lm_head, hidden_states, + sampling_metadata) + return logits + + def make_empty_intermediate_tensors( + self, batch_size: int, dtype: torch.dtype, + device: torch.device) -> IntermediateTensors: + return IntermediateTensors({ + "hidden_states": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + "residual": + torch.zeros((batch_size, self.config.hidden_size), + dtype=dtype, + device=device), + }) + + def sample( + self, + logits: Optional[torch.Tensor], + sampling_metadata: SamplingMetadata, + ) -> Optional[SamplerOutput]: + next_tokens = self.sampler(logits, sampling_metadata) + return next_tokens + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + new_weights = {} + for n, p in weights: + if n.endswith('.block_sparse_moe.input_linear.weight'): + for e in range(p.size(0)): + w1_name = n.replace( + '.block_sparse_moe.input_linear.weight', + ".block_sparse_moe.experts.%d.w1.weight" % e) + w3_name = n.replace( + '.block_sparse_moe.input_linear.weight', + ".block_sparse_moe.experts.%d.w3.weight" % e) + w1_param, w3_param = p[e].chunk(2, dim=0) + assert w1_name not in new_weights + assert w3_name not in new_weights + new_weights[w1_name] = w1_param + new_weights[w3_name] = w3_param + elif n.endswith('.block_sparse_moe.output_linear.weight'): + for e in range(p.size(0)): + w2_name = n.replace( + '.block_sparse_moe.output_linear.weight', + ".block_sparse_moe.experts.%d.w2.weight" % e) + w2_param = p[e] + assert w2_name not in new_weights + new_weights[w2_name] = w2_param + elif n.endswith('.block_sparse_moe.router.layer.weight'): + gate_name = n.replace('.block_sparse_moe.router.layer.weight', + ".block_sparse_moe.gate.weight") + assert gate_name not in new_weights + new_weights[gate_name] = p + elif n == 'lm_head.weight' and self.config.tie_word_embeddings: + pass + else: + new_weights[n] = 
p + mixtral.MixtralForCausalLM.load_weights(self, new_weights.items()) diff --git a/vllm/model_executor/models/interfaces.py b/vllm/model_executor/models/interfaces.py index 069948f812253..298174fa05965 100644 --- a/vllm/model_executor/models/interfaces.py +++ b/vllm/model_executor/models/interfaces.py @@ -1,11 +1,17 @@ -from typing import (ClassVar, Dict, List, Literal, Optional, Protocol, Type, - Union, overload, runtime_checkable) +import inspect +from typing import (TYPE_CHECKING, ClassVar, Dict, List, Literal, Optional, + Protocol, Type, Union, overload, runtime_checkable) +import torch from typing_extensions import TypeIs -from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig from vllm.logger import init_logger +if TYPE_CHECKING: + from vllm.attention import AttentionMetadata + from vllm.config import LoRAConfig, MultiModalConfig, SchedulerConfig + from vllm.sequence import IntermediateTensors + logger = init_logger(__name__) @@ -22,7 +28,7 @@ class SupportsMultiModal(Protocol): MRO of your model class. """ - def __init__(self, *, multimodal_config: MultiModalConfig) -> None: + def __init__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -32,7 +38,7 @@ def __init__(self, *, multimodal_config: MultiModalConfig) -> None: class _SupportsMultiModalType(Protocol): supports_multimodal: Literal[True] - def __call__(self, *, multimodal_config: MultiModalConfig) -> None: + def __call__(self, *, multimodal_config: "MultiModalConfig") -> None: ... @@ -75,7 +81,7 @@ class SupportsLoRA(Protocol): embedding_padding_modules: ClassVar[List[str]] # lora_config is None when LoRA is not enabled - def __init__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __init__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -90,7 +96,7 @@ class _SupportsLoRAType(Protocol): embedding_modules: Dict[str, str] embedding_padding_modules: List[str] - def __call__(self, *, lora_config: Optional[LoRAConfig] = None) -> None: + def __call__(self, *, lora_config: Optional["LoRAConfig"] = None) -> None: ... @@ -145,6 +151,132 @@ def _supports_lora( return isinstance(model, SupportsLoRA) +@runtime_checkable +class SupportsPP(Protocol): + """The interface required for all models that support pipeline parallel.""" + + supports_pp: ClassVar[Literal[True]] = True + """ + A flag that indicates this model supports pipeline parallel. + + Note: + There is no need to redefine this flag if this class is in the + MRO of your model class. + """ + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + """Called when PP rank > 0 for profiling purposes.""" + ... + + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + """ + Accept :class:`IntermediateTensors` when PP rank > 0. + + Return :class:`IntermediateTensors` only for the last PP rank. + """ + ... + + +# We can't use runtime_checkable with ClassVar for issubclass checks +# so we need to treat the class as an instance and use isinstance instead +@runtime_checkable +class _SupportsPPType(Protocol): + supports_pp: Literal[True] + + def make_empty_intermediate_tensors( + self, + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> "IntermediateTensors": + ... 
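The GraniteMoe load_weights remapping that closes above rewrites the checkpoint's fused expert tensors into the names Mixtral's loader expects: each expert slice of block_sparse_moe.input_linear.weight is chunked into w1 (gate projection) and w3 (up projection), output_linear becomes w2, and router.layer becomes the gate. The tensor surgery is just a chunk(2, dim=0) per expert; with hypothetical shapes implied by that chunking:

import torch

num_experts, hidden, inter = 4, 8, 16
input_linear = torch.randn(num_experts, 2 * inter, hidden)
output_linear = torch.randn(num_experts, hidden, inter)

remapped = {}
for e in range(num_experts):
    w1, w3 = input_linear[e].chunk(2, dim=0)            # gate proj, up proj
    remapped[f"experts.{e}.w1.weight"] = w1             # (inter, hidden)
    remapped[f"experts.{e}.w3.weight"] = w3             # (inter, hidden)
    remapped[f"experts.{e}.w2.weight"] = output_linear[e]   # (hidden, inter)

assert remapped["experts.0.w1.weight"].shape == (inter, hidden)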
+ + def forward( + self, + input_ids: torch.Tensor, + position_ids: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: "AttentionMetadata", + intermediate_tensors: Optional["IntermediateTensors"], + ) -> Union[torch.Tensor, "IntermediateTensors"]: + ... + + +@overload +def supports_pp(model: Type[object]) -> TypeIs[Type[SupportsPP]]: + ... + + +@overload +def supports_pp(model: object) -> TypeIs[SupportsPP]: + ... + + +def supports_pp( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + supports_attributes = _supports_pp_attributes(model) + supports_inspect = _supports_pp_inspect(model) + + if supports_attributes and not supports_inspect: + logger.warning( + "The model (%s) sets `supports_pp=True`, but does not accept " + "`intermediate_tensors` in its `forward` method", model) + + if not supports_attributes: + pp_attrs = ("make_empty_intermediate_tensors", ) + missing_attrs = tuple(attr for attr in pp_attrs + if not hasattr(model, attr)) + + if getattr(model, "supports_pp", False): + if missing_attrs: + logger.warning( + "The model (%s) sets `supports_pp=True`, " + "but is missing PP-specific attributes: %s", + model, + missing_attrs, + ) + else: + if not missing_attrs: + logger.warning( + "The model (%s) contains all PP-specific attributes, " + "but does not set `supports_pp=True`.", model) + + return supports_attributes and supports_inspect + + +def _supports_pp_attributes( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + if isinstance(model, type): + return isinstance(model, _SupportsPPType) + + return isinstance(model, SupportsPP) + + +def _supports_pp_inspect( + model: Union[Type[object], object], +) -> Union[bool, TypeIs[Type[SupportsPP]], TypeIs[SupportsPP]]: + model_forward = getattr(model, "forward", None) + if not callable(model_forward): + return False + + forward_params = inspect.signature(model_forward).parameters + return "intermediate_tensors" in forward_params + + @runtime_checkable class HasInnerState(Protocol): """The interface required for all models that has inner state.""" @@ -158,7 +290,7 @@ class HasInnerState(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... @@ -168,7 +300,7 @@ class _HasInnerStateType(Protocol): def __init__(self, *, - scheduler_config: Optional[SchedulerConfig] = None) -> None: + scheduler_config: Optional["SchedulerConfig"] = None) -> None: ... 
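The supports_pp helper above combines two independent checks: a runtime_checkable Protocol match on the PP-specific attributes and an inspect.signature probe verifying that forward accepts an intermediate_tensors parameter. The signature probe is easy to reproduce in isolation, with toy classes standing in for real models:

import inspect
from typing import Optional


class WithPP:
    def forward(self, input_ids, intermediate_tensors: Optional[dict] = None):
        ...


class WithoutPP:
    def forward(self, input_ids):
        ...


def forward_accepts_intermediate_tensors(model_cls) -> bool:
    forward = getattr(model_cls, "forward", None)
    if not callable(forward):
        return False
    return "intermediate_tensors" in inspect.signature(forward).parameters


assert forward_accepts_intermediate_tensors(WithPP)
assert not forward_accepts_intermediate_tensors(WithoutPP)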
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 11a8431a5e7f7..f6cde44e9d83d 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -18,8 +18,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,7 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP from .utils import (is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers) @@ -266,7 +266,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: IntermediateTensors = None, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: @@ -297,7 +297,7 @@ def forward( return hidden_states -class InternLM2ForCausalLM(nn.Module): +class InternLM2ForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -325,7 +325,7 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - intermediate_tensors: IntermediateTensors, + intermediate_tensors: Optional[IntermediateTensors], ) -> torch.Tensor: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index e84990a2ab109..816e93818f2ee 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -5,9 +5,9 @@ # Licensed under The MIT License [see LICENSE for details] # -------------------------------------------------------- import re -from functools import partial -from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, - Tuple, TypedDict, Union) +from functools import cached_property, partial +from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, + TypedDict, Union) import torch import torch.nn as nn @@ -17,7 +17,6 @@ from vllm.attention import AttentionMetadata from vllm.config import CacheConfig, MultiModalConfig -from vllm.distributed import get_pp_group from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput @@ -32,7 +31,7 @@ from .clip import (dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_num_patches) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, merge_multimodal_embeddings) @@ -123,7 +122,7 @@ def calculate_num_blocks(orig_width: int, orig_height: int, min_num: int, return blocks, target_width, target_height -def calculate_num_blocks_wrapper(hf_config: Dict[str, Any], +def calculate_num_blocks_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): if max_dynamic_patch is None: 
max_dynamic_patch = hf_config.max_dynamic_patch @@ -183,7 +182,7 @@ def image_to_pixel_values(image: Image.Image, input_size: int, min_num: int, return pixel_values -def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], +def image_to_pixel_values_wrapper(hf_config: PretrainedConfig, max_dynamic_patch: Optional[int] = None): image_size = hf_config.vision_config.image_size min_num = hf_config.min_dynamic_patch @@ -197,7 +196,7 @@ def image_to_pixel_values_wrapper(hf_config: Dict[str, Any], use_thumbnail=use_thumbnail) -def get_internvl_num_patches(hf_config: Dict[str, Any]): +def get_internvl_num_patches(hf_config: PretrainedConfig): vision_config = hf_config.vision_config downsample_ratio = hf_config.downsample_ratio image_size = vision_config.image_size @@ -362,7 +361,7 @@ def dummy_data_for_internvl(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_internvl_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_internvl) @INPUT_REGISTRY.register_input_processor(input_processor_for_internvl) -class InternVLChatModel(nn.Module, SupportsMultiModal): +class InternVLChatModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, @@ -408,10 +407,12 @@ def __init__(self, self.make_empty_intermediate_tensors = ( self.language_model.make_empty_intermediate_tensors) + @cached_property + def sampler(self): if hasattr(self.language_model, "sampler"): - self.sampler = self.language_model.sampler - else: - self.sampler = Sampler() + return self.language_model.sampler + + return Sampler() def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() @@ -515,18 +516,22 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: - image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None and get_pp_group().is_first_rank: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.img_context_token_id) + ) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + if image_input is not None: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.img_context_token_id) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index b0fbb7e9829e0..c5e5393442e30 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -33,8 +33,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -43,7 +42,9 @@ from vllm.sequence import IntermediateTensors 
from vllm.transformers_utils.configs import JAISConfig -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class SwiGLUActivation(nn.Module): @@ -244,6 +245,9 @@ def __init__( ) self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.n_embd)) def forward( self, @@ -279,7 +283,7 @@ def forward( return hidden_states -class JAISLMHeadModel(nn.Module): +class JAISLMHeadModel(nn.Module, SupportsPP): def __init__( self, @@ -304,6 +308,8 @@ def __init__( self.logits_processor = LogitsProcessor(vocab_size=config.vocab_size, scale=self.output_logits_scale) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -326,16 +332,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 330a2b6e3fd7f..06ec324b3e108 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -25,20 +25,18 @@ causal_conv1d_fn, causal_conv1d_update) from vllm.model_executor.layers.mamba.ops.mamba_ssm import ( selective_scan_fn, selective_state_update) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import HasInnerState from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, _get_graph_batch_size) -from .interfaces import SupportsLoRA +from .interfaces import HasInnerState, SupportsLoRA KVCache = Tuple[torch.Tensor, torch.Tensor] diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index eba607b93d634..f4a91298f7a15 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -52,8 +51,9 @@ from vllm.sequence import IntermediateTensors from vllm.utils import is_hip -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import 
SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) is_hpu = current_platform.is_hpu() @@ -75,12 +75,15 @@ def __init__( output_sizes=[intermediate_size] * 2, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj") - self.down_proj = RowParallelLinear(input_size=intermediate_size, - output_size=hidden_size, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.down_proj") + prefix=f"{prefix}.gate_up_proj", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.down_proj", + ) if hidden_act != "silu": raise ValueError(f"Unsupported activation: {hidden_act}. " "Only silu is supported for now.") @@ -164,12 +167,14 @@ def __init__( rope_scaling=rope_scaling, is_neox_style=is_neox_style, ) - self.attn = Attention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads, - cache_config=cache_config, - quant_config=quant_config) + self.attn = Attention( + self.num_heads, + self.head_dim, + self.scaling, + num_kv_heads=self.num_kv_heads, + cache_config=cache_config, + quant_config=quant_config, + ) def forward( self, @@ -251,12 +256,10 @@ def forward( else: hidden_states, residual = self.input_layernorm( hidden_states, residual) - hidden_states = self.self_attn( - positions=positions, - hidden_states=hidden_states, - kv_cache=kv_cache, - attn_metadata=attn_metadata, - ) + hidden_states = self.self_attn(positions=positions, + hidden_states=hidden_states, + kv_cache=kv_cache, + attn_metadata=attn_metadata) # Fully Connected hidden_states, residual = self.post_attention_layernorm( @@ -298,12 +301,17 @@ def __init__( cache_config=cache_config, quant_config=quant_config, prefix=prefix), - prefix=f"{prefix}.layers") + prefix=f"{prefix}.layers", + ) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -332,16 +340,11 @@ def forward( htorch.core.mark_step() for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer( - positions, - hidden_states, - kv_caches[i - self.start_layer], - attn_metadata, - residual, - ) + hidden_states, residual = layer(positions, hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, residual) if is_hpu: htorch.core.mark_step() - if not get_pp_group().is_last_rank: return IntermediateTensors({ "hidden_states": hidden_states, @@ -352,17 +355,10 @@ def forward( return hidden_states -class LlamaForCausalLM(nn.Module, SupportsLoRA): +class LlamaForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { - "qkv_proj": [ - "q_proj", - "k_proj", - "v_proj", - ], - "gate_up_proj": [ - "gate_proj", - "up_proj", - ], + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"] } # LoRA specific attributes @@ -372,7 +368,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ] embedding_modules = { "embed_tokens": "input_embeddings", - "lm_head": "output_embeddings", + "lm_head": "output_embeddings" } embedding_padding_modules = ["lm_head"] bitsandbytes_stacked_params_mapping = { @@ -428,10 +424,12 @@ def __init__( 
self.unpadded_vocab_size, config.hidden_size, org_num_embeddings=config.vocab_size, - padding_size=DEFAULT_VOCAB_PADDING_SIZE - # We need bigger padding if using lora for kernel - # compatibility - if not lora_config else lora_config.lora_vocab_padding_size, + padding_size=( + DEFAULT_VOCAB_PADDING_SIZE + # We need bigger padding if using lora for kernel + # compatibility + if not lora_config else + lora_config.lora_vocab_padding_size), quant_config=quant_config, ) if config.tie_word_embeddings: @@ -444,6 +442,8 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -466,28 +466,11 @@ def compute_logits( sampling_metadata) return logits - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: + def sample(self, logits: torch.Tensor, + sampling_metadata: SamplingMetadata) -> Optional[SamplerOutput]: next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -521,7 +504,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): loaded_weight = loaded_weight[0] weight_loader(param, loaded_weight) continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) diff --git a/vllm/model_executor/models/llama_embedding.py b/vllm/model_executor/models/llama_embedding.py index 8f1c77da50d96..ce05d8e3911bf 100644 --- a/vllm/model_executor/models/llama_embedding.py +++ b/vllm/model_executor/models/llama_embedding.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -8,10 +8,13 @@ from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.pooling_metadata import PoolingMetadata -from vllm.sequence import PoolerOutput +from vllm.sequence import IntermediateTensors, PoolerOutput +from .interfaces import SupportsPP +from .utils import is_pp_missing_parameter -class LlamaEmbeddingModel(nn.Module): + +class LlamaEmbeddingModel(nn.Module, SupportsPP): """A model that uses Llama with additional embedding functionalities. 
This class encapsulates the LlamaModel and provides an interface for @@ -29,6 +32,8 @@ def __init__( super().__init__() self.model = LlamaModel(**kwargs) self._pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True) + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -36,10 +41,12 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: return self.model.forward(input_ids, positions, kv_caches, - attn_metadata, inputs_embeds) + attn_metadata, intermediate_tensors, + inputs_embeds) def pooler( self, @@ -73,6 +80,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -81,6 +90,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 69eb177a7dea8..a62231b628cb9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -11,7 +12,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -21,7 +22,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_max_clip_image_tokens, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens, input_processor_for_siglip) @@ -198,7 +199,7 @@ def _init_vision_tower(hf_config: LlavaConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava) -class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: LlavaConfig, @@ -220,6 +221,16 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if 
hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size expected_dims = (3, h, w) @@ -315,7 +326,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LLaVA-1.5. One key thing to understand is the `input_ids` already accounts for the @@ -351,26 +362,30 @@ def forward( See also: :class:`LlavaImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index 4341cc38bdd28..efad800d7d760 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -1,3 +1,4 @@ +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -13,7 +14,7 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -23,7 +24,7 @@ from .clip import (CLIPVisionModel, dummy_image_for_clip, dummy_seq_data_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .llava import LlavaMultiModalProjector from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_siglip_image_feature_size, @@ -286,7 +287,8 @@ def _init_vision_tower(hf_config: LlavaNextConfig): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_llava_next_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next) -class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaNextForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, 
config: LlavaNextConfig, @@ -300,6 +302,8 @@ def __init__(self, # TODO: Optionally initializes this for supporting embeddings. self.vision_tower = _init_vision_tower(config) + self.image_newline = nn.Parameter( + torch.empty(config.text_config.hidden_size)) self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, @@ -308,8 +312,15 @@ def __init__(self, self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) - self.image_newline = nn.Parameter( - torch.empty(config.text_config.hidden_size)) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -542,7 +553,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT. One key thing to understand is the `input_ids` already accounts for the @@ -587,26 +598,30 @@ def forward( See also: :class:`LlavaNextImageInputs` """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py index 397a6cce5af2c..44b3073b46358 100644 --- a/vllm/model_executor/models/llava_next_video.py +++ b/vllm/model_executor/models/llava_next_video.py @@ -1,4 +1,5 @@ import math +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -12,9 +13,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel from vllm.model_executor.sampling_metadata import 
SamplingMetadata @@ -25,7 +25,7 @@ from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip) from .utils import (group_weights_with_prefix, init_vllm_registered_model, @@ -267,7 +267,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: "video", get_max_llava_next_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_next_video) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_next_video) -class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaNextVideoForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: LlavaNextVideoConfig, @@ -281,13 +282,23 @@ def __init__(self, # Initialize the vision tower only up to the required feature layer self.vision_tower = _init_vision_tower(config) + self.vision_resampler = LlavaNextVideoPooler(config) self.multi_modal_projector = LlavaNextMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, projector_hidden_act=config.projector_hidden_act) self.language_model = init_vllm_registered_model( config.text_config, cache_config, quant_config) - self.vision_resampler = LlavaNextVideoPooler(config) + + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_video_pixel_values( self, data: Union[torch.Tensor, List[torch.Tensor]] @@ -397,34 +408,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-NeXT-Video. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values_videos: Pixels in each frames for each input videos. 
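The LLaVA-family and InternVL wrappers above replace the eager `self.sampler = ...` assignment with a `cached_property` that reuses the wrapped language model's sampler when it exposes one. A small standalone sketch of that delegation pattern (the toy class names are hypothetical and not part of the patch):

```python
from functools import cached_property


class Sampler:
    """Stand-in for vllm's Sampler; only object identity matters here."""


class ToyLanguageModel:
    """Hypothetical inner model that already owns a sampler instance."""

    def __init__(self) -> None:
        self.sampler = Sampler()


class ToyMultiModalWrapper:

    def __init__(self, language_model) -> None:
        self.language_model = language_model

    @cached_property
    def sampler(self) -> Sampler:
        # Prefer the inner model's sampler; fall back to a fresh one.
        # cached_property evaluates this once and serves the cached value after.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return Sampler()


wrapper = ToyMultiModalWrapper(ToyLanguageModel())
assert wrapper.sampler is wrapper.language_model.sampler
```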
""" - video_input = self._parse_and_validate_video_input(**kwargs) - - # merge video embeddings into input embeddings - if video_input is not None: - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = self.language_model \ - .model.get_input_embeddings(input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) - + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + video_input = self._parse_and_validate_video_input(**kwargs) + if video_input is not None: + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = self.language_model \ + .model.get_input_embeddings(input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py index 9099d4f88222d..af957e35d8089 100644 --- a/vllm/model_executor/models/llava_onevision.py +++ b/vllm/model_executor/models/llava_onevision.py @@ -1,4 +1,5 @@ import math +from functools import cached_property from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -17,9 +18,8 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger from vllm.model_executor.layers.activation import get_act_fn -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY @@ -31,7 +31,7 @@ from .clip import (CLIPVisionModel, dummy_seq_data_for_clip, dummy_video_for_clip, get_clip_image_feature_size, get_clip_patch_grid_length, input_processor_for_clip) -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_seq_data_for_siglip, dummy_video_for_siglip, get_siglip_image_feature_size, get_siglip_patch_grid_length, input_processor_for_siglip) @@ -414,7 +414,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: "video", get_max_llava_onevision_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_llava_onevision) @INPUT_REGISTRY.register_input_processor(input_processor_for_llava_onevision) -class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal): +class LlavaOnevisionForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: LlavaOnevisionConfig, @@ -434,6 +435,16 @@ def __init__(self, self.image_newline = nn.Parameter( torch.empty(config.text_config.hidden_size)) + self.make_empty_intermediate_tensors = ( + self.language_model.model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _validate_image_sizes(self, data: 
torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -805,39 +816,42 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for LlaVA-Onevision. Args: input_ids: Flattened (concatenated) input_ids corresponding to a batch. pixel_values_videos: Pixels in each frames for each input videos. """ - modalities = self._parse_and_validate_multimodal_inputs(**kwargs) - # merge video embeddings into input embeddings - if modalities: - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - if "images" in modalities: - image_input = modalities["images"] - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) - if "videos" in modalities: - video_input = modalities["videos"] - video_embeddings = self._process_video_pixels(video_input) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, video_embeddings, - self.config.video_token_index) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + modalities = self._parse_and_validate_multimodal_inputs(**kwargs) + if modalities: + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + if "images" in modalities: + image_input = modalities["images"] + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) + if "videos" in modalities: + video_input = modalities["videos"] + video_embeddings = self._process_video_pixels(video_input) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, video_embeddings, + self.config.video_token_index) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 963ad7553fe1d..6bba1594c270f 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -22,7 +22,7 @@ # limitations under the License. 
"""Inference-only MiniCPM model compatible with HuggingFace weights.""" import math -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -30,7 +30,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.activation import SiluAndMul @@ -41,8 +41,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -52,7 +51,9 @@ from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class MiniCPMMoE(nn.Module): @@ -264,7 +265,7 @@ class MiniCPMDecoderLayer(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ) -> None: @@ -346,10 +347,11 @@ class MiniCPMModel(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -365,15 +367,24 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self._init_layers() + self._init_layers(prefix, config, cache_config, quant_config) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], self.config.hidden_size)) - def _init_layers(self): - self.layers = nn.ModuleList([ - MiniCPMDecoderLayer(self.config, self.cache_config, - self.quant_config) - for _ in range(self.config.num_hidden_layers) - ]) + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPMDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: embedding = self.embed_tokens(input_ids) @@ -387,27 +398,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.get_input_embeddings(input_ids) + residual = None 
else: - hidden_states = self.get_input_embeddings(input_ids) - residual = None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] - for i in range(len(self.layers)): + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states = self.norm(hidden_states) return hidden_states -class MiniCPMForCausalLM(nn.Module, SupportsLoRA): +class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -470,6 +490,8 @@ def __init__( self.logits_processor = LogitsProcessor(unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def _init_model(self): self.model = MiniCPMModel(config=self.config, @@ -484,7 +506,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -548,6 +570,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -557,6 +581,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -568,6 +594,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/minicpm3.py b/vllm/model_executor/models/minicpm3.py index a048a3dba0415..c37bc5ad7c38f 100644 --- a/vllm/model_executor/models/minicpm3.py +++ b/vllm/model_executor/models/minicpm3.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig @@ -34,19 +35,20 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear, ReplicatedLinear, RowParallelLinear) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.models.minicpm import (MiniCPMDecoderLayer, MiniCPMForCausalLM, MiniCPMModel) +from .utils import make_layers + class MiniCPM3Attention(nn.Module): def __init__( self, - config, + config: PretrainedConfig, hidden_size: int, num_heads: int, qk_nope_head_dim: int, @@ -199,12 +201,18 @@ def _init_attn_block(self): class MiniCPM3Model(MiniCPMModel): - def _init_layers(self): - self.layers = nn.ModuleList([ - MiniCPM3DecoderLayer(self.config, self.cache_config, - self.quant_config) - for _ in range(self.config.num_hidden_layers) - ]) + def _init_layers( + self, + prefix: str, + config: PretrainedConfig, + cache_config: Optional[CacheConfig], + quant_config: Optional[QuantizationConfig], + ): + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MiniCPM3DecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") class MiniCPM3ForCausalLM(MiniCPMForCausalLM): diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 0e0e86f2fe503..6d0fa34f299ad 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -45,7 +45,6 @@ from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.utils import set_default_torch_dtype from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.llama import LlamaModel from vllm.model_executor.models.minicpm import MiniCPMModel from vllm.model_executor.models.module_mapping import MultiModelKeys @@ -59,7 +58,8 @@ from vllm.sequence import IntermediateTensors, SequenceData from .idefics2_vision_model import Idefics2VisionTransformer -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsMultiModal, SupportsPP +from .utils import is_pp_missing_parameter _KEYS_TO_MODIFY_MAPPING = { "llm.lm_head": "lm_head", @@ -337,7 +337,7 @@ def input_mapper_for_minicpmv(ctx: InputContext, data: object): return MultiModalInputs(batch_data) -class MiniCPMVBaseModel(nn.Module, SupportsMultiModal): +class MiniCPMVBaseModel(nn.Module, SupportsMultiModal, SupportsPP): """ The abstract class of MiniCPMV can only be inherited, but cannot be instantiated. 
@@ -374,6 +374,9 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.llm.make_empty_intermediate_tensors) + def get_embedding( self, input_ids: torch.Tensor, @@ -498,9 +501,12 @@ def forward( intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: Any, ) -> torch.Tensor: - image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) + if intermediate_tensors is not None: + vlm_embeddings = None + else: + image_inputs = self._parse_and_validate_inputs(input_ids, **kwargs) - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) output = self.llm( input_ids=None, @@ -557,6 +563,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + if is_pp_missing_parameter( + name.replace(weight_name, param_name), self): + continue param = params_dict[name.replace(weight_name, param_name)] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -564,6 +573,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: use_default_weight_loading = True if use_default_weight_loading: + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 7a075162d579f..27096e5d0e814 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
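MiniCPM-V above, like most wrappers in this patch, advertises pipeline-parallel support by inheriting `SupportsPP` and simply re-exporting the inner decoder's `make_empty_intermediate_tensors` callable. A toy illustration of that wiring, using a `Protocol` purely as a stand-in for the vllm interface (the real `SupportsPP` lives in `.interfaces` and differs in detail; all class names here are hypothetical):

```python
from typing import Dict, Protocol, runtime_checkable

import torch
from torch import nn


@runtime_checkable
class SupportsPPLike(Protocol):
    """Stand-in interface: a PP-capable model must be able to allocate
    placeholders for activations arriving from the previous stage."""

    def make_empty_intermediate_tensors(
            self, batch_size: int, dtype: torch.dtype,
            device: torch.device) -> Dict[str, torch.Tensor]: ...


class ToyDecoder(nn.Module):

    def __init__(self, hidden_size: int = 64) -> None:
        super().__init__()
        # In the patch this callable is produced by
        # make_empty_intermediate_tensors_factory(["hidden_states"], ...).
        self.make_empty_intermediate_tensors = (
            lambda batch_size, dtype, device: {
                "hidden_states":
                torch.zeros((batch_size, hidden_size),
                            dtype=dtype,
                            device=device)
            })


class ToyForCausalLM(nn.Module):

    def __init__(self) -> None:
        super().__init__()
        self.model = ToyDecoder()
        # Top-level wrappers re-export the inner decoder's callable, exactly
        # like `self.llm.make_empty_intermediate_tensors` above.
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)


assert isinstance(ToyForCausalLM(), SupportsPPLike)
```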
"""Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -36,8 +36,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -48,8 +47,9 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class MixtralMoE(nn.Module): @@ -277,6 +277,9 @@ def __init__( prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -285,7 +288,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -307,7 +310,7 @@ def forward( return hidden_states -class MixtralForCausalLM(nn.Module, SupportsLoRA): +class MixtralForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -366,6 +369,8 @@ def __init__( self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -374,7 +379,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -388,20 +393,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 68471f6ac77d1..63e2c60a84271 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Mixtral model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import numpy as np import torch @@ -31,7 +31,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, tensor_model_parallel_all_reduce) from vllm.model_executor.layers.layernorm import RMSNorm @@ -39,8 +39,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,6 +48,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class MixtralMLP(nn.Module): @@ -296,6 +299,7 @@ def __init__( config: MixtralConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -305,13 +309,15 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - MixtralDecoderLayer(config, - cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: MixtralDecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -319,19 +325,30 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + kv_caches[i - self.start_layer], + attn_metadata, residual) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class MixtralForCausalLM(nn.Module): +class MixtralForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False def __init__( @@ -351,6 +368,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = 
LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -359,9 +378,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -400,6 +419,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -412,6 +433,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if ("block_sparse_moe.experts." in name and name not in params_dict): continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 0fcbf06e1a060..e3d3937b13fa0 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -1,22 +1,21 @@ # coding=utf-8 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch import torch.nn as nn from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) @@ -25,6 +24,10 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs.mpt import MPTConfig +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def _get_alibi_slopes( total_num_heads: int, @@ -208,6 +211,7 @@ def __init__( config: MPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() assert config.embedding_fraction == 1.0 @@ -217,10 +221,10 @@ def __init__( config.vocab_size, config.d_model, ) - self.blocks = nn.ModuleList([ - MPTBlock(config, cache_config, quant_config) - for _ in range(config.n_layers) - ]) + self.start_layer, self.end_layer, self.blocks = make_layers( + config.n_layers, + lambda prefix: MPTBlock(config, cache_config, quant_config), + prefix=f"{prefix}.blocks") self.norm_f = nn.LayerNorm(config.d_model) if config.no_bias: for module in self.modules(): @@ -228,6 +232,9 @@ def __init__( 
module.bias, nn.Parameter): # Remove the bias term in Linear and LayerNorm. module.register_parameter("bias", None) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.d_model)) def forward( self, @@ -235,21 +242,29 @@ def forward( position_ids: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.wte(input_ids) - for i in range(len(self.blocks)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.wte(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): block = self.blocks[i] hidden_states = block( position_ids, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm_f(hidden_states) return hidden_states -class MPTForCausalLM(nn.Module): +class MPTForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -266,6 +281,8 @@ def __init__( self.lm_head = self.transformer.wte self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def forward( self, @@ -274,9 +291,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.transformer(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +319,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index e9ff12de2094e..14515e16e34ac 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -34,8 +34,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -46,8 +45,9 @@ from vllm.sequence import IntermediateTensors from vllm.transformers_utils.configs import NemotronConfig -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) # The architecture is pretty similar to Llama, with these changes: # - There is no gate_proj, just up_proj @@ -328,6 +328,9 @@ def __init__( eps=config.norm_eps) else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -372,7 +375,7 @@ def forward( return hidden_states -class NemotronForCausalLM(nn.Module, SupportsLoRA): +class NemotronForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -440,6 +443,8 @@ def __init__( self.sampler = Sampler() else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -470,20 +475,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 97749725dd132..5ca7c66f5407d 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only OLMo model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,14 +29,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,6 +44,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OlmoAttention(nn.Module): """ @@ -223,19 +226,24 @@ class OlmoModel(nn.Module): def __init__(self, config: OlmoConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - OlmoDecoderLayer(config, cache_config, quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoDecoderLayer(config, cache_config, quant_config + ), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, elementwise_affine=False, bias=False) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -243,34 +251,41 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: """ :param input_ids: A tensor of shape `(batch_size, seq_len)`. """ - # Get embeddings of input. - # shape: (batch_size, seq_len, d_model) - inputs_embeds = self.embed_tokens(input_ids) + if get_pp_group().is_first_rank: + # Get embeddings of input. + # shape: (batch_size, seq_len, d_model) + inputs_embeds = self.embed_tokens(input_ids) - # embed positions - hidden_states = inputs_embeds + # embed positions + hidden_states = inputs_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] # Apply blocks one-by-one. 
- for layer_idx, decoder_layer in enumerate(self.layers): + for i in range(self.start_layer, self.end_layer): # shape: (batch_size, seq_len, d_model) - hidden_states = decoder_layer( + hidden_states = self.layers[i]( positions, hidden_states, - kv_caches[layer_idx], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) # Apply final layer norm. # shape: (batch_size, seq_len or 1, d_model) hidden_states = self.norm(hidden_states) return hidden_states -class OlmoForCausalLM(nn.Module): +class OlmoForCausalLM(nn.Module, SupportsPP): """ Extremely barebones HF model wrapper. """ @@ -294,6 +309,8 @@ def __init__(self, ) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -302,12 +319,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, ) return hidden_states @@ -358,6 +376,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -366,6 +386,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py index c76e5e86c89d8..a1ba80e0d7108 100644 --- a/vllm/model_executor/models/olmoe.py +++ b/vllm/model_executor/models/olmoe.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
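The OLMo rewrite above is the forward-pass template repeated for every decoder-only model in this patch: the first pipeline rank embeds the inputs, later ranks read them from intermediate_tensors, only the local slice of layers runs, and every rank except the last returns an IntermediateTensors instead of normalized hidden states. Condensed into a single illustrative function (hidden_states-only; models such as Orion and PhiMoE additionally thread a residual tensor):

from vllm.distributed import get_pp_group
from vllm.sequence import IntermediateTensors


def pp_decoder_forward(model, input_ids, positions, kv_caches, attn_metadata,
                       intermediate_tensors):
    if get_pp_group().is_first_rank:
        # Only the first pipeline stage owns the embedding table.
        hidden_states = model.embed_tokens(input_ids)
    else:
        # Later stages consume activations shipped from the previous stage.
        assert intermediate_tensors is not None
        hidden_states = intermediate_tensors["hidden_states"]

    for i in range(model.start_layer, model.end_layer):
        # kv_caches holds entries only for this rank's layers, hence the offset.
        hidden_states = model.layers[i](positions, hidden_states,
                                        kv_caches[i - model.start_layer],
                                        attn_metadata)

    if not get_pp_group().is_last_rank:
        # Hand off to the next stage instead of applying the final norm.
        return IntermediateTensors({"hidden_states": hidden_states})
    return model.norm(hidden_states)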
"""Inference-only OLMoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -18,15 +18,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -36,6 +35,10 @@ from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OlmoeMoE(nn.Module): """A tensor-parallel MoE implementation for Olmoe that shards each expert @@ -243,6 +246,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -252,34 +256,54 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - OlmoeDecoderLayer(config, - layer_idx, - cache_config, - quant_config=quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OlmoeDecoderLayer(config, int( + prefix.split(".")[-1]), cache_config, quant_config), + prefix=f"{prefix}.layers") self.norm = RMSNorm(config.hidden_size, eps=1e-5) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class 
OlmoeForCausalLM(nn.Module): +class OlmoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -299,6 +323,9 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -306,9 +333,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -363,6 +390,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue if name not in params_dict: continue @@ -376,6 +406,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, @@ -388,6 +421,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index 47ec718a43420..727dd65acc749 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -17,7 +17,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
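The load_weights additions above (and in every file that follows) all come down to one extra check: any checkpoint tensor whose target parameter lives on a different pipeline rank is skipped before the usual loaders run. Stripped of the stacked-parameter and expert remapping, the loop looks roughly like this:

from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from .utils import is_pp_missing_parameter  # models-package helper


def load_weights_sketch(model, weights):
    params_dict = dict(model.named_parameters())
    for name, loaded_weight in weights:
        # Skip extra biases that some quantized checkpoints (e.g. GPTQ) carry.
        if name.endswith(".bias") and name not in params_dict:
            continue
        # Skip parameters belonging to layers hosted on other pipeline ranks.
        if is_pp_missing_parameter(name, model):
            continue
        param = params_dict[name]
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)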
"""Inference-only OPT model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -25,15 +25,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) @@ -41,6 +40,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OPTLearnedPositionalEmbedding(nn.Embedding): @@ -189,6 +192,7 @@ def __init__( config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config @@ -232,10 +236,10 @@ def __init__( else: self.final_layer_norm = None - self.layers = nn.ModuleList([ - OPTDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OPTDecoderLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -246,18 +250,28 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is None: - inputs_embeds = self.get_input_embeddings(input_ids) - pos_embeds = self.embed_positions(positions) - if self.project_in is not None: - inputs_embeds, _ = self.project_in(inputs_embeds) - hidden_states = inputs_embeds + pos_embeds - for i in range(len(self.layers)): + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is None: + inputs_embeds = self.get_input_embeddings(input_ids) + pos_embeds = self.embed_positions(positions) + if self.project_in is not None: + inputs_embeds, _ = self.project_in(inputs_embeds) + hidden_states = inputs_embeds + pos_embeds + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(hidden_states, kv_caches[i], attn_metadata) + hidden_states = layer(hidden_states, + kv_caches[i - self.start_layer], + attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) if self.final_layer_norm is not None: hidden_states = self.final_layer_norm(hidden_states) if self.project_out is 
not None: @@ -275,6 +289,9 @@ def __init__( ): super().__init__() self.decoder = OPTDecoder(config, cache_config, quant_config) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.decoder.get_input_embeddings(input_ids) @@ -285,20 +302,22 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: return self.decoder(input_ids, positions, kv_caches, attn_metadata, + intermediate_tensors, inputs_embeds=inputs_embeds) -class OPTForCausalLM(nn.Module): +class OPTForCausalLM(nn.Module, SupportsPP): def __init__( self, - config, + config: OPTConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, ): @@ -313,6 +332,8 @@ def __init__( config.word_embed_proj_dim) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -321,9 +342,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -352,7 +373,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): ] params_dict = dict(self.named_parameters(remove_duplicate=False)) for name, loaded_weight in weights: - if "lm_head.weight" in name: + if "lm_head.weight" in name and self.config.tie_word_embeddings: continue if name.startswith("decoder."): name = "model." + name @@ -364,6 +385,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -372,6 +395,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index b01ce87adfa46..0913193f73a48 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -4,7 +4,7 @@ # Copyright (c) OrionStar Inc. 
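The OPT decoder above (and Orion just below, like the rest of the models here) swaps its nn.ModuleList comprehension for make_layers, which owns the pipeline split and hands back start_layer/end_layer for the forward loop. A rough, simplified sketch of what that helper does; the real implementation lives in vllm/model_executor/models/utils.py and also copes with uneven partitions:

from typing import Callable

import torch.nn as nn

from vllm.distributed import get_pp_group


class PPMissingLayer(nn.Identity):
    """Stand-in for layers owned by other pipeline ranks (real class: .utils)."""


def make_layers_sketch(num_hidden_layers: int,
                       layer_fn: Callable[[str], nn.Module], prefix: str):
    pp_rank = get_pp_group().rank_in_group
    pp_size = get_pp_group().world_size
    # Assume an even split purely for illustration.
    per_rank = num_hidden_layers // pp_size
    start_layer = pp_rank * per_rank
    end_layer = start_layer + per_rank
    layers = nn.ModuleList([
        layer_fn(f"{prefix}.{idx}")
        if start_layer <= idx < end_layer else PPMissingLayer()
        for idx in range(num_hidden_layers)
    ])
    return start_layer, end_layer, layers

Because the returned ModuleList keeps its full length, the forward loops can keep indexing self.layers[i] with global indices while offsetting into the rank-local kv_caches.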
# LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE """Inference-only Orion-14B model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -12,14 +12,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -28,6 +27,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class OrionMLP(nn.Module): @@ -210,6 +213,7 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -219,11 +223,18 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - OrionDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: OrionDecoderLayer( + config, + cache_config, + quant_config, + ), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -231,23 +242,34 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states = self.norm(hidden_states) return hidden_states -class OrionForCausalLM(nn.Module): +class OrionForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -266,6 +288,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight 
self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -274,9 +298,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -321,6 +345,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -329,6 +355,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index 8130eb54753ea..93032b4095917 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -9,9 +9,8 @@ from vllm.config import CacheConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import Sampler, SamplerOutput +from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.gemma import GemmaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -19,7 +18,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .siglip import (SiglipVisionModel, dummy_image_for_siglip, dummy_seq_data_for_siglip, get_max_siglip_image_tokens) from .utils import group_weights_with_prefix, merge_multimodal_embeddings @@ -129,7 +128,8 @@ def forward(self, image_features: torch.Tensor) -> torch.Tensor: @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_paligemma_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_paligemma) @INPUT_REGISTRY.register_input_processor(input_processor_for_paligemma) -class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal): +class PaliGemmaForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: PaliGemmaConfig, @@ -149,12 +149,15 @@ def __init__(self, self.quant_config = quant_config self.language_model = GemmaForCausalLM(config.text_config, cache_config, quant_config) - self.unpadded_vocab_size = config.text_config.vocab_size logit_scale = getattr(config, "logit_scale", 1.0) - self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, - config.text_config.vocab_size, - logit_scale) - self.sampler = Sampler() + self.language_model.logits_processor.scale *= logit_scale + 
+ self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @property + def sampler(self): + return self.language_model.sampler def _validate_pixel_values(self, data: torch.Tensor) -> torch.Tensor: h = w = self.config.vision_config.image_size @@ -239,32 +242,36 @@ def forward(self, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - **kwargs: object) -> SamplerOutput: - - parsed_image_input = self._parse_and_validate_image_input(**kwargs) + **kwargs: object) -> Union[SamplerOutput, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + parsed_image_input = self._parse_and_validate_image_input(**kwargs) - if parsed_image_input is not None: - vision_embeddings = self._process_image_input(parsed_image_input) - # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa - vision_embeddings = vision_embeddings * (self.config.hidden_size** - -0.5) + if parsed_image_input is not None: + vision_embeddings = self._process_image_input( + parsed_image_input) + # https://github.com/huggingface/transformers/blob/main/src/transformers/models/paligemma/modeling_paligemma.py#L294 # noqa + vision_embeddings = vision_embeddings * ( + self.config.hidden_size**-0.5) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.config.image_token_index) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.config.image_token_index) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index fda0602110a0b..b625d19f6447d 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
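PaliGemma above no longer builds its own LogitsProcessor and Sampler: it folds logit_scale into the wrapped Gemma backbone's logits processor and exposes the backbone's sampler through a property; Phi-3-Vision and Pixtral later in the diff do the same via cached_property. In outline (an illustrative class, not the actual vLLM code):

from functools import cached_property

from vllm.model_executor.layers.sampler import Sampler


class VisionLanguageWrapperSketch:

    def __init__(self, language_model, logit_scale: float = 1.0):
        self.language_model = language_model
        # Fold the model-specific logit scale into the backbone's processor
        # instead of constructing a second LogitsProcessor.
        self.language_model.logits_processor.scale *= logit_scale
        # Pipeline-parallel plumbing is delegated to the backbone as well.
        self.make_empty_intermediate_tensors = (
            self.language_model.make_empty_intermediate_tensors)

    @cached_property
    def sampler(self):
        # Reuse the backbone's sampler when it has one, else fall back.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return Sampler()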
"""Inference-only persimmon model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -28,14 +28,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -44,6 +43,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class PersimmonMLP(nn.Module): @@ -211,20 +214,23 @@ class PersimmonModel(nn.Module): def __init__(self, config: PersimmonConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.vocab_size = config.vocab_size self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - PersimmonDecoderLayer(config, - cache_config=cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PersimmonDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -232,24 +238,31 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - if inputs_embeds is not None: - hidden_states = inputs_embeds + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + if inputs_embeds is not None: + hidden_states = inputs_embeds + else: + hidden_states = self.embed_tokens(input_ids) else: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): hidden_states = self.layers[i]( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) return hidden_states -class PersimmonForCausalLM(nn.Module): +class PersimmonForCausalLM(nn.Module, SupportsPP): def __init__(self, 
config: PersimmonConfig, @@ -266,6 +279,8 @@ def __init__(self, bias=False) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -281,6 +296,7 @@ def forward( positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, inputs_embeds=inputs_embeds, ) return hidden_states @@ -312,6 +328,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] if "query_key_value" in name: diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index 15c21cfa2d8a8..c90fe2e0ab9ea 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -35,7 +35,7 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """Inference-only Phi-1.5 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -43,14 +43,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -59,7 +58,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class PhiAttention(nn.Module): @@ -196,18 +197,22 @@ class PhiModel(nn.Module): def __init__(self, config: PhiConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.quant_config = quant_config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - PhiLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiLayer(config, cache_config, quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -215,23 
+220,31 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(self.config.num_hidden_layers): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) + hidden_states = self.final_layernorm(hidden_states) return hidden_states -class PhiForCausalLM(nn.Module, SupportsLoRA): +class PhiForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -274,6 +287,8 @@ def __init__( quant_config=quant_config) self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -282,9 +297,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states @@ -325,6 +340,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -335,6 +352,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue # pylint: disable=E1136 + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index afc6fe9844ad6..4cfeb3bb3496f 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -1,5 +1,5 @@ import math -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -7,14 +7,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, +from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size) from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -23,6 +22,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + def load_column_parallel_weight(param: torch.nn.Parameter, loaded_weight: torch.Tensor): @@ -301,20 +304,25 @@ def __init__( config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", ): super().__init__() self.config = config self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) self.mup_embedding_multiplier = config.mup_embedding_multiplier - self.layers = nn.ModuleList([ - Phi3SmallDecoderLayer(config, layer_idx, cache_config, - quant_config) - for layer_idx in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Phi3SmallDecoderLayer(config, + int(prefix.split('.')[-1]), + cache_config, quant_config), + prefix=f"{prefix}.layers") self.final_layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def get_input_embeddings(self): return self.embed_tokens @@ -327,30 +335,37 @@ def forward( input_ids: torch.LongTensor, positions: Optional[torch.LongTensor], kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata = None, - ): - hidden_states = self.embed_tokens(input_ids) - if (self.mup_embedding_multiplier is not None - and self.mup_embedding_multiplier > 0.0): - hidden_states = hidden_states * self.mup_embedding_multiplier - for i in 
range(len(self.layers)): + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + if (self.mup_embedding_multiplier is not None + and self.mup_embedding_multiplier > 0.0): + hidden_states = hidden_states * self.mup_embedding_multiplier + else: + assert intermediate_tensors + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.final_layernorm(hidden_states) return hidden_states -class Phi3SmallForCausalLM(nn.Module): +class Phi3SmallForCausalLM(nn.Module, SupportsPP): _tied_weights_keys = ["lm_head.weight"] def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -372,6 +387,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) # tokens in tiktoken but not used if hasattr(config, 'dummy_token_indices'): @@ -419,12 +436,13 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: output_hidden_states = self.model( input_ids=input_ids, positions=positions, kv_caches=kv_caches, attn_metadata=attn_metadata, + intermediate_tensors=intermediate_tensors, ) output_hidden_states = output_hidden_states return output_hidden_states @@ -447,6 +465,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): continue if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index 245381518a7f8..ebfffb25360cd 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -15,7 +15,7 @@ # limitations under the License. 
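One subtlety in the Phi-3-Small rewrite above (shared with OLMoE earlier): make_layers calls the layer factory with a module prefix rather than an enumerate() index, so decoder layers that need their layer_idx recover it from that prefix. For example:

def layer_idx_from_prefix(prefix: str) -> int:
    # make_layers names each layer "<parent>.layers.<global index>".
    return int(prefix.split(".")[-1])


assert layer_idx_from_prefix("model.layers.17") == 17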
import itertools import re -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union) @@ -29,13 +29,11 @@ from vllm.config import CacheConfig, ModelConfig, MultiModalConfig from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.logger import init_logger -from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import Sampler, SamplerOutput -from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.clip import CLIPVisionModel -from vllm.model_executor.models.llama import LlamaModel +from vllm.model_executor.models.llama import LlamaForCausalLM from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.utils import cached_get_tokenizer, repeat_and_pad_token @@ -43,8 +41,9 @@ from vllm.utils import is_list_of from .clip import dummy_image_for_clip, dummy_seq_data_for_clip -from .interfaces import SupportsMultiModal -from .utils import flatten_bn, merge_multimodal_embeddings +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, group_weights_with_prefix, + merge_multimodal_embeddings) logger = init_logger(__name__) @@ -295,6 +294,37 @@ def add_image_newline(self, image_features_hd): dim=2).reshape(num_images, -1, hid_dim) return image_features_hd_newline + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + # prepare weight iterators for components + weights_group = group_weights_with_prefix(weights) + + # load vision encoder + self.img_processor.load_weights(weights_group["img_processor"]) + + # load glb_GN + for name, loaded_weight in weights_group["glb_GN"]: + assert name == "" + param = self.glb_GN + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load sub_GN + for name, loaded_weight in weights_group["sub_GN"]: + assert name == "" + param = self.sub_GN + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + + # load mlp projector + mlp_params_dict = dict(self.img_projection.named_parameters()) + for name, loaded_weight in weights_group["img_projection"]: + param = mlp_params_dict[name] + weight_loader = getattr(param, "weight_loader", + default_weight_loader) + weight_loader(param, loaded_weight) + # Based on https://huggingface.co/microsoft/Phi-3-vision-128k-instruct/blob/main/image_processing_phi3_v.py#L57 def _calc_padded_size(*, width: int, height: int, padding_unit: int = 336): @@ -508,7 +538,7 @@ def input_processor_for_phi3v(ctx: InputContext, @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_phi3v_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_phi3v) @INPUT_REGISTRY.register_input_processor(input_processor_for_phi3v) -class Phi3VForCausalLM(nn.Module, SupportsMultiModal): +class Phi3VForCausalLM(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: PretrainedConfig, @@ -521,17 +551,21 @@ def __init__(self, self.multimodal_config = multimodal_config self.image_token_id = _IMAGE_TOKEN_ID - self.model = LlamaModel(config, cache_config, quant_config) - # TODO: Optionally initializes this for 
supporting embeddings. self.vision_embed_tokens = Phi3HDImageEmbedding(config) - self.lm_head = ParallelLMHead(config.vocab_size, - config.hidden_size, - quant_config=quant_config) - if self.config.tie_word_embeddings: - self.lm_head.weight = self.model.embed_tokens.weight - self.logits_processor = LogitsProcessor(config.vocab_size) - self.sampler = Sampler() + + self.language_model = LlamaForCausalLM(config, cache_config, + quant_config) + + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() def _validate_image_sizes(self, data: torch.Tensor) -> torch.Tensor: expected_dims = (2, ) @@ -631,24 +665,29 @@ def forward(self, attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object): - image_input = self._parse_and_validate_image_input(**kwargs) - - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.model.get_input_embeddings(input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.image_token_id) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) + + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.image_token_id) + input_ids = None + else: + inputs_embeds = None - hidden_states = self.model(input_ids, - positions, - kv_caches, - attn_metadata, - intermediate_tensors, - inputs_embeds=inputs_embeds) + hidden_states = self.language_model.model(input_ids, + positions, + kv_caches, + attn_metadata, + intermediate_tensors, + inputs_embeds=inputs_embeds) return hidden_states @@ -657,66 +696,38 @@ def compute_logits( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[torch.Tensor]: - logits = self.logits_processor(self.lm_head, hidden_states, - sampling_metadata) - return logits + return self.language_model.compute_logits(hidden_states, + sampling_metadata) def sample( self, logits: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens + return self.language_model.sample(logits, sampling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): - stacked_params_mapping = [ - # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), - ] + hf_to_vllm_mapping = { + "model.vision_embed_tokens.": "vision_embed_tokens.", + "lm_head.": "language_model.lm_head.", + "model.": "language_model.model.", + } - # TODO(ChristopherCho): This is a temporary fix to load - # the vision weights with CLIPVisionModel.load_weights() - vision_weights = [] - params_dict = dict(self.named_parameters()) - for name, loaded_weight in weights: - if "rotary_emb.inv_freq" in name: - continue - # Skip loading the img_processor weights since they are - # loaded separately. 
- if "vision_embed_tokens.img_processor" in name: - vision_weights.append((name, loaded_weight)) - continue - - for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items(): - if key_to_modify in name: - name = name.replace(key_to_modify, new_key) - for (param_name, weight_name, shard_id) in stacked_params_mapping: - if weight_name not in name: - continue - - param = params_dict[name.replace(weight_name, param_name)] - weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) - break - else: - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - if name in params_dict: - param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) - weight_loader(param, loaded_weight) - - # We use regex to extract the sub-module name - # from "model.vision_embed_tokens.img_processor.*" - vision_weights = [ - (re.search(r"vision_embed_tokens\.img_processor\.(.*)", - n).group(1), w) for n, w in vision_weights - ] - self.vision_embed_tokens.img_processor.load_weights(vision_weights) + def hf_to_vllm_name(key: str) -> str: + for hf_name, vllm_name in hf_to_vllm_mapping.items(): + if key.startswith(hf_name): + return key.replace(hf_name, vllm_name, 1) + + return key + + vllm_weights = {hf_to_vllm_name(k): v for k, v in weights} + + # prepare weight iterators for components + weights_group = group_weights_with_prefix(vllm_weights.items()) + + # load vision embeddings and encoder + self.vision_embed_tokens.load_weights( + weights_group["vision_embed_tokens"]) + + # load llm backbone + self.language_model.load_weights(weights_group["language_model"]) diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py index 487d9fc2f4337..a9c815916ed59 100644 --- a/vllm/model_executor/models/phimoe.py +++ b/vllm/model_executor/models/phimoe.py @@ -21,7 +21,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
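The Phi-3-Vision load_weights rewrite above drops the manual stacked-parameter bookkeeping: checkpoint keys are renamed into the new module layout, grouped by their top-level prefix, and each sub-module loads its own group. Reduced to its essentials (generator-based here, whereas the diff materialises a dict first):

from .utils import group_weights_with_prefix  # models-package helper

HF_TO_VLLM_PREFIXES = {
    "model.vision_embed_tokens.": "vision_embed_tokens.",
    "lm_head.": "language_model.lm_head.",
    "model.": "language_model.model.",
}


def hf_to_vllm_name(key: str) -> str:
    # Order matters: the specific "model.vision_embed_tokens." prefix must be
    # tried before the bare "model." prefix.
    for hf_prefix, vllm_prefix in HF_TO_VLLM_PREFIXES.items():
        if key.startswith(hf_prefix):
            return key.replace(hf_prefix, vllm_prefix, 1)
    return key


def load_weights_sketch(model, weights):
    renamed = ((hf_to_vllm_name(name), weight) for name, weight in weights)
    weights_group = group_weights_with_prefix(renamed)
    model.vision_embed_tokens.load_weights(
        weights_group["vision_embed_tokens"])
    model.language_model.load_weights(weights_group["language_model"])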
"""Inference-only PhiMoE model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -29,7 +29,7 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.linear import (QKVParallelLinear, ReplicatedLinear, @@ -46,7 +46,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class PhiMoEConfig(PretrainedConfig): @@ -435,6 +437,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.padding_idx = config.pad_token_id @@ -448,33 +451,56 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - PhiMoEDecoderLayer(config, cache_config, quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: PhiMoEDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers") self.norm = nn.LayerNorm(config.hidden_size, eps=config.rms_norm_eps, elementwise_affine=True) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + residual = intermediate_tensors["residual"] + + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states, residual = layer(positions, hidden_states, - kv_caches[i], attn_metadata, - residual) + hidden_states, residual = layer( + positions, + hidden_states, + kv_caches[i - self.start_layer], + attn_metadata, + residual, + ) + + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) + hidden_states = self.norm(hidden_states) return hidden_states -class PhiMoEForCausalLM(nn.Module, SupportsLoRA): +class PhiMoEForCausalLM(nn.Module, SupportsLoRA, SupportsPP): fall_back_to_pt_during_load = False packed_modules_mapping = { @@ -537,6 +563,9 @@ def __init__( config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -544,9 +573,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = 
None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits(self, hidden_states: torch.Tensor, @@ -589,6 +618,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -599,6 +631,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader( @@ -613,6 +648,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + # Skip layers on other devices. + if is_pp_missing_parameter(name, self): + continue # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index aa92e62a30d3f..c8957dcae6b16 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -1,4 +1,5 @@ from dataclasses import dataclass, fields +from functools import cached_property from itertools import tee from typing import Iterable, List, Mapping, Optional, Tuple, Union @@ -16,7 +17,7 @@ from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.quantization import QuantizationConfig -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import merge_multimodal_embeddings from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -25,7 +26,7 @@ from vllm.multimodal.utils import cached_get_tokenizer from vllm.sequence import IntermediateTensors, SequenceData -from .interfaces import SupportsMultiModal +from .interfaces import SupportsMultiModal, SupportsPP from .utils import init_vllm_registered_model @@ -126,7 +127,8 @@ def input_processor_for_pixtral(ctx: InputContext, llm_inputs: LLMInputs): @MULTIMODAL_REGISTRY.register_max_image_tokens(get_max_pixtral_image_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_pixtral) @INPUT_REGISTRY.register_input_processor(input_processor_for_pixtral) -class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal): +class PixtralForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: PretrainedConfig, @@ -155,6 +157,16 @@ def __init__(self, self.vision_language_adapter = VisionLanguageAdapter( self.vision_args, dim=config.text_config.hidden_size) + self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def forward( self, 
input_ids: torch.Tensor, @@ -163,32 +175,36 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for pixtral. TODO """ - image_input = self._parse_and_validate_image_input(**kwargs) + if intermediate_tensors is not None: + input_ids = None + inputs_embeds = None + else: + image_input = self._parse_and_validate_image_input(**kwargs) - if image_input is not None: - vision_embeddings = self._process_image_input(image_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) + if image_input is not None: + vision_embeddings = self._process_image_input(image_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, vision_embeddings, - self.vision_args.image_token_id) + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, vision_embeddings, + self.vision_args.image_token_id) - input_ids = None - else: - inputs_embeds = None + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model(input_ids, positions, kv_caches, attn_metadata, - None, + intermediate_tensors, inputs_embeds=inputs_embeds) return hidden_states diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index 761c1370b9776..fd8a27eec3b9a 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -31,15 +31,13 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.resampler import Resampler2, get_abs_pos from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.base import MultiModalInputs @@ -47,7 +45,9 @@ from vllm.sequence import IntermediateTensors, SequenceData from vllm.utils import is_list_of -from .utils import flatten_bn, is_pp_missing_parameter, make_layers +from .interfaces import SupportsMultiModal, SupportsPP +from .utils import (flatten_bn, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) logger = init_logger(__name__) @@ -568,6 +568,9 @@ def __init__( lambda prefix: QWenBlock(config, cache_config, quant_config), prefix=f"{prefix}.h") self.ln_f = RMSNorm(config.hidden_size, eps=config.layer_norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) self.visual = VisionTransformer(**config.visual, quant_config=quant_config) if hasattr( config, "visual") else None @@ -580,7 +583,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors], pixel_values: Optional[QwenImageInputs], - ) -> torch.Tensor: + ) -> Union[torch.Tensor, 
IntermediateTensors]: img_pos = None # If pixel / visual embeddings are provided, this is a visual model if pixel_values is not None and self.visual is not None: @@ -860,7 +863,7 @@ def dummy_data_for_qwen( @MULTIMODAL_REGISTRY.register_max_image_tokens(MAX_QWEN_IMG_TOKENS) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen) -class QWenLMHeadModel(nn.Module, SupportsMultiModal): +class QWenLMHeadModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__( self, @@ -881,6 +884,8 @@ def __init__( self.lm_head.weight = self.transformer.wte.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.transformer.make_empty_intermediate_tensors) def _get_image_input_type( self, @@ -912,33 +917,26 @@ def _get_image_input_type( ) return None - def forward(self, - input_ids: torch.Tensor, - positions: torch.Tensor, - kv_caches: List[torch.Tensor], - attn_metadata: AttentionMetadata, - intermediate_tensors: Optional[IntermediateTensors] = None, - pixel_values: Optional[torch.Tensor] = None) -> torch.Tensor: - pixel_values = self._get_image_input_type(pixel_values) + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + kv_caches: List[torch.Tensor], + attn_metadata: AttentionMetadata, + intermediate_tensors: Optional[IntermediateTensors] = None, + pixel_values: Optional[torch.Tensor] = None + ) -> Union[torch.Tensor, IntermediateTensors]: + if intermediate_tensors is not None: + input_ids = None + pixel_values = None + else: + pixel_values = self._get_image_input_type(pixel_values) + hidden_states = self.transformer(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors, pixel_values) return hidden_states - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def compute_logits( self, hidden_states: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 7fd90b2e8b282..fe842d2d0ef9a 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
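The QWenLMHeadModel rewrite above follows the same pipeline-parallel convention the rest of this diff applies: when intermediate_tensors is supplied, the current rank is not the first pipeline stage, so token ids and pixel inputs are cleared before the inner transformer runs. A minimal, self-contained sketch of that control flow, using a plain dict and the hypothetical names wrapper_forward / toy_transformer rather than the vLLM classes:

from typing import Optional

import torch


def wrapper_forward(
    transformer,                           # inner model, e.g. a QWenModel-like module
    input_ids: Optional[torch.Tensor],
    pixel_values: Optional[torch.Tensor],
    intermediate_tensors: Optional[dict],  # plain-dict stand-in for IntermediateTensors
):
    if intermediate_tensors is not None:
        # Non-first pipeline rank: hidden states arrive via intermediate_tensors,
        # so per-request token ids and image inputs are unused here.
        input_ids = None
        pixel_values = None
    # The inner model decides whether to embed tokens (first rank) or to
    # continue from the received hidden states (later ranks).
    return transformer(input_ids, pixel_values, intermediate_tensors)


if __name__ == "__main__":
    def toy_transformer(ids, pix, tensors):
        # First rank embeds token ids; later ranks resume from hidden states.
        if tensors is not None:
            return tensors["hidden_states"]
        return torch.zeros(ids.shape[0], 8)

    print(wrapper_forward(toy_transformer, torch.tensor([[1, 2]]), None, None).shape)
    print(wrapper_forward(toy_transformer, None, None,
                          {"hidden_states": torch.ones(1, 8)}).shape)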
"""Inference-only Qwen2 model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -37,8 +37,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -49,8 +48,9 @@ from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA -from .utils import PPMissingLayer, is_pp_missing_parameter, make_layers +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class Qwen2MLP(nn.Module): @@ -254,6 +254,9 @@ def __init__( prefix=f"{prefix}.layers", ) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) if get_pp_group().is_last_rank: self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) else: @@ -270,7 +273,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, inputs_embeds: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds @@ -306,7 +309,7 @@ def forward( return hidden_states -class Qwen2ForCausalLM(nn.Module, SupportsLoRA): +class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -365,6 +368,8 @@ def __init__( self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -373,7 +378,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -387,20 +392,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: torch.Tensor, diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index d80064601d993..d4475b7ca27af 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -22,7 +22,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen2MoE model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch import torch.nn.functional as F @@ -42,8 +42,7 @@ ReplicatedLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -53,7 +52,9 @@ from vllm.sequence import IntermediateTensors from vllm.utils import print_warning_once -from .utils import is_pp_missing_parameter, make_layers +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class Qwen2MoeMLP(nn.Module): @@ -338,6 +339,9 @@ def __init__( prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -346,7 +350,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: if get_pp_group().is_first_rank: hidden_states = self.embed_tokens(input_ids) residual = None @@ -368,7 +372,7 @@ def forward( return hidden_states -class Qwen2MoeForCausalLM(nn.Module): +class Qwen2MoeForCausalLM(nn.Module, SupportsPP): fall_back_to_pt_during_load = False @@ -389,6 +393,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -397,7 +403,7 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, attn_metadata, intermediate_tensors) return hidden_states @@ -411,20 +417,6 @@ def compute_logits( sampling_metadata) return logits - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - "residual": - torch.zeros((batch_size, self.config.hidden_size), - dtype=dtype, - device=device), - }) - def sample( self, logits: Optional[torch.Tensor], diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py index c82e8ed6ed1e0..24fd5152ecd09 100644 --- a/vllm/model_executor/models/qwen2_vl.py +++ b/vllm/model_executor/models/qwen2_vl.py @@ -55,7 +55,6 @@ from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from 
vllm.model_executor.models.qwen2 import Qwen2Model from vllm.multimodal import (MULTIMODAL_REGISTRY, MultiModalDataDict, MultiModalInputs) @@ -68,6 +67,7 @@ from vllm.transformers_utils.processor import get_processor from vllm.utils import is_cpu +from .interfaces import SupportsMultiModal, SupportsPP from .utils import (PPMissingLayer, is_pp_missing_parameter, make_empty_intermediate_tensors_factory) @@ -883,7 +883,8 @@ def input_processor_for_qwen2_vl(ctx: InputContext, "video", get_max_qwen2_vl_video_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_qwen2_vl) @INPUT_REGISTRY.register_input_processor(input_processor_for_qwen2_vl) -class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal): +class Qwen2VLForConditionalGeneration(nn.Module, SupportsMultiModal, + SupportsPP): def __init__(self, config: Qwen2VLConfig, @@ -966,6 +967,9 @@ def _parse_and_validate_image_input( image_grid_thw=image_grid_thw) if image_embeds is not None: + image_embeds = self._validate_and_reshape_mm_tensor( + image_embeds, "image embeds") + if not isinstance(image_embeds, torch.Tensor): raise ValueError("Incorrect type of image embeddings. " f"Got type: {type(image_embeds)}") @@ -1027,7 +1031,7 @@ def forward( attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, **kwargs: object, - ) -> SamplerOutput: + ) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Qwen2-VL. Args: @@ -1047,41 +1051,43 @@ def forward( video_grid_thw: Tensor `(n_videos, 3)` of video 3D grid in LLM. `None` if no videos are passed. """ - - image_input = self._parse_and_validate_image_input(**kwargs) - video_input = self._parse_and_validate_video_input(**kwargs) - - if (image_input is None - and video_input is None) or not get_pp_group().is_first_rank: + if intermediate_tensors is not None: + input_ids = None inputs_embeds = None else: - if getattr(self.config, "rope_scaling", {}).get("type", - None) == "mrope": - assert positions.ndim == 2 and positions.size(0) == 3, ( - "multimodal section rotary embedding requires " - f"(3, seq_len) positions, but got {positions.size()}") - - inputs_embeds = self.model.embed_tokens(input_ids) - - if image_input is not None: - image_embeds = self._process_image_input(image_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - image_embeds, - placeholder_token_id=self.config.image_token_id, - ) + image_input = self._parse_and_validate_image_input(**kwargs) + video_input = self._parse_and_validate_video_input(**kwargs) - if video_input is not None: - video_embeds = self._process_video_input(video_input) - inputs_embeds = self._merge_multimodal_embeddings( - input_ids, - inputs_embeds, - video_embeds, - placeholder_token_id=self.config.video_token_id, - ) - - input_ids = None + if image_input is None and video_input is None: + inputs_embeds = None + else: + rope_scaling = getattr(self.config, "rope_scaling", {}) + if rope_scaling.get("type", None) == "mrope": + assert positions.ndim == 2 and positions.size(0) == 3, ( + "multimodal section rotary embedding requires " + f"(3, seq_len) positions, but got {positions.size()}") + + inputs_embeds = self.model.embed_tokens(input_ids) + + if image_input is not None: + image_embeds = self._process_image_input(image_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + image_embeds, + placeholder_token_id=self.config.image_token_id, + ) + + if video_input is not None: + video_embeds = 
self._process_video_input(video_input) + inputs_embeds = self._merge_multimodal_embeddings( + input_ids, + inputs_embeds, + video_embeds, + placeholder_token_id=self.config.video_token_id, + ) + + input_ids = None hidden_states = self.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/registry.py b/vllm/model_executor/models/registry.py new file mode 100644 index 0000000000000..aa5736e7cd517 --- /dev/null +++ b/vllm/model_executor/models/registry.py @@ -0,0 +1,320 @@ +import importlib +import string +import subprocess +import sys +import uuid +from functools import lru_cache, partial +from typing import Callable, Dict, List, Optional, Tuple, Type, Union + +import torch.nn as nn + +from vllm.logger import init_logger +from vllm.utils import is_hip + +from .interfaces import supports_multimodal, supports_pp + +logger = init_logger(__name__) + +_GENERATION_MODELS = { + "AquilaModel": ("llama", "LlamaForCausalLM"), + "AquilaForCausalLM": ("llama", "LlamaForCausalLM"), # AquilaChat2 + "ArcticForCausalLM": ("arctic", "ArcticForCausalLM"), + "BaiChuanForCausalLM": ("baichuan", "BaiChuanForCausalLM"), # baichuan-7b + "BaichuanForCausalLM": ("baichuan", "BaichuanForCausalLM"), # baichuan-13b + "BloomForCausalLM": ("bloom", "BloomForCausalLM"), + "ChatGLMModel": ("chatglm", "ChatGLMForCausalLM"), + "ChatGLMForConditionalGeneration": ("chatglm", "ChatGLMForCausalLM"), + "CohereForCausalLM": ("commandr", "CohereForCausalLM"), + "DbrxForCausalLM": ("dbrx", "DbrxForCausalLM"), + "DeciLMForCausalLM": ("decilm", "DeciLMForCausalLM"), + "DeepseekForCausalLM": ("deepseek", "DeepseekForCausalLM"), + "DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"), + "ExaoneForCausalLM": ("exaone", "ExaoneForCausalLM"), + "FalconForCausalLM": ("falcon", "FalconForCausalLM"), + "GemmaForCausalLM": ("gemma", "GemmaForCausalLM"), + "Gemma2ForCausalLM": ("gemma2", "Gemma2ForCausalLM"), + "GPT2LMHeadModel": ("gpt2", "GPT2LMHeadModel"), + "GPTBigCodeForCausalLM": ("gpt_bigcode", "GPTBigCodeForCausalLM"), + "GPTJForCausalLM": ("gpt_j", "GPTJForCausalLM"), + "GPTNeoXForCausalLM": ("gpt_neox", "GPTNeoXForCausalLM"), + "GraniteForCausalLM": ("granite", "GraniteForCausalLM"), + "GraniteMoeForCausalLM": ("granitemoe", "GraniteMoeForCausalLM"), + "InternLMForCausalLM": ("llama", "LlamaForCausalLM"), + "InternLM2ForCausalLM": ("internlm2", "InternLM2ForCausalLM"), + "JAISLMHeadModel": ("jais", "JAISLMHeadModel"), + "JambaForCausalLM": ("jamba", "JambaForCausalLM"), + "LlamaForCausalLM": ("llama", "LlamaForCausalLM"), + # For decapoda-research/llama-* + "LLaMAForCausalLM": ("llama", "LlamaForCausalLM"), + "MistralForCausalLM": ("llama", "LlamaForCausalLM"), + "MixtralForCausalLM": ("mixtral", "MixtralForCausalLM"), + "QuantMixtralForCausalLM": ("mixtral_quant", "MixtralForCausalLM"), + # transformers's mpt class has lower case + "MptForCausalLM": ("mpt", "MPTForCausalLM"), + "MPTForCausalLM": ("mpt", "MPTForCausalLM"), + "MiniCPMForCausalLM": ("minicpm", "MiniCPMForCausalLM"), + "MiniCPM3ForCausalLM": ("minicpm3", "MiniCPM3ForCausalLM"), + "NemotronForCausalLM": ("nemotron", "NemotronForCausalLM"), + "OlmoForCausalLM": ("olmo", "OlmoForCausalLM"), + "OlmoeForCausalLM": ("olmoe", "OlmoeForCausalLM"), + "OPTForCausalLM": ("opt", "OPTForCausalLM"), + "OrionForCausalLM": ("orion", "OrionForCausalLM"), + "PersimmonForCausalLM": ("persimmon", "PersimmonForCausalLM"), + "PhiForCausalLM": ("phi", "PhiForCausalLM"), + "Phi3ForCausalLM": ("phi3", "Phi3ForCausalLM"), + "Phi3SmallForCausalLM": ("phi3_small", 
"Phi3SmallForCausalLM"), + "PhiMoEForCausalLM": ("phimoe", "PhiMoEForCausalLM"), + "Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"), + "Qwen2MoeForCausalLM": ("qwen2_moe", "Qwen2MoeForCausalLM"), + "Qwen2VLForConditionalGeneration": + ("qwen2_vl", "Qwen2VLForConditionalGeneration"), + "RWForCausalLM": ("falcon", "FalconForCausalLM"), + "StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"), + "StableLmForCausalLM": ("stablelm", "StablelmForCausalLM"), + "Starcoder2ForCausalLM": ("starcoder2", "Starcoder2ForCausalLM"), + "SolarForCausalLM": ("solar", "SolarForCausalLM"), + "XverseForCausalLM": ("xverse", "XverseForCausalLM"), + # NOTE: The below models are for speculative decoding only + "MedusaModel": ("medusa", "Medusa"), + "EAGLEModel": ("eagle", "EAGLE"), + "MLPSpeculatorPreTrainedModel": ("mlp_speculator", "MLPSpeculator"), +} + +_EMBEDDING_MODELS = { + "MistralModel": ("llama_embedding", "LlamaEmbeddingModel"), + "Qwen2ForRewardModel": ("qwen2_rm", "Qwen2ForRewardModel"), +} + +_MULTIMODAL_MODELS = { + "Blip2ForConditionalGeneration": + ("blip2", "Blip2ForConditionalGeneration"), + "ChameleonForConditionalGeneration": + ("chameleon", "ChameleonForConditionalGeneration"), + "FuyuForCausalLM": ("fuyu", "FuyuForCausalLM"), + "InternVLChatModel": ("internvl", "InternVLChatModel"), + "LlavaForConditionalGeneration": ("llava", + "LlavaForConditionalGeneration"), + "LlavaNextForConditionalGeneration": ("llava_next", + "LlavaNextForConditionalGeneration"), + "LlavaNextVideoForConditionalGeneration": + ("llava_next_video", "LlavaNextVideoForConditionalGeneration"), + "LlavaOnevisionForConditionalGeneration": + ("llava_onevision", "LlavaOnevisionForConditionalGeneration"), + "MiniCPMV": ("minicpmv", "MiniCPMV"), + "PaliGemmaForConditionalGeneration": ("paligemma", + "PaliGemmaForConditionalGeneration"), + "Phi3VForCausalLM": ("phi3v", "Phi3VForCausalLM"), + "PixtralForConditionalGeneration": ("pixtral", + "PixtralForConditionalGeneration"), + "QWenLMHeadModel": ("qwen", "QWenLMHeadModel"), + "Qwen2VLForConditionalGeneration": ("qwen2_vl", + "Qwen2VLForConditionalGeneration"), + "UltravoxModel": ("ultravox", "UltravoxModel"), + "MllamaForConditionalGeneration": ("mllama", + "MllamaForConditionalGeneration"), +} +_CONDITIONAL_GENERATION_MODELS = { + "BartModel": ("bart", "BartForConditionalGeneration"), + "BartForConditionalGeneration": ("bart", "BartForConditionalGeneration"), +} + +_MODELS = { + **_GENERATION_MODELS, + **_EMBEDDING_MODELS, + **_MULTIMODAL_MODELS, + **_CONDITIONAL_GENERATION_MODELS, +} + +# Architecture -> type. +# out of tree models +_OOT_MODELS: Dict[str, Type[nn.Module]] = {} + +# Models not supported by ROCm. +_ROCM_UNSUPPORTED_MODELS: List[str] = [] + +# Models partially supported by ROCm. +# Architecture -> Reason. +_ROCM_SWA_REASON = ("Sliding window attention (SWA) is not yet supported in " + "Triton flash attention. For half-precision SWA support, " + "please use CK flash attention by setting " + "`VLLM_USE_TRITON_FLASH_ATTN=0`") +_ROCM_PARTIALLY_SUPPORTED_MODELS: Dict[str, str] = { + "Qwen2ForCausalLM": + _ROCM_SWA_REASON, + "MistralForCausalLM": + _ROCM_SWA_REASON, + "MixtralForCausalLM": + _ROCM_SWA_REASON, + "PaliGemmaForConditionalGeneration": + ("ROCm flash attention does not yet " + "fully support 32-bit precision on PaliGemma"), + "Phi3VForCausalLM": + ("ROCm Triton flash attention may run into compilation errors due to " + "excessive use of shared memory. 
If this happens, disable Triton FA " + "by setting `VLLM_USE_TRITON_FLASH_ATTN=0`") +} + + +class ModelRegistry: + + @staticmethod + def _get_module_cls_name(model_arch: str) -> Tuple[str, str]: + module_relname, cls_name = _MODELS[model_arch] + return f"vllm.model_executor.models.{module_relname}", cls_name + + @staticmethod + @lru_cache(maxsize=128) + def _try_get_model_stateful(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch not in _MODELS: + return None + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + module = importlib.import_module(module_name) + return getattr(module, cls_name, None) + + @staticmethod + def _try_get_model_stateless(model_arch: str) -> Optional[Type[nn.Module]]: + if model_arch in _OOT_MODELS: + return _OOT_MODELS[model_arch] + + if is_hip(): + if model_arch in _ROCM_UNSUPPORTED_MODELS: + raise ValueError( + f"Model architecture {model_arch} is not supported by " + "ROCm for now.") + if model_arch in _ROCM_PARTIALLY_SUPPORTED_MODELS: + logger.warning( + "Model architecture %s is partially supported by ROCm: %s", + model_arch, _ROCM_PARTIALLY_SUPPORTED_MODELS[model_arch]) + + return None + + @staticmethod + def _try_load_model_cls(model_arch: str) -> Optional[Type[nn.Module]]: + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return model + + return ModelRegistry._try_get_model_stateful(model_arch) + + @staticmethod + def resolve_model_cls( + architectures: Union[str, List[str]], ) -> Tuple[Type[nn.Module], str]: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + for arch in architectures: + model_cls = ModelRegistry._try_load_model_cls(arch) + if model_cls is not None: + return (model_cls, arch) + + raise ValueError( + f"Model architectures {architectures} are not supported for now. " + f"Supported architectures: {ModelRegistry.get_supported_archs()}") + + @staticmethod + def get_supported_archs() -> List[str]: + return list(_MODELS.keys()) + list(_OOT_MODELS.keys()) + + @staticmethod + def register_model(model_arch: str, model_cls: Type[nn.Module]): + if model_arch in _MODELS: + logger.warning( + "Model architecture %s is already registered, and will be " + "overwritten by the new model class %s.", model_arch, + model_cls.__name__) + + _OOT_MODELS[model_arch] = model_cls + + @staticmethod + @lru_cache(maxsize=128) + def _check_stateless( + func: Callable[[Type[nn.Module]], bool], + model_arch: str, + *, + default: Optional[bool] = None, + ) -> bool: + """ + Run a boolean function against a model and return the result. + + If the model is not found, returns the provided default value. + + If the model is not already imported, the function is run inside a + subprocess to avoid initializing CUDA for the main program. 
+ """ + model = ModelRegistry._try_get_model_stateless(model_arch) + if model is not None: + return func(model) + + if model_arch not in _MODELS and default is not None: + return default + + module_name, cls_name = ModelRegistry._get_module_cls_name(model_arch) + + valid_name_characters = string.ascii_letters + string.digits + "._" + if any(s not in valid_name_characters for s in module_name): + raise ValueError(f"Unsafe module name detected for {model_arch}") + if any(s not in valid_name_characters for s in cls_name): + raise ValueError(f"Unsafe class name detected for {model_arch}") + if any(s not in valid_name_characters for s in func.__module__): + raise ValueError(f"Unsafe module name detected for {func}") + if any(s not in valid_name_characters for s in func.__name__): + raise ValueError(f"Unsafe class name detected for {func}") + + err_id = uuid.uuid4() + + stmts = ";".join([ + f"from {module_name} import {cls_name}", + f"from {func.__module__} import {func.__name__}", + f"assert {func.__name__}({cls_name}), '{err_id}'", + ]) + + result = subprocess.run([sys.executable, "-c", stmts], + capture_output=True) + + if result.returncode != 0: + err_lines = [line.decode() for line in result.stderr.splitlines()] + if err_lines and err_lines[-1] != f"AssertionError: {err_id}": + err_str = "\n".join(err_lines) + raise RuntimeError( + "An unexpected error occurred while importing the model in " + f"another process. Error log:\n{err_str}") + + return result.returncode == 0 + + @staticmethod + def is_embedding_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + return any(arch in _EMBEDDING_MODELS for arch in architectures) + + @staticmethod + def is_multimodal_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_mm = partial(ModelRegistry._check_stateless, + supports_multimodal, + default=False) + + return any(is_mm(arch) for arch in architectures) + + @staticmethod + def is_pp_supported_model(architectures: Union[str, List[str]]) -> bool: + if isinstance(architectures, str): + architectures = [architectures] + if not architectures: + logger.warning("No model architectures are specified") + + is_pp = partial(ModelRegistry._check_stateless, + supports_pp, + default=False) + + return any(is_pp(arch) for arch in architectures) diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py index cd99538378412..743a81f8f9e95 100644 --- a/vllm/model_executor/models/siglip.py +++ b/vllm/model_executor/models/siglip.py @@ -246,7 +246,7 @@ class SiglipParallelAttention(nn.Module): def __init__( self, - config, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() @@ -312,7 +312,7 @@ class SiglipMLP(nn.Module): def __init__( self, - config, + config: SiglipVisionConfig, quant_config: Optional[QuantizationConfig] = None, ): super().__init__() diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py index 16e576d0ac29c..b9298ed031144 100644 --- a/vllm/model_executor/models/solar.py +++ b/vllm/model_executor/models/solar.py @@ -26,6 +26,7 @@ import torch from torch import nn +from transformers import PretrainedConfig from vllm.attention import Attention, AttentionMetadata from vllm.config 
import CacheConfig, LoRAConfig @@ -37,8 +38,7 @@ QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.quantization.compressed_tensors.utils import ( get_compressed_tensors_cache_scale) from vllm.model_executor.layers.rotary_embedding import get_rope @@ -47,14 +47,14 @@ DEFAULT_VOCAB_PADDING_SIZE, ParallelLMHead, VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, kv_cache_scales_loader, maybe_remap_kv_scale_name) -from vllm.model_executor.models.interfaces import SupportsLoRA -from vllm.model_executor.models.utils import (PPMissingLayer, - is_pp_missing_parameter, - make_layers) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors from vllm.utils import is_hip +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (PPMissingLayer, is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class SolarMLP(nn.Module): @@ -98,7 +98,7 @@ class SolarAttention(nn.Module): def __init__( self, - config, + config: PretrainedConfig, hidden_size: int, num_heads: int, num_kv_heads: int, @@ -187,7 +187,7 @@ class SolarDecoderLayer(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", @@ -267,7 +267,7 @@ class SolarModel(nn.Module): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -304,6 +304,10 @@ def __init__( else: self.norm = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -368,7 +372,7 @@ def forward( return hidden_states -class SolarForCausalLM(nn.Module, SupportsLoRA): +class SolarForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -406,7 +410,7 @@ class SolarForCausalLM(nn.Module, SupportsLoRA): def __init__( self, - config, + config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, @@ -448,6 +452,9 @@ def __init__( else: self.lm_head = PPMissingLayer() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) + def forward( self, input_ids: torch.Tensor, @@ -474,24 +481,6 @@ def sample( next_tokens = self.sampler(logits, sampling_metadata) return next_tokens - def make_empty_intermediate_tensors( - self, batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: - return IntermediateTensors({ - "hidden_states": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - "residual": - torch.zeros( - (batch_size, self.config.hidden_size), - dtype=dtype, - device=device, - ), - }) - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) diff --git a/vllm/model_executor/models/stablelm.py 
b/vllm/model_executor/models/stablelm.py index 6236426dcd4e1..083a48588d01a 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -19,7 +19,7 @@ # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json """Inference-only StabeLM (https://github.com/Stability-AI/StableLM) model compatible with HuggingFace weights.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -27,14 +27,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -43,6 +42,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class StablelmMLP(nn.Module): @@ -194,19 +197,25 @@ class StableLMEpochModel(nn.Module): def __init__(self, config: PretrainedConfig, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None) -> None: + quant_config: Optional[QuantizationConfig] = None, + prefix: str = '') -> None: super().__init__() self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - StablelmDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: StablelmDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) norm_eps = getattr(config, "norm_eps", getattr(config, "layer_norm_eps", 1e-05)) self.norm = nn.LayerNorm(config.hidden_size, eps=norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -214,21 +223,28 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) 
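On the model side, the StableLMEpochModel hunk shows the inner half of the pattern: each rank owns only layers[start_layer:end_layer], kv_caches is indexed with a rank-local offset, and every rank except the last returns its hidden states as intermediate tensors. ToyPPModel below is a runnable toy version of that loop (explicit is_first/is_last flags replace get_pp_group(), and a dict replaces IntermediateTensors); unlike vLLM's make_layers, it only allocates the local slice of layers:

from typing import Dict, List, Optional, Union

import torch
from torch import nn


class ToyPPModel(nn.Module):
    """Toy stand-in: is_first/is_last flags replace get_pp_group() checks."""

    def __init__(self, hidden_size: int, start_layer: int, end_layer: int,
                 is_first: bool, is_last: bool) -> None:
        super().__init__()
        self.start_layer, self.end_layer = start_layer, end_layer
        self.is_first, self.is_last = is_first, is_last
        self.embed = nn.Embedding(100, hidden_size)
        # Only this rank's slice of layers is allocated here; vLLM's
        # make_layers instead builds a full-length ModuleList with
        # placeholders, which is why the diff can index self.layers[i].
        self.layers = nn.ModuleList(
            nn.Linear(hidden_size, hidden_size)
            for _ in range(end_layer - start_layer))
        self.norm = nn.LayerNorm(hidden_size)

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        kv_caches: List[Optional[torch.Tensor]],
        intermediate_tensors: Optional[Dict[str, torch.Tensor]],
    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
        if self.is_first:
            hidden_states = self.embed(input_ids)
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
        for i in range(self.start_layer, self.end_layer):
            # kv_caches only holds entries for this rank's layers,
            # hence the rank-local index i - self.start_layer.
            _ = kv_caches[i - self.start_layer]
            hidden_states = self.layers[i - self.start_layer](hidden_states)
        if not self.is_last:
            # Hand hidden states to the next pipeline stage.
            return {"hidden_states": hidden_states}
        return self.norm(hidden_states)


rank0 = ToyPPModel(8, start_layer=0, end_layer=2, is_first=True, is_last=False)
rank1 = ToyPPModel(8, start_layer=2, end_layer=4, is_first=False, is_last=True)
ids = torch.tensor([[3, 5, 7]])
sent = rank0(ids, kv_caches=[None, None], intermediate_tensors=None)
out = rank1(None, kv_caches=[None, None], intermediate_tensors=sent)
assert out.shape == (1, 3, 8)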
hidden_states = self.norm(hidden_states) return hidden_states -class StablelmForCausalLM(nn.Module): +class StablelmForCausalLM(nn.Module, SupportsPP): def __init__( self, @@ -247,6 +263,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -255,9 +273,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +320,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -310,6 +330,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index d3a3a83c8437f..81dd7c4daa5e9 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -18,7 +18,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
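Every load_weights hunk in this diff adds the same guard: checkpoint tensors that belong to layers hosted on another pipeline rank are skipped instead of being looked up in params_dict. The helper below is a simplified, hypothetical stand-in for that check, parsing the layer index out of the weight name; the real is_pp_missing_parameter in .utils works off the model's missing-layer bookkeeping rather than a regex:

import re
from typing import Dict, Iterable, Tuple

import torch


def is_outside_local_layer_range(name: str, start_layer: int,
                                 end_layer: int) -> bool:
    # Pull the layer index out of names such as
    # "model.layers.17.mlp.down_proj.weight" and check whether that layer
    # is materialized on this pipeline rank.
    match = re.search(r"\.layers\.(\d+)\.", name)
    if match is None:
        # Embeddings, final norm, lm_head, etc. are handled by their own
        # first-/last-rank logic and are not skipped here.
        return False
    layer_idx = int(match.group(1))
    return not (start_layer <= layer_idx < end_layer)


def load_weights_sketch(params_dict: Dict[str, torch.nn.Parameter],
                        weights: Iterable[Tuple[str, torch.Tensor]],
                        start_layer: int, end_layer: int) -> None:
    for name, loaded_weight in weights:
        if is_outside_local_layer_range(name, start_layer, end_layer):
            # Skip layers on other devices (i.e. other pipeline ranks).
            continue
        params_dict[name].data.copy_(loaded_weight)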
""" PyTorch Starcoder2 model.""" -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -26,14 +26,13 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import get_act_fn from vllm.model_executor.layers.linear import (ColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -42,6 +41,10 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors +from .interfaces import SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) + class Starcoder2Attention(nn.Module): @@ -195,7 +198,8 @@ class Starcoder2Model(nn.Module): def __init__(self, config: Starcoder2Config, cache_config: Optional[CacheConfig] = None, - quant_config: Optional[QuantizationConfig] = None): + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() self.config = config self.padding_idx = config.pad_token_id @@ -204,13 +208,16 @@ def __init__(self, # TODO: consider padding_idx (currently removed) self.embed_tokens = VocabParallelEmbedding(config.vocab_size, config.hidden_size) - self.layers = nn.ModuleList([ - Starcoder2DecoderLayer(config, - cache_config, - quant_config=quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + config.num_hidden_layers, + lambda prefix: Starcoder2DecoderLayer( + config, cache_config, quant_config=quant_config), + prefix=f"{prefix}.layers", + ) self.norm = nn.LayerNorm(config.hidden_size, eps=config.norm_epsilon) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory(["hidden_states"], + config.hidden_size)) def forward( self, @@ -218,17 +225,25 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + else: + assert intermediate_tensors is not None + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] - hidden_states = layer(positions, hidden_states, kv_caches[i], + hidden_states = layer(positions, hidden_states, + kv_caches[i - self.start_layer], attn_metadata) + if not get_pp_group().is_last_rank: + return IntermediateTensors({"hidden_states": hidden_states}) hidden_states = self.norm(hidden_states) return hidden_states -class Starcoder2ForCausalLM(nn.Module): +class Starcoder2ForCausalLM(nn.Module, SupportsPP): def __init__(self, config: Starcoder2Config, @@ -255,6 +270,8 @@ def 
__init__(self, self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -263,9 +280,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -302,6 +319,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): if weight_name not in name: continue name = name.replace(weight_name, param_name) + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -309,6 +328,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): else: if self.config.tie_word_embeddings and "lm_head.weight" in name: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py index 71808eb4c2719..daa6e72dd1002 100644 --- a/vllm/model_executor/models/ultravox.py +++ b/vllm/model_executor/models/ultravox.py @@ -3,7 +3,7 @@ import math from array import array -from functools import lru_cache +from functools import cached_property, lru_cache from typing import (Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict, Union, cast) @@ -22,12 +22,10 @@ from vllm.inputs.registry import InputContext from vllm.model_executor.layers.activation import SiluAndMul, get_act_fn from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) -from vllm.model_executor.layers.sampler import SamplerOutput +from vllm.model_executor.layers.quantization import QuantizationConfig +from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.model_loader.loader import DefaultModelLoader from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsMultiModal from vllm.model_executor.models.utils import (flatten_bn, group_weights_with_prefix, init_vllm_registered_model, @@ -37,9 +35,12 @@ from vllm.multimodal.base import MultiModalInputs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) -from vllm.sequence import VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData +from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, IntermediateTensors, + SequenceData) from vllm.transformers_utils.configs.ultravox import UltravoxConfig +from .interfaces import SupportsMultiModal, SupportsPP + _AUDIO_PLACEHOLDER_TOKEN = 128002 _AUDIO_TOKENS_PER_SECOND = 6.25 @@ -323,7 +324,7 @@ def forward( "audio", get_ultravox_max_audio_tokens) @INPUT_REGISTRY.register_dummy_data(dummy_data_for_ultravox) @INPUT_REGISTRY.register_input_processor(input_processor_for_ultravox) -class UltravoxModel(nn.Module, SupportsMultiModal): +class UltravoxModel(nn.Module, SupportsMultiModal, SupportsPP): def __init__(self, config: UltravoxConfig, @@ -353,6 +354,16 @@ def __init__(self, revision=None, prefix="language_model.")) + 
self.make_empty_intermediate_tensors = ( + self.language_model.make_empty_intermediate_tensors) + + @cached_property + def sampler(self): + if hasattr(self.language_model, "sampler"): + return self.language_model.sampler + + return Sampler() + def _audio_features_to_embeddings( self, input_features: torch.Tensor) -> torch.Tensor: audio_input = input_features.to(self.audio_tower.dtype) @@ -425,7 +436,7 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[torch.Tensor], - **kwargs) -> SamplerOutput: + **kwargs) -> Union[torch.Tensor, IntermediateTensors]: """Run forward pass for Ultravox One key thing to understand is the `input_ids` already accounts for the @@ -438,18 +449,22 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor, Args: audio_features: A batch of audio inputs [B, N, 80, M]. """ - audio_input = self._parse_and_validate_audio_input(**kwargs) - if audio_input is not None: - audio_embeddings = self._process_audio_input(audio_input) - inputs_embeds = self.language_model.model.get_input_embeddings( - input_ids) - - inputs_embeds = merge_multimodal_embeddings( - input_ids, inputs_embeds, audio_embeddings, - _AUDIO_PLACEHOLDER_TOKEN) + if intermediate_tensors is not None: input_ids = None - else: inputs_embeds = None + else: + audio_input = self._parse_and_validate_audio_input(**kwargs) + if audio_input is not None: + audio_embeddings = self._process_audio_input(audio_input) + inputs_embeds = self.language_model.model.get_input_embeddings( + input_ids) + + inputs_embeds = merge_multimodal_embeddings( + input_ids, inputs_embeds, audio_embeddings, + _AUDIO_PLACEHOLDER_TOKEN) + input_ids = None + else: + inputs_embeds = None hidden_states = self.language_model.model( input_ids=input_ids, diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py index f6218bad4ef1e..761f0406b1333 100644 --- a/vllm/model_executor/models/utils.py +++ b/vllm/model_executor/models/utils.py @@ -24,7 +24,7 @@ class WeightsGroup(UserDict): when attempting to access a weight component that does not exist. """ - def __getitem__(self, key: str) -> int: + def __getitem__(self, key: str) -> Iterable[Tuple[str, torch.Tensor]]: try: return super().__getitem__(key) except KeyError as exc: @@ -49,8 +49,7 @@ def filter_weights(weights: Iterable[Tuple[str, torch.Tensor]], def group_weights_with_prefix( - weights: Iterable[Tuple[str, torch.Tensor]] -) -> Dict[str, Iterable[Tuple[str, torch.Tensor]]]: + weights: Iterable[Tuple[str, torch.Tensor]], ) -> WeightsGroup: """ Helper function to group weights with prefix """ @@ -183,10 +182,7 @@ def merge_multimodal_embeddings(input_ids: torch.Tensor, class LayerFn(Protocol): - def __call__( - self, - prefix="", - ) -> torch.nn.Module: + def __call__(self, prefix: str) -> torch.nn.Module: ... 
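Pixtral and Ultravox gain an identical cached_property fallback: reuse the wrapped language model's sampler when it exists, otherwise construct a default one, so sampling behaviour stays consistent with whatever the inner model defines. A small self-contained sketch of that idiom, with toy classes in place of the vLLM ones:

from functools import cached_property


class ToySampler:
    """Stand-in for vllm.model_executor.layers.sampler.Sampler."""


class ToyMultiModalWrapper:

    def __init__(self, language_model) -> None:
        self.language_model = language_model

    @cached_property
    def sampler(self):
        # Prefer the inner model's sampler so any per-model overrides are
        # kept; fall back to a default sampler otherwise.
        if hasattr(self.language_model, "sampler"):
            return self.language_model.sampler
        return ToySampler()


class _LMWithSampler:
    sampler = ToySampler()


class _LMWithoutSampler:
    pass


assert ToyMultiModalWrapper(_LMWithSampler()).sampler is _LMWithSampler.sampler
assert isinstance(ToyMultiModalWrapper(_LMWithoutSampler()).sampler, ToySampler)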
@@ -319,8 +315,10 @@ def is_pp_missing_parameter(name: str, model: torch.nn.Module) -> bool: def make_empty_intermediate_tensors_factory(keys: List[str], hidden_size: int): def make_empty_intermediate_tensors( - batch_size: int, dtype: torch.dtype, - device: torch.device) -> IntermediateTensors: + batch_size: int, + dtype: torch.dtype, + device: torch.device, + ) -> IntermediateTensors: return IntermediateTensors({ key: torch.zeros((batch_size, hidden_size), dtype=dtype, @@ -342,8 +340,14 @@ def __init__(self, llm: nn.Module, name: str) -> None: self.model_name = name setattr(self, name, llm) - def forward(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name)(*args, **kwargs) + def __getattr__(self, key: str): + llm = super().__getattr__(self.model_name) + if key == self.model_name: + return llm - def embed_tokens(self, *args, **kwargs) -> Any: - return getattr(self, self.model_name).embed_tokens(*args, **kwargs) + return getattr(llm, key) + + # We need to explicitly override this + def __call__(self, *args: Any, **kwargs: Any) -> Any: + llm = super().__getattr__(self.model_name) + return llm(*args, **kwargs) diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 24cc3728f85e4..3bded82033c08 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. """Inference-only Xverse model compatible with HuggingFace weights.""" -from typing import Any, Dict, Iterable, List, Optional, Tuple +from typing import Any, Dict, Iterable, List, Optional, Tuple, Union import torch from torch import nn @@ -28,15 +28,14 @@ from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import get_tensor_model_parallel_world_size +from vllm.distributed import get_pp_group, get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler, SamplerOutput from vllm.model_executor.layers.vocab_parallel_embedding import ( @@ -45,7 +44,9 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from .interfaces import SupportsLoRA +from .interfaces import SupportsLoRA, SupportsPP +from .utils import (is_pp_missing_parameter, + make_empty_intermediate_tensors_factory, make_layers) class XverseMLP(nn.Module): @@ -227,6 +228,7 @@ def __init__( cache_config: Optional[CacheConfig] = None, quant_config: Optional[QuantizationConfig] = None, lora_config: Optional[LoRAConfig] = None, + prefix: str = "", ) -> None: super().__init__() self.config = config @@ -240,11 +242,16 @@ def __init__( config.hidden_size, org_num_embeddings=config.vocab_size, ) - self.layers = nn.ModuleList([ - XverseDecoderLayer(config, cache_config, quant_config) - for _ in range(config.num_hidden_layers) - ]) + self.start_layer, self.end_layer, self.layers = make_layers( + 
config.num_hidden_layers, + lambda prefix: XverseDecoderLayer(config, cache_config, + quant_config), + prefix=f"{prefix}.layers", + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.make_empty_intermediate_tensors = ( + make_empty_intermediate_tensors_factory( + ["hidden_states", "residual"], config.hidden_size)) def forward( self, @@ -252,23 +259,32 @@ def forward( positions: torch.Tensor, kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, - ) -> torch.Tensor: - hidden_states = self.embed_tokens(input_ids) - residual = None - for i in range(len(self.layers)): + intermediate_tensors: Optional[IntermediateTensors], + ) -> Union[torch.Tensor, IntermediateTensors]: + if get_pp_group().is_first_rank: + hidden_states = self.embed_tokens(input_ids) + residual = None + else: + hidden_states = intermediate_tensors["hidden_states"] + for i in range(self.start_layer, self.end_layer): layer = self.layers[i] hidden_states, residual = layer( positions, hidden_states, - kv_caches[i], + kv_caches[i - self.start_layer], attn_metadata, residual, ) + if not get_pp_group().is_last_rank: + return IntermediateTensors({ + "hidden_states": hidden_states, + "residual": residual + }) hidden_states, _ = self.norm(hidden_states, residual) return hidden_states -class XverseForCausalLM(nn.Module, SupportsLoRA): +class XverseForCausalLM(nn.Module, SupportsLoRA, SupportsPP): packed_modules_mapping = { "qkv_proj": [ "q_proj", @@ -317,6 +333,8 @@ def __init__( self.lm_head.weight = self.model.embed_tokens.weight self.logits_processor = LogitsProcessor(config.vocab_size) self.sampler = Sampler() + self.make_empty_intermediate_tensors = ( + self.model.make_empty_intermediate_tensors) def forward( self, @@ -325,9 +343,9 @@ def forward( kv_caches: List[torch.Tensor], attn_metadata: AttentionMetadata, intermediate_tensors: Optional[IntermediateTensors] = None, - ) -> torch.Tensor: + ) -> Union[torch.Tensor, IntermediateTensors]: hidden_states = self.model(input_ids, positions, kv_caches, - attn_metadata) + attn_metadata, intermediate_tensors) return hidden_states def compute_logits( @@ -368,6 +386,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) @@ -376,6 +396,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): # Skip loading extra bias for GPTQ models. 
if name.endswith(".bias") and name not in params_dict: continue + if is_pp_missing_parameter(name, self): + continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) diff --git a/vllm/model_executor/sampling_metadata.py b/vllm/model_executor/sampling_metadata.py index 97d36d31f2b11..84f35f75a0c32 100644 --- a/vllm/model_executor/sampling_metadata.py +++ b/vllm/model_executor/sampling_metadata.py @@ -146,7 +146,7 @@ def __init__( def prepare( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, pin_memory: bool, generators: Optional[Dict[str, torch.Generator]] = None, @@ -194,7 +194,7 @@ def __repr__(self) -> str: def _prepare_seq_groups( seq_group_metadata_list: List[SequenceGroupMetadata], seq_lens: List[int], - query_lens: Optional[List[int]], + query_lens: List[int], device: str, generators: Optional[Dict[str, torch.Generator]] = None, cache: Optional[SamplingMetadataCache] = None, @@ -284,7 +284,9 @@ def _prepare_seq_groups( else: # Decode prompt_logprob_len = 0 - sample_len = len(seq_ids) if do_sample else 0 + query_len = query_lens[i] if query_lens is not None and len( + query_lens) > 0 else 1 + sample_len = len(seq_ids) * query_len if do_sample else 0 if sampling_params.seed is not None and generators is not None: generator = generators.get(seq_group_metadata.request_id) @@ -440,14 +442,14 @@ def from_sampling_metadata( if seq_group.do_sample: sample_lens = len(seq_group.sample_indices) - assert sample_lens == len(seq_ids) - temperatures += [temperature] * len(seq_ids) - top_ps += [top_p] * len(seq_ids) - top_ks += [top_k] * len(seq_ids) - min_ps += [min_p] * len(seq_ids) - presence_penalties += [p] * len(seq_ids) - frequency_penalties += [f] * len(seq_ids) - repetition_penalties += [r] * len(seq_ids) + assert sample_lens >= len(seq_ids) + temperatures += [temperature] * sample_lens + top_ps += [top_p] * sample_lens + top_ks += [top_k] * sample_lens + min_ps += [min_p] * sample_lens + presence_penalties += [p] * sample_lens + frequency_penalties += [f] * sample_lens + repetition_penalties += [r] * sample_lens if do_penalties: for seq_group in sampling_metadata.seq_groups: diff --git a/vllm/multimodal/image.py b/vllm/multimodal/image.py index d3a230e40477e..7ca64152e481a 100644 --- a/vllm/multimodal/image.py +++ b/vllm/multimodal/image.py @@ -56,7 +56,12 @@ def _default_input_mapper( .preprocess(data, return_tensors="pt") \ .data except Exception: - logger.error("Failed to process image (%s)", data) + logger.error( + "Failed to process image (%s) with the default mapper. " + "This is most likely an edge-case with this model's image " + "processor in transformers (type: %s), and not vLLM.", + data, + type(image_processor).__name__) raise return MultiModalInputs(batch_data) diff --git a/vllm/spec_decode/batch_expansion.py b/vllm/spec_decode/batch_expansion.py index 9eb8bbfc54076..59e71cc8deb48 100644 --- a/vllm/spec_decode/batch_expansion.py +++ b/vllm/spec_decode/batch_expansion.py @@ -12,7 +12,6 @@ from vllm.spec_decode.interfaces import (SpeculativeProposals, SpeculativeScorer, SpeculativeScores) from vllm.spec_decode.util import nvtx_range, split_batch_by_proposal_len -from vllm.worker.worker_base import WorkerBase SeqId = int TargetSeqId = int @@ -36,12 +35,6 @@ class BatchExpansionTop1Scorer(SpeculativeScorer): of topk/tree. 
""" - def __init__(self, scorer_worker: WorkerBase, device: str, - vocab_size: int): - self._scorer_worker = scorer_worker - self._device = device - self._vocab_size = vocab_size - @nvtx_range("BatchExpansionTop1Scorer.score_proposals") def score_proposals( self, diff --git a/vllm/spec_decode/draft_model_runner.py b/vllm/spec_decode/draft_model_runner.py index cf64af72a14a5..984747c53c6c0 100644 --- a/vllm/spec_decode/draft_model_runner.py +++ b/vllm/spec_decode/draft_model_runner.py @@ -2,6 +2,7 @@ import torch +from vllm.forward_context import set_forward_context from vllm.model_executor.layers.sampler import SamplerOutput try: @@ -94,8 +95,6 @@ def _update_sampling_metadata(self, sampling_metadata, num_seqs, assert seq_group.is_prompt is False # No prompt assert seq_group.prompt_logprob_indices == [] # No prompt assert seq_group.sample_indices == [i] # Simple - assert seq_group.seq_len is None # Decode - assert seq_group.query_len is None # Decode def _gpu_advance_step( self, model_input: ModelInputForGPUWithSamplingMetadata, @@ -293,16 +292,17 @@ def execute_model( if previous_hidden_states is not None else {} # Run model - hidden_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **kwargs, - ) + with set_forward_context(model_input.attn_metadata): + hidden_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **kwargs, + ) # Compute the logits. 
logits = self.model.compute_logits(hidden_states, diff --git a/vllm/spec_decode/interfaces.py b/vllm/spec_decode/interfaces.py index 11ab09f10c1f5..029f56460f5c1 100644 --- a/vllm/spec_decode/interfaces.py +++ b/vllm/spec_decode/interfaces.py @@ -5,6 +5,7 @@ import torch from vllm.sequence import ExecuteModelRequest +from vllm.worker.worker_base import WorkerBase @dataclass @@ -74,6 +75,12 @@ def get_spec_proposals( class SpeculativeScorer(ABC): + def __init__(self, scorer_worker: WorkerBase, device: str, + vocab_size: int): + self._scorer_worker = scorer_worker + self._device = device + self._vocab_size = vocab_size + @abstractmethod def score_proposals( self, diff --git a/vllm/spec_decode/mqa_scorer.py b/vllm/spec_decode/mqa_scorer.py new file mode 100644 index 0000000000000..59f2a4191a8b2 --- /dev/null +++ b/vllm/spec_decode/mqa_scorer.py @@ -0,0 +1,80 @@ +from vllm.sequence import (ExecuteModelRequest, SequenceData, + SequenceGroupMetadata, get_all_seq_ids) +from vllm.spec_decode.interfaces import (SpeculativeProposals, + SpeculativeScorer, SpeculativeScores) + +SeqId = int +TargetSeqId = int + + +class MQAScorer(SpeculativeScorer): + + def score_proposals( + self, + execute_model_req: ExecuteModelRequest, + proposals: SpeculativeProposals, + ) -> SpeculativeScores: + target_seq_group_metadata_list = [] + target_seq_id_start = max( + get_all_seq_ids(execute_model_req.seq_group_metadata_list)) + 1 + all_proposal_tokens = proposals.proposal_token_ids.tolist() + for i, seq_group_metadata in enumerate( + execute_model_req.seq_group_metadata_list): + seq_data_dict = seq_group_metadata.seq_data + assert len(seq_data_dict) == 1 + seq_id = next(iter(seq_data_dict.keys())) + + seq_data: SequenceData = seq_data_dict[seq_id] + prompt_token_ids = seq_data.get_prompt_token_ids() + output_token_ids = seq_data.get_output_token_ids() + proposal_token_ids = all_proposal_tokens[i] + new_output_token_ids = [*output_token_ids, *proposal_token_ids] + + target_seq_id = target_seq_id_start + i + new_seq_data = SequenceData.from_seqs( + prompt_token_ids=prompt_token_ids, + output_token_ids=new_output_token_ids, + ) + new_seq_data.update_num_computed_tokens( + len(prompt_token_ids) + len(output_token_ids) - 1) + + # Ensure that the new sequence has at least one token + # because we only use mqa scorer in the decoding stage. 
+ assert len(output_token_ids) >= 1 + new_seq_data_dict = {target_seq_id: new_seq_data} + + new_seq_group_metadata = SequenceGroupMetadata( + request_id=seq_group_metadata.request_id, + is_prompt=seq_group_metadata.is_prompt, + seq_data=new_seq_data_dict, + sampling_params=seq_group_metadata.sampling_params, + block_tables={ + target_seq_id: seq_group_metadata.block_tables[seq_id], + }, + lora_request=None, + token_chunk_size=1, + ) + target_seq_group_metadata_list.append(new_seq_group_metadata) + + target_sampler_output = self._scorer_worker.execute_model( + execute_model_req=execute_model_req.clone( + seq_group_metadata_list=target_seq_group_metadata_list)) + + target_sampler_output = target_sampler_output[0] + + bs, k = proposals.proposal_token_ids.shape + all_tokens = target_sampler_output.sampled_token_ids.reshape(bs, k + 1) + + all_probs = target_sampler_output.sampled_token_probs.reshape( + bs, k + 1, self._vocab_size) + all_logprobs = target_sampler_output.logprobs.reshape( + bs, k + 1, self._vocab_size) + + hidden_states = None + if target_sampler_output.hidden_states is not None: + hidden_states = target_sampler_output.hidden_states.reshape( + bs, (k + 1), -1) + return SpeculativeScores(probs=all_probs, + token_ids=all_tokens, + logprobs=all_logprobs, + hidden_states=hidden_states) diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index dbf880a8f475c..a67715290a515 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -1,6 +1,6 @@ from collections import defaultdict from functools import cached_property -from typing import Any, Dict, List, Optional, Set, Tuple +from typing import Any, Dict, List, Optional, Set, Tuple, Type import torch @@ -24,6 +24,7 @@ from vllm.spec_decode.medusa_worker import MedusaWorker from vllm.spec_decode.metrics import AsyncMetricsCollector from vllm.spec_decode.mlp_speculator_worker import MLPSpeculatorWorker +from vllm.spec_decode.mqa_scorer import MQAScorer from vllm.spec_decode.multi_step_worker import MultiStepWorker from vllm.spec_decode.ngram_worker import NGramWorker from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase @@ -70,6 +71,7 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": spec_decode_worker = SpecDecodeWorker.create_worker( scorer_worker=target_worker, draft_worker_kwargs=draft_worker_kwargs, + disable_mqa_scorer=speculative_config.speculative_disable_mqa_scorer, disable_by_batch_size=speculative_config. speculative_disable_by_batch_size, draft_token_acceptance_method=speculative_config. 
@@ -116,6 +118,7 @@ def create_worker( cls, scorer_worker: Worker, draft_worker_kwargs: Dict[str, Any], + disable_mqa_scorer: bool, disable_by_batch_size: Optional[int], draft_token_acceptance_method: str, typical_acceptance_sampler_posterior_threshold: float, @@ -173,12 +176,43 @@ def create_worker( typical_acceptance_sampler_posterior_threshold, posterior_alpha=typical_acceptance_sampler_posterior_alpha, ) - logger.info("Configuring SpecDecodeWorker with sampler=%s", - type(spec_decode_sampler)) + logger.info( + "[Speculative Decoding] Configuring" + " SpecDecodeWorker with sampler=%s", type(spec_decode_sampler)) + + if not disable_mqa_scorer: + if scorer_worker.model_runner.attn_backend.get_name( + ) != "flash-attn": + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "MQA is only available with flash attn backend.") + + if ngram_prompt_lookup_max > 0: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "NGramWorker does not support MQA scorer.") + + if "model_config" in draft_worker_kwargs and \ + draft_worker_kwargs["model_config"].max_model_len < \ + scorer_worker.model_config.max_model_len: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "draft model max_model_len is smaller than the target " + "model max_model_len.") + + if not scorer_worker.model_runner.model_config.enforce_eager: + disable_mqa_scorer = True + logger.info( + "[Speculative Decoding] Disabling MQA scorer as the " + "target model is not running in eager mode.") return SpecDecodeWorker( proposer_worker, scorer_worker, + disable_mqa_scorer=disable_mqa_scorer, disable_logprobs=disable_logprobs, disable_log_stats=disable_log_stats, disable_by_batch_size=disable_by_batch_size, @@ -190,6 +224,7 @@ def __init__( proposer_worker: ProposerWorkerBase, scorer_worker: WorkerBase, spec_decode_sampler: SpecDecodeBaseSampler, + disable_mqa_scorer: bool = False, disable_logprobs: bool = False, disable_log_stats: bool = False, metrics_collector: Optional[AsyncMetricsCollector] = None, @@ -211,6 +246,8 @@ def __init__( types of sampler namely RejectionSampler and TypicalAcceptanceSampler. 'spec_decode_sampler' is either an instance of RejectionSampler or TypicalAcceptanceSampler. + disable_mqa_scorer: If set to True, disable the MQA scorer and use + the BatchExpansionTop1Scorer instead. disable_logprobs: If set to True, token log probabilities will not be output in both the draft worker and the target worker. If set to False, log probabilities will be output by both. @@ -248,6 +285,7 @@ def __init__( self.token_id_dtype = self.spec_decode_sampler.token_id_dtype # Lazy initialization. self.scorer: SpeculativeScorer + self.disable_mqa_scorer = disable_mqa_scorer # Hidden states from target model to pass to proposer # in the subsequent step. 
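The checks in `create_worker` above decide when the MQA scorer must be disabled so that `init_device` falls back to batch expansion. A hypothetical condensation of that decision follows; the function name and signature are illustrative, not part of the patch.

```python
# Hypothetical condensation of the MQA-scorer gating above (names assumed).
def should_disable_mqa_scorer(attn_backend_name: str,
                              ngram_prompt_lookup_max: int,
                              draft_max_model_len: int,
                              target_max_model_len: int,
                              enforce_eager: bool) -> bool:
    """True when batch expansion must be used instead of the MQA scorer."""
    if attn_backend_name != "flash-attn":
        return True  # MQA scoring is only wired up for the flash-attn backend
    if ngram_prompt_lookup_max > 0:
        return True  # the NGramWorker proposer is not supported
    if draft_max_model_len < target_max_model_len:
        return True  # draft model cannot cover the target's full context
    if not enforce_eager:
        return True  # the target must run eagerly (no CUDA graphs) for now
    return False
```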
@@ -270,10 +308,19 @@ def init_device(self) -> None: self._metrics.init_gpu_tensors(self.rank) self.spec_decode_sampler.init_gpu_tensors(self.rank) - self.scorer = BatchExpansionTop1Scorer( - scorer_worker=self.scorer_worker, - device=self.device, - vocab_size=self._vocab_size) + scorer_cls: Type[SpeculativeScorer] + if self.disable_mqa_scorer: + scorer_cls = BatchExpansionTop1Scorer + logger.info("[Speculative Decoding] Use batch " + "expansion for scoring proposals.") + else: + scorer_cls = MQAScorer + logger.info( + "[Speculative Decoding] Use MQA scorer for scoring proposals.") + + self.scorer = scorer_cls(scorer_worker=self.scorer_worker, + device=self.device, + vocab_size=self._vocab_size) self._configure_model_sampler_for_spec_decode() diff --git a/vllm/worker/embedding_model_runner.py b/vllm/worker/embedding_model_runner.py index 1ccf10f1a60da..1fd37eac6b851 100644 --- a/vllm/worker/embedding_model_runner.py +++ b/vllm/worker/embedding_model_runner.py @@ -6,6 +6,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context from vllm.logger import init_logger from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.multimodal import MultiModalInputs @@ -119,7 +120,8 @@ def execute_model( device=self.device), } - hidden_states = model_executable(**execute_model_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_states = model_executable(**execute_model_kwargs) # Only perform pooling in the driver worker. if not self.is_driver_worker: diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index 90dfad62e0286..59b4b8c4ddf38 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -14,6 +14,7 @@ from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, ObservabilityConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig) +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.model_executor import SamplingMetadata @@ -198,17 +199,18 @@ def execute_model( } if self.has_seqlen_agnostic else {} multi_modal_kwargs = model_input.multi_modal_kwargs or {} - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - encoder_input_ids=model_input.encoder_input_tokens, - encoder_positions=model_input.encoder_input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + encoder_input_ids=model_input.encoder_input_tokens, + encoder_positions=model_input.encoder_input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) logits = self.model.compute_logits(hidden_or_intermediate_states, model_input.sampling_metadata) diff --git a/vllm/worker/habana_model_runner.py b/vllm/worker/hpu_model_runner.py similarity index 99% 
rename from vllm/worker/habana_model_runner.py rename to vllm/worker/hpu_model_runner.py index 2d72be5690664..b1b62e6bde7f6 100644 --- a/vllm/worker/habana_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -489,7 +489,7 @@ def from_broadcasted_tensor_dict( return cls(**tensor_dict) -class HabanaModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): +class HPUModelRunnerBase(ModelRunnerBase[TModelInputForHPU]): """ Helper class for shared methods between GPU model runners. """ @@ -1730,8 +1730,7 @@ def unwrap_model(model): return modules -class HabanaModelRunner( - HabanaModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): +class HPUModelRunner(HPUModelRunnerBase[ModelInputForHPUWithSamplingMetadata]): """ GPU model runner with sampling step. """ @@ -1872,7 +1871,7 @@ def execute_model( ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]: if num_steps > 1: raise ValueError( - "num_steps > 1 is not supported in HabanaModelRunner") + "num_steps > 1 is not supported in HPUModelRunner") if self.lora_config: assert model_input.lora_requests is not None diff --git a/vllm/worker/habana_worker.py b/vllm/worker/hpu_worker.py similarity index 99% rename from vllm/worker/habana_worker.py rename to vllm/worker/hpu_worker.py index 7fc1e48b8c960..59a5adf65ebc1 100644 --- a/vllm/worker/habana_worker.py +++ b/vllm/worker/hpu_worker.py @@ -25,14 +25,14 @@ from vllm.sequence import ExecuteModelRequest from vllm.utils import hpu_backend_string, hpu_device_string, is_fake_hpu from vllm.worker.cache_engine import CacheEngine -from vllm.worker.habana_model_runner import HabanaModelRunner +from vllm.worker.hpu_model_runner import HPUModelRunner from vllm.worker.model_runner_base import ModelRunnerBase from vllm.worker.worker_base import LocalOrDistributedWorkerBase, WorkerInput logger = init_logger(__name__) -class HabanaWorker(LocalOrDistributedWorkerBase): +class HPUWorker(LocalOrDistributedWorkerBase): """A worker class that executes (a partition of) the model on a HPU. Each worker is associated with a single HPU. 
The worker is responsible for @@ -79,7 +79,7 @@ def __init__( from vllm.utils import init_cached_hf_modules init_cached_hf_modules() - self.model_runner: HabanaModelRunner = HabanaModelRunner( + self.model_runner: HPUModelRunner = HPUModelRunner( model_config, parallel_config, scheduler_config, diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index 40c0f5d0d99dc..9784438841980 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -24,6 +24,7 @@ from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_pp_group from vllm.distributed.parallel_state import graph_capture +from vllm.forward_context import set_forward_context from vllm.inputs import INPUT_REGISTRY, InputRegistry from vllm.logger import init_logger from vllm.lora.layers import LoRAMapping @@ -34,8 +35,7 @@ from vllm.model_executor.layers.sampler import SamplerOutput from vllm.model_executor.model_loader import get_model from vllm.model_executor.model_loader.tensorizer import TensorizerConfig -from vllm.model_executor.models.interfaces import (supports_lora, - supports_multimodal) +from vllm.model_executor.models import supports_lora, supports_multimodal from vllm.model_executor.models.utils import set_cpu_offload_max_bytes from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs, MultiModalRegistry) @@ -468,43 +468,26 @@ def _compute_lens(self, inter_data: InterDataForSeqGroup, seq_idx: int, # Compute context length (the number of tokens that are # already computed) and sequence length (total number of tokens). + seq_len = seq_data.get_len() if inter_data.is_prompt: context_len = seq_data.get_num_computed_tokens() - else: - # get_num_computed_tokens is incorrect for spec decoding. - # So, we should have a special logic here. - # TODO(sang): Fix it. + seq_len = min(seq_len, context_len + token_chunk_size) + elif self.runner.scheduler_config.is_multi_step or \ + self.runner.model_config.is_encoder_decoder_model: context_len = seq_len - 1 - seq_len = min(seq_len, context_len + token_chunk_size) + else: + context_len = seq_data.get_num_computed_tokens() # Compute tokens. - if inter_data.is_prompt: - tokens = seq_data.get_token_ids() - if context_len != 0 or seq_len < len(tokens): - tokens = tokens[context_len:seq_len] - else: - # Optimization. get_token_ids requires the entire copy of - # tokens. 
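Since the interleaved removals and additions in `_compute_lens` are hard to follow in diff form, here is a simplified paraphrase of the resulting logic. The free-standing signature is condensed for illustration and is not the exact final code.

```python
# Simplified paraphrase of the unified length/token computation.
def compute_lens(seq_data, is_prompt: bool, token_chunk_size: int,
                 multi_step_or_enc_dec: bool):
    seq_len = seq_data.get_len()
    if is_prompt:
        # Chunked prefill: advance by at most token_chunk_size tokens.
        context_len = seq_data.get_num_computed_tokens()
        seq_len = min(seq_len, context_len + token_chunk_size)
    elif multi_step_or_enc_dec:
        # Multi-step or encoder-decoder decode: exactly one new token.
        context_len = seq_len - 1
    else:
        # Regular decode, including spec decode: use the tracked count.
        context_len = seq_data.get_num_computed_tokens()

    tokens = seq_data.get_token_ids()[context_len:seq_len]
    query_len = seq_len - context_len
    return tokens, context_len, seq_len, query_len
```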
- tokens = seq_data.get_last_token_id() + tokens = seq_data.get_token_ids()[context_len:seq_len] inter_data.seq_lens[seq_idx] = seq_len inter_data.orig_seq_lens[seq_idx] = seq_len inter_data.context_lens[seq_idx] = context_len - - if isinstance(tokens, list): - inter_data.input_tokens[seq_idx].extend(tokens) - else: - inter_data.input_tokens[seq_idx].append(tokens) - - if (seq_len - context_len) == 1: - inter_data.input_positions[seq_idx].append(seq_len - 1) - else: - inter_data.input_positions[seq_idx].extend( - range(context_len, seq_len)) - - inter_data.query_lens[ - seq_idx] = seq_len - context_len if inter_data.is_prompt else 1 + inter_data.input_tokens[seq_idx].extend(tokens) + inter_data.input_positions[seq_idx].extend(range(context_len, seq_len)) + inter_data.query_lens[seq_idx] = seq_len - context_len if seq_data.mrope_position_delta is not None: if inter_data.mrope_input_positions is None: @@ -729,14 +712,62 @@ def add_seq_group(self, seq_group_metadata: SequenceGroupMetadata): def _use_captured_graph(self, batch_size: int, + decode_only: bool, max_decode_seq_len: int, max_encoder_seq_len: int = 0) -> bool: - return (self.decode_only and not self.runner.model_config.enforce_eager + return (decode_only and not self.runner.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_decode_seq_len <= self.runner.max_seq_len_to_capture and max_encoder_seq_len <= self.runner.max_seq_len_to_capture and batch_size <= self.runner.max_batchsize_to_capture) + def _get_cuda_graph_pad_size(self, + num_seqs: int, + max_decode_seq_len: int, + max_encoder_seq_len: int = 0) -> int: + """ + Determine the number of padding sequences required for running in + CUDA graph mode. Returns -1 if CUDA graphs cannot be used. + + In the multi-step + chunked-prefill case, only the first step + has Prefills (if any). The rest of the steps are guaranteed to be all + decodes. In this case, we set up the padding as if all the sequences + are decodes so we may run all steps except the first step in CUDA graph + mode. The padding is accounted for in the multi-step `advance_step` + family of functions. + + Args: + num_seqs (int): Number of sequences scheduled to run. + max_decode_seq_len (int): Greatest of all the decode sequence + lengths. Used only in checking the viablility of using + CUDA graphs. + max_encoder_seq_len (int, optional): Greatest of all the encode + sequence lengths. Defaults to 0. Used only in checking the + viability of using CUDA graphs. + Returns: + int: Returns the determined number of padding sequences. If + CUDA graphs is not viable, returns -1. + """ + is_mscp: bool = self.runner.scheduler_config.is_multi_step and \ + self.runner.scheduler_config.chunked_prefill_enabled + decode_only = self.decode_only or is_mscp + if not decode_only: + # Early exit so we can treat num_seqs as the batch_size below. + return -1 + + # batch_size out of this function refers to the number of input + # tokens being scheduled. This conflation of num_seqs as batch_size + # is valid as this is a decode-only case. + batch_size = num_seqs + if not self._use_captured_graph(batch_size, decode_only, + max_decode_seq_len, + max_encoder_seq_len): + return -1 + + graph_batch_size = _get_graph_batch_size(batch_size) + assert graph_batch_size >= batch_size + return graph_batch_size - batch_size + def build(self) -> ModelInputForGPU: """Finalize the builder intermediate data and create on-device tensors. 
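`_get_cuda_graph_pad_size` returns how many dummy sequences must be appended so the decode batch matches one of the pre-captured CUDA graph sizes, or -1 when graphs cannot be used. A small worked example of that padding arithmetic is below; the capture sizes listed are an assumption, not necessarily vLLM's exact `_BATCH_SIZES_TO_CAPTURE`.

```python
# Worked example of CUDA-graph batch padding (capture sizes are assumed).
_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [8 * i for i in range(1, 33)]


def get_graph_batch_size(batch_size: int) -> int:
    """Round a decode batch size up to the nearest captured graph size."""
    for captured in _BATCH_SIZES_TO_CAPTURE:
        if captured >= batch_size:
            return captured
    raise ValueError("batch size exceeds the largest captured graph")


def cuda_graph_pad_size(num_seqs: int, can_use_graph: bool) -> int:
    """Number of padding sequences to add, or -1 if graphs are not viable."""
    if not can_use_graph:
        return -1
    return get_graph_batch_size(num_seqs) - num_seqs


assert cuda_graph_pad_size(13, can_use_graph=True) == 3   # 13 padded to 16
assert cuda_graph_pad_size(13, can_use_graph=False) == -1
```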
@@ -795,21 +826,17 @@ def build(self) -> ModelInputForGPU: for data in self.inter_data_list } - batch_size = len(input_tokens) - use_captured_graph = self._use_captured_graph( - batch_size, - max_decode_seq_len, + cuda_graph_pad_size = self._get_cuda_graph_pad_size( + num_seqs=len(seq_lens), + max_decode_seq_len=max_decode_seq_len, + max_encoder_seq_len=max_encoder_seq_len) - # If cuda graph can be used, pad tensors accordingly. - # See `capture_model` API for more details. - # vLLM uses cuda graph only for decoding requests. - cuda_graph_pad_size = -1 - if use_captured_graph: - graph_batch_size = _get_graph_batch_size(batch_size) - assert graph_batch_size >= batch_size - cuda_graph_pad_size = graph_batch_size - batch_size - batch_size = graph_batch_size + batch_size = len(input_tokens) + if cuda_graph_pad_size != -1: + # If cuda graph can be used, pad tensors accordingly. + # See `capture_model` API for more details. + # vLLM uses cuda graph only for decoding requests. + batch_size += cuda_graph_pad_size # Tokens and positions. if cuda_graph_pad_size: @@ -1472,7 +1499,8 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: self._update_inputs_to_capture_for_enc_dec_model( capture_inputs) - graph_runner.capture(**capture_inputs) + with set_forward_context(attn_metadata): + graph_runner.capture(**capture_inputs) self.graph_memory_pool = graph_runner.graph.pool() self.graph_runners[virtual_engine][batch_size] = ( graph_runner) @@ -1614,15 +1642,16 @@ def execute_model( model_forward_end = torch.cuda.Event(enable_timing=True) model_forward_start.record() - hidden_or_intermediate_states = model_executable( - input_ids=model_input.input_tokens, - positions=model_input.input_positions, - kv_caches=kv_caches, - attn_metadata=model_input.attn_metadata, - intermediate_tensors=intermediate_tensors, - **MultiModalInputs.as_kwargs(multi_modal_kwargs, - device=self.device), - **seqlen_agnostic_kwargs) + with set_forward_context(model_input.attn_metadata): + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) if (self.observability_config is not None and self.observability_config.collect_model_forward_time): diff --git a/vllm/worker/openvino_model_runner.py b/vllm/worker/openvino_model_runner.py index f335e4e32efd4..77ee2eadf29a2 100644 --- a/vllm/worker/openvino_model_runner.py +++ b/vllm/worker/openvino_model_runner.py @@ -42,6 +42,7 @@ class OpenVINOModelRunner: def __init__( self, + ov_core: ov.Core, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -55,6 +56,7 @@ def __init__( *args, **kwargs, ): + self.ov_core = ov_core self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config @@ -89,11 +91,10 @@ def __init__( self.model: nn.Module # Set after init_Model def load_model(self) -> None: - self.model = get_model( - model_config=self.model_config, - device_config=self.device_config, - kv_cache_dtype=self.kv_cache_dtype, - ) + self.model = get_model(model_config=self.model_config, + device_config=self.device_config, + kv_cache_dtype=self.kv_cache_dtype, + ov_core=self.ov_core) def _prepare_model_input( self, diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index
36339e175d7bb..6b818186779b6 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -5,6 +5,7 @@ import torch import torch.distributed +import vllm.envs as envs from vllm.attention import get_attn_backend from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig, ModelConfig, MultiModalConfig, ParallelConfig, @@ -12,10 +13,14 @@ from vllm.distributed import (broadcast_tensor_dict, ensure_model_parallel_initialized, init_distributed_environment) +from vllm.executor.openvino_executor import is_openvino_cpu +from vllm.inputs import INPUT_REGISTRY from vllm.logger import init_logger from vllm.model_executor import set_random_seed from vllm.model_executor.layers.sampler import SamplerOutput -from vllm.sequence import ExecuteModelRequest +from vllm.multimodal import MULTIMODAL_REGISTRY +from vllm.sampling_params import SamplingParams +from vllm.sequence import ExecuteModelRequest, SequenceGroupMetadata from vllm.worker.openvino_model_runner import OpenVINOModelRunner from vllm.worker.worker_base import LoraNotSupportedWorkerBase @@ -36,6 +41,8 @@ def __init__( model_config: ModelConfig, parallel_config: ParallelConfig, device_config: DeviceConfig, + ov_core: ov.Core, + ov_device: str, ) -> None: assert device_config.device_type == "openvino" self.cache_config = cache_config @@ -56,9 +63,10 @@ def __init__( self.block_size = cache_config.block_size # Note: In CacheConfig, num_gpu_blocks actual is num_cpu_blocks - # for OpenVINO backend, because we want to reuse KV cache management - # in the scheduler. - self.num_cpu_blocks = cache_config.num_gpu_blocks + # for OpenVINO backend with a CPU target device, because we want + # to reuse KV cache management in the scheduler. + self.num_device_blocks = cache_config.num_gpu_blocks + self.num_swap_blocks = cache_config.num_cpu_blocks # Get attention backend. self.attn_backend = get_attn_backend( @@ -74,34 +82,100 @@ def __init__( # Initialize the cache. self.kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = self._allocate_kv_cache( - self.num_cpu_blocks) + self.num_device_blocks, ov_core, + ov_device) + + # Initialize the swap. 
+ self.swap_cache: List[Tuple[ov.Tensor, + ov.Tensor]] = self._allocate_swap_cache( + self.num_swap_blocks, ov_device) def _allocate_kv_cache( self, num_blocks: int, + ov_core: ov.Core, + ov_device: str, ) -> List[Tuple[ov.Tensor, ov.Tensor]]: """Allocates KV cache.""" k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] kv_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + + if is_openvino_cpu(): + for _ in range(self.num_layers): + key_blocks = ov.Tensor(self.cache_config.cache_dtype, + k_block_shape) + value_blocks = ov.Tensor(self.cache_config.cache_dtype, + v_block_shape) + kv_cache.append((key_blocks, value_blocks)) + else: + # Update key_cache shape: + k_block_shape = (v_block_shape[0], v_block_shape[1], + v_block_shape[3], v_block_shape[2]) + + remote_context = ov_core.get_default_context(ov_device) + + for _ in range(self.num_layers): + key_blocks = \ + remote_context.create_tensor(self.cache_config.cache_dtype, + ov.Shape(k_block_shape), + {}) + + value_blocks = \ + remote_context.create_tensor(self.cache_config.cache_dtype, + ov.Shape(v_block_shape), + {}) + + kv_cache.append((key_blocks, value_blocks)) + + return kv_cache + + def _allocate_swap_cache( + self, + num_blocks: int, + ov_device: str, + ) -> List[Tuple[ov.Tensor, ov.Tensor]]: + """Allocates swap cache.""" + k_block_shape = v_block_shape = self.attn_backend.get_kv_cache_shape( + num_blocks, self.block_size, self.num_kv_heads, self.head_size)[1:] + swap_cache: List[Tuple[ov.Tensor, ov.Tensor]] = [] + + if num_blocks == 0: + return swap_cache + + assert not is_openvino_cpu(), \ + "CPU device isn't supposed to have swap cache" + + # Update key_cache shape: + k_block_shape = (v_block_shape[0], v_block_shape[1], v_block_shape[3], + v_block_shape[2]) + for _ in range(self.num_layers): key_blocks = ov.Tensor(self.cache_config.cache_dtype, k_block_shape) value_blocks = ov.Tensor(self.cache_config.cache_dtype, v_block_shape) - kv_cache.append((key_blocks, value_blocks)) - return kv_cache + swap_cache.append((key_blocks, value_blocks)) + + return swap_cache - def swap_in(self, src_to_dst: Dict[int, int]) -> None: - raise NotImplementedError( - "Swap is not supported in OpenVINOCacheEngine.") + def swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None: + for i in range(self.num_layers): + for swap_tensor, kv_tensor in zip(self.swap_cache[i], + self.kv_cache[i]): + self.attn_backend.swap_blocks(swap_tensor, kv_tensor, + src_to_dst) - def swap_out(self, src_to_dst: Dict[int, int]) -> None: - raise NotImplementedError( - "Swap is not supported in OpenVINOCacheEngine.") + def swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None: + for i in range(self.num_layers): + for swap_tensor, kv_tensor in zip(self.swap_cache[i], + self.kv_cache[i]): + self.attn_backend.swap_blocks(kv_tensor, swap_tensor, + src_to_dst) - def copy(self, src_to_dsts: Dict[int, List[int]]) -> None: - self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts) + def copy(self, src_to_dsts: List[Tuple[int, int]]) -> None: + if (len(src_to_dsts) > 0): + self.attn_backend.copy_blocks(self.kv_cache, src_to_dsts) @staticmethod def get_cache_block_size( @@ -139,6 +213,7 @@ class OpenVINOWorker(LoraNotSupportedWorkerBase): def __init__( self, + ov_core: ov.Core, model_config: ModelConfig, parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, @@ -153,6 +228,7 @@ def __init__( kv_cache_dtype: Optional[ov.Type] = ov.Type.undefined, is_driver_worker: bool = False, ) -> None: 
+ self.ov_core = ov_core self.model_config = model_config self.parallel_config = parallel_config self.parallel_config.rank = rank @@ -175,6 +251,7 @@ def __init__( init_cached_hf_modules() self.model_runner = OpenVINOModelRunner( + self.ov_core, model_config, parallel_config, scheduler_config, @@ -204,56 +281,69 @@ def determine_num_available_blocks(self) -> Tuple[int, int]: This determines how many KV blocks can fit into the configured KV cache space. - - Note that since vLLM assumes a block resides on GPU if it can be - modified, we return num_gpu_blocks=num_cpu_blocks and num_cpu_blocks=0. - This allows us to reuse the scheduler of vLLM without generalizing it - to different devices. """ - # For OpenVINO backend, the block number will be calculated based on the - # openvino_kvcache_space_bytes. + # For OpenVINO backend, in case of CPU device, the block number will be + # calculated based on the openvino_kvcache_space_bytes. cache_block_size = self.get_cache_block_size_bytes() - num_cpu_blocks = int(self.cache_config.openvino_kvcache_space_bytes // - cache_block_size) - num_cpu_blocks = max(num_cpu_blocks, 0) + kvcache_space_bytes = self.cache_config.openvino_kvcache_space_bytes - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - num_gpu_blocks = num_cpu_blocks - num_cpu_blocks = 0 - return num_gpu_blocks, num_cpu_blocks + if is_openvino_cpu(): + num_device_blocks = int(kvcache_space_bytes // cache_block_size) + num_swap_blocks = 0 + else: + if kvcache_space_bytes > 0: + logger.info("KV_CACHE size was explicitly configured via " + "VLLM_OPENVINO_KVCACHE_SPACE environment " + "variable, ignoring profiling run.") + kv_cache_size = kvcache_space_bytes + else: + try: + kv_cache_size = self.profile_run() + except Exception as err: + raise RuntimeError( + "The error occurred during profile run. This might be " + "due to insufficient GPU memory. Consider decreasing " + "`max_model_len` to limit the maximum simultaneously " + "processed tokens.") from err + + num_device_blocks = int(kv_cache_size // cache_block_size) + num_swap_blocks = int(self.cache_config.swap_space_bytes // + cache_block_size) + + return num_device_blocks, num_swap_blocks def initialize_cache(self, num_gpu_blocks: int, num_cpu_blocks: int) -> None: - """Initialize the KV cache. Currently, swappable CPU memory is not - supported. + """Initialize the KV cache. Swappable CPU memory is only + supported on GPU. - Since this worker does not support GPUs, we use the num_gpu_blocks to + For CPU, we use the num_gpu_blocks to determine how many non-swappable CPU blocks to allocate. """ - assert (num_cpu_blocks == 0 - ), f"{type(self)} does not support swappable cache" - # Note: To reuse the cache management procedure, - # use cpu cache as 'gpu cache'. - num_cpu_blocks = num_gpu_blocks + num_device_blocks = num_gpu_blocks + num_swap_blocks = num_cpu_blocks + + if is_openvino_cpu(): + assert (num_swap_blocks == 0 + ), f"{type(self)} does not support swappable cache for CPU" - self._validate_num_cpu_blocks(num_cpu_blocks) - self.cache_config.num_gpu_blocks = num_cpu_blocks - self.cache_config.num_cpu_blocks = 0 + self._validate_num_blocks(num_device_blocks) + self.cache_config.num_gpu_blocks = num_device_blocks + self.cache_config.num_cpu_blocks = num_swap_blocks # Initialize the cache. 
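To summarize the block accounting introduced above: with a CPU target the device block count still comes from `VLLM_OPENVINO_KVCACHE_SPACE` and there is no swap space, while with a GPU target the device blocks come from an explicit KV-cache size or a profiling run and the swap blocks from the configured swap space. The sketch below is a hypothetical condensation; the function and argument names are illustrative only.

```python
# Hypothetical condensation of determine_num_available_blocks (illustrative).
from typing import Tuple


def split_blocks(is_cpu: bool, kvcache_space_bytes: int,
                 profiled_kv_cache_bytes: int, swap_space_bytes: int,
                 cache_block_size: int) -> Tuple[int, int]:
    """Return (num_device_blocks, num_swap_blocks)."""
    if is_cpu:
        # CPU target: sized purely by VLLM_OPENVINO_KVCACHE_SPACE, no swap.
        return max(kvcache_space_bytes // cache_block_size, 0), 0

    # GPU target: an explicit KV-cache size wins over the profiling estimate.
    kv_bytes = (kvcache_space_bytes
                if kvcache_space_bytes > 0 else profiled_kv_cache_bytes)
    return (kv_bytes // cache_block_size,
            swap_space_bytes // cache_block_size)
```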
self._init_cache_engine() - def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: - """Raise errors if the num_cpu_blocks is invalid.""" - if num_cpu_blocks <= 0: + def _validate_num_blocks(self, num_blocks: int) -> None: + """Raise errors if the num_blocks is invalid.""" + if num_blocks <= 0: raise ValueError( "No available memory for the cache blocks. " "Try increasing `VLLM_OPENVINO_KVCACHE_SPACE` when " "initializing the engine.") - max_seq_len = self.cache_config.block_size * num_cpu_blocks + max_seq_len = self.cache_config.block_size * num_blocks if self.model_config.max_model_len > max_seq_len: raise ValueError( f"The model's max seq len ({self.model_config.max_model_len}) " @@ -263,11 +353,14 @@ def _validate_num_cpu_blocks(self, num_cpu_blocks: int) -> None: "when initializing the engine.") def _init_cache_engine(self) -> None: + ov_device = envs.VLLM_OPENVINO_DEVICE self.cache_engine = OpenVINOCacheEngine( self.cache_config, self.model_config, self.parallel_config, self.device_config, + self.ov_core, + ov_device, ) self.kv_cache = self.cache_engine.kv_cache self.model_runner.block_size = self.cache_engine.block_size @@ -275,9 +368,16 @@ def _init_cache_engine(self) -> None: assert self.kv_cache is not None # Populate the cache to warmup the memory - for key_cache, value_cache in self.kv_cache: - key_cache.data[:] = 0 - value_cache.data[:] = 0 + if is_openvino_cpu(): + for key_cache, value_cache in self.kv_cache: + key_cache.data[:] = 0 + value_cache.data[:] = 0 + + def cache_swap_in(self, src_to_dst: List[Tuple[int, int]]) -> None: + self.cache_engine.swap_in(src_to_dst) + + def cache_swap_out(self, src_to_dst: List[Tuple[int, int]]) -> None: + self.cache_engine.swap_out(src_to_dst) def cache_copy( self, @@ -300,17 +400,28 @@ def execute_model( num_seq_groups: int = len(seq_group_metadata_list) assert execute_model_req is not None blocks_to_copy = execute_model_req.blocks_to_copy - assert len(execute_model_req.blocks_to_swap_in) == 0 - assert len(execute_model_req.blocks_to_swap_out) == 0 + blocks_to_swap_in = execute_model_req.blocks_to_swap_in + blocks_to_swap_out = execute_model_req.blocks_to_swap_out data: Dict[str, Any] = { "num_seq_groups": num_seq_groups, "blocks_to_copy": execute_model_req.blocks_to_copy, + "blocks_to_swap_in": execute_model_req.blocks_to_swap_in, + "blocks_to_swap_out": execute_model_req.blocks_to_swap_out, } broadcast_tensor_dict(data, src=0) else: data = broadcast_tensor_dict(src=0) num_seq_groups = data["num_seq_groups"] blocks_to_copy = data["blocks_to_copy"] + blocks_to_swap_in = data["blocks_to_swap_in"] + blocks_to_swap_out = data["blocks_to_swap_out"] + + if is_openvino_cpu(): + assert len(execute_model_req.blocks_to_swap_in) == 0 + assert len(execute_model_req.blocks_to_swap_out) == 0 + else: + self.cache_swap_in(blocks_to_swap_in) + self.cache_swap_out(blocks_to_swap_out) self.cache_copy(blocks_to_copy) @@ -353,3 +464,149 @@ def get_cache_block_size_bytes(self) -> int: self.model_config, self.parallel_config, ) + + def profile_run(self) -> int: + ov_device = envs.VLLM_OPENVINO_DEVICE + + assert not is_openvino_cpu(), \ + "CPU device isn't supposed to use profile run." 
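The GPU profiling run that follows builds `max_num_seqs` dummy prompts whose lengths add up to exactly `max_num_batched_tokens`, spreading the remainder over the first sequences. The same arithmetic with toy numbers:

```python
# How the dummy profiling prompts split the token budget (toy numbers).
max_num_batched_tokens, max_num_seqs = 2048, 6

seq_lens = [
    max_num_batched_tokens // max_num_seqs
    + (group_id < max_num_batched_tokens % max_num_seqs)  # spread remainder
    for group_id in range(max_num_seqs)
]

assert seq_lens == [342, 342, 341, 341, 341, 341]
assert sum(seq_lens) == max_num_batched_tokens
```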
+ + import openvino.properties.device as device + import openvino.properties.intel_gpu as intel_gpu + + ov_core = self.ov_core + cache_config = self.cache_config + model_config = self.model_config + parallel_config = self.parallel_config + device_config = self.device_config + input_registry = INPUT_REGISTRY + mm_registry = MULTIMODAL_REGISTRY + mm_registry.init_mm_limits_per_prompt(model_config) + + # Execute a forward pass with dummy inputs to profile the memory usage + # of the model. + def model_profile_run(): + top_k = model_config.get_vocab_size() - 1 + sampling_params = SamplingParams(top_p=0.99, top_k=top_k) + + max_num_batched_tokens = \ + self.scheduler_config.max_num_batched_tokens + max_num_seqs = self.scheduler_config.max_num_seqs + tmp_cache_config = CacheConfig(cache_config.block_size, + cache_config.gpu_memory_utilization, + cache_config.swap_space_bytes, + "auto") + tmp_cache_config.num_gpu_blocks = 1 + tmp_cache_config.num_cpu_blocks = 0 + tmp_cache_config.cache_dtype = cache_config.cache_dtype + + profiling_cache_engine = OpenVINOCacheEngine( + tmp_cache_config, model_config, parallel_config, device_config, + ov_core, ov_device) + + # Profile memory usage with max_num_sequences sequences and the + # total # number of tokens equal to max_num_batched_tokens. + seqs: List[SequenceGroupMetadata] = [] + for group_id in range(max_num_seqs): + seq_len = (max_num_batched_tokens // max_num_seqs + + (group_id < max_num_batched_tokens % max_num_seqs)) + block_size = cache_config.block_size + seq_num_blocks = (seq_len + block_size - 1) // block_size + + seq_data, dummy_multi_modal_data = input_registry \ + .dummy_data_for_profiling(model_config, + seq_len, + mm_registry) + + block_tables = [[0] * seq_num_blocks] * max_num_seqs + seq = SequenceGroupMetadata( + request_id=str(group_id), + is_prompt=True, + seq_data={group_id: seq_data}, + sampling_params=sampling_params, + block_tables=block_tables, + lora_request=None, + multi_modal_data=dummy_multi_modal_data) + seqs.append(seq) + + self.model_runner.block_size = tmp_cache_config.block_size + + # Run the model with the dummy inputs. + self.model_runner.execute_model(seqs, + profiling_cache_engine.kv_cache) + + # explicitly delete temporary KV cache manager to free KV cache + # when real inputs will be passed to OV + del profiling_cache_engine + + logger.info( + "Start profiling run with dummy inputs to evaluate " + "memory usage for %s. It might take a while.", ov_device) + + model_profile_run() + + gpu_device_type = ov_core.get_property(ov_device, device.type) + memory_statistics = \ + ov_core.get_property(ov_device, intel_gpu.memory_statistics) + memory_utilization = cache_config.gpu_memory_utilization + + if gpu_device_type == device.Type.INTEGRATED and \ + memory_utilization >= 0.9: + logger.warning( + "iGPU is used with high gpu_memory_utilization=%f " + "value. This may cause low performance due to " + "occupying the majority of available system " + "memory. 
Please consider decreasing " + "gpu_memory_utilization or explicitly setting the " + "`VLLM_OPENVINO_KVCACHE_SPACE` (GB) environment " + "variable.", memory_utilization) + + # sum up all used device memory + device_memory_types = ["cl_mem", "usm_device"] + used_device_mem = \ + sum(memory_statistics.get(key, 0) for key in device_memory_types) + + if gpu_device_type == device.Type.INTEGRATED: + used_device_mem += memory_statistics.get("usm_host", 0) + + # there could be unaccounted extra memory reserved by kernels, kept + # in memory pools, etc + # therefore, add a threshold to account for this + used_memory_threshold = 1.1 + used_device_mem *= used_memory_threshold + + total_device_memory = \ + ov_core.get_property(ov_device, intel_gpu.device_total_mem_size) + + def format_memory_size(size) -> str: + units = ["B", "KB", "MB", "GB"] + unit_index = 0 + + while size > 1024 and unit_index < len(units) - 1: + size /= 1024 + unit_index += 1 + + return f"{size:.2f} {units[unit_index]}" + + total_device_memory_str = \ + format(format_memory_size(total_device_memory)) + used_device_memory_str = \ + format(format_memory_size(used_device_mem)) + + logger.info( + "Total %s memory: %s. " + "Amount of memory required to run the model with " + "max_num_batched_tokens=%d: %s.", ov_device, + total_device_memory_str, + self.scheduler_config.max_num_batched_tokens, + used_device_memory_str) + + if used_device_mem >= total_device_memory: + raise RuntimeError( + f"The required memory size {used_device_memory_str} for the model " + "is higher than the total available device " + f"memory {total_device_memory_str}. Please consider " + "decreasing `max_num_batched_tokens` or increasing " + "`gpu_memory_utilization`") + + return total_device_memory * memory_utilization - used_device_mem
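The value returned by `profile_run` is the memory budget that `determine_num_available_blocks` then divides into device blocks. A worked example with made-up sizes; the 1.1 head-room factor is the `used_memory_threshold` from the code above.

```python
# Worked example of the KV-cache budget computed at the end of profile_run.
GiB = 1024 ** 3

total_device_memory = 16 * GiB   # reported by the device
used_device_mem = 6 * GiB        # measured during the dummy forward pass
used_device_mem *= 1.1           # head-room for pools / kernel reservations
memory_utilization = 0.9         # cache_config.gpu_memory_utilization

kv_cache_budget = total_device_memory * memory_utilization - used_device_mem
print(f"{kv_cache_budget / GiB:.2f} GiB left for the KV cache")  # ~7.80 GiB
```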