diff --git a/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml new file mode 100644 index 0000000000000..0ecfc01ef049f --- /dev/null +++ b/.buildkite/lm-eval-harness/configs/Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml @@ -0,0 +1,11 @@ +# bash .buildkite/lm-eval-harness/run-lm-eval-gsm-vllm-baseline.sh -m nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test -b "auto" -l 250 -f 5 -t 1 +model_name: "nm-testing/Meta-Llama-3-8B-Instruct-W8-Channel-A8-Dynamic-Asym-Per-Token-Test" +tasks: +- name: "gsm8k" + metrics: + - name: "exact_match,strict-match" + value: 0.764 + - name: "exact_match,flexible-extract" + value: 0.764 +limit: 250 +num_fewshot: 5 diff --git a/.buildkite/lm-eval-harness/configs/models-small.txt b/.buildkite/lm-eval-harness/configs/models-small.txt index 064883859218a..64a0f428587af 100644 --- a/.buildkite/lm-eval-harness/configs/models-small.txt +++ b/.buildkite/lm-eval-harness/configs/models-small.txt @@ -1,6 +1,7 @@ Meta-Llama-3-8B-Instruct.yaml Meta-Llama-3-8B-Instruct-FP8-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-INT8-compressed-tensors.yaml +Meta-Llama-3-8B-Instruct-INT8-compressed-tensors-asym.yaml Meta-Llama-3-8B-Instruct-nonuniform-compressed-tensors.yaml Meta-Llama-3-8B-Instruct-Channelwise-compressed-tensors.yaml Minitron-4B-Base-FP8.yaml diff --git a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py index aa0b1b096b9ce..afc935c1a9318 100644 --- a/.buildkite/lm-eval-harness/test_lm_eval_correctness.py +++ b/.buildkite/lm-eval-harness/test_lm_eval_correctness.py @@ -49,10 +49,15 @@ def test_lm_eval_correctness(): results = launch_lm_eval(eval_config) # Confirm scores match ground truth. + success = True for task in eval_config["tasks"]: for metric in task["metrics"]: ground_truth = metric["value"] measured_value = results["results"][task["name"]][metric["name"]] print(f'{task["name"]} | {metric["name"]}: ' f'ground_truth={ground_truth} | measured={measured_value}') - assert numpy.isclose(ground_truth, measured_value, rtol=RTOL) + success = success and numpy.isclose( + ground_truth, measured_value, rtol=RTOL) + + # Assert at the end, print all scores even on failure for debugging. 
+ assert success diff --git a/.buildkite/release-pipeline.yaml b/.buildkite/release-pipeline.yaml index 416fe344a36ea..e72138e29dd65 100644 --- a/.buildkite/release-pipeline.yaml +++ b/.buildkite/release-pipeline.yaml @@ -8,8 +8,9 @@ steps: - "docker run --rm -v $(pwd)/artifacts:/artifacts_host vllm-ci:build-image bash -c 'cp -r dist /artifacts_host && chmod -R a+rw /artifacts_host'" # rename the files to change linux -> manylinux1 - "for f in artifacts/dist/*.whl; do mv -- \"$$f\" \"$${f/linux/manylinux1}\"; done" - - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/$BUILDKITE_COMMIT/" - - "aws s3 cp --recursive artifacts/dist s3://vllm-wheels/nightly/" + - "mv artifacts/dist/$(ls artifacts/dist) artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/$BUILDKITE_COMMIT/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" + - "aws s3 cp artifacts/dist/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl s3://vllm-wheels/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl" env: DOCKER_BUILDKIT: "1" diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index 22a7e76937a76..6ffa66d5ef3d6 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -11,4 +11,4 @@ trap remove_docker_container EXIT remove_docker_container # Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test python3 examples/offline_inference.py +docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test python3 examples/offline_inference.py diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index ea8b3d46f1b3f..b12bf7b382d0f 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -9,6 +9,7 @@ # label(str): the name of the test. emoji allowed. # fast_check(bool): whether to run this on each commit on fastcheck pipeline. # fast_check_only(bool): run this test on fastcheck pipeline only +# optional(bool): never run this test by default (i.e. need to unblock manually) # command(str): the single command to run for tests. incompatible with commands. # commands(list): the list of commands to run for test. incompatbile with command. # mirror_hardwares(list): the list of hardwares to run the test on as well. 
currently only supports [amd] @@ -39,7 +40,7 @@ steps: # Check API reference (if it fails, you may have missing mock imports) - grep \"sig sig-object py\" build/html/dev/sampling_params.html -- label: Async Engine, Inputs, Utils, Worker Test # 15min +- label: Async Engine, Inputs, Utils, Worker Test # 24min fast_check: true source_file_dependencies: - vllm/ @@ -63,13 +64,21 @@ steps: fast_check: true source_file_dependencies: - vllm/ - - tests/basic_correctness + - tests/basic_correctness/test_basic_correctness + - tests/basic_correctness/test_cpu_offload + - tests/basic_correctness/test_preemption commands: - pytest -v -s basic_correctness/test_basic_correctness.py - pytest -v -s basic_correctness/test_cpu_offload.py + - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py + +- label: Chunked Prefill Test + source_file_dependencies: + - vllm/ + - tests/basic_correctness/test_chunked_prefill + commands: - VLLM_ATTENTION_BACKEND=XFORMERS pytest -v -s basic_correctness/test_chunked_prefill.py - VLLM_ATTENTION_BACKEND=FLASH_ATTN pytest -v -s basic_correctness/test_chunked_prefill.py - - VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest -v -s basic_correctness/test_preemption.py - label: Core Test # 10min mirror_hardwares: [amd] @@ -81,7 +90,7 @@ steps: commands: - pytest -v -s core -- label: Entrypoints Test # 20min +- label: Entrypoints Test # 40min working_dir: "/vllm-workspace/tests" fast_check: true mirror_hardwares: [amd] @@ -95,7 +104,8 @@ steps: - pytest -v -s entrypoints/llm/test_generate.py # it needs a clean process - pytest -v -s entrypoints/llm/test_generate_multiple_loras.py # it needs a clean process - pytest -v -s entrypoints/llm/test_guided_generate.py # it needs a clean process - - pytest -v -s entrypoints/openai + - pytest -v -s entrypoints/openai --ignore=entrypoints/openai/test_oot_registration.py + - pytest -v -s entrypoints/openai/test_oot_registration.py # it needs a clean process - pytest -v -s entrypoints/test_chat_utils.py - pytest -v -s entrypoints/offline_mode # Needs to avoid interference with other tests @@ -150,7 +160,7 @@ steps: # OOM in the CI unless we run this separately - pytest -v -s tokenization -- label: Examples Test # 12min +- label: Examples Test # 15min working_dir: "/vllm-workspace/examples" #mirror_hardwares: [amd] source_file_dependencies: @@ -168,7 +178,7 @@ steps: - python3 tensorize_vllm_model.py --model facebook/opt-125m serialize --serialized-directory /tmp/ --suffix v1 && python3 tensorize_vllm_model.py --model facebook/opt-125m deserialize --path-to-tensors /tmp/vllm/facebook/opt-125m/v1/model.tensors - python3 offline_inference_encoder_decoder.py -- label: Prefix Caching Test # 7min +- label: Prefix Caching Test # 9min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -176,7 +186,7 @@ steps: commands: - pytest -v -s prefix_caching -- label: Samplers Test # 18min +- label: Samplers Test # 36min source_file_dependencies: - vllm/model_executor/layers - vllm/sampling_metadata.py @@ -192,17 +202,15 @@ steps: - tests/test_logits_processor command: pytest -v -s test_logits_processor.py -- label: Speculative decoding tests # 22min +- label: Speculative decoding tests # 30min source_file_dependencies: - vllm/spec_decode - tests/spec_decode commands: - # See https://github.com/vllm-project/vllm/issues/5152 - - export VLLM_ATTENTION_BACKEND=XFORMERS - pytest -v -s spec_decode/e2e/test_multistep_correctness.py - pytest -v -s spec_decode --ignore=spec_decode/e2e/test_multistep_correctness.py -- label: LoRA Test 
%N # 30min each +- label: LoRA Test %N # 15min each mirror_hardwares: [amd] source_file_dependencies: - vllm/lora @@ -210,7 +218,7 @@ steps: command: pytest -v -s lora --shard-id=$$BUILDKITE_PARALLEL_JOB --num-shards=$$BUILDKITE_PARALLEL_JOB_COUNT --ignore=lora/test_long_context.py parallelism: 4 -- label: "PyTorch Fullgraph Smoke Test" +- label: "PyTorch Fullgraph Smoke Test" # 9min fast_check: true source_file_dependencies: - vllm/ @@ -218,14 +226,14 @@ steps: commands: - pytest -v -s compile/test_full_graph_smoke.py -- label: "PyTorch Fullgraph Test" +- label: "PyTorch Fullgraph Test" # 18min source_file_dependencies: - vllm/ - tests/compile commands: - pytest -v -s compile/test_full_graph.py -- label: Kernels Test %N # 30min each +- label: Kernels Test %N # 1h each mirror_hardwares: [amd] source_file_dependencies: - csrc/ @@ -255,7 +263,7 @@ steps: - pip install aiohttp - bash run-benchmarks.sh -- label: Quantization Test # 15min +- label: Quantization Test # 33min source_file_dependencies: - csrc/ - vllm/model_executor/layers/quantization @@ -299,7 +307,7 @@ steps: - pytest -v -s models/test_oot_registration.py # it needs a clean process - pytest -v -s models/*.py --ignore=models/test_oot_registration.py -- label: Decoder-only Language Models Test # 1h3min +- label: Decoder-only Language Models Test # 1h36min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -307,7 +315,7 @@ steps: commands: - pytest -v -s models/decoder_only/language -- label: Decoder-only Multi-Modal Models Test # 56min +- label: Decoder-only Multi-Modal Models Test # 1h31min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ @@ -317,15 +325,25 @@ steps: - pytest -v -s models/decoder_only/audio_language - pytest -v -s models/decoder_only/vision_language -- label: Other Models Test # 5min +- label: Other Models Test # 6min #mirror_hardwares: [amd] source_file_dependencies: - vllm/ - tests/models/embedding/language - tests/models/encoder_decoder/language + - tests/models/encoder_decoder/vision_language commands: - pytest -v -s models/embedding/language - pytest -v -s models/encoder_decoder/language + - pytest -v -s models/encoder_decoder/vision_language + +- label: Custom Models Test + #mirror_hardwares: [amd] + optional: true + commands: + # PR authors can temporarily add commands below to test individual models + # e.g. 
pytest -v -s models/encoder_decoder/vision_language/test_mllama.py + # *To avoid merge conflicts, remember to REMOVE (not just comment out) them before merging the PR* ##### 1 GPU test ##### ##### multi gpus test ##### @@ -358,7 +376,7 @@ steps: - # the following commands are for the second node, with ip 192.168.10.11 (ray environment already set up) - VLLM_TEST_SAME_HOST=0 torchrun --nnodes 2 --nproc-per-node=2 --rdzv_backend=c10d --rdzv_endpoint=192.168.10.10 distributed/test_same_node.py | grep -q 'Same node test passed' -- label: Distributed Tests (2 GPUs) # 28min +- label: Distributed Tests (2 GPUs) # 40min #mirror_hardwares: [amd] working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -375,14 +393,16 @@ steps: - VLLM_TEST_SAME_HOST=1 torchrun --nproc-per-node=4 distributed/test_same_node.py | grep -q 'Same node test passed' - TARGET_TEST_SUITE=L4 pytest basic_correctness/ -v -s -m distributed_2_gpus # Avoid importing model tests that cause CUDA reinitialization error - - pytest models/encoder_decoder/language/test_bart.py models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/language/test_bart.py -v -s -m distributed_2_gpus + - pytest models/encoder_decoder/vision_language/test_broadcast.py -v -s -m distributed_2_gpus + - pytest models/decoder_only/vision_language/test_broadcast.py -v -s -m distributed_2_gpus - pytest -v -s spec_decode/e2e/test_integration_dist_tp2.py - pip install -e ./plugins/vllm_add_dummy_model - pytest -v -s distributed/test_distributed_oot.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s test_sharded_state_loader.py - CUDA_VISIBLE_DEVICES=0,1 pytest -v -s distributed/test_utils.py -- label: Multi-step Tests (4 GPUs) # 21min +- label: Multi-step Tests (4 GPUs) # 36min working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -400,7 +420,7 @@ steps: - pytest -v -s multi_step/test_correctness_async_llm.py - pytest -v -s multi_step/test_correctness_llm.py -- label: Pipeline Parallelism Test # 23min +- label: Pipeline Parallelism Test # 45min working_dir: "/vllm-workspace/tests" num_gpus: 4 source_file_dependencies: @@ -426,7 +446,7 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -v -s -x lora/test_long_context.py -- label: Weight Loading Multiple GPU Test +- label: Weight Loading Multiple GPU Test # 33min working_dir: "/vllm-workspace/tests" num_gpus: 2 source_file_dependencies: @@ -459,7 +479,7 @@ steps: # NOTE: don't test llama model here, it seems hf implementation is buggy # see https://github.com/vllm-project/vllm/pull/5689 for details - pytest -v -s distributed/test_custom_all_reduce.py - - TARGET_TEST_SUITE=A100 pytest -v -s distributed/test_basic_distributed_correctness.py + - TARGET_TEST_SUITE=A100 pytest basic_correctness/ -v -s -m distributed_2_gpus - pytest -v -s -x lora/test_mixtral.py - label: LM Eval Large Models # optional diff --git a/.dockerignore b/.dockerignore index 79fa088fa809c..17ed0d97c88b3 100644 --- a/.dockerignore +++ b/.dockerignore @@ -1,4 +1,6 @@ -vllm/*.so +/.github/ /.venv /build dist +Dockerfile* +vllm/*.so diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS new file mode 100644 index 0000000000000..e15f129719f8f --- /dev/null +++ b/.github/CODEOWNERS @@ -0,0 +1,19 @@ +# See https://help.github.com/articles/about-codeowners/ +# for more info about CODEOWNERS file + +/tests/async_engine @njhill @robertgshaw2-neuralmagic @simon-mo +/tests/test_inputs.py @DarkLight1337 @ywang96 +/tests/entrypoints @DarkLight1337 @robertgshaw2-neuralmagic @simon-mo 
+/tests/models @DarkLight1337 @ywang96 +/tests/multimodal @DarkLight1337 @ywang96 +/tests/prefix_caching @comaniac @KuntaiDu +/tests/spec_decode @njhill @LiuXiaoxuanPKU +/tests/kernels @tlrmchlsmth @WoosukKwon +/tests/quantization @mgoin @robertgshaw2-neuralmagic +/.buildkite/lm-eval-harness @mgoin @simon-mo +/tests/distributed/test_multi_node_assignment.py @youkaichao +/tests/distributed/test_pipeline_parallel.py @youkaichao +/tests/distributed/test_same_node.py @youkaichao +/tests/multi_step @alexm-neuralmagic @SolitaryThinker @comaniac +/tests/weight_loading @mgoin @youkaichao +/tests/basic_correctness/test_chunked_prefill @rkooo567 @comaniac diff --git a/.github/workflows/scripts/build.sh b/.github/workflows/scripts/build.sh index cd617e9f19fb2..cda0c28c75c2a 100644 --- a/.github/workflows/scripts/build.sh +++ b/.github/workflows/scripts/build.sh @@ -8,7 +8,7 @@ PATH=${cuda_home}/bin:$PATH LD_LIBRARY_PATH=${cuda_home}/lib64:$LD_LIBRARY_PATH # Install requirements -$python_executable -m pip install wheel packaging +$python_executable -m pip install wheel packaging 'setuptools-scm>=8' $python_executable -m pip install -r requirements-cuda.txt # Limit the number of parallel jobs to avoid OOM diff --git a/.gitignore b/.gitignore index abeaf0a82e303..5367ece834890 100644 --- a/.gitignore +++ b/.gitignore @@ -33,6 +33,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +/.deps/ # PyInstaller # Usually these files are written by a python script from a template diff --git a/CMakeLists.txt b/CMakeLists.txt index b2fa72d4775c4..e531a410ec8c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -166,7 +166,16 @@ if(NVCC_THREADS AND VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_GPU_FLAGS "--threads=${NVCC_THREADS}") endif() + +# +# Use FetchContent for C++ dependencies that are compiled as part of vLLM's build process. +# Configure it to place files in vllm/.deps, in order to play nicely with sccache. 
+# include(FetchContent) +get_filename_component(PROJECT_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE) +file(MAKE_DIRECTORY "${FETCHCONTENT_BASE_DIR}") +set(FETCHCONTENT_BASE_DIR "${PROJECT_ROOT_DIR}/.deps") +message(STATUS "FetchContent base directory: ${FETCHCONTENT_BASE_DIR}") # # Define other extension targets diff --git a/Dockerfile b/Dockerfile index 6bb4bd032c39c..872b1bc47054a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -27,6 +27,14 @@ RUN echo 'tzdata tzdata/Areas select America' | debconf-set-selections \ && curl -sS https://bootstrap.pypa.io/get-pip.py | python${PYTHON_VERSION} \ && python3 --version && python3 -m pip --version +# Upgrade to GCC 10 to avoid https://gcc.gnu.org/bugzilla/show_bug.cgi?id=92519 +# as it was causing spam when compiling the CUTLASS kernels +RUN apt-get install -y gcc-10 g++-10 +RUN update-alternatives --install /usr/bin/gcc gcc /usr/bin/gcc-10 110 --slave /usr/bin/g++ g++ /usr/bin/g++-10 +RUN < /dev/null && \ echo "deb [signed-by=/usr/share/keyrings/intel-oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main " | tee /etc/apt/sources.list.d/oneAPI.list && \ @@ -7,20 +7,49 @@ RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRO echo "deb [arch=amd64,i386 signed-by=/usr/share/keyrings/intel-graphics.gpg] https://repositories.intel.com/graphics/ubuntu jammy arc" | tee /etc/apt/sources.list.d/intel.gpu.jammy.list && \ chmod 644 /usr/share/keyrings/intel-graphics.gpg -RUN apt-get update -y && \ - apt-get install -y curl libicu70 lsb-release git wget vim numactl python3 python3-pip ffmpeg libsm6 libxext6 libgl1 - -COPY ./ /workspace/vllm +RUN apt-get update -y && \ + apt-get install -y --no-install-recommends --fix-missing \ + curl \ + ffmpeg \ + git \ + libsndfile1 \ + libsm6 \ + libxext6 \ + libgl1 \ + lsb-release \ + numactl \ + python3 \ + python3-dev \ + python3-pip \ + # vim \ + wget WORKDIR /workspace/vllm +COPY requirements-xpu.txt /workspace/vllm/requirements-xpu.txt +COPY requirements-common.txt /workspace/vllm/requirements-common.txt RUN --mount=type=cache,target=/root/.cache/pip \ - pip install -v --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ - cmake>=3.26 ninja packaging setuptools-scm>=8 wheel jinja2 \ - -r requirements-xpu.txt + pip install --no-cache-dir \ + --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ \ + -r requirements-xpu.txt + +COPY ./ /workspace/vllm + +ENV VLLM_TARGET_DEVICE=xpu RUN --mount=type=cache,target=/root/.cache/pip \ --mount=type=bind,source=.git,target=.git \ - VLLM_TARGET_DEVICE=xpu python3 setup.py install + python3 setup.py install CMD ["/bin/bash"] + +FROM vllm-base AS vllm-openai + +# install additional dependencies for openai api server +RUN --mount=type=cache,target=/root/.cache/pip \ + pip install accelerate hf_transfer 'modelscope!=1.15.0' + +ENV VLLM_USAGE_SOURCE production-docker-image \ + TRITON_XPU_PROFILE 1 + +ENTRYPOINT ["python3", "-m", "vllm.entrypoints.openai.api_server"] diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index a39d1cf842f06..eadf994cacd34 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -11,7 +11,7 @@ from vllm import LLM, SamplingParams from vllm.engine.arg_utils import DEVICE_OPTIONS, EngineArgs -from vllm.inputs import PromptInputs +from vllm.inputs import PromptType from vllm.model_executor.layers.quantization import QUANTIZATION_METHODS from vllm.utils import FlexibleArgumentParser @@ -61,7 +61,7 @@ 
def main(args: argparse.Namespace): dummy_prompt_token_ids = np.random.randint(10000, size=(args.batch_size, args.input_len)) - dummy_inputs: List[PromptInputs] = [{ + dummy_prompts: List[PromptType] = [{ "prompt_token_ids": batch } for batch in dummy_prompt_token_ids.tolist()] @@ -74,13 +74,13 @@ def run_to_completion(profile_dir: Optional[str] = None): ], on_trace_ready=torch.profiler.tensorboard_trace_handler( str(profile_dir))) as p: - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) print(p.key_averages()) else: start_time = time.perf_counter() - llm.generate(dummy_inputs, + llm.generate(dummy_prompts, sampling_params=sampling_params, use_tqdm=False) end_time = time.perf_counter() diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index a407a263120bb..996a92d2a8b3d 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -1,4 +1,4 @@ -"""Benchmark online serving throughput. +r"""Benchmark online serving throughput. On the server side, run one of the following commands: vLLM OpenAI API server @@ -89,8 +89,6 @@ def sample_sharegpt_requests( tokenizer: PreTrainedTokenizerBase, fixed_output_len: Optional[int] = None, ) -> List[Tuple[str, int, int, None]]: - if fixed_output_len is not None and fixed_output_len < 4: - raise ValueError("output_len too small") # Load the dataset. with open(dataset_path) as f: dataset = json.load(f) @@ -117,7 +115,7 @@ def sample_sharegpt_requests( prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: + if prompt_len < 4 or (fixed_output_len is None and output_len < 4): # Prune too short sequences. continue if prompt_len > 1024 or prompt_len + output_len > 2048: @@ -228,10 +226,11 @@ def sample_hf_requests( prompt_len = len(prompt_token_ids) output_len = len(completion_token_ids ) if fixed_output_len is None else fixed_output_len - if prompt_len < 4 or output_len < 4: + if fixed_output_len is None and (prompt_len < 4 or output_len < 4): # Prune too short sequences. continue - if prompt_len > 1024 or prompt_len + output_len > 2048: + if fixed_output_len is None and \ + (prompt_len > 1024 or prompt_len + output_len > 2048): # Prune too long sequences. continue @@ -963,4 +962,4 @@ def main(args: argparse.Namespace): ) args = parser.parse_args() - main(args) \ No newline at end of file + main(args) diff --git a/collect_env.py b/collect_env.py index c5cd8c315e749..ae7f97f355253 100644 --- a/collect_env.py +++ b/collect_env.py @@ -267,13 +267,23 @@ def get_neuron_sdk_version(run_lambda): def get_vllm_version(): + version = "" try: import vllm - return vllm.__version__ + "@" + vllm.__commit__ + version = vllm.__version__ except Exception: - # old version of vllm does not have __commit__ - return 'N/A' - + pass + commit = "" + try: + import vllm + commit = vllm.__commit__ + except Exception: + pass + if version != "" and commit != "": + return f"{version}@{commit}" + if version == "" and commit == "": + return "N/A" + return version or commit def summarize_vllm_build_flags(): # This could be a static method if the flags are constant, or dynamic if you need to check environment variables, etc. 
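
For readability, the reworked get_vllm_version() in the collect_env.py hunk above can be restated as the standalone Python sketch below (a condensed restatement of what the diff adds, not additional patch content): the version string and the commit hash are probed independently, joined as version@commit when both exist, and "N/A" is returned only when neither is available.

    def get_vllm_version():
        # Probe the version and the commit hash separately so that an older
        # vLLM build without __commit__ still reports its __version__.
        version = ""
        try:
            import vllm
            version = vllm.__version__
        except Exception:
            pass

        commit = ""
        try:
            import vllm
            commit = vllm.__commit__
        except Exception:
            pass

        if version != "" and commit != "":
            return f"{version}@{commit}"
        if version == "" and commit == "":
            return "N/A"
        return version or commit

Compared with the previous implementation, a missing __commit__ attribute no longer collapses the whole result to "N/A".
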
diff --git a/csrc/core/exception.hpp b/csrc/core/exception.hpp new file mode 100644 index 0000000000000..f3b2ffaef6cce --- /dev/null +++ b/csrc/core/exception.hpp @@ -0,0 +1,3 @@ +#pragma once + +#define VLLM_IMPLIES(p, q) (!(p) || (q)) diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.cu b/csrc/mamba/causal_conv1d/causal_conv1d.cu index 32261ec17d897..30831efdfa1a2 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.cu +++ b/csrc/mamba/causal_conv1d/causal_conv1d.cu @@ -39,8 +39,6 @@ template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream); @@ -55,8 +53,11 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, const at::Tensor x, const at::Tensor weight, const at::Tensor out, - void* bias_ptr, - bool silu_activation) { + const c10::optional& bias, + bool silu_activation, + const c10::optional& query_start_loc = std::nullopt, + const c10::optional& cache_indices = std::nullopt, + const c10::optional& has_initial_state = std::nullopt) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -71,26 +72,31 @@ void set_conv_params_fwd(ConvParamsBase ¶ms, // Set the pointers and strides. params.x_ptr = x.data_ptr(); params.weight_ptr = weight.data_ptr(); - params.bias_ptr = bias_ptr; + params.bias_ptr = bias.has_value() ? bias.value().data_ptr() : nullptr; params.out_ptr = out.data_ptr(); // All stride are in elements, not bytes. - params.x_batch_stride = x.stride(0); - params.x_c_stride = x.stride(1); - params.x_l_stride = x.stride(-1); + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; + const bool varlen = params.query_start_loc_ptr != nullptr; + params.x_batch_stride = x.stride(varlen ? 1 : 0); + params.x_c_stride = x.stride(varlen ? 0 : 1); + params.x_l_stride = x.stride(varlen ? 1 : -1); params.weight_c_stride = weight.stride(0); params.weight_width_stride = weight.stride(1); - params.out_batch_stride = out.stride(0); - params.out_c_stride = out.stride(1); - params.out_l_stride = out.stride(-1); + params.out_batch_stride = out.stride(varlen ? 1 : 0); + params.out_c_stride = out.stride(varlen ? 0 : 1); + params.out_l_stride = out.stride(varlen ? 1 : -1); } at::Tensor causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, - const c10::optional &seq_idx_, - const c10::optional &initial_states_, - const c10::optional &final_states_out_, + const c10::optional &conv_states, + const c10::optional &query_start_loc, + const c10::optional &cache_indices, + const c10::optional &has_initial_state, bool silu_activation) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); @@ -99,24 +105,22 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, TORCH_CHECK(x.is_cuda()); TORCH_CHECK(weight.is_cuda()); - + + const bool varlen = query_start_loc.has_value() ? true : false; const auto sizes = x.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; + const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? 
sizes[1] : sizes[2]; const int width = weight.size(-1); - - CHECK_SHAPE(x, batch_size, dim, seqlen); + if (varlen){ + CHECK_SHAPE(x, dim, seqlen); + } + else { + CHECK_SHAPE(x, batch_size, dim, seqlen); + } CHECK_SHAPE(weight, dim, width); - TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1); - const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1; - if (is_channel_last) { - TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now"); - TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8"); - } - TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); if (bias_.has_value()) { auto bias = bias_.value(); @@ -126,56 +130,50 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, CHECK_SHAPE(bias, dim); } - if (seq_idx_.has_value()) { - TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout"); - auto seq_idx = seq_idx_.value(); - TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32); - TORCH_CHECK(seq_idx.is_cuda()); - TORCH_CHECK(seq_idx.is_contiguous()); - CHECK_SHAPE(seq_idx, batch_size, seqlen); + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); } - at::Tensor out = torch::empty_like(x); - ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, - silu_activation); - - if (seq_idx_.has_value()) { - params.seq_idx_ptr = seq_idx_.value().data_ptr(); - } else { - params.seq_idx_ptr = nullptr; + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); } - if (initial_states_.has_value()) { - TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout"); - auto initial_states = initial_states_.value(); - TORCH_CHECK(initial_states.scalar_type() == input_type); - TORCH_CHECK(initial_states.is_cuda()); - CHECK_SHAPE(initial_states, batch_size, dim, width - 1); - TORCH_CHECK(initial_states.stride(1) == 1); - params.initial_states_ptr = initial_states.data_ptr(); - params.initial_states_batch_stride = initial_states.stride(0); - params.initial_states_c_stride = initial_states.stride(1); - params.initial_states_l_stride = initial_states.stride(2); - } else { - params.initial_states_ptr = nullptr; + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); } - if (final_states_out_.has_value()) { - TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout"); - auto final_states = final_states_out_.value(); - TORCH_CHECK(final_states.scalar_type() == input_type); - TORCH_CHECK(final_states.is_cuda()); - CHECK_SHAPE(final_states, batch_size, dim, width - 1); - TORCH_CHECK(final_states.stride(1) == 1); - params.final_states_ptr = final_states.data_ptr(); - params.final_states_batch_stride = final_states.stride(0); - params.final_states_c_stride = final_states.stride(1); - params.final_states_l_stride = 
final_states.stride(2); + at::Tensor out = torch::empty_like(x); + + ConvParamsBase params; + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_, + silu_activation, + query_start_loc, + cache_indices, + has_initial_state + ); + + if (conv_states.has_value()) { + auto conv_states_ = conv_states.value(); + TORCH_CHECK(conv_states_.scalar_type() == input_type); + TORCH_CHECK(conv_states_.is_cuda()); + params.conv_states_ptr = conv_states_.data_ptr(); + params.conv_states_batch_stride = conv_states_.stride(0); + params.conv_states_c_stride = conv_states_.stride(1); + params.conv_states_l_stride = conv_states_.stride(2); } else { - params.final_states_ptr = nullptr; + params.conv_states_ptr = nullptr; } // Otherwise the kernel will be launched from cuda:0 device @@ -183,11 +181,7 @@ causal_conv1d_fwd(const at::Tensor &x, const at::Tensor &weight, at::cuda::CUDAGuard device_guard{(char)x.get_device()}; auto stream = at::cuda::getCurrentCUDAStream().stream(); DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] { - if (!is_channel_last) { - causal_conv1d_fwd_cuda(params, stream); - } else { - causal_conv1d_channellast_fwd_cuda(params, stream); - } + causal_conv1d_fwd_cuda(params, stream); }); return out; } @@ -199,6 +193,7 @@ causal_conv1d_update(const at::Tensor &x, const at::Tensor &weight, const c10::optional &bias_, bool silu_activation, + const c10::optional &cache_seqlens_, const c10::optional &conv_state_indices_) { auto input_type = x.scalar_type(); auto weight_type = weight.scalar_type(); @@ -214,9 +209,12 @@ causal_conv1d_update(const at::Tensor &x, const auto sizes = x.sizes(); const int batch_size = sizes[0]; const int dim = sizes[1]; + const int seqlen = sizes[2]; const int width = weight.size(-1); + const int conv_state_len = conv_state.size(2); + TORCH_CHECK(conv_state_len >= width - 1); - CHECK_SHAPE(x, batch_size, dim); + CHECK_SHAPE(x, batch_size, dim, seqlen); CHECK_SHAPE(weight, dim, width); TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4"); @@ -232,15 +230,27 @@ causal_conv1d_update(const at::Tensor &x, at::Tensor out = torch::empty_like(x); ConvParamsBase params; - set_conv_params_fwd(params, batch_size, dim, /*seqlen=*/1, width, x, weight, out, - bias_.has_value() ? bias_.value().data_ptr() : nullptr, + set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out, + bias_, silu_activation); params.conv_state_ptr = conv_state.data_ptr(); + params.conv_state_len = conv_state_len; // All stride are in elements, not bytes. 
params.conv_state_batch_stride = conv_state.stride(0); params.conv_state_c_stride = conv_state.stride(1); params.conv_state_l_stride = conv_state.stride(2); + if (cache_seqlens_.has_value()) { + auto cache_seqlens = cache_seqlens_.value(); + TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32); + TORCH_CHECK(cache_seqlens.is_cuda()); + TORCH_CHECK(cache_seqlens.stride(-1) == 1); + CHECK_SHAPE(cache_seqlens, batch_size); + params.cache_seqlens = cache_seqlens.data_ptr(); + } else { + params.cache_seqlens = nullptr; + } + if (conv_state_indices_.has_value()) { auto conv_state_indices = conv_state_indices_.value(); TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32) @@ -249,11 +259,11 @@ causal_conv1d_update(const at::Tensor &x, CHECK_SHAPE(conv_state_indices, batch_size); int conv_state_entries = conv_state.size(0); - CHECK_SHAPE(conv_state, conv_state_entries, dim, width); + CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len); params.conv_state_indices_ptr = conv_state_indices.data_ptr(); } else { - CHECK_SHAPE(conv_state, batch_size, dim, width); + CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len); params.conv_state_indices_ptr = nullptr; } @@ -296,7 +306,7 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNElts = Ktraits::kNElts; - static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; + constexpr bool kIsVecLoad = Ktraits::kIsVecLoad; using input_t = typename Ktraits::input_t; using vec_t = typename Ktraits::vec_t; using weight_t = typename Ktraits::weight_t; @@ -309,20 +319,39 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { auto& smem_store_vec = reinterpret_cast(smem_); vec_t *smem_exchange = reinterpret_cast(smem_ + Ktraits::kSmemIOSize); + const bool kVarlen = params.query_start_loc_ptr != nullptr; const int tidx = threadIdx.x; const int batch_id = blockIdx.x; const int channel_id = blockIdx.y; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + const int *query_start_loc = kVarlen ? reinterpret_cast(params.query_start_loc_ptr) : nullptr; + const int sequence_start_index = kVarlen ? query_start_loc[batch_id] : batch_id; + const int seqlen = kVarlen ? query_start_loc[batch_id + 1] - sequence_start_index : params.seqlen; + + input_t *x = reinterpret_cast(params.x_ptr) + sequence_start_index * params.x_batch_stride + channel_id * params.x_c_stride; weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + channel_id * params.out_c_stride; float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + + input_t *conv_states = params.conv_states_ptr == nullptr ? nullptr + : reinterpret_cast(params.conv_states_ptr) + cache_index * params.conv_states_batch_stride + channel_id * params.conv_states_c_stride; + // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0. 
if (tidx == 0) { - input_t zeros[kNElts] = {0}; - smem_exchange[kNThreads - 1] = reinterpret_cast(zeros)[0]; + input_t initial_state[kNElts] = {0}; + if (has_initial_state) { + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ initial_state[kNElts - 1 - (kWidth - 2) + w ] = conv_states[w]; } + } + smem_exchange[kNThreads - 1] = reinterpret_cast(initial_state)[0]; } float weight_vals[kWidth]; @@ -330,14 +359,14 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } constexpr int kChunkSize = kNThreads * kNElts; - const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize; + const int n_chunks = (seqlen + kChunkSize - 1) / kChunkSize; for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t x_vals_load[2 * kNElts] = {0}; if constexpr(kIsVecLoad) { - typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts); + typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast(x), *reinterpret_cast(&x_vals_load[kNElts]), (seqlen - chunk * kChunkSize) / kNElts); } else { __syncthreads(); - typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize); + typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast(&x_vals_load[kNElts]), seqlen - chunk * kChunkSize); } x += kChunkSize; __syncthreads(); @@ -375,19 +404,57 @@ void causal_conv1d_fwd_kernel(ConvParamsBase params) { #pragma unroll for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; } if constexpr(kIsVecLoad) { - typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts); + typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast(out), reinterpret_cast(out_vals_store), (seqlen - chunk * kChunkSize) / kNElts); } else { - typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize); + typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, seqlen - chunk * kChunkSize); } out += kChunkSize; } + // Final state is stored in the smem_exchange last token slot, + // in case seqlen < kWidth, we would need to take the final state from the + // initial state which is stored in conv_states + // in case seqlen > kWidth, we would need to load the last kWidth - 1 data + // and load it into conv_state accordingly + int last_thread = ((seqlen - (kWidth - 1)) - (n_chunks - 1) * kChunkSize) / kNElts; + if (conv_states != nullptr && tidx == last_thread) { + input_t x_vals_load[kNElts * 2] = {0}; + // in case we are on the first kWidth tokens + if (last_thread == 0 && seqlen < kWidth){ + // Need to take the initial state + reinterpret_cast(x_vals_load)[0] = smem_exchange[0]; + const int offset = seqlen - (kWidth - 1); + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + // pad the existing state + if ((w - seqlen) >= 0 && has_initial_state) { conv_states[w - seqlen] = conv_states[w]; } + else if ((w - seqlen) >= 0 && !has_initial_state) { conv_states[w - seqlen] = input_t(0.0f); } + } + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + if (offset + w >= 0) + conv_states[w] = x_vals_load[offset + w ]; + } + } + else { + // in case the final state is in between the threads data + reinterpret_cast(x_vals_load)[1] = smem_exchange[last_thread + 1]; + 
reinterpret_cast(x_vals_load)[0] = smem_exchange[last_thread]; + const int offset = ((seqlen - (kWidth - 1)) % (kNElts)); + #pragma unroll + for (int w = 0; w < kWidth - 1; ++w){ + conv_states[w] = x_vals_load[offset + w ]; + } + } + + } } template void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8; - BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] { + const bool kVarlen = params.query_start_loc_ptr != nullptr; + BOOL_SWITCH(params.seqlen % kNElts == 0 && !kVarlen, kIsVecLoad, [&] { using Ktraits = Causal_conv1d_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize; dim3 grid(params.batch, params.dim); @@ -422,220 +489,11 @@ void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { } } -template -struct Causal_conv1d_channellast_fwd_kernel_traits { - // The cache line is 128 bytes, and we try to read 16 bytes per thread. - // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension. - // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128 - // threads). Each each load is 16 x 32|64 elements in the L x C dimensions. - using input_t = input_t_; - using weight_t = weight_t_; - static constexpr int kNThreads = kNThreads_; - static_assert(kNThreads % 32 == 0); - static constexpr int kNWarps = kNThreads / 32; - static constexpr int kWidth = kWidth_; - static constexpr int kChunkSizeL = kChunkSizeL_; - static constexpr int kNBytes = sizeof(input_t); - static_assert(kNBytes == 2 || kNBytes == 4); - static constexpr int kNElts = kNBytes == 4 ? 4 : 8; - static constexpr int kNEltsPerRow = 128 / kNBytes; - static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now - static_assert(kNThreadsPerRow * kNBytes * kNElts == 128); - static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now - static_assert(kNColsPerWarp * kNThreadsPerRow == 32); - static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps; - static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad; - static_assert(kNLoads * kNColsPerLoad == kChunkSizeL); - static constexpr bool kIsVecLoad = kIsVecLoad_; - using vec_t = typename BytesToType::Type; - // using BlockLoadT = cub::BlockLoad; - // using BlockStoreT = cub::BlockStore; - // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage), - // sizeof(typename BlockStoreT::TempStorage)}); - // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes; -}; - -template -__global__ __launch_bounds__(Ktraits::kNThreads) -void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) { - constexpr int kWidth = Ktraits::kWidth; - constexpr int kNThreads = Ktraits::kNThreads; - constexpr int kNElts = Ktraits::kNElts; - constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow; - constexpr int kLPerLoad = Ktraits::kNColsPerLoad; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - using input_t = typename Ktraits::input_t; - using vec_t = typename Ktraits::vec_t; - using weight_t = typename Ktraits::weight_t; - - // Shared memory. 
- __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts]; - - const int batch_id = blockIdx.x; - const int chunk_l_id = blockIdx.y; - const int chunk_c_id = blockIdx.z; - const int tid = threadIdx.x; - const int l_idx = tid / kNThreadsPerC; - const int c_idx = tid % kNThreadsPerC; - input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - weight_t *weight = reinterpret_cast(params.weight_ptr) - + chunk_c_id * kChunkSizeC * params.weight_c_stride; - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride - + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast(params.seq_idx_ptr) - + batch_id * params.seqlen + chunk_l_id * kChunkSizeL; - input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr - : reinterpret_cast(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - // The last L-chunk will also have enough info to write to final states, since it also contain a few x values - // from the previous L-chunk. - input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr - : reinterpret_cast(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts; - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x + l * kLPerLoad * params.x_l_stride); - } - reinterpret_cast(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - // Load the elements from the previous chunk that are needed for convolution. 
- if (l_idx < kWidth - 1) { - input_t x_vals_load[kNElts] = {0}; - if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0 - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(x - (kWidth - 1) * params.x_l_stride); - } else if (initial_states != nullptr - && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - reinterpret_cast(x_vals_load)[0] = *reinterpret_cast(initial_states); - } - reinterpret_cast(x_smem[l_idx])[c_idx] = reinterpret_cast(x_vals_load)[0]; - } - - __syncthreads(); - - if (final_states != nullptr - && l_idx < kWidth - 1 - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1) - // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx] - *reinterpret_cast(final_states) = reinterpret_cast(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx]; - } - - constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL); - static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC); - constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread; - static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL); - // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity - static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0); - static_assert((kLPerThread & (kLPerThread - 1)) == 0); - static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0); - static_assert(kNThreadsPerRow <= 32); - - const int row_idx = tid / kNThreadsPerRow; - const int col_idx = tid % kNThreadsPerRow; - - float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]); - float weight_vals[kWidth] = {0}; - if (chunk_c_id * kChunkSizeC + row_idx < params.dim) { - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride]; - } - } - float x_vals[kWidth - 1 + kLPerThread]; - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]); - } - int seq_idx_thread[kWidth - 1 + kLPerThread]; - if constexpr (kHasSeqIdx) { - #pragma unroll - for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) { - seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1; - } - } - - float out_vals[kLPerThread]; - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { - out_vals[i] = bias_val; - const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1]; - #pragma unroll - for (int w = 0; w < kWidth; ++w) { - if constexpr (!kHasSeqIdx) { - out_vals[i] += weight_vals[w] * x_vals[i + w]; - } else { - out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? 
weight_vals[w] * x_vals[i + w] : 0.f; - } - } - if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); } - } - - __syncthreads(); - #pragma unroll - for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; } - __syncthreads(); - - #pragma unroll - for (int l = 0; l < Ktraits::kNLoads; ++l) { - input_t out_vals_store[kNElts]; - reinterpret_cast(out_vals_store)[0] = reinterpret_cast(x_smem[l * kLPerLoad + l_idx])[c_idx]; - if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen - && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) { - *reinterpret_cast(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast(out_vals_store)[0]; - } - } - -} - -template -void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) { - BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] { - using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits; - // constexpr int kSmemSize = Ktraits::kSmemSize; - constexpr int kChunkSizeL = Ktraits::kChunkSizeL; - constexpr int kChunkSizeC = Ktraits::kNEltsPerRow; - const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL; - const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC; - dim3 grid(params.batch, n_chunks_L, n_chunks_C); - dim3 block(Ktraits::kNThreads); - auto kernel = &causal_conv1d_channellast_fwd_kernel; - // if (kSmemSize >= 48 * 1024) { - // C10_CUDA_CHECK(cudaFuncSetAttribute( - // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize)); - // } - // kernel<<>>(params); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); - }); -} - -template -void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) { - if (params.width == 2) { - causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream); - } else if (params.width == 3) { - causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream); - } else if (params.width == 4) { - causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream); - } -} template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); template void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -template void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream); -/////// - @@ -649,7 +507,7 @@ struct Causal_conv1d_update_kernel_traits { static_assert(kNBytes == 2 || kNBytes == 4); }; -template +template __global__ __launch_bounds__(Ktraits::kNThreads) void causal_conv1d_update_kernel(ConvParamsBase params) { constexpr int kWidth = Ktraits::kWidth; @@ -660,6 +518,8 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { const int tidx = threadIdx.x; const int batch_id = blockIdx.x; const int channel_id = blockIdx.y * kNThreads + tidx; + if (channel_id >= params.dim) return; + input_t *x = reinterpret_cast(params.x_ptr) + batch_id * params.x_batch_stride + channel_id * params.x_c_stride; @@ -675,35 +535,70 @@ void causal_conv1d_update_kernel(ConvParamsBase params) { weight_t *weight = reinterpret_cast(params.weight_ptr) + channel_id * params.weight_c_stride; input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + channel_id * params.out_c_stride; - float 
bias_val = params.bias_ptr == nullptr || channel_id >= params.dim ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast(params.bias_ptr)[channel_id]); + + int state_len = params.conv_state_len; + int advance_len = params.seqlen; + int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0; + int update_idx = cache_seqlen - (kWidth - 1); + update_idx = update_idx < 0 ? update_idx + state_len : update_idx; float weight_vals[kWidth] = {0}; - if (channel_id < params.dim) { - #pragma unroll - for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } - } + #pragma unroll + for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); } float x_vals[kWidth] = {0}; - if (channel_id < params.dim) { + if constexpr (!kIsCircularBuffer) { + #pragma unroll 2 + for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) { + conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride]; + } #pragma unroll - for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = float(conv_state[(i + 1) * params.conv_state_l_stride]); } - x_vals[kWidth - 1] = float(x[0]); + for (int i = 0; i < kWidth - 1; ++i) { + input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride]; + if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) { + conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val; + } + x_vals[i] = float(state_val); + } + } else { #pragma unroll - for (int i = 0; i < kWidth; ++i) { conv_state[i * params.conv_state_l_stride] = input_t(x_vals[i]); } + for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) { + input_t state_val = conv_state[update_idx * params.conv_state_l_stride]; + x_vals[i] = float(state_val); + } + } + #pragma unroll 2 + for (int i = 0; i < params.seqlen; ++i) { + input_t x_val = x[i * params.x_l_stride]; + if constexpr (!kIsCircularBuffer) { + if (i < advance_len && state_len - advance_len + i >= 0) { + conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val; + } + } else { + conv_state[update_idx * params.conv_state_l_stride] = x_val; + ++update_idx; + update_idx = update_idx >= state_len ? update_idx - state_len : update_idx; + } + x_vals[kWidth - 1] = float(x_val); + float out_val = bias_val; + #pragma unroll + for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; } + if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } + out[i * params.out_l_stride] = input_t(out_val); + // Shift the input buffer by 1 + #pragma unroll + for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; } } - - float out_val = bias_val; - #pragma unroll - for (int i = 0; i < kWidth; ++i) { out_val += weight_vals[i] * x_vals[i]; } - if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); } - if (channel_id < params.dim) { out[0] = input_t(out_val); } } template void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) { using Ktraits = Causal_conv1d_update_kernel_traits; dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads); - auto kernel = &causal_conv1d_update_kernel; + auto kernel = params.cache_seqlens == nullptr + ? 
&causal_conv1d_update_kernel + : &causal_conv1d_update_kernel; kernel<<>>(params); C10_CUDA_KERNEL_LAUNCH_CHECK(); } diff --git a/csrc/mamba/causal_conv1d/causal_conv1d.h b/csrc/mamba/causal_conv1d/causal_conv1d.h index 32a7d83c09b8d..49e37ee4528be 100644 --- a/csrc/mamba/causal_conv1d/causal_conv1d.h +++ b/csrc/mamba/causal_conv1d/causal_conv1d.h @@ -24,6 +24,7 @@ struct ConvParamsBase { index_t out_c_stride; index_t out_l_stride; + int conv_state_len; index_t conv_state_batch_stride; index_t conv_state_c_stride; index_t conv_state_l_stride; @@ -35,6 +36,10 @@ struct ConvParamsBase { void *__restrict__ out_ptr; void *__restrict__ conv_state_ptr; + void *__restrict__ query_start_loc_ptr; + void *__restrict__ has_initial_state_ptr; + void *__restrict__ cache_indices_ptr; + int32_t *__restrict__ cache_seqlens; // For the continuous batching case. Makes it so that the mamba state for // the current batch doesn't need to be a contiguous tensor. @@ -52,6 +57,11 @@ struct ConvParamsBase { index_t final_states_batch_stride; index_t final_states_l_stride; index_t final_states_c_stride; + + void * conv_states_ptr; + index_t conv_states_batch_stride; + index_t conv_states_l_stride; + index_t conv_states_c_stride; }; diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 0070c92f6cd0f..580d0b2e17e74 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -54,10 +54,14 @@ struct SSMParamsBase { void *__restrict__ delta_ptr; void *__restrict__ delta_bias_ptr; void *__restrict__ out_ptr; - void *__restrict__ x_ptr; + void *__restrict__ ssm_states_ptr; void *__restrict__ z_ptr; void *__restrict__ out_z_ptr; - void *__restrict__ index_ptr; + + void *__restrict__ query_start_loc_ptr; + void *__restrict__ cache_indices_ptr; + void *__restrict__ has_initial_state_ptr; + }; @@ -201,7 +205,7 @@ inline __device__ void load_input(typename Ktraits::input_t *u, typename Ktraits::input_t (&u_vals)[Ktraits::kNItems], typename Ktraits::BlockLoadT::TempStorage &smem_load, int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { auto& smem_load_vec = reinterpret_cast(smem_load); using vec_t = typename Ktraits::vec_t; typename Ktraits::BlockLoadVecT(smem_load_vec).Load( @@ -217,21 +221,6 @@ inline __device__ void load_input(typename Ktraits::input_t *u, } } -template -inline __device__ void load_index(int *u, - int (&u_vals)[Ktraits::kNItems], - typename Ktraits::BlockLoadIndexT::TempStorage &smem_load_index, - int seqlen) { - if constexpr (Ktraits::kIsEvenLen) { - auto& smem_load_index_vec = reinterpret_cast(smem_load_index); - Ktraits::BlockLoadIndexVecT(smem_load_index_vec).Load( - reinterpret_cast(u), - reinterpret_cast(u_vals) - ); - } else { - Ktraits::BlockLoadIndexT(smem_load_index).Load(u, u_vals, seqlen, 0); - } -} template inline __device__ void load_weight(typename Ktraits::input_t *Bvar, @@ -240,7 +229,7 @@ inline __device__ void load_weight(typename Ktraits::input_t *Bvar, int seqlen) { constexpr int kNItems = Ktraits::kNItems; typename Ktraits::input_t B_vals_load[kNItems]; - if constexpr (Ktraits::kIsEvenLen) { + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { auto& smem_load_weight_vec = reinterpret_cast(smem_load_weight); using vec_t = typename Ktraits::vec_t; typename Ktraits::BlockLoadWeightVecT(smem_load_weight_vec).Load( @@ -263,7 +252,7 @@ inline __device__ void store_output(typename Ktraits::input_t *out, typename Ktraits::input_t 
write_vals[Ktraits::kNItems]; #pragma unroll for (int i = 0; i < Ktraits::kNItems; ++i) { write_vals[i] = out_vals[i]; } - if constexpr (Ktraits::kIsEvenLen) { + if constexpr (Ktraits::kIsEvenLen && !Ktraits::kVarlen) { auto& smem_store_vec = reinterpret_cast(smem_store); using vec_t = typename Ktraits::vec_t; typename Ktraits::BlockStoreVecT(smem_store_vec).Store( diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index d7829f5d583d4..6b225b41d295d 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -23,7 +23,7 @@ template + bool kHasZ_, bool kVarlen_, typename input_t_, typename weight_t_> struct Selective_Scan_fwd_kernel_traits { static_assert(kNItems_ % 4 == 0); using input_t = input_t_; @@ -38,22 +38,19 @@ struct Selective_Scan_fwd_kernel_traits { static constexpr int kNElts = kNBytes == 4 ? 4 : constexpr_min(8, kNItems); static_assert(kNItems % kNElts == 0); static constexpr int kNLoads = kNItems / kNElts; - static constexpr bool kIsEvenLen = kIsEvenLen_; + static constexpr bool kIsEvenLen = kVarlen_ ? false : kIsEvenLen_; static constexpr bool kIsVariableB = kIsVariableB_; static constexpr bool kIsVariableC = kIsVariableC_; static constexpr bool kHasZ = kHasZ_; - static constexpr bool kUseIndex = kUseIndex_; + static constexpr bool kVarlen = kVarlen_; - static constexpr bool kDirectIO = kIsEvenLen && kNLoads == 1; + static constexpr bool kDirectIO = kVarlen_ ? false : kIsEvenLen && kNLoads == 1; static constexpr int kNLoadsIndex = kNItems / 4; using vec_t = typename BytesToType::Type; using scan_t = float2; using BlockLoadT = cub::BlockLoad; using BlockLoadVecT = cub::BlockLoad; - using BlockLoadIndexT = cub::BlockLoad; - using BlockLoadIndexVecT = cub::BlockLoad; using BlockLoadWeightT = cub::BlockLoad; using BlockLoadWeightVecT = cub::BlockLoad; @@ -65,8 +62,6 @@ struct Selective_Scan_fwd_kernel_traits { using BlockScanT = cub::BlockScan; static constexpr int kSmemIOSize = custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockLoadVecT::TempStorage), - sizeof(typename BlockLoadIndexT::TempStorage), - sizeof(typename BlockLoadIndexVecT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightT::TempStorage), (int(kIsVariableB) + int(kIsVariableC)) * sizeof(typename BlockLoadWeightVecT::TempStorage), sizeof(typename BlockStoreT::TempStorage), @@ -80,7 +75,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableB = Ktraits::kIsVariableB; constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; - constexpr bool kUseIndex = Ktraits::kUseIndex; + constexpr bool kVarlen = Ktraits::kVarlen; constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; @@ -97,7 +92,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // auto& smem_load = reinterpret_cast(smem_loadstorescan); auto& smem_load = reinterpret_cast(smem_); auto& smem_load_weight = reinterpret_cast(smem_); - auto& smem_load_index = reinterpret_cast(smem_); auto& smem_load_weight1 = *reinterpret_cast(smem_ + sizeof(typename Ktraits::BlockLoadWeightT::TempStorage)); auto& smem_store = reinterpret_cast(smem_); auto& smem_scan = *reinterpret_cast(smem_ + Ktraits::kSmemIOSize); @@ -108,17 +102,29 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { const int batch_id = blockIdx.x; const int dim_id = blockIdx.y; const int group_id = dim_id / 
(params.dim_ngroups_ratio); - input_t *u = reinterpret_cast(params.u_ptr) + batch_id * params.u_batch_stride + int seqlen = params.seqlen; + int sequence_start_index = batch_id; + if constexpr (kVarlen){ + int *query_start_loc = reinterpret_cast(params.query_start_loc_ptr); + sequence_start_index = query_start_loc[batch_id]; + seqlen = query_start_loc[batch_id + 1] - sequence_start_index; + } + const bool has_initial_state = params.has_initial_state_ptr == nullptr ? false + : reinterpret_cast(params.has_initial_state_ptr)[batch_id]; + + const int* cache_indices = params.cache_indices_ptr == nullptr ? nullptr + : reinterpret_cast(params.cache_indices_ptr); + const int cache_index = cache_indices == nullptr ? batch_id : cache_indices[batch_id]; + input_t *u = reinterpret_cast(params.u_ptr) + sequence_start_index * params.u_batch_stride + dim_id * kNRows * params.u_d_stride; - input_t *delta = reinterpret_cast(params.delta_ptr) + batch_id * params.delta_batch_stride + input_t *delta = reinterpret_cast(params.delta_ptr) + sequence_start_index * params.delta_batch_stride + dim_id * kNRows * params.delta_d_stride; weight_t *A = reinterpret_cast(params.A_ptr) + dim_id * kNRows * params.A_d_stride; weight_t *B = reinterpret_cast(params.B_ptr) + dim_id * kNRows * params.B_d_stride; - input_t *Bvar = reinterpret_cast(params.B_ptr) + batch_id * params.B_batch_stride + group_id * params.B_group_stride; + input_t *Bvar = reinterpret_cast(params.B_ptr) + sequence_start_index * params.B_batch_stride + group_id * params.B_group_stride; weight_t *C = reinterpret_cast(params.C_ptr) + dim_id * kNRows * params.C_d_stride; - input_t *Cvar = reinterpret_cast(params.C_ptr) + batch_id * params.C_batch_stride + group_id * params.C_group_stride; - scan_t *x = reinterpret_cast(params.x_ptr) + (batch_id * params.dim + dim_id * kNRows) * params.n_chunks * params.dstate; - int *index = !kUseIndex ? 
nullptr :reinterpret_cast(params.index_ptr) + batch_id * params.seqlen; + input_t *Cvar = reinterpret_cast(params.C_ptr) + sequence_start_index * params.C_batch_stride + group_id * params.C_group_stride; + input_t *ssm_states = reinterpret_cast(params.ssm_states_ptr) + (cache_index * params.dim + dim_id * kNRows) * params.dstate; float D_val[kNRows] = {0}; if (params.D_ptr != nullptr) { @@ -142,9 +148,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // } constexpr int kChunkSize = kNThreads * kNItems; - for (int chunk = 0; chunk < params.n_chunks; ++chunk) { + const int n_chunks = (seqlen + 2048 - 1) / 2048; + for (int chunk = 0; chunk < n_chunks; ++chunk) { input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; - int index_vals_load[kNRows][kNItems]; __syncthreads(); #pragma unroll @@ -152,15 +158,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, params.seqlen - chunk * kChunkSize); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, params.seqlen - chunk * kChunkSize); - if constexpr (kUseIndex) { - load_index(index + r * params.delta_d_stride, index_vals_load[r], smem_load_index, params.seqlen - chunk * kChunkSize); - } - } - if constexpr (kUseIndex) { - index += kChunkSize; + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); } u += kChunkSize; delta += kChunkSize; @@ -195,9 +195,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // If both B and C vary, this is unused. weight_t BC_val[kNRows]; weight_t B_vals[kNItems], C_vals[kNItems]; - if constexpr (kIsVariableB) { + if constexpr (kIsVariableB) { load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (params.seqlen - chunk * kChunkSize) * (1)); + smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); if constexpr (!kIsVariableC) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -208,7 +208,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (params.seqlen - chunk * kChunkSize) * (1 )); + smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1 )); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -232,24 +232,16 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - // Reset A bar for cumulative sequences (Real) - if constexpr (kUseIndex) { - if (index_vals_load[r][i] == 0) { - thread_data[i].x = 0.f; - } - } - - if constexpr (!Ktraits::kIsEvenLen) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= params.seqlen - chunk * kChunkSize) { + if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct + if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { thread_data[i] = make_float2(1.f, 0.f); } } } // Initialize running total - scan_t running_prefix; - // If we use WARP_SCAN then all lane 0 of all warps (not just thread 0) needs to read - running_prefix = chunk == 0 ? 
x[(r * params.n_chunks) * params.dstate + state_idx] : ( threadIdx.x % 32 == 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.f, 0.f)); - // running_prefix = chunk > 0 && threadIdx.x == 0 ? smem_running_prefix[state_idx] : make_float2(1.f, 0.f); + + scan_t running_prefix = chunk > 0 ? smem_running_prefix[state_idx + r * MAX_DSTATE] : make_float2(1.0, has_initial_state ? float(ssm_states[state_idx]): 0.0); + SSMScanPrefixCallbackOp prefix_op(running_prefix); typename Ktraits::BlockScanT(smem_scan).InclusiveScan( thread_data, thread_data, SSMScanOp(), prefix_op @@ -258,7 +250,9 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { // Unless there's only 1 warp, but then it's the same thread (0) reading and writing. if (threadIdx.x == 0) { smem_running_prefix[state_idx] = prefix_op.running_prefix; - x[(r * params.n_chunks + chunk) * params.dstate + state_idx] = prefix_op.running_prefix; + if (chunk == n_chunks - 1) { + ssm_states[state_idx] = input_t(prefix_op.running_prefix.y); + } } #pragma unroll for (int i = 0; i < kNItems; ++i) { @@ -270,7 +264,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } - input_t *out = reinterpret_cast(params.out_ptr) + batch_id * params.out_batch_stride + input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; __syncthreads(); #pragma unroll @@ -278,26 +272,26 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); } if constexpr (kHasZ) { - input_t *z = reinterpret_cast(params.z_ptr) + batch_id * params.z_batch_stride + input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; - input_t *out_z = reinterpret_cast(params.out_z_ptr) + batch_id * params.out_z_batch_stride + input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; #pragma unroll for (int r = 0; r < kNRows; ++r) { input_t z_vals[kNItems]; __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, params.seqlen - chunk * kChunkSize); + load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); #pragma unroll for (int i = 0; i < kNItems; ++i) { float z_val = z_vals[i]; out_vals[r][i] *= z_val / (1 + expf(-z_val)); } __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, params.seqlen - chunk * kChunkSize); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); } } @@ -316,8 +310,8 @@ void selective_scan_fwd_launch(SSMParamsBase ¶ms, cudaStream_t stream) { constexpr bool kIsVariableC = true; constexpr bool kHasZ = true; BOOL_SWITCH(params.seqlen % (kNThreads * kNItems) == 0, kIsEvenLen, [&] { - BOOL_SWITCH(params.index_ptr != nullptr , kUseIndex, [&] { - using Ktraits = Selective_Scan_fwd_kernel_traits; + BOOL_SWITCH(params.query_start_loc_ptr != nullptr , kVarlen, [&] { + using Ktraits = Selective_Scan_fwd_kernel_traits; constexpr int kSmemSize = Ktraits::kSmemSize + kNRows * MAX_DSTATE * sizeof(typename Ktraits::scan_t); dim3 grid(params.batch, params.dim / kNRows); 
auto kernel = &selective_scan_fwd_kernel; @@ -405,12 +399,15 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, const torch::Tensor out, const torch::Tensor z, const torch::Tensor out_z, - void* D_ptr, - void* delta_bias_ptr, - void* x_ptr, + const c10::optional& D, + const c10::optional& delta_bias, + const torch::Tensor ssm_states, bool has_z, bool delta_softplus, - void* index_ptr) { + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + bool varlen) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -434,55 +431,83 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.A_ptr = A.data_ptr(); params.B_ptr = B.data_ptr(); params.C_ptr = C.data_ptr(); - params.D_ptr = D_ptr; - params.delta_bias_ptr = delta_bias_ptr; + params.D_ptr = D.has_value() ? D.value().data_ptr() : nullptr; + params.delta_bias_ptr = delta_bias.has_value() ? delta_bias.value().data_ptr() : nullptr; params.out_ptr = out.data_ptr(); - params.x_ptr = x_ptr; + params.ssm_states_ptr = ssm_states.data_ptr(); params.z_ptr = has_z ? z.data_ptr() : nullptr; params.out_z_ptr = has_z ? out_z.data_ptr() : nullptr; + params.query_start_loc_ptr = query_start_loc.has_value() ? query_start_loc.value().data_ptr() : nullptr; + params.cache_indices_ptr = cache_indices.has_value() ? cache_indices.value().data_ptr() : nullptr; + params.has_initial_state_ptr = has_initial_state.has_value() ? has_initial_state.value().data_ptr() : nullptr; - params.index_ptr = index_ptr; // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); params.A_dstate_stride = A.stride(1); - if (!is_variable_B) { - params.B_d_stride = B.stride(0); - } else { - params.B_batch_stride = B.stride(0); - params.B_group_stride = B.stride(1); - } - params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); - if (!is_variable_C) { - params.C_d_stride = C.stride(0); - } else { - params.C_batch_stride = C.stride(0); - params.C_group_stride = C.stride(1); + + if (varlen){ + params.B_batch_stride = B.stride(2); + params.B_group_stride = B.stride(0); + params.B_dstate_stride = B.stride(1); + params.C_batch_stride = C.stride(2); + params.C_group_stride = C.stride(0); + params.C_dstate_stride = C.stride(1); + + params.u_batch_stride = u.stride(1); + params.u_d_stride = u.stride(0); + params.delta_batch_stride = delta.stride(1); + params.delta_d_stride = delta.stride(0); + if (has_z) { + params.z_batch_stride = z.stride(1); + params.z_d_stride = z.stride(0); + params.out_z_batch_stride = out_z.stride(1); + params.out_z_d_stride = out_z.stride(0); + } + params.out_batch_stride = out.stride(1); + params.out_d_stride = out.stride(0); + } - params.C_dstate_stride = !is_variable_C ? C.stride(1) : C.stride(2); - params.u_batch_stride = u.stride(0); - params.u_d_stride = u.stride(1); - params.delta_batch_stride = delta.stride(0); - params.delta_d_stride = delta.stride(1); - if (has_z) { - params.z_batch_stride = z.stride(0); - params.z_d_stride = z.stride(1); - params.out_z_batch_stride = out_z.stride(0); - params.out_z_d_stride = out_z.stride(1); + else{ + if (!is_variable_B) { + params.B_d_stride = B.stride(0); + } else { + params.B_batch_stride = B.stride(0); + params.B_group_stride = B.stride(1); + } + params.B_dstate_stride = !is_variable_B ? B.stride(1) : B.stride(2); + if (!is_variable_C) { + params.C_d_stride = C.stride(0); + } else { + params.C_batch_stride = C.stride(0); + params.C_group_stride = C.stride(1); + } + params.C_dstate_stride = !is_variable_C ? 
C.stride(1) : C.stride(2); + params.u_batch_stride = u.stride(0); + params.u_d_stride = u.stride(1); + params.delta_batch_stride = delta.stride(0); + params.delta_d_stride = delta.stride(1); + if (has_z) { + params.z_batch_stride = z.stride(0); + params.z_d_stride = z.stride(1); + params.out_z_batch_stride = out_z.stride(0); + params.out_z_d_stride = out_z.stride(1); + } + params.out_batch_stride = out.stride(0); + params.out_d_stride = out.stride(1); } - params.out_batch_stride = out.stride(0); - params.out_d_stride = out.stride(1); } -std::vector -selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, +void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, const torch::Tensor &A, const torch::Tensor &B, const torch::Tensor &C, const c10::optional &D_, const c10::optional &z_, const c10::optional &delta_bias_, bool delta_softplus, - const c10::optional &index_, - const c10::optional &x) { + const c10::optional &query_start_loc, + const c10::optional &cache_indices, + const c10::optional &has_initial_state, + const torch::Tensor &ssm_states) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -505,23 +530,37 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, TORCH_CHECK(delta.stride(-1) == 1 || delta.size(-1) == 1); const auto sizes = u.sizes(); - const int batch_size = sizes[0]; - const int dim = sizes[1]; - const int seqlen = sizes[2]; + const bool varlen = query_start_loc.has_value(); + const int batch_size = varlen ? query_start_loc.value().sizes()[0] - 1 : sizes[0]; + const int dim = varlen ? sizes[0] : sizes[1]; + const int seqlen = varlen ? sizes[1] : sizes[2]; const int dstate = A.size(1); - const int n_groups = is_variable_B ? B.size(1) : 1; + const int n_groups = varlen ? 
B.size(0) : B.size(1); TORCH_CHECK(dstate <= 256, "selective_scan only supports state dimension <= 256"); - CHECK_SHAPE(u, batch_size, dim, seqlen); - CHECK_SHAPE(delta, batch_size, dim, seqlen); + if (varlen) { + CHECK_SHAPE(u, dim, seqlen); + CHECK_SHAPE(delta, dim, seqlen); + } else { + CHECK_SHAPE(u, batch_size, dim, seqlen); + CHECK_SHAPE(delta, batch_size, dim, seqlen); + } CHECK_SHAPE(A, dim, dstate); TORCH_CHECK(is_variable_B, "is_variable_B = False is disabled in favor of reduced binary size") - CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen ); + if (varlen) { + CHECK_SHAPE(B, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(B, batch_size, n_groups, dstate, seqlen); + } TORCH_CHECK(B.stride(-1) == 1 || B.size(-1) == 1); TORCH_CHECK(is_variable_C, "is_variable_C = False is disabled in favor of reduced binary size") - CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + if (varlen) { + CHECK_SHAPE(C, n_groups, dstate, seqlen); + } else { + CHECK_SHAPE(C, batch_size, n_groups, dstate, seqlen); + } TORCH_CHECK(C.stride(-1) == 1 || C.size(-1) == 1); if (D_.has_value()) { @@ -539,12 +578,30 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, TORCH_CHECK(delta_bias.stride(-1) == 1 || delta_bias.size(-1) == 1); CHECK_SHAPE(delta_bias, dim); } - if (index_.has_value()) { - auto index = index_.value(); - TORCH_CHECK(index.scalar_type() == at::ScalarType::Int); - TORCH_CHECK(index.is_cuda()); - CHECK_SHAPE(index, batch_size, seqlen); + + + if (has_initial_state.has_value()) { + auto has_initial_state_ = has_initial_state.value(); + TORCH_CHECK(has_initial_state_.scalar_type() == at::ScalarType::Bool); + TORCH_CHECK(has_initial_state_.is_cuda()); + CHECK_SHAPE(has_initial_state_, batch_size); + } + + + if (query_start_loc.has_value()) { + auto query_start_loc_ = query_start_loc.value(); + TORCH_CHECK(query_start_loc_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(query_start_loc_.is_cuda()); + } + + + if (cache_indices.has_value()) { + auto cache_indices_ = cache_indices.value(); + TORCH_CHECK(cache_indices_.scalar_type() == at::ScalarType::Int); + TORCH_CHECK(cache_indices_.is_cuda()); + CHECK_SHAPE(cache_indices_, batch_size); } + at::Tensor z, out_z; const bool has_z = z_.has_value(); @@ -553,32 +610,39 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, TORCH_CHECK(z.scalar_type() == input_type); TORCH_CHECK(z.is_cuda()); TORCH_CHECK(z.stride(-1) == 1 || z.size(-1) == 1); - CHECK_SHAPE(z, batch_size, dim, seqlen); - out_z = torch::empty_like(z); + if (varlen){ + CHECK_SHAPE(z, dim, seqlen); + } else { + CHECK_SHAPE(z, batch_size, dim, seqlen); + } + + out_z = z; const int n_chunks = (seqlen + 2048 - 1) / 2048; // const int n_chunks = (seqlen + 1024 - 1) / 1024; // at::Tensor out = torch::empty_like(u); // Right now u has BHL layout and delta has HBL layout, and we want out to have HBL layout - at::Tensor out = torch::empty_like(delta); - if (x.has_value()){ - auto _x = x.value(); - TORCH_CHECK(_x.scalar_type() == weight_type); - TORCH_CHECK(_x.is_cuda()); - TORCH_CHECK(_x.stride(-1) == 1); - CHECK_SHAPE(_x, batch_size, dim, n_chunks, dstate * 2); - } + at::Tensor out = delta; + TORCH_CHECK(ssm_states.scalar_type() == input_type); + TORCH_CHECK(ssm_states.is_cuda()); + TORCH_CHECK(ssm_states.stride(-1) == 1); + CHECK_SHAPE(ssm_states, batch_size, dim, dstate); SSMParamsBase params; set_ssm_params_fwd(params, batch_size, dim, seqlen, dstate, n_groups, n_chunks, is_variable_B, is_variable_C, u, delta, A, B, C, out, z, out_z, - 
D_.has_value() ? D_.value().data_ptr() : nullptr, - delta_bias_.has_value() ? delta_bias_.value().data_ptr() : nullptr, - x.value().data_ptr(), + D_, + delta_bias_, + ssm_states, has_z, delta_softplus, - index_.has_value() ? index_.value().data_ptr() : nullptr); + query_start_loc, + cache_indices, + has_initial_state, + varlen + ); + // Otherwise the kernel will be launched from cuda:0 device // Cast to char to avoid compiler warning about narrowing at::cuda::CUDAGuard device_guard{(char)u.get_device()}; @@ -586,8 +650,5 @@ selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, DISPATCH_WTYPE_ITYPE_FLOAT_AND_HALF_AND_BF16(u.scalar_type(), "selective_scan_fwd", [&] { selective_scan_fwd_cuda(params, stream); }); - std::vector result = {out}; - if (has_z) { result.push_back(out_z); } - return result; } diff --git a/csrc/moe/marlin_moe_ops.cu b/csrc/moe/marlin_moe_ops.cu index dfe0437414013..c97b5dbd2a54e 100644 --- a/csrc/moe/marlin_moe_ops.cu +++ b/csrc/moe/marlin_moe_ops.cu @@ -25,6 +25,7 @@ #include +#include "core/exception.hpp" #include "core/scalar_type.hpp" #include "marlin_kernels/marlin_moe_kernel_ku4b8.h" #include "marlin_kernels/marlin_moe_kernel_ku8b128.h" @@ -189,7 +190,7 @@ int get_scales_cache_size(thread_config_t const& th_config, int prob_m, int load_groups = tb_groups * STAGES * 2; // Chunk size is 2x pipeline over dim K load_groups = max(load_groups, 32); // We load at least 32 scale groups - return load_groups * tb_n * 2; + return load_groups * tb_n * 4; } else { int tb_scales = tb_groups * tb_n * 2; @@ -433,11 +434,7 @@ void marlin_mm_moe(const void* A, const void* B, void* C, int4* C_ptr = (int4*)C; const float* topk_weights_ptr = (const float*)topk_weights; const int* sorted_ids_ptr = (const int*)sorted_ids; - const int4* s_ptr = - (const int4*)s + - (((group_size == -1 || group_size == 0) ? 
1 : prob_k / group_size) * - prob_n / 8) * - expert_idx; + const int4* s_ptr = (const int4*)s + num_groups * prob_n / 8 * expert_idx; const int* g_idx_ptr = (const int*)g_idx + prob_k * expert_idx; const int* perm_ptr = (const int*)perm + prob_k * expert_idx; int* locks = (int*)workspace; @@ -521,6 +518,9 @@ torch::Tensor marlin_gemm_moe( " is not size_n = ", size_n); num_groups = b_scales.size(1); + TORCH_CHECK(VLLM_IMPLIES(!is_k_full, has_act_order), + "if is_k_full is false, has_act_order must be true"); + if (has_act_order) { if (is_k_full) { TORCH_CHECK(num_groups > 1, "For act_order, num_groups must be > 1"); diff --git a/csrc/ops.h b/csrc/ops.h index 7ad0abd46c82a..3e31ddb286e80 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -215,25 +215,30 @@ void moe_align_block_size(torch::Tensor topk_ids, int64_t num_experts, torch::Tensor experts_ids, torch::Tensor num_tokens_post_pad); -std::vector selective_scan_fwd( - const torch::Tensor& u, const torch::Tensor& delta, const torch::Tensor& A, - const torch::Tensor& B, const torch::Tensor& C, - const c10::optional& D_, - const c10::optional& z_, - const c10::optional& delta_bias_, bool delta_softplus, - const c10::optional& index_, - const c10::optional& x); +void selective_scan_fwd(const torch::Tensor& u, const torch::Tensor& delta, + const torch::Tensor& A, const torch::Tensor& B, + const torch::Tensor& C, + const c10::optional& D_, + const c10::optional& z_, + const c10::optional& delta_bias_, + bool delta_softplus, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, + const torch::Tensor& ssm_states); at::Tensor causal_conv1d_update( const at::Tensor& x, const at::Tensor& conv_state, const at::Tensor& weight, - const c10::optional& bias, bool silu_activation, - const c10::optional& conv_state_indices); + const c10::optional& bias_, bool silu_activation, + const c10::optional& cache_seqlens_, + const c10::optional& conv_state_indices_); at::Tensor causal_conv1d_fwd(const at::Tensor& x, const at::Tensor& weight, const c10::optional& bias_, - const c10::optional& seq_idx_, - const c10::optional& initial_states_, - const c10::optional& final_states_out_, + const c10::optional& conv_states, + const c10::optional& query_start_loc, + const c10::optional& cache_indices, + const c10::optional& has_initial_state, bool silu_activation); #ifndef USE_ROCM diff --git a/csrc/prepare_inputs/advance_step.cu b/csrc/prepare_inputs/advance_step.cu index a9d08ca0dc14c..195eb27dee749 100644 --- a/csrc/prepare_inputs/advance_step.cu +++ b/csrc/prepare_inputs/advance_step.cu @@ -52,7 +52,7 @@ __global__ void advance_step_flashattn_kernel( slot_mapping_ptr[cur_query_id] = slot_num; } -inline void verify_tensor(std::string const& name, torch::Tensor& t, +inline void verify_tensor(std::string const& name, torch::Tensor const& t, int64_t const size_0, int64_t const size_1, c10::ScalarType const type) { bool size_0_cond = true; @@ -211,7 +211,7 @@ void advance_step_flashinfer( printf(" num_seqs = %d\n", num_seqs); printf(" num_queries = %d\n", num_queries); printf(" block_size = %d\n", block_size); - printf(" block_tables.stride(0) = %d\n", block_tables.stride(0)); + printf(" block_tables.stride(0) = %zu\n", block_tables.stride(0)); } // Verify all tensors verify_tensor("input_tokens", input_tokens, num_seqs, -1, at::kLong); @@ -303,4 +303,4 @@ void advance_step_flashinfer( num_seqs, num_queries, block_size, input_tokens, sampled_token_ids, input_positions, seq_lens, slot_mapping, block_tables, 
paged_kv_indices, paged_kv_indptr, paged_kv_last_page_len, block_table_bound); -} \ No newline at end of file +} diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 8ed81ea727aa3..c35dfe94c9c41 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -457,7 +457,13 @@ def generate(): )), ] - schedules = list(set([x[1] for x in default_heuristic])) + # Do not use schedules = list(set(...)) because we need to make sure + # the output list is deterministic; otherwise the generated kernel file + # will be non-deterministic and causes ccache miss. + schedules = [] + for _, schedule_config in default_heuristic: + if schedule_config not in schedules: + schedules.append(schedule_config) impl_configs = [] diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index b6ba1b2a26e10..3538f2850f915 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -273,26 +273,31 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { ops.def( "selective_scan_fwd(Tensor! u, Tensor! delta," "Tensor! A, Tensor! B, Tensor! C," - "Tensor? D_, Tensor? z_, Tensor? delta_bias_," + "Tensor? D_, Tensor!? z_, Tensor? delta_bias_," "bool delta_softplus," - "Tensor? index_, Tensor!? x) -> Tensor[]"); + "Tensor? query_start_loc," + "Tensor? cache_indices," + "Tensor? has_initial_state," + "Tensor! ssm_states) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); ops.def( "causal_conv1d_update(Tensor! x," "Tensor! conv_state," "Tensor! weight," - "Tensor? bias," + "Tensor? bias_," "bool silu_activation," + "Tensor? cache_seqlens_," "Tensor? conv_state_indices) -> Tensor"); ops.impl("causal_conv1d_update", torch::kCUDA, &causal_conv1d_update); ops.def( "causal_conv1d_fwd(Tensor! x, Tensor! weight," "Tensor? bias_," - "Tensor? seq_idx_," - "Tensor? initial_states_," - "Tensor!? final_states_out_," + "Tensor!? conv_states," + "Tensor? query_start_loc," + "Tensor? cache_indices," + "Tensor? has_initial_state," "bool silu_activation) -> Tensor"); ops.impl("causal_conv1d_fwd", torch::kCUDA, &causal_conv1d_fwd); #endif diff --git a/docs/source/dev/multimodal/multimodal_index.rst b/docs/source/dev/multimodal/multimodal_index.rst index 241b2ccd0991e..e112b43aade5e 100644 --- a/docs/source/dev/multimodal/multimodal_index.rst +++ b/docs/source/dev/multimodal/multimodal_index.rst @@ -8,7 +8,7 @@ Multi-Modality vLLM provides experimental support for multi-modal models through the :mod:`vllm.multimodal` package. Multi-modal inputs can be passed alongside text and token prompts to :ref:`supported models ` -via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptInputs`. +via the ``multi_modal_data`` field in :class:`vllm.inputs.PromptType`. Currently, vLLM only has built-in support for image data. You can extend vLLM to process additional modalities by following :ref:`this guide `. diff --git a/docs/source/dev/offline_inference/llm_inputs.rst b/docs/source/dev/offline_inference/llm_inputs.rst index 9adf82d43f3e0..0d47281db485e 100644 --- a/docs/source/dev/offline_inference/llm_inputs.rst +++ b/docs/source/dev/offline_inference/llm_inputs.rst @@ -1,7 +1,7 @@ LLM Inputs ========== -.. autodata:: vllm.inputs.PromptInputs +.. autodata:: vllm.inputs.PromptType .. 
autoclass:: vllm.inputs.TextPrompt :show-inheritance: diff --git a/docs/source/getting_started/installation.rst b/docs/source/getting_started/installation.rst index afae6e6556021..c6db74c18629f 100644 --- a/docs/source/getting_started/installation.rst +++ b/docs/source/getting_started/installation.rst @@ -12,8 +12,8 @@ Requirements * Python: 3.8 -- 3.12 * GPU: compute capability 7.0 or higher (e.g., V100, T4, RTX20xx, A100, L4, H100, etc.) -Install with pip ----------------- +Install released versions +-------------------------- You can install vLLM using pip: @@ -46,31 +46,75 @@ You can install vLLM using pip: Therefore, it is recommended to install vLLM with a **fresh new** conda environment. If either you have a different CUDA version or you want to use an existing PyTorch installation, you need to build vLLM from source. See below for instructions. -.. note:: +Install the latest code +---------------------------- - vLLM also publishes a subset of wheels (Python 3.10, 3.11 with CUDA 12) for every commit since v0.5.3. You can download them with the following command: +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for Linux running on x86 platform with cuda 12 for every commit since v0.5.3. You can download and install the latest one with the following command: - .. code-block:: console +.. code-block:: console + + $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl + +If you want to access the wheels for previous commits, you can specify the commit hash in the URL: + +.. code-block:: console + + $ export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch + $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl + +Note that the wheels are built with Python 3.8 abi (see `PEP 425 `_ for more details about abi), so **they are compatible with Python 3.8 and later**. The version string in the wheel file name (``1.0.0.dev``) is just a placeholder to have a unified URL for the wheels. The actual versions of wheels are contained in the wheel metadata. + +Another way to access the latest code is to use the docker images: + +.. code-block:: console + + $ export VLLM_COMMIT=33f460b17a54acb3b6cc0b03f4a17876cff5eafd # use full commit hash from the main branch + $ docker pull public.ecr.aws/q9t5s3a7/vllm-ci-test-repo:${VLLM_COMMIT} + +These docker images are used for CI and testing only, and they are not intended for production use. They will be expired after several days. + +Latest code can contain bugs and may not be stable. Please use it with caution. + +Build from source (without compilation) +--------------------------------------- + +If you want to develop vLLM, and you only need to change the Python code, you can build vLLM without compilation. + +The first step is to follow the previous instructions to install the latest vLLM wheel: + +.. code-block:: console + + $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-1.0.0.dev-cp38-abi3-manylinux1_x86_64.whl + +After verifying that the installation is successful, we have a script for you to copy and link directories, so that you can edit the Python code directly: + +.. 
code-block:: console + + $ git clone https://github.com/vllm-project/vllm.git + $ cd vllm + $ python python_only_dev.py + +It will: - $ export VLLM_VERSION=0.6.1.post1 # vLLM's main branch version is currently set to latest released tag - $ pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/nightly/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl - $ # You can also access a specific commit - $ # export VLLM_COMMIT=... - $ # pip install https://vllm-wheels.s3.us-west-2.amazonaws.com/${VLLM_COMMIT}/vllm-${VLLM_VERSION}-cp38-abi3-manylinux1_x86_64.whl +- Find the installed vLLM in the current environment. +- Copy built files to the current directory. +- Rename the installed vLLM +- Symbolically link the current directory to the installed vLLM. +This way, you can edit the Python code in the current directory, and the changes will be reflected in the installed vLLM. .. _build_from_source: -Build from source ------------------ +Build from source (with compilation) +------------------------------------ -You can also build and install vLLM from source: +If you need to touch the C++ or CUDA code, you need to build vLLM from source: .. code-block:: console $ git clone https://github.com/vllm-project/vllm.git $ cd vllm - $ pip install -e . # This may take 5-10 minutes. + $ pip install -e . # This can take a long time .. note:: diff --git a/docs/source/models/supported_models.rst b/docs/source/models/supported_models.rst index c41903f84910d..b05cba3b5d423 100644 --- a/docs/source/models/supported_models.rst +++ b/docs/source/models/supported_models.rst @@ -281,7 +281,7 @@ Multimodal Language Models - * - :code:`Qwen2VLForConditionalGeneration` - Qwen2-VL - - Image\ :sup:`+` / Video\ :sup:`+` + - Image\ :sup:`E+` / Video\ :sup:`+` - :code:`Qwen/Qwen2-VL-2B-Instruct`, :code:`Qwen/Qwen2-VL-7B-Instruct`, :code:`Qwen/Qwen2-VL-72B-Instruct`, etc. - * - :code:`UltravoxModel` diff --git a/docs/source/models/vlm.rst b/docs/source/models/vlm.rst index 08db891665044..3f4f01e3ae7ac 100644 --- a/docs/source/models/vlm.rst +++ b/docs/source/models/vlm.rst @@ -27,7 +27,7 @@ The :class:`~vllm.LLM` class can be instantiated in much the same way as languag We have removed all vision language related CLI args in the ``0.5.1`` release. **This is a breaking change**, so please update your code to follow the above snippet. Specifically, ``image_feature_size`` can no longer be specified as we now calculate that internally for each model. -To pass an image to the model, note the following in :class:`vllm.inputs.PromptInputs`: +To pass an image to the model, note the following in :class:`vllm.inputs.PromptType`: * ``prompt``: The prompt should follow the format that is documented on HuggingFace. * ``multi_modal_data``: This is a dictionary that follows the schema defined in :class:`vllm.multimodal.MultiModalDataDict`. @@ -60,7 +60,24 @@ To pass an image to the model, note the following in :class:`vllm.inputs.PromptI for o in outputs: generated_text = o.outputs[0].text print(generated_text) + + # Inference with image embeddings as input with additional parameters + # Specifically, we are conducting a trial run of Qwen2VL with the new input format, as the model utilizes additional parameters for calculating positional encoding. + image_embeds = torch.load(...) # torch.Tensor of shape (1, image_feature_size, hidden_size of LM) + image_grid_thw = torch.load(...) 
# torch.Tensor of shape (1, 3) + mm_data['image'] = { + "image_embeds": image_embeds, + "image_grid_thw": image_grid_thw, + } + outputs = llm.generate({ + "prompt": prompt, + "multi_modal_data": mm_data, + }) + for o in outputs: + generated_text = o.outputs[0].text + print(generated_text) + # Batch inference image_1 = PIL.Image.open(...) image_2 = PIL.Image.open(...) diff --git a/docs/source/serving/openai_compatible_server.md b/docs/source/serving/openai_compatible_server.md index eb4ea0fb5655e..e0eba7f09bd65 100644 --- a/docs/source/serving/openai_compatible_server.md +++ b/docs/source/serving/openai_compatible_server.md @@ -157,10 +157,10 @@ vLLM will use guided decoding to ensure the response matches the tool parameter To enable this feature, you should set the following flags: * `--enable-auto-tool-choice` -- **mandatory** Auto tool choice. tells vLLM that you want to enable the model to generate its own tool calls when it deems appropriate. -* `--tool-call-parser` -- select the tool parser to use - currently either `hermes` or `mistral`. Additional tool parsers +* `--tool-call-parser` -- select the tool parser to use - currently either `hermes`, `mistral` or `llama3_json`. Additional tool parsers will continue to be added in the future. * `--chat-template` -- **optional** for auto tool choice. the path to the chat template which handles `tool`-role messages and `assistant`-role messages -that contain previously generated tool calls. Hermes and Mistral models have tool-compatible chat templates in their +that contain previously generated tool calls. Hermes, Mistral and Llama models have tool-compatible chat templates in their `tokenizer_config.json` files, but you can specify a custom template. This argument can be set to `tool_use` if your model has a tool use-specific chat template configured in the `tokenizer_config.json`. In this case, it will be used per the `transformers` specification. More on this [here](https://huggingface.co/docs/transformers/en/chat_templating#why-do-some-models-have-multiple-templates) from HuggingFace; and you can find an example of this in a `tokenizer_config.json` [here](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B/blob/main/tokenizer_config.json) @@ -197,3 +197,25 @@ when tools are provided, that results in much better reliability when working wi Recommended flags: `--tool-call-parser mistral --chat-template examples/tool_chat_template_mistral_parallel.jinja` + +#### Llama Models +Supported models: +* `meta-llama/Meta-Llama-3.1-8B-Instruct` +* `meta-llama/Meta-Llama-3.1-70B-Instruct` +* `meta-llama/Meta-Llama-3.1-405B-Instruct` +* `meta-llama/Meta-Llama-3.1-405B-Instruct-FP8` + +The tool calling that is supported is the [JSON based tool calling](https://llama.meta.com/docs/model-cards-and-prompt-formats/llama3_1/#json-based-tool-calling). +Other tool calling formats like the built in python tool calling or custom tool calling are not supported. + +Known issues: +1. Parallel tool calls are not supported. +2. The model can generate parameters with a wrong format, such as generating + an array serialized as string instead of an array. + +The `tool_chat_template_llama3_json.jinja` file contains the "official" Llama chat template, but tweaked so that +it works better with vLLM. 
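
For illustration only (this sketch is not part of the patch): with a server started using the flags recommended below, e.g. `vllm serve meta-llama/Meta-Llama-3.1-8B-Instruct --enable-auto-tool-choice --tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja`, a JSON-based tool call can be exercised through the standard OpenAI client. The server address and the `get_weather` tool schema here are assumptions made up for the example, not anything defined by this change.

```python
# Minimal sketch of JSON-based tool calling against the OpenAI-compatible server.
# Assumes the server above is listening on localhost:8000; the tool schema is hypothetical.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "description": "Get the current weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

resp = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "What is the weather in Dallas?"}],
    tools=tools,
)

# With the llama3_json parser, at most one tool call per message is expected
# (parallel tool calls are not supported, as noted above).
print(resp.choices[0].message.tool_calls)
```

Because the parser emits a single call per assistant message, client code should not rely on receiving multiple entries in `tool_calls` for one response.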
+ +Recommended flags: `--tool-call-parser llama3_json --chat-template examples/tool_chat_template_llama3_json.jinja` + + diff --git a/examples/offline_inference_vision_language.py b/examples/offline_inference_vision_language.py index 6d34621a8a9bc..b94ef537d783f 100644 --- a/examples/offline_inference_vision_language.py +++ b/examples/offline_inference_vision_language.py @@ -12,6 +12,10 @@ from vllm.assets.video import VideoAsset from vllm.utils import FlexibleArgumentParser +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. + # LLaVA-1.5 def run_llava(question, modality): @@ -19,7 +23,7 @@ def run_llava(question, modality): prompt = f"USER: \n{question}\nASSISTANT:" - llm = LLM(model="llava-hf/llava-1.5-7b-hf") + llm = LLM(model="llava-hf/llava-1.5-7b-hf", max_model_len=4096) stop_token_ids = None return llm, prompt, stop_token_ids @@ -57,7 +61,7 @@ def run_llava_onevision(question, modality): <|im_start|>assistant\n" llm = LLM(model="llava-hf/llava-onevision-qwen2-7b-ov-hf", - max_model_len=32768) + max_model_len=16384) stop_token_ids = None return llm, prompt, stop_token_ids @@ -67,7 +71,7 @@ def run_fuyu(question, modality): assert modality == "image" prompt = f"{question}\n" - llm = LLM(model="adept/fuyu-8b") + llm = LLM(model="adept/fuyu-8b", max_model_len=2048, max_num_seqs=2) stop_token_ids = None return llm, prompt, stop_token_ids @@ -99,7 +103,8 @@ def run_phi3v(question, modality): llm = LLM( model="microsoft/Phi-3-vision-128k-instruct", trust_remote_code=True, - max_num_seqs=5, + max_model_len=4096, + max_num_seqs=2, mm_processor_kwargs={"num_crops": 16}, ) stop_token_ids = None @@ -122,7 +127,7 @@ def run_chameleon(question, modality): assert modality == "image" prompt = f"{question}" - llm = LLM(model="facebook/chameleon-7b") + llm = LLM(model="facebook/chameleon-7b", max_model_len=4096) stop_token_ids = None return llm, prompt, stop_token_ids @@ -145,6 +150,8 @@ def run_minicpmv(question, modality): trust_remote_code=True) llm = LLM( model=model_name, + max_model_len=4096, + max_num_seqs=2, trust_remote_code=True, ) # NOTE The stop_token_ids are different for various versions of MiniCPM-V @@ -177,7 +184,7 @@ def run_internvl(question, modality): llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, + max_model_len=4096, ) tokenizer = AutoTokenizer.from_pretrained(model_name, @@ -215,7 +222,8 @@ def run_qwen_vl(question, modality): llm = LLM( model="Qwen/Qwen-VL", trust_remote_code=True, - max_num_seqs=5, + max_model_len=1024, + max_num_seqs=2, ) prompt = f"{question}Picture 1: \n" @@ -229,8 +237,10 @@ def run_qwen2_vl(question, modality): model_name = "Qwen/Qwen2-VL-7B-Instruct" + # Tested on L40 llm = LLM( model=model_name, + max_model_len=8192, max_num_seqs=5, ) @@ -252,10 +262,10 @@ def run_mllama(question, modality): # max_model_len (131072) for this model may cause OOM. # You may lower either to run this example on lower-end GPUs. - # The configuration below has been confirmed to launch on a - # single H100 GPU. + # The configuration below has been confirmed to launch on a single L40 GPU. 
llm = LLM( model=model_name, + max_model_len=4096, max_num_seqs=16, enforce_eager=True, ) diff --git a/examples/offline_inference_vision_language_multi_image.py b/examples/offline_inference_vision_language_multi_image.py index 8c5f1a7b7af08..66936ab125b81 100644 --- a/examples/offline_inference_vision_language_multi_image.py +++ b/examples/offline_inference_vision_language_multi_image.py @@ -28,12 +28,18 @@ class ModelRequestData(NamedTuple): chat_template: Optional[str] +# NOTE: The default `max_num_seqs` and `max_model_len` may result in OOM on +# lower-end GPUs. +# Unless specified, these settings have been tested to work on a single L4. + + def load_qwenvl_chat(question: str, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen-VL-Chat" llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, + max_model_len=1024, + max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, ) placeholders = "".join(f"Picture {i}: \n" @@ -83,6 +89,7 @@ def load_phi3v(question: str, image_urls: List[str]) -> ModelRequestData: model="microsoft/Phi-3.5-vision-instruct", trust_remote_code=True, max_model_len=4096, + max_num_seqs=2, limit_mm_per_prompt={"image": len(image_urls)}, mm_processor_kwargs={"num_crops": 4}, ) @@ -106,9 +113,9 @@ def load_internvl(question: str, image_urls: List[str]) -> ModelRequestData: llm = LLM( model=model_name, trust_remote_code=True, - max_num_seqs=5, max_model_len=4096, limit_mm_per_prompt={"image": len(image_urls)}, + mm_processor_kwargs={"max_dynamic_patch": 4}, ) placeholders = "\n".join(f"Image-{i}: \n" @@ -148,10 +155,11 @@ def load_qwen2_vl(question, image_urls: List[str]) -> ModelRequestData: model_name = "Qwen/Qwen2-VL-7B-Instruct" + # Tested on L40 llm = LLM( model=model_name, - max_num_seqs=5, max_model_len=32768 if process_vision_info is None else 4096, + max_num_seqs=5, limit_mm_per_prompt={"image": len(image_urls)}, ) diff --git a/examples/tool_chat_template_llama3.1_json.jinja b/examples/tool_chat_template_llama3.1_json.jinja new file mode 100644 index 0000000000000..c24a7e51335ef --- /dev/null +++ b/examples/tool_chat_template_llama3.1_json.jinja @@ -0,0 +1,94 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {#- Llama 3.1 doesn't pass all tests if the tools are in the system prompt #} + {%- set tools_in_user_message = true %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." 
%} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/examples/tool_chat_template_llama3.2_json.jinja b/examples/tool_chat_template_llama3.2_json.jinja new file mode 100644 index 0000000000000..7e24777726a35 --- /dev/null +++ b/examples/tool_chat_template_llama3.2_json.jinja @@ -0,0 +1,93 @@ +{{- bos_token }} +{%- if custom_tools is defined %} + {%- set tools = custom_tools %} +{%- endif %} +{%- if not tools_in_user_message is defined %} + {%- set tools_in_user_message = false %} +{%- endif %} +{%- if not date_string is defined %} + {%- if strftime_now is defined %} + {%- set date_string = strftime_now("%d %b %Y") %} + {%- else %} + {%- set date_string = "26 Jul 2024" %} + {%- endif %} +{%- endif %} +{%- if not tools is 
defined %} + {%- set tools = none %} +{%- endif %} + +{#- This block extracts the system message, so we can slot it into the right place. #} +{%- if messages[0]['role'] == 'system' %} + {%- set system_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} +{%- else %} + {%- set system_message = "You are a helpful assistant with tool calling capabilities. Only reply with a tool call if the function exists in the library provided by the user. If it doesn't exist, just reply directly in natural language. When you receive a tool call response, use the output to format an answer to the original user question." %} +{%- endif %} + +{#- System message #} +{{- "<|start_header_id|>system<|end_header_id|>\n\n" }} +{%- if tools is not none %} + {{- "Environment: ipython\n" }} +{%- endif %} +{{- "Cutting Knowledge Date: December 2023\n" }} +{{- "Today Date: " + date_string + "\n\n" }} +{%- if tools is not none and not tools_in_user_message %} + {{- "You have access to the following functions. To call a function, please respond with JSON for a function call." }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' }} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} +{%- endif %} +{{- system_message }} +{{- "<|eot_id|>" }} + +{#- Custom tools are passed in a user message with some extra guidance #} +{%- if tools_in_user_message and not tools is none %} + {#- Extract the first user message so we can plug it in here #} + {%- if messages | length != 0 %} + {%- set first_user_message = messages[0]['content']|trim %} + {%- set messages = messages[1:] %} + {%- else %} + {{- raise_exception("Cannot put tools in the first user message when there's no first user message!") }} + {%- endif %} + {{- '<|start_header_id|>user<|end_header_id|>\n\n' -}} + {{- "Given the following functions, please respond with a JSON for a function call " }} + {{- "with its proper arguments that best answers the given prompt.\n\n" }} + {{- 'Respond in the format {"name": function name, "parameters": dictionary of argument name and its value}.' 
}} + {{- "Do not use variables.\n\n" }} + {%- for t in tools %} + {{- t | tojson(indent=4) }} + {{- "\n\n" }} + {%- endfor %} + {{- first_user_message + "<|eot_id|>"}} +{%- endif %} + +{%- for message in messages %} + {%- if not (message.role == 'ipython' or message.role == 'tool' or 'tool_calls' in message) %} + {{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n'+ message['content'] | trim + '<|eot_id|>' }} + {%- elif 'tool_calls' in message %} + {%- if not message.tool_calls|length == 1 %} + {{- raise_exception("This model only supports single tool-calls at once!") }} + {%- endif %} + {%- set tool_call = message.tool_calls[0].function %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' -}} + {{- '{"name": "' + tool_call.name + '", ' }} + {{- '"parameters": ' }} + {{- tool_call.arguments | tojson }} + {{- "}" }} + {{- "<|eot_id|>" }} + {%- elif message.role == "tool" or message.role == "ipython" %} + {{- "<|start_header_id|>ipython<|end_header_id|>\n\n" }} + {%- if message.content is mapping %} + {{- message.content | tojson }} + {%- else %} + {{- { "output": message.content } | tojson }} + {%- endif %} + {{- "<|eot_id|>" }} + {%- endif %} +{%- endfor %} +{%- if add_generation_prompt %} + {{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }} +{%- endif %} diff --git a/python_only_dev.py b/python_only_dev.py new file mode 100644 index 0000000000000..d84122280a3c2 --- /dev/null +++ b/python_only_dev.py @@ -0,0 +1,54 @@ +# enable python only development +# copy compiled files to the current directory directly + +import os +import shutil +import subprocess +import sys + +# cannot directly `import vllm` , because it will try to +# import from the current directory +output = subprocess.run([sys.executable, "-m", "pip", "show", "vllm"], + capture_output=True) + +assert output.returncode == 0, "vllm is not installed" + +text = output.stdout.decode("utf-8") + +package_path = None +for line in text.split("\n"): + if line.startswith("Location: "): + package_path = line.split(": ")[1] + break + +assert package_path is not None, "could not find package path" + +cwd = os.getcwd() + +assert cwd != package_path, "should not import from the current directory" + +files_to_copy = [ + "vllm/_C.abi3.so", + "vllm/_core_C.abi3.so", + "vllm/_moe_C.abi3.so", + "vllm/vllm_flash_attn/vllm_flash_attn_c.abi3.so", + "vllm/vllm_flash_attn/flash_attn_interface.py", + "vllm/vllm_flash_attn/__init__.py", + # "vllm/_version.py", # not available in nightly wheels yet +] + +for file in files_to_copy: + src = os.path.join(package_path, file) + dst = file + print(f"Copying {src} to {dst}") + shutil.copyfile(src, dst) + +pre_built_vllm_path = os.path.join(package_path, "vllm") +tmp_path = os.path.join(package_path, "vllm_pre_built") +current_vllm_path = os.path.join(cwd, "vllm") + +print(f"Renaming {pre_built_vllm_path} to {tmp_path}") +os.rename(pre_built_vllm_path, tmp_path) + +print(f"linking {current_vllm_path} to {pre_built_vllm_path}") +os.symlink(current_vllm_path, pre_built_vllm_path) diff --git a/requirements-common.txt b/requirements-common.txt index 2fc89c026901b..aa165ff6d6a5e 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -1,14 +1,14 @@ psutil sentencepiece # Required for LLaMA tokenizer. numpy < 2.0.0 -requests +requests >= 2.26.0 tqdm py-cpuinfo transformers >= 4.45.0 # Required for Llama 3.2. tokenizers >= 0.19.1 # Required for Llama 3. protobuf # Required by LlamaTokenizer. 
-fastapi < 0.113.0; python_version < '3.9' -fastapi >= 0.114.1; python_version >= '3.9' +fastapi >= 0.107.0, < 0.113.0; python_version < '3.9' +fastapi >= 0.107.0, != 0.113.*, != 0.114.0; python_version >= '3.9' aiohttp openai >= 1.40.0 # Ensure modern openai package (ensure types module present) uvicorn[standard] @@ -26,7 +26,7 @@ pyzmq msgspec gguf == 0.10.0 importlib_metadata -mistral_common >= 1.4.3 +mistral_common[opencv] >= 1.4.4 pyyaml six>=1.16.0; python_version > '3.11' # transitive dependency of pandas that needs to be the latest version for python 3.12 setuptools>=74.1.1; python_version > '3.11' # Setuptools is used by triton, we need to ensure a modern version is installed for 3.12+ so that it does not try to import distutils, which was removed in 3.12 diff --git a/requirements-xpu.txt b/requirements-xpu.txt index 9b21845e084d8..ce83a178c618f 100644 --- a/requirements-xpu.txt +++ b/requirements-xpu.txt @@ -1,9 +1,13 @@ # Common dependencies -r requirements-common.txt -setuptools < 70.0.0 # IPEX's torch have some dependency. to be removed. - ray >= 2.9 +cmake>=3.26 +ninja +packaging +setuptools-scm>=8 +wheel +jinja2 # Following pkgs retrieved from https://pytorch-extension.intel.com/release-whl/stable/xpu/us/ torch == 2.3.1+cxx11.abi intel-extension-for-pytorch == 2.3.110+xpu diff --git a/setup.py b/setup.py index 721bcbbd89ca5..00f75f1ff166e 100644 --- a/setup.py +++ b/setup.py @@ -457,6 +457,8 @@ def _read_requirements(filename: str) -> List[str]: for line in requirements: if line.startswith("-r "): resolved_requirements += _read_requirements(line.split()[1]) + elif line.startswith("--"): + continue else: resolved_requirements.append(line) return resolved_requirements @@ -554,7 +556,6 @@ def _read_requirements(filename: str) -> List[str]: ext_modules=ext_modules, extras_require={ "tensorizer": ["tensorizer>=2.9.0"], - "video": ["opencv-python"], # Required for video processing "audio": ["librosa", "soundfile"] # Required for audio processing }, cmdclass={"build_ext": cmake_build_ext} if len(ext_modules) > 0 else {}, diff --git a/tests/async_engine/test_async_llm_engine.py b/tests/async_engine/test_async_llm_engine.py index 6cae76f74603d..1903a7582dc89 100644 --- a/tests/async_engine/test_async_llm_engine.py +++ b/tests/async_engine/test_async_llm_engine.py @@ -86,17 +86,19 @@ class MockAsyncLLMEngine(AsyncLLMEngine): @pytest.mark.asyncio async def test_new_requests_event(): + params = SamplingParams() + engine = MockAsyncLLMEngine() engine.start_background_loop() await asyncio.sleep(0.01) assert engine.engine.step_calls == 0 - await engine.add_request("1", "", None) + await engine.add_request("1", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 1 assert engine.engine.step_calls == 1 - await engine.add_request("2", "", None) + await engine.add_request("2", "", params) engine.engine.generate("2") await asyncio.sleep(0) await asyncio.sleep(0) @@ -111,7 +113,7 @@ async def test_new_requests_event(): await asyncio.sleep(0.001) assert engine.engine.step_calls == old_step_calls - await engine.add_request("3", "", None) + await engine.add_request("3", "", params) await asyncio.sleep(0.01) assert engine.engine.add_request_calls == 3 assert engine.engine.step_calls == old_step_calls + 1 diff --git a/tests/basic_correctness/test_preemption.py b/tests/basic_correctness/test_preemption.py index 00806c3e129b1..05e7859759002 100644 --- a/tests/basic_correctness/test_preemption.py +++ b/tests/basic_correctness/test_preemption.py @@ -23,8 +23,10 @@ 
@pytest.fixture(scope="module", autouse=True) def check_settings(): assert ENABLE_ARTIFICIAL_PREEMPT is True, ( - "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1. " - "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 pytest " + "Use an env var VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1, " + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1. " + "`VLLM_TEST_ENABLE_ARTIFICIAL_PREEMPT=1 " + "VLLM_ALLOW_DEPRECATED_BEAM_SEARCH=1 pytest " "tests/basic_correctness/test_preemption.py`") @@ -199,6 +201,7 @@ def test_swap( @pytest.mark.parametrize("dtype", ["float"]) @pytest.mark.parametrize("max_tokens", [96]) @pytest.mark.parametrize("beam_width", [4]) +@pytest.mark.parametrize("use_v2_block_manager", [True, False]) def test_swap_infeasible( vllm_runner, example_prompts, @@ -207,6 +210,7 @@ def test_swap_infeasible( max_tokens: int, beam_width: int, worker_use_ray: bool, + use_v2_block_manager: bool, ) -> None: """Verify infeasible swap request will be ignored.""" BLOCK_SIZE = 16 @@ -223,6 +227,7 @@ def test_swap_infeasible( num_gpu_blocks_override=prefill_blocks + decode_blocks, max_model_len=(prefill_blocks + decode_blocks) * BLOCK_SIZE, worker_use_ray=worker_use_ray, + use_v2_block_manager=use_v2_block_manager, ) as vllm_model: sampling_params = SamplingParams(n=beam_width, use_beam_search=True, diff --git a/tests/conftest.py b/tests/conftest.py index 2f73cdf1e2972..997b98a67f7d9 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -246,17 +246,14 @@ def video_assets() -> _VideoAssets: class HfRunner: - def wrap_device(self, input: _T) -> _T: - if not is_cpu(): - # Check if the input is already on the GPU - if hasattr(input, 'device') and input.device.type == "cuda": - return input # Already on GPU, no need to move - return input.to("cuda") - else: - # Check if the input is already on the CPU - if hasattr(input, 'device') and input.device.type == "cpu": - return input # Already on CPU, no need to move - return input.to("cpu") + def wrap_device(self, input: _T, device: Optional[str] = None) -> _T: + if device is None: + return self.wrap_device(input, "cpu" if is_cpu() else "cuda") + + if hasattr(input, "device") and input.device.type == device: + return input + + return input.to(device) def __init__( self, @@ -333,7 +330,7 @@ def generate( inputs = self.postprocess_inputs(inputs) output_ids = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, **kwargs, ) @@ -406,7 +403,7 @@ def generate_greedy_logprobs( inputs = self.postprocess_inputs(inputs) output = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -414,40 +411,39 @@ def generate_greedy_logprobs( return_dict_in_generate=True, **kwargs, ) - seq_logprobs: List[torch.Tensor] = [] - for hidden_states in output.hidden_states: - last_hidden_states = hidden_states[-1][0] - logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), - ) - if self.model.get_output_embeddings().bias is not None: - logits += self.model.get_output_embeddings( - ).bias.unsqueeze(0) - logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) - seq_logprobs.append(logprobs) + seq_logprobs = self._hidden_states_to_seq_logprobs( + output.hidden_states) all_logprobs.append(seq_logprobs) return all_logprobs - def _hidden_states_to_logprobs( + def _hidden_states_to_seq_logprobs( self, - hidden_states, - num_logprobs, - ) -> Tuple[List[Dict[int, float]], int]: 
+ hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], + ) -> List[torch.Tensor]: + output_embeddings = self.model.get_output_embeddings() + seq_logprobs: List[torch.Tensor] = [] - output_len = len(hidden_states) for _, hidden_state in enumerate(hidden_states): last_hidden_states = hidden_state[-1][0] logits = torch.matmul( - last_hidden_states, - self.model.get_output_embeddings().weight.t(), + last_hidden_states.to(output_embeddings.weight.device), + output_embeddings.weight.t(), ) - if getattr(self.model.get_output_embeddings(), "bias", - None) is not None: - logits += self.model.get_output_embeddings().bias.unsqueeze(0) + if getattr(output_embeddings, "bias", None) is not None: + logits += output_embeddings.bias.unsqueeze(0) logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32) seq_logprobs.append(logprobs) + return seq_logprobs + + def _hidden_states_to_logprobs( + self, + hidden_states: Tuple[Tuple[torch.Tensor, ...], ...], + num_logprobs: int, + ) -> Tuple[List[Dict[int, float]], int]: + seq_logprobs = self._hidden_states_to_seq_logprobs(hidden_states) + output_len = len(hidden_states) + # convert to dict seq_logprobs_lst: List[Dict[int, float]] = [] for tok_idx, tok_logprobs in enumerate(seq_logprobs): @@ -500,7 +496,7 @@ def generate_greedy_logprobs_limit( inputs = self.postprocess_inputs(inputs) output = self.model.generate( - **self.wrap_device(inputs), + **self.wrap_device(inputs, device=self.model.device.type), use_cache=True, do_sample=False, max_new_tokens=max_tokens, @@ -543,12 +539,20 @@ def generate_encoder_decoder_greedy_logprobs_limit( for (encoder_prompt, decoder_prompt) in to_enc_dec_tuple_list(encoder_decoder_prompts): + encoder_input_ids = self.wrap_device( - self.tokenizer(encoder_prompt, return_tensors="pt").input_ids) - decoder_input_ids = ( - None if decoder_prompt is None else self.wrap_device( + self.tokenizer(encoder_prompt, return_tensors="pt").input_ids, + device=self.model.device.type, + ) + + if decoder_prompt is None: + decoder_input_ids = None + else: + decoder_input_ids = self.wrap_device( self.tokenizer(decoder_prompt, - return_tensors="pt").input_ids)) + return_tensors="pt").input_ids, + device=self.model.device.type, + ) output = self.model.generate( encoder_input_ids, @@ -699,7 +703,6 @@ def generate_w_logprobs( if videos is not None: for i, video in enumerate(videos): inputs[i]["multi_modal_data"] = {"video": video} - print(f"[INPUTS!!!!]: {inputs}, {sampling_params}") req_outputs = self.model.generate(inputs, sampling_params=sampling_params) diff --git a/tests/core/block/test_block_manager_v2.py b/tests/core/block/test_block_manager_v2.py index 30efe4437741d..e67883367879f 100644 --- a/tests/core/block/test_block_manager_v2.py +++ b/tests/core/block/test_block_manager_v2.py @@ -373,6 +373,52 @@ def test_can_swap(block_size, num_gpu_blocks, num_lookahead_slots, seq_group, num_lookahead_slots) == AllocStatus.NEVER +@pytest.mark.parametrize("num_lookahead_slots", [0, 2, 10]) +@pytest.mark.parametrize("enable_caching", [False, True]) +def test_swap_in_infeasible(num_lookahead_slots, enable_caching): + """Verifies that swapping fails if there is not enough free blocks + to account for unseen tokens and lookahead_slots. 
+ """ + block_size = 8 + num_cpu_blocks = 1 + num_gpu_blocks = 1 + block_manager = BlockSpaceManagerV2(block_size, + num_cpu_blocks, + num_gpu_blocks, + watermark=0, + enable_caching=enable_caching) + prompt_length = block_size - 3 + assert prompt_length > 0 + prompt, seq_group = create_dummy_prompt("1", prompt_length=prompt_length) + prompt.status = SequenceStatus.WAITING + block_manager.allocate(seq_group) + # Emulate a forward pass by appending a single token. + # The block manager then knows how many unprocessed + # tokens will be written in the next forward pass. + token_id = 0 + prompt.status = SequenceStatus.RUNNING + prompt.append_token_id(token_id, {token_id: Logprob(0.0)}) + + # Swap seq group from GPU -> CPU. + assert block_manager.can_swap_out(seq_group) + block_manager.swap_out(seq_group) + prompt.status = SequenceStatus.SWAPPED + + # Swap seq group from CPU -> GPU. + # The number of unseen tokens is 1. If the number of existing + # tokens plus the unseen ones and number of lookahead slots exceeds + # the total number of available GPU blocks then the swap + # should fail. + num_unseen_tokens = 1 + if (num_lookahead_slots + num_unseen_tokens + + prompt_length) <= (block_size * num_gpu_blocks): + assert block_manager.can_swap_in(seq_group, + num_lookahead_slots) == AllocStatus.OK + else: + assert block_manager.can_swap_in( + seq_group, num_lookahead_slots) == AllocStatus.NEVER + + # TODO(cade/kaiyang): add comprehensive tests for swapping at allocator level. @@ -400,7 +446,6 @@ def check_used(min_n, max_n=None): if max_n is None: max_n = min_n used = num_gpu_blocks - block_manager.get_num_free_gpu_blocks() - #print("check", min_n, used, max_n) assert min_n <= used assert used <= max_n diff --git a/tests/core/block/test_naive_block.py b/tests/core/block/test_naive_block.py index e2e814c278603..10d5964dcfe8a 100644 --- a/tests/core/block/test_naive_block.py +++ b/tests/core/block/test_naive_block.py @@ -104,9 +104,9 @@ def test_get_num_free_blocks(allocate_type: str, num_blocks: int, @staticmethod @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) - def test_naive_block_get_num_blocks_touched(num_blocks, block_size): + def test_naive_block_get_num_full_blocks_touched(num_blocks, block_size): """ Verify the allocator can correctly return the number of - blocks touched, with different lookahead slots. + full blocks touched. 
""" allocator_src = NaiveBlockAllocator(create_block=NaiveBlock, num_blocks=num_blocks, @@ -124,7 +124,7 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size): src_blocks = [allocate_block() for _ in range(num_blocks - 1)] # All blocks are cached - assert allocator_dst.get_num_blocks_touched( + assert allocator_dst.get_num_full_blocks_touched( src_blocks) == num_blocks - 1 # Insert one non-full block in the src @@ -136,9 +136,10 @@ def test_naive_block_get_num_blocks_touched(num_blocks, block_size): src_blocks.append(allocate_non_full_block()) src_blocks[-1].append_token_ids([0]) - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=1) == num_blocks - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=block_size - 1) == num_blocks - assert allocator_dst.get_num_blocks_touched( - src_blocks, num_lookahead_slots=block_size) == (num_blocks + 1) + assert allocator_dst.get_num_full_blocks_touched( + src_blocks) == num_blocks - 1 + # Fill up the last source block and then invoke + # get_num_blocks_touched + src_blocks[-1].append_token_ids([0] * (block_size - 1)) + assert allocator_dst.get_num_full_blocks_touched( + src_blocks) == num_blocks diff --git a/tests/core/block/test_prefix_caching_block.py b/tests/core/block/test_prefix_caching_block.py index 25be2dd13f8bd..1a6e17ef7b445 100644 --- a/tests/core/block/test_prefix_caching_block.py +++ b/tests/core/block/test_prefix_caching_block.py @@ -318,11 +318,10 @@ def test_get_num_free_blocks(num_blocks: int, block_size: int, seed: int): @staticmethod @pytest.mark.parametrize("num_blocks", [4]) @pytest.mark.parametrize("block_size", [8]) - def test_prefix_caching_block_get_num_blocks_touched( + def test_prefix_caching_block_get_num_full_blocks_touched( num_blocks, block_size): """ Verify the allocator can correctly return the number of - blocks touched, when there are cached prefixes and different - lookahead slots. + blocks touched, when there are cached prefixes. """ allocator_src = PrefixCachingBlockAllocator(num_blocks=num_blocks, block_size=block_size) @@ -346,28 +345,30 @@ def test_prefix_caching_block_get_num_blocks_touched( token_ids=token_ids, allocator=allocator_src, ) - # All blocks are cached - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 0 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 0 # Free the first block in the dst allocator_dst.free(cached_blocks[0]) # Now the first block becomes dangling, the swapped blocks need # to reclaim the first block in the dst - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in) == 1 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 1 # Insert one non-full block in the src non_full_block = allocator_src.allocate_mutable_block( blocks_to_swap_in[-1]) non_full_block.append_token_ids([0]) blocks_to_swap_in.append(non_full_block) - assert allocator_dst.get_num_blocks_touched(blocks_to_swap_in, - num_lookahead_slots=1) == 2 - assert allocator_dst.get_num_blocks_touched( - blocks_to_swap_in, num_lookahead_slots=block_size - 1) == 2 - assert allocator_dst.get_num_blocks_touched( - blocks_to_swap_in, num_lookahead_slots=block_size) == 3 + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 1 + # Fill up the last mutable block and invoke get_num_blocks_touched. + # Note: The last block is not cached so it will be touched. 
+ non_full_block.append_token_ids([0] * (block_size - 1)) + assert allocator_dst.get_num_full_blocks_touched( + blocks_to_swap_in) == 2 @staticmethod @pytest.mark.parametrize("num_blocks", [1024]) diff --git a/tests/distributed/test_pipeline_parallel.py b/tests/distributed/test_pipeline_parallel.py index 280a8abdd13a7..2e8e83c3d271b 100644 --- a/tests/distributed/test_pipeline_parallel.py +++ b/tests/distributed/test_pipeline_parallel.py @@ -8,8 +8,6 @@ import os import pytest -from packaging import version -from transformers import __version__ as transformers_version from vllm.logger import init_logger @@ -39,7 +37,9 @@ (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-1B", "mp"), (1, 2, 1, 1, 1, "OpenGVLab/InternVL2-2B", "mp"), (1, 2, 1, 0, 1, "OpenGVLab/InternVL2-4B", "mp"), - (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp") + (1, 2, 0, 1, 0, "Qwen/Qwen2-VL-2B-Instruct", "mp"), + # TP only models + (2, 1, 1, 0, 0, "adept/fuyu-8b", "mp"), ], ) @fork_new_process_for_each_test @@ -49,11 +49,6 @@ def test_compare_tp(TP_SIZE, PP_SIZE, EAGER_MODE, CHUNKED_PREFILL, pytest.skip("Skipping multi-node pipeline parallel test for " "multiprocessing distributed backend") - # Skip tests that require transformers>=4.45.0 - if "Qwen2-VL" in MODEL_NAME and version.parse( - transformers_version) < version.parse("4.45.0.dev0"): - pytest.skip("This test requires transformers>=4.45.0") - pp_args = [ # use half precision for speed and memory savings in CI environment "--dtype", diff --git a/tests/engine/test_custom_executor.py b/tests/engine/test_custom_executor.py index bff0fc99ed022..bbabb936e92ba 100644 --- a/tests/engine/test_custom_executor.py +++ b/tests/engine/test_custom_executor.py @@ -48,9 +48,9 @@ def test_custom_executor_type_checking(model): @pytest.mark.parametrize("model", ["facebook/opt-125m"]) -def test_custom_executor(model, tmpdir): +def test_custom_executor(model, tmp_path): cwd = os.path.abspath(".") - os.chdir(tmpdir) + os.chdir(tmp_path) try: assert not os.path.exists(".marker") @@ -68,9 +68,9 @@ def test_custom_executor(model, tmpdir): @pytest.mark.parametrize("model", ["facebook/opt-125m"]) -def test_custom_executor_async(model, tmpdir): +def test_custom_executor_async(model, tmp_path): cwd = os.path.abspath(".") - os.chdir(tmpdir) + os.chdir(tmp_path) try: assert not os.path.exists(".marker") diff --git a/tests/entrypoints/llm/test_encode.py b/tests/entrypoints/llm/test_encode.py index d1056a0490509..1885f2e168d80 100644 --- a/tests/entrypoints/llm/test_encode.py +++ b/tests/entrypoints/llm/test_encode.py @@ -49,21 +49,6 @@ def assert_outputs_equal(o1: List[EmbeddingRequestOutput], assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=prompt, pooling_params=pooling_params) - - v2_output = llm.encode(prompt, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode({"prompt": prompt}, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,25 +64,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def 
test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - pooling_params = PoolingParams() - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.encode(prompts=PROMPTS, pooling_params=pooling_params) - - v2_output = llm.encode(PROMPTS, pooling_params=pooling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.encode( - [{ - "prompt": p - } for p in PROMPTS], - pooling_params=pooling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): pooling_params = PoolingParams() diff --git a/tests/entrypoints/llm/test_generate.py b/tests/entrypoints/llm/test_generate.py index cd989225e2483..6543c4bb1b58e 100644 --- a/tests/entrypoints/llm/test_generate.py +++ b/tests/entrypoints/llm/test_generate.py @@ -47,23 +47,6 @@ def assert_outputs_equal(o1: List[RequestOutput], o2: List[RequestOutput]): assert [o.outputs for o in o1] == [o.outputs for o in o2] -@pytest.mark.skip_global_cleanup -@pytest.mark.parametrize('prompt', PROMPTS) -def test_v1_v2_api_consistency_single_prompt_string(llm: LLM, prompt): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=prompt, - sampling_params=sampling_params) - - v2_output = llm.generate(prompt, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate({"prompt": prompt}, - sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup @pytest.mark.parametrize('prompt_token_ids', TOKEN_IDS) def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, @@ -79,26 +62,6 @@ def test_v1_v2_api_consistency_single_prompt_tokens(llm: LLM, assert_outputs_equal(v1_output, v2_output) -@pytest.mark.skip_global_cleanup -def test_v1_v2_api_consistency_multi_prompt_string(llm: LLM): - sampling_params = SamplingParams(temperature=0.0, top_p=1.0) - - with pytest.warns(DeprecationWarning, match="'prompts'"): - v1_output = llm.generate(prompts=PROMPTS, - sampling_params=sampling_params) - - v2_output = llm.generate(PROMPTS, sampling_params=sampling_params) - assert_outputs_equal(v1_output, v2_output) - - v2_output = llm.generate( - [{ - "prompt": p - } for p in PROMPTS], - sampling_params=sampling_params, - ) - assert_outputs_equal(v1_output, v2_output) - - @pytest.mark.skip_global_cleanup def test_v1_v2_api_consistency_multi_prompt_tokens(llm: LLM): sampling_params = SamplingParams(temperature=0.0, top_p=1.0) diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 873e115421257..2841dfc6bd9c2 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -7,7 +7,7 @@ from vllm.entrypoints.llm import LLM from vllm.outputs import RequestOutput -from vllm.sampling_params import SamplingParams +from vllm.sampling_params import GuidedDecodingParams, SamplingParams from ...conftest import cleanup @@ -31,14 +31,12 @@ def test_guided_regex(sample_regex, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) - outputs = llm.generate( - prompts=[ - f"Give an example IPv4 address with this regex: {sample_regex}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_regex=sample_regex)) + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + outputs = 
llm.generate(prompts=[ + f"Give an example IPv4 address with this regex: {sample_regex}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None for output in outputs: @@ -57,15 +55,13 @@ def test_guided_json_completion(sample_json_schema, llm): sampling_params = SamplingParams( temperature=1.0, max_tokens=1000, - ) - outputs = llm.generate( - prompts=[ - f"Give an example JSON for an employee profile " - f"that fits this schema: {sample_json_schema}" - ] * 2, - sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_json=sample_json_schema)) + guided_decoding=GuidedDecodingParams(json=sample_json_schema)) + outputs = llm.generate(prompts=[ + f"Give an example JSON for an employee profile " + f"that fits this schema: {sample_json_schema}" + ] * 2, + sampling_params=sampling_params, + use_tqdm=True) assert outputs is not None @@ -86,12 +82,11 @@ def test_guided_choice_completion(sample_guided_choice, llm): sampling_params = SamplingParams( temperature=0.8, top_p=0.95, - ) + guided_decoding=GuidedDecodingParams(choice=sample_guided_choice)) outputs = llm.generate( prompts="The best language for type-safe systems programming is ", sampling_params=sampling_params, - use_tqdm=True, - guided_options_request=dict(guided_choice=sample_guided_choice)) + use_tqdm=True) assert outputs is not None for output in outputs: @@ -112,13 +107,13 @@ def test_guided_grammar(sample_sql_statements, llm): temperature=0.8, top_p=0.95, max_tokens=1000, - ) + guided_decoding=GuidedDecodingParams(grammar=sample_sql_statements)) outputs = llm.generate( prompts=("Generate a sql state that select col_1 from " "table_1 where it is equals to 1"), sampling_params=sampling_params, use_tqdm=True, - guided_options_request=dict(guided_grammar=sample_sql_statements)) + ) assert outputs is not None for output in outputs: @@ -140,3 +135,28 @@ def test_guided_grammar(sample_sql_statements, llm): assert generated_text.strip() == ground_truth print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + + +@pytest.mark.skip_global_cleanup +def test_guided_options_request_deprecation_warning(sample_regex, llm): + sampling_params = SamplingParams(temperature=0.8, top_p=0.95) + + with pytest.warns(DeprecationWarning, match="guided_options_request"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_validation_against_both_guided_decoding_options(sample_regex, llm): + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + guided_decoding=GuidedDecodingParams(regex=sample_regex)) + + with pytest.raises(ValueError, match="Cannot set both"): + llm.generate(prompts="This should fail", + sampling_params=sampling_params, + use_tqdm=True, + guided_options_request=dict(guided_regex=sample_regex)) diff --git a/tests/entrypoints/openai/test_chat_template.py b/tests/entrypoints/openai/test_chat_template.py index b98ab2e30d78d..e1e1dcff7475d 100644 --- a/tests/entrypoints/openai/test_chat_template.py +++ b/tests/entrypoints/openai/test_chat_template.py @@ -12,7 +12,7 @@ # Define models, templates, and their corresponding expected outputs MODEL_TEMPLATE_GENERATON_OUTPUT = [ - ("facebook/opt-125m", chatml_jinja_path, True, """<|im_start|>user + ("facebook/opt-125m", chatml_jinja_path, True, False, """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> @@ -20,12 +20,20 @@ What is the capital 
of<|im_end|> <|im_start|>assistant """), - ("facebook/opt-125m", chatml_jinja_path, False, """<|im_start|>user + ("facebook/opt-125m", chatml_jinja_path, False, False, """<|im_start|>user Hello<|im_end|> <|im_start|>assistant Hi there!<|im_end|> <|im_start|>user -What is the capital of""") +What is the capital of"""), + ("facebook/opt-125m", chatml_jinja_path, False, True, """<|im_start|>user +Hello<|im_end|> +<|im_start|>assistant +Hi there!<|im_end|> +<|im_start|>user +What is the capital of<|im_end|> +<|im_start|>assistant +The capital of"""), ] TEST_MESSAGES = [ @@ -42,6 +50,10 @@ 'content': 'What is the capital of' }, ] +ASSISTANT_MESSAGE_TO_CONTINUE = { + 'role': 'assistant', + 'content': 'The capital of' +} def test_load_chat_template(): @@ -73,10 +85,10 @@ def test_no_load_chat_template_literallike(): @pytest.mark.parametrize( - "model,template,add_generation_prompt,expected_output", + "model,template,add_generation_prompt,continue_final_message,expected_output", MODEL_TEMPLATE_GENERATON_OUTPUT) def test_get_gen_prompt(model, template, add_generation_prompt, - expected_output): + continue_final_message, expected_output): # Initialize the tokenizer tokenizer = get_tokenizer(tokenizer_name=model) template_content = load_chat_template(chat_template=template) @@ -84,8 +96,11 @@ def test_get_gen_prompt(model, template, add_generation_prompt, # Create a mock request object using keyword arguments mock_request = ChatCompletionRequest( model=model, - messages=TEST_MESSAGES, - add_generation_prompt=add_generation_prompt) + messages=TEST_MESSAGES + [ASSISTANT_MESSAGE_TO_CONTINUE] + if continue_final_message else TEST_MESSAGES, + add_generation_prompt=add_generation_prompt, + continue_final_message=continue_final_message, + ) # Call the function and get the result result = apply_hf_chat_template( @@ -93,6 +108,7 @@ def test_get_gen_prompt(model, template, add_generation_prompt, conversation=mock_request.messages, chat_template=mock_request.chat_template or template_content, add_generation_prompt=mock_request.add_generation_prompt, + continue_final_message=mock_request.continue_final_message, ) # Test assertion diff --git a/tests/entrypoints/openai/test_serving_chat.py b/tests/entrypoints/openai/test_serving_chat.py index db31745cc102e..ec550fe82c70f 100644 --- a/tests/entrypoints/openai/test_serving_chat.py +++ b/tests/entrypoints/openai/test_serving_chat.py @@ -15,6 +15,11 @@ BASE_MODEL_PATHS = [BaseModelPath(name=MODEL_NAME, model_path=MODEL_NAME)] +@dataclass +class MockHFConfig: + model_type: str = "any" + + @dataclass class MockModelConfig: tokenizer = MODEL_NAME @@ -24,6 +29,7 @@ class MockModelConfig: tokenizer_revision = None embedding_mode = False multimodal_config = MultiModalConfig() + hf_config = MockHFConfig() @dataclass diff --git a/tests/entrypoints/openai/test_tokenization.py b/tests/entrypoints/openai/test_tokenization.py index 316ca11b8e95a..859a676a9c777 100644 --- a/tests/entrypoints/openai/test_tokenization.py +++ b/tests/entrypoints/openai/test_tokenization.py @@ -104,28 +104,40 @@ async def test_tokenize_chat(client: openai.AsyncOpenAI, model_name: str, "role": "user", "content": "Can I ask a question? 
vllm1" }] - - prompt = tokenizer.apply_chat_template( - add_generation_prompt=add_generation, - conversation=conversation, - tokenize=False) - tokens = tokenizer.encode(prompt, add_special_tokens=add_special) - - response = requests.post(base_url + "/tokenize", - json={ - "add_generation_prompt": - add_generation, - "add_special_tokens": add_special, - "messages": conversation, - "model": model_name - }) - response.raise_for_status() - - assert response.json() == { - "tokens": tokens, - "count": len(tokens), - "max_model_len": 8192 - } + for continue_final in [False, True]: + if add_generation and continue_final: + continue + if continue_final: + conversation.append({ + "role": "assistant", + "content": "Sure," + }) + + prompt = tokenizer.apply_chat_template( + add_generation_prompt=add_generation, + continue_final_message=continue_final, + conversation=conversation, + tokenize=False) + tokens = tokenizer.encode(prompt, + add_special_tokens=add_special) + + response = requests.post(base_url + "/tokenize", + json={ + "add_generation_prompt": + add_generation, + "continue_final_message": + continue_final, + "add_special_tokens": add_special, + "messages": conversation, + "model": model_name + }) + response.raise_for_status() + + assert response.json() == { + "tokens": tokens, + "count": len(tokens), + "max_model_len": 8192 + } @pytest.mark.asyncio diff --git a/tests/kernels/test_causal_conv1d.py b/tests/kernels/test_causal_conv1d.py index 744e445fe6673..069020a536d0e 100644 --- a/tests/kernels/test_causal_conv1d.py +++ b/tests/kernels/test_causal_conv1d.py @@ -3,7 +3,6 @@ import pytest import torch import torch.nn.functional as F -from einops import rearrange from tests.kernels.utils import opcheck from vllm import _custom_ops as ops # noqa: F401 @@ -57,43 +56,72 @@ def causal_conv1d_ref( return (out, None) if not return_final_states else (out, final_states_out) -def causal_conv1d_update_ref(x: torch.Tensor, - conv_state: torch.Tensor, - weight: torch.Tensor, - bias: Optional[torch.Tensor] = None, - activation: Optional[str] = None): +def causal_conv1d_update_ref(x, + conv_state, + weight, + bias=None, + activation=None, + cache_seqlens=None): """ - x: (batch, dim) - conv_state: (batch, dim, width) + x: (batch, dim) or (batch, dim, seqlen) + conv_state: (batch, dim, state_len), where state_len >= width - 1 weight: (dim, width) bias: (dim,) + cache_seqlens: (batch,), dtype int32. + If not None, the conv_state is treated as a circular buffer. + The conv_state will be updated by copying x to the + conv_state starting at the index + @cache_seqlens % state_len before performing the convolution. 
- out: (batch, dim) + out: (batch, dim) or (batch, dim, seqlen) """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") dtype_in = x.dtype - batch, dim = x.shape + unsqueeze = x.dim() == 2 + if unsqueeze: + x = x.unsqueeze(-1) + batch, dim, seqlen = x.shape width = weight.shape[1] - assert conv_state.shape == (batch, dim, width) + state_len = conv_state.shape[-1] + assert conv_state.shape == (batch, dim, state_len) assert weight.shape == (dim, width) - conv_state.copy_(torch.roll(conv_state, shifts=-1, - dims=-1)) # Update state (B D W) - conv_state[:, :, -1] = x - out = torch.sum(conv_state * weight, dim=-1) # (B D) - if bias is not None: - out += bias + if cache_seqlens is None: + x_new = torch.cat([conv_state, x], dim=-1).to( + weight.dtype) # (batch, dim, state_len + seqlen) + conv_state.copy_(x_new[:, :, -state_len:]) + else: + width_idx = torch.arange( + -(width - 1), 0, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand( + -1, dim, -1) + x_new = torch.cat([conv_state.gather(2, width_idx), x], + dim=-1).to(weight.dtype) + copy_idx = torch.arange( + seqlen, dtype=torch.long, + device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1) + copy_idx = torch.remainder(copy_idx, + state_len).unsqueeze(1).expand(-1, dim, -1) + conv_state.scatter_(2, copy_idx, x) + out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, + groups=dim)[:, :, -seqlen:] + if unsqueeze: + out = out.squeeze(-1) return (out if activation is None else F.silu(out)).to(dtype=dtype_in) +@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) def causal_conv1d_opcheck_fn( x: torch.Tensor, weight: torch.Tensor, bias: Optional[torch.Tensor] = None, - seq_idx: Optional[torch.Tensor] = None, - initial_states: Optional[torch.Tensor] = None, - return_final_states: bool = False, - final_states_out=None, + cu_seq_len: Optional[torch.Tensor] = None, + cache_indices: Optional[torch.Tensor] = None, + has_initial_state: Optional[torch.Tensor] = None, + conv_states: Optional[torch.Tensor] = None, activation: Optional[str] = "silu", ): """ @@ -109,135 +137,93 @@ def causal_conv1d_opcheck_fn( """ if activation not in [None, "silu", "swish"]: raise NotImplementedError("activation must be None, silu, or swish") - if x.stride(2) != 1 and x.stride(1) != 1: + if x.stride(-1) != 1: x = x.contiguous() bias = bias.contiguous() if bias is not None else None - if seq_idx is not None: - assert (initial_states is - None), "initial_states must be None if seq_idx is not None" - assert (not return_final_states - ), "If seq_idx is not None, we don't return final_states_out" - seq_idx = seq_idx.contiguous() if seq_idx is not None else None - if initial_states is not None and (initial_states.stride(2) != 1 - and initial_states.stride(1) != 1): - initial_states = initial_states.contiguous() - if return_final_states: - assert ( - x.stride(1) == 1 - ), "Only channel-last layout support returning final_states_out" - if final_states_out is not None: - assert (final_states_out.stride(2) == 1 - or final_states_out.stride(1) == 1) - else: - batch, dim, seqlen = x.shape - width = weight.shape[1] - final_states_out = torch.empty(batch, - width - 1, - dim, - device=x.device, - dtype=x.dtype).transpose(1, 2) - else: - final_states_out = None - opcheck(torch.ops._C.causal_conv1d_fwd, - (x, 
weight, bias, seq_idx, initial_states, final_states_out, - activation in ["silu", "swish"])) + opcheck(torch.ops._C.causal_conv1d_fwd, ( + x, + weight, + bias, + conv_states, + cu_seq_len, + cache_indices, + has_initial_state, + activation in ["silu", "swish"], + )) -@pytest.mark.parametrize("return_final_states", [False, True]) -@pytest.mark.parametrize("has_initial_states", [False, True]) -@pytest.mark.parametrize("channel_last", [False, True]) -@pytest.mark.parametrize("itype", [torch.bfloat16]) -@pytest.mark.parametrize("silu_activation", [False, True]) -@pytest.mark.parametrize("has_bias", [False, True]) +@pytest.mark.parametrize("itype", [torch.bfloat16, torch.float]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) @pytest.mark.parametrize("width", [4]) -@pytest.mark.parametrize("seqlen", [128, 512, 4096]) -@pytest.mark.parametrize('dim', [64, 4096 + 32]) -@pytest.mark.parametrize('batch', [1, 2]) +@pytest.mark.parametrize( + 'seqlen', [1, 8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +@pytest.mark.parametrize('dim', [64]) +@pytest.mark.parametrize('batch', [1]) def test_causal_conv1d(batch, dim, seqlen, width, has_bias, silu_activation, - itype, channel_last, has_initial_states, - return_final_states): - if not channel_last and (has_initial_states or return_final_states): - pytest.skip( - "Only channel_last support initial_states or return_final_states") + itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 # set seed seed_everything(0) - if not channel_last: - x = torch.randn(batch, - 4096 + dim + 64, - seqlen, - device=device, - dtype=itype)[:, 4096:4096 + dim, :] - else: - x = rearrange( - torch.randn(batch, - seqlen, - 4096 + dim + 64, - device=device, - dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s") + x = torch.randn(batch, dim, seqlen, device=device, + dtype=itype).contiguous() + weight = torch.randn(dim, width, device=device, dtype=itype) bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None - if has_initial_states: - initial_states = torch.randn(batch, - width - 1, - dim, - device=device, - dtype=itype).transpose(1, 2) - else: - initial_states = None - x_ref = x.detach().clone() - weight_ref = weight.detach().clone() - bias_ref = bias.detach().clone() if bias is not None else None - initial_states_ref = initial_states.detach().clone( + initial_states = torch.randn(batch, + dim, + width - 1, + device=device, + dtype=itype) + x_ref = x.clone() + weight_ref = weight.clone() + bias_ref = bias.clone() if bias is not None else None + initial_states_ref = initial_states.clone( ) if initial_states is not None else None activation = None if not silu_activation else "silu" - out, final_states = causal_conv1d_fn( - x, - weight, - bias, - initial_states=initial_states, - return_final_states=return_final_states, - activation=activation) + out = causal_conv1d_fn(x, + weight, + bias, + activation=activation, + conv_states=initial_states, + has_initial_state=torch.ones(batch, + dtype=torch.bool, + device=x.device)) out_ref, final_states_ref = causal_conv1d_ref( x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, - return_final_states=return_final_states, + return_final_states=True, activation=activation) - - causal_conv1d_opcheck_fn(x_ref, - weight_ref, - bias_ref, - initial_states=initial_states_ref, - return_final_states=return_final_states, - activation=activation) - - if return_final_states: - 
assert final_states is not None and final_states_ref is not None - assert torch.allclose(final_states, - final_states_ref, - rtol=rtol, - atol=atol) - + assert initial_states is not None and final_states_ref is not None + assert torch.allclose(initial_states, + final_states_ref, + rtol=rtol, + atol=atol) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - if return_final_states: - out += F.sigmoid(final_states).sum(dim=-1, keepdim=True) - out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True) + causal_conv1d_opcheck_fn(x, + weight, + bias, + activation=activation, + conv_states=initial_states, + has_initial_state=torch.ones(batch, + dtype=torch.bool, + device=x.device)) @pytest.mark.parametrize("itype", [torch.bfloat16]) @pytest.mark.parametrize("silu_activation", [False, True]) @pytest.mark.parametrize("has_bias", [False, True]) -@pytest.mark.parametrize("width", [2, 3, 4]) +@pytest.mark.parametrize("seqlen", [1]) +@pytest.mark.parametrize("width", [4]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) -@pytest.mark.parametrize("batch", [1, 2]) -def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, +def test_causal_conv1d_update(dim, width, seqlen, has_bias, silu_activation, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) @@ -246,8 +232,9 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, # set seed seed_everything(0) batch = 2 - x = torch.randn(batch, dim, device=device, dtype=itype) - conv_state = torch.randn(batch, dim, width, device=device, dtype=itype) + x = torch.randn(batch, dim, seqlen, device=device, dtype=itype) + conv_state = torch.randn(batch, dim, width - 1, device=device, dtype=itype) + weight = torch.randn(dim, width, device=device, @@ -273,9 +260,15 @@ def test_causal_conv1d_update(batch, dim, width, has_bias, silu_activation, assert torch.equal(conv_state, conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - opcheck( - torch.ops._C.causal_conv1d_update, - (x, conv_state, weight, bias, activation in ["silu", "swish"], None)) + opcheck(torch.ops._C.causal_conv1d_update, ( + x, + conv_state, + weight, + bias, + activation in ["silu", "swish"], + None, + None, + )) @pytest.mark.parametrize("itype", @@ -292,16 +285,16 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, if itype == torch.bfloat16: rtol, atol = 1e-2, 5e-2 - # set seed - torch.random.manual_seed(0) + # set seed + seed_everything(0) batch = 64 - x = torch.randn(batch, dim, device=device, dtype=itype) + x = torch.randn(batch, dim, 1, device=device, dtype=itype) total_entries = 10 * batch conv_state = torch.randn(total_entries, dim, - width, + width - 1, device=device, dtype=itype) conv_state_indices = torch.randperm(total_entries)[:batch].to( @@ -332,3 +325,100 @@ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_bias, assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref) assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + opcheck(torch.ops._C.causal_conv1d_update, ( + x, + conv_state, + weight, + bias, + activation in ["silu", "swish"], + None, + conv_state_indices, + )) + + +@pytest.mark.parametrize("itype", [torch.bfloat16]) +@pytest.mark.parametrize("silu_activation", [True]) +@pytest.mark.parametrize("has_bias", [True]) +@pytest.mark.parametrize("width", [4]) +@pytest.mark.parametrize('seqlen', + [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096]) +@pytest.mark.parametrize('dim', [64, 
4096]) +def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, + itype): + device = "cuda" + rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 1e-2, 5e-2 + # set seed + seed_everything(0) + batch = 1 + seqlens = [] + nsplits = 3 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( + torch.diff( + torch.cat( + [torch.tensor([-1]), eos_pos, + torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen + assert all(s > 0 for s in seqlens[-1]) + + cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], + dim=0) + x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, + dtype=itype)[:, 4096:4096 + dim, :] + weight = torch.randn(dim, width, device=device, dtype=itype) + bias = torch.randn(dim, device=device, dtype=itype) if has_bias else None + x_ref = x.clone() + weight_ref = weight.clone() + bias_ref = bias.clone() if bias is not None else None + activation = None if not silu_activation else "silu" + final_states = torch.randn(nsplits + 1, + dim, + width - 1, + device=x.device, + dtype=x.dtype) + final_states_ref = final_states.clone() + has_initial_states = torch.randint(0, + 2, (cumsum.shape[0] - 1, ), + dtype=torch.bool, + device=x.device) + cache_indices = torch.randperm(cumsum.shape[0] - 1, + dtype=torch.int32, + device=x.device) + out = causal_conv1d_fn(x.squeeze(0), weight, bias, cumsum.cuda(), + cache_indices, has_initial_states, final_states, + activation) + out_ref = [] + out_ref_b = [] + + splits = [torch.split(var, seqlens[0], dim=-1) for var in (x_ref)] + for i in range(len(seqlens[0])): + x_s = [v[i].unsqueeze(0) for v in splits][0] + out_ref_b.append( + causal_conv1d_ref( + x_s, + weight_ref, + bias_ref, + activation=activation, + return_final_states=True, + final_states_out=final_states_ref[cache_indices[i]].unsqueeze( + 0), + initial_states=final_states_ref[cache_indices[i]].unsqueeze(0) + if has_initial_states[i] else None)) + out_ref.append(torch.cat([t[0] for t in out_ref_b], dim=2)) + out_ref = torch.cat(out_ref, dim=0) + + print(f"Output max diff: {(out - out_ref).abs().max().item()}") + print(f"Output mean diff: {(out - out_ref).abs().mean().item()}") + print("Output state max diff" + f":{(final_states - final_states_ref).abs().max()}") + print("Output state mean diff" + f":{(final_states - final_states_ref).abs().mean()}") + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol) + causal_conv1d_opcheck_fn(x.squeeze(0), weight, bias, cumsum.cuda(), + cache_indices, has_initial_states, final_states, + activation) diff --git a/tests/kernels/test_encoder_decoder_attn.py b/tests/kernels/test_encoder_decoder_attn.py index b550a7fdd84f0..6b979d0558c46 100644 --- a/tests/kernels/test_encoder_decoder_attn.py +++ b/tests/kernels/test_encoder_decoder_attn.py @@ -136,7 +136,9 @@ class that Attention will automatically select when it is constructed. 
) if test_pt.num_blocks is None or test_pt.num_heads is None: # Caller does not require a KV cache - return TestResources(scale, attn_backend, attn, None) + return TestResources( + scale, attn_backend, attn, + torch.tensor([], dtype=torch.float32, device=CUDA_DEVICE)) # Construct KV cache kv_cache = make_kv_cache(test_pt.num_blocks, @@ -620,7 +622,9 @@ def _run_encoder_attention_test( return attn.forward(packed_qkv.query, packed_qkv.key, packed_qkv.value, - None, + torch.tensor([], + dtype=torch.float32, + device=packed_qkv.query.device), attn_metadata, attn_type=attn_type) diff --git a/tests/kernels/test_mamba_ssm.py b/tests/kernels/test_mamba_ssm.py index 5a6149562e886..8fa55e75f6c11 100644 --- a/tests/kernels/test_mamba_ssm.py +++ b/tests/kernels/test_mamba_ssm.py @@ -98,8 +98,8 @@ def selective_scan_ref(u, delta_bias=None, delta_softplus=False, return_last_state=False, - position_indices=None, - prev_state=None): + prev_state=None, + final_state_out=None): """ u: r(B D L) delta: r(B D L) @@ -139,12 +139,8 @@ def selective_scan_ref(u, deltaB_u = torch.einsum('bdl,bdnl,bdl->bdln', delta, B, u) if is_variable_C and C.dim() == 4: C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None for i in range(u.shape[2]): - if position_indices is not None and position_indices[0, i] == 0: - x = deltaB_u[:, :, i] - else: - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] if not is_variable_C: y = torch.einsum('bdn,dn->bd', x, C) else: @@ -153,14 +149,17 @@ def selective_scan_ref(u, else: y = torch.einsum('bdn,bdn->bd', x, C[:, :, :, i]) if i == u.shape[2] - 1: - last_state = x + if final_state_out is None: + final_state_out = x + else: + final_state_out.copy_(x) ys.append(y) y = torch.stack(ys, dim=2) # (batch dim L) out = y if D is None else y + u * rearrange(D, "d -> d 1") if z is not None: out = out * F.silu(z) out = out.to(dtype=dtype_in) - return out if not return_last_state else (out, last_state) + return out if not return_last_state else (out, final_state_out) def selective_scan_opcheck_fn(u, @@ -172,9 +171,10 @@ def selective_scan_opcheck_fn(u, z=None, delta_bias=None, delta_softplus=False, - return_last_state=False, - position_indices=None, - prev_state=None): + cu_seq_len=None, + cache_indices=None, + has_initial_state=None, + ssm_states=None): """if return_last_state is True, returns (out, last_state) last_state has shape (batch, dim, dstate). """ @@ -190,36 +190,27 @@ def selective_scan_opcheck_fn(u, C = C.contiguous() if z is not None and z.stride(-1) != 1: z = z.contiguous() - if B.dim() == 3: + if B.dim() == 3 and cu_seq_len is None: B = B.unsqueeze(1) - if C.dim() == 3: + if B.dim() == 2 and cu_seq_len is not None: + B = B.unsqueeze(0) + if C.dim() == 3 and cu_seq_len is None: C = C.unsqueeze(1) - n_chunks = int((u.shape[-1] + 2048 - 1) / 2048) - x = torch.zeros(( - u.shape[0], - u.shape[1], - n_chunks, - int(A.shape[1] * 2), - ), - device=u.device, - dtype=torch.float32, - requires_grad=False) - x[:, :, 0, 0::2] = 1 - if prev_state is not None: - x[:, :, 0, 1::2].copy_(prev_state) + if C.dim() == 2 and cu_seq_len is not None: + C = C.unsqueeze(0) # Disable test_autograd_registration for now as it seems to trigger # a bogus error. 
opcheck(torch.ops._C.selective_scan_fwd, - (u, delta, A, B, C, D, z, delta_bias, delta_softplus, - position_indices, x), + (u, delta, A, B, C, D, z, delta_bias, delta_softplus, cu_seq_len, + cache_indices, has_initial_state, ssm_states), test_utils=["test_schema", "test_faketensor"]) @pytest.mark.parametrize('wtype', [torch.float32]) -@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('itype', + [torch.float32, torch.float16, torch.bfloat16]) @pytest.mark.parametrize('seqlen', [128, 256, 512, 1024, 2048, 4096]) -@pytest.mark.parametrize("return_last_state", [True]) @pytest.mark.parametrize('has_delta_bias', [True]) @pytest.mark.parametrize('delta_softplus', [True]) @pytest.mark.parametrize('has_z', [True]) @@ -229,8 +220,8 @@ def selective_scan_opcheck_fn(u, @pytest.mark.parametrize("is_variable_B", [True]) @pytest.mark.parametrize("scan_chunks", [1, 2, 3]) def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, - has_z, has_delta_bias, delta_softplus, - return_last_state, seqlen, itype, wtype, scan_chunks): + has_z, has_delta_bias, delta_softplus, seqlen, itype, + wtype, scan_chunks): if varBC_groups > 1 and (not is_variable_B or not is_variable_C): pytest.skip() # This config is not applicable device = 'cuda' @@ -243,10 +234,11 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, atolw = max(atolw, atol) # set seed seed_everything(0) - batch_size = 2 + batch_size = 1 dim = 4 dstate = 8 A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A_ref = A.clone() if not is_variable_B: B_shape = [dim, dstate] elif varBC_groups == 1: @@ -256,6 +248,7 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B = torch.randn(B_shape, device=device, dtype=wtype if not is_variable_B else itype) + B_ref = B.clone() if not is_variable_C: C_shape = [dim, dstate] elif varBC_groups == 1: @@ -265,16 +258,25 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, C = torch.randn(C_shape, device=device, dtype=wtype if not is_variable_C else itype) + C_ref = C.clone() D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + D_ref = D.clone() z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) if has_z else None + z_ref = z.clone() if has_z else None delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) ) if has_delta_bias else None u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + u_ref = u.clone() delta = (0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype)) - state = None - state_ref = None + delta_ref = delta.clone() + state_shape = (batch_size, u.shape[1], int(A.shape[1])) + state = torch.randn(state_shape, + device=u.device, + dtype=itype, + requires_grad=False) + state_ref = state.clone() out = None out_ref = None outs = [] @@ -294,40 +296,40 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, if has_z: assert z is not None _z = z[..., chunk_start:chunk_end] - out, *rest = selective_scan_fn(u[..., chunk_start:chunk_end], - delta[..., chunk_start:chunk_end], - A, - _B, - _C, - D, - z=_z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - return_last_state=return_last_state, - prev_state=state if c > 0 else None) + out = selective_scan_fn( + u[..., chunk_start:chunk_end], + state, + delta[..., chunk_start:chunk_end], + A, + _B, + _C, + D, + z=_z, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + has_initial_state=torch.ones(batch_size, + 
device=u.device, + dtype=torch.bool) if c > 0 else None) outs.append(out) - if return_last_state: - state = rest[0] if len(outs) > 1: out = torch.cat(outs, dim=-1) - out_ref, *rest = selective_scan_ref(u, - delta, - A, - B, - C, - D, - z=z, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - return_last_state=return_last_state) - if return_last_state: - state_ref = rest[0] + + out_ref, state_ref, *rest = selective_scan_ref( + u_ref, + delta_ref, + A_ref, + B_ref, + C_ref, + D_ref, + z=z_ref, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=True) assert out is not None and out_ref is not None assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) - if return_last_state: - assert state is not None and state_ref is not None - assert torch.allclose(state, state_ref, rtol=rtol, atol=atol) + assert state is not None and state_ref is not None + assert torch.allclose(state, state_ref.to(itype), rtol=rtol, atol=atol) selective_scan_opcheck_fn(u, delta, @@ -335,10 +337,10 @@ def test_selective_scan(is_variable_B, is_variable_C, varBC_groups, has_D, B, C, D, - z=z, + z, delta_bias=delta_bias, delta_softplus=delta_softplus, - return_last_state=return_last_state) + ssm_states=state) @pytest.mark.parametrize("itype", @@ -391,9 +393,131 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) +@pytest.mark.parametrize('wtype', [torch.float32]) +@pytest.mark.parametrize('itype', [torch.float32]) +@pytest.mark.parametrize('seqlen', [1, 128, 129, 256, 512, 1024, 2048, 4096]) +@pytest.mark.parametrize("return_last_state", [True]) +@pytest.mark.parametrize('has_delta_bias', [True]) +@pytest.mark.parametrize('delta_softplus', [True]) +@pytest.mark.parametrize('has_z', [True]) +@pytest.mark.parametrize('has_D', [True]) +@pytest.mark.parametrize("varBC_groups", [1, 2]) +@pytest.mark.parametrize("is_variable_C", [True]) +@pytest.mark.parametrize("is_variable_B", [True]) +def test_selective_scan_varlen(is_variable_B, is_variable_C, varBC_groups, + has_D, has_z, has_delta_bias, delta_softplus, + return_last_state, seqlen, itype, wtype): + if varBC_groups > 1 and (not is_variable_B or not is_variable_C): + pytest.skip() # This config is not applicable + device = 'cuda' + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + rtolw, atolw = (1e-3, 1e-3) + if has_z: # If we have z, the errors on the weights seem higher + rtolw = max(rtolw, rtol) + atolw = max(atolw, atol) + # set seed + torch.random.manual_seed(0) + seqlens = [] + nsplits = 3 + if seqlen < 10: + nsplits = 0 + eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values + seqlens.append( + torch.diff( + torch.cat( + [torch.tensor([-1]), eos_pos, + torch.tensor([seqlen - 1])])).tolist()) + assert sum(seqlens[-1]) == seqlen + assert all(s > 0 for s in seqlens[-1]) + + cumsum = torch.cumsum(torch.tensor(seqlens[0]), dim=0).to(torch.int32) + cumsum = torch.concat([torch.tensor([0], dtype=torch.int32), cumsum], + dim=0).cuda() + + dim = 4 + dstate = 8 + A = (-0.5 * torch.rand(dim, dstate, device=device, dtype=wtype)) + A_ref = A.clone() + B_shape = [varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, + device=device, + dtype=wtype if not is_variable_B else itype) + B_ref = B.clone() + C_shape = [varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, + device=device, + dtype=wtype if not is_variable_C else itype) + C_ref = C.clone() + D = torch.randn(dim, device=device, 
dtype=torch.float32) if has_D else None + D_ref = D.clone() + z = torch.randn(dim, seqlen, device=device, dtype=itype) + z_ref = z.clone() + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + ) if has_delta_bias else None + u = torch.randn(dim, seqlen, device=device, dtype=itype) + u_ref = u.clone() + delta = (0.5 * torch.rand(dim, seqlen, device=device, dtype=itype)) + delta_ref = delta.clone() + out = None + out_ref = None + prev_state_shape = (cumsum.shape[0] - 1, u.shape[0], int(A.shape[1])) + prev_state = torch.randn(prev_state_shape, + device=u.device, + dtype=itype, + requires_grad=False) + prev_state_ref = prev_state.clone() + cache_indices = torch.randperm(cumsum.shape[0] - 1, + dtype=torch.int32, + device=u.device) + + has_initial_state = torch.randint(0, + 2, (cumsum.shape[0] - 1, ), + dtype=torch.bool, + device=u.device) + out = selective_scan_fn(u, prev_state, delta, A, B, C, D, z, delta_bias, + delta_softplus, cumsum, cache_indices, + has_initial_state) + outs_ref = [] + splits = [ + torch.split(var, seqlens[0], dim=-1) + for var in (u_ref, delta_ref, B_ref, C_ref, z_ref) + ] + for i in range(len(seqlens[0])): + u_s, delta_s, B_s, C_s, z_s = [v[i].unsqueeze(0) for v in splits] + out_ref_s, _ = selective_scan_ref( + u_s, + delta_s, + A_ref, + B_s, + C_s, + D_ref, + z=z_s, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + return_last_state=return_last_state, + prev_state=prev_state_ref[cache_indices[i]].unsqueeze(0) + if has_initial_state[i] else None, + final_state_out=prev_state_ref[cache_indices[i]].unsqueeze(0)) + outs_ref.append(out_ref_s) + out_ref = torch.cat(outs_ref, dim=-1) if len(outs_ref) > 1 else outs_ref[0] + + print("Output diff max", (out - out_ref[0]).max()) + print("Output diff mean", (out - out_ref[0]).mean()) + print("Output state diff max", (prev_state - prev_state_ref).max()) + print("Output state diff mean", (prev_state - prev_state_ref).mean()) + assert torch.allclose(prev_state, prev_state_ref, rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref[0], rtol=rtol, atol=atol) + + selective_scan_opcheck_fn(u, delta, A, B, C, D, z, delta_bias, + delta_softplus, cumsum, cache_indices, + has_initial_state, prev_state) + + @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("has_z", [True]) @pytest.mark.parametrize("dstate", [16, 32, 64]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): @@ -405,7 +529,7 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): atol *= 2 # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 3 total_entries = 10 * batch_size state = torch.randn(total_entries, dim, dstate, dtype=itype, device=device) @@ -443,6 +567,11 @@ def test_selective_state_update_with_batch_indices(dim, dstate, has_z, itype): dt_bias=dt_bias, dt_softplus=True) + print("Output diff max", (out - out_ref[0]).max()) + print("Output diff mean", (out - out_ref[0]).mean()) + print("Output state diff max", (state[state_indices, :] - state_ref).max()) + print("Output state diff mean", + (state[state_indices, :] - state_ref).mean()) assert torch.allclose(state[state_indices, :], state_ref, rtol=rtol, @@ -465,7 +594,7 @@ def test_selective_state_update_with_heads_with_batch_indices( rtol, atol = 1e-1, 1e-1 # set seed torch.random.manual_seed(0) - batch_size = 16 + batch_size = 3 headdim = 64 nheads = 
dim // headdim diff --git a/tests/kernels/test_moe.py b/tests/kernels/test_moe.py index c6ddcc8ce79f5..cbbb5c9b79c42 100644 --- a/tests/kernels/test_moe.py +++ b/tests/kernels/test_moe.py @@ -145,6 +145,7 @@ def compute_max_diff(output, output_ref): @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_fused_marlin_moe( m: int, n: int, @@ -154,6 +155,7 @@ def test_fused_marlin_moe( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): seed_everything(7) @@ -166,6 +168,9 @@ def test_fused_marlin_moe( return if group_size in (k, n): return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -246,6 +251,7 @@ def test_fused_marlin_moe( w1_scale=scales1, w2_scale=scales2, num_bits=num_bits, + is_k_full=is_k_full, ) assert compute_max_diff(marlin_output, triton_output) < 4e-2 @@ -290,6 +296,7 @@ def test_fused_marlin_moe( @pytest.mark.parametrize("group_size", [-1, 32, 64, 128]) @pytest.mark.parametrize("act_order", [True, False]) @pytest.mark.parametrize("num_bits", [4, 8]) +@pytest.mark.parametrize("is_k_full", [True, False]) def test_single_marlin_moe_multiply( m: int, n: int, @@ -299,6 +306,7 @@ def test_single_marlin_moe_multiply( group_size: int, act_order: bool, num_bits: int, + is_k_full: bool, ): if topk > e: return @@ -309,6 +317,9 @@ def test_single_marlin_moe_multiply( return if group_size == k: return + else: + if not is_k_full: + return quant_type = (scalar_types.uint4b8 if num_bits == 4 else scalar_types.uint8b128) @@ -339,15 +350,18 @@ def test_single_marlin_moe_multiply( sort_indices = stack_and_dev(sort_indices_l) score = torch.randn((m, e), device="cuda", dtype=dtype) - marlin_output = single_marlin_moe(a, - qweight, - scales, - score, - g_idx, - sort_indices, - topk, - renormalize=False, - num_bits=num_bits) + marlin_output = single_marlin_moe( + a, + qweight, + scales, + score, + g_idx, + sort_indices, + topk, + renormalize=False, + num_bits=num_bits, + is_k_full=is_k_full, + ) torch_output = torch_moe_single(a, w_ref.transpose(1, 2), score, topk) assert compute_max_diff(marlin_output, torch_output) < 1e-2 diff --git a/tests/lora/conftest.py b/tests/lora/conftest.py index 1c30f4147e8b5..65e38b2e4e6e4 100644 --- a/tests/lora/conftest.py +++ b/tests/lora/conftest.py @@ -197,6 +197,11 @@ def baichuan_zero_lora_files(): return snapshot_download(repo_id="jeeejeee/baichuan7b-zero-init") +@pytest.fixture(scope="session") +def minicpmv_lora_files(): + return snapshot_download(repo_id="jeeejeee/minicpmv25-lora-pokemon") + + @pytest.fixture(scope="session") def tinyllama_lora_files(): return snapshot_download(repo_id="jashing/tinyllama-colorist-lora") diff --git a/tests/lora/test_baichuan.py b/tests/lora/test_baichuan.py index 56cec4db89e64..cbc3668997817 100644 --- a/tests/lora/test_baichuan.py +++ b/tests/lora/test_baichuan.py @@ -63,12 +63,11 @@ def test_baichuan_lora(baichuan_lora_files): assert output2[i] == expected_lora_output[i] -@pytest.mark.skip("Requires multiple GPUs") @pytest.mark.parametrize("fully_sharded", [True, False]) -def test_baichuan_tensor_parallel_equality(baichuan_lora_files, fully_sharded): - # Cannot use as it will initialize torch.cuda too early... 
- # if torch.cuda.device_count() < 4: - # pytest.skip(f"Not enough GPUs for tensor parallelism {4}") +def test_baichuan_tensor_parallel_equality(baichuan_lora_files, + num_gpus_available, fully_sharded): + if num_gpus_available < 4: + pytest.skip(f"Not enough GPUs for tensor parallelism {4}") llm_tp1 = vllm.LLM(MODEL_PATH, enable_lora=True, diff --git a/tests/lora/test_minicpmv.py b/tests/lora/test_minicpmv.py new file mode 100644 index 0000000000000..81b8188e638c9 --- /dev/null +++ b/tests/lora/test_minicpmv.py @@ -0,0 +1,71 @@ +from typing import List + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n") + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should begin with `A`. +EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=5, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +def test_minicpmv_lora(minicpmv_lora_files): + llm = vllm.LLM( + MODEL_PATH, + max_num_seqs=2, + enable_lora=True, + max_loras=4, + max_lora_rank=64, + trust_remote_code=True, + ) + + output1 = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output1[i]) + output2 = do_sample(llm, minicpmv_lora_files, lora_id=2) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output2[i]) diff --git a/tests/lora/test_minicpmv_tp.py b/tests/lora/test_minicpmv_tp.py new file mode 100644 index 0000000000000..ba29e562e58ec --- /dev/null +++ b/tests/lora/test_minicpmv_tp.py @@ -0,0 +1,95 @@ +from typing import List + +import pytest + +import vllm +from vllm.assets.image import ImageAsset +from vllm.lora.request import LoRARequest + +from ..utils import multi_gpu_test + +MODEL_PATH = "openbmb/MiniCPM-Llama3-V-2_5" + +PROMPT_TEMPLATE = ( + "<|begin_of_text|><|start_header_id|>user<|end_header_id|>\n\n" + "(./)\nWhat is in the image?<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n") + +IMAGE_ASSETS = [ + ImageAsset("stop_sign"), + ImageAsset("cherry_blossom"), +] + +# After fine-tuning with LoRA, all generated content should begin with `A`.
+EXPECTED_OUTPUT = [ + "A red and white stop sign with a Chinese archway in the background featuring red lanterns and gold accents.", # noqa: E501 + "A pink cherry blossom tree with a blue sky in the background.", +] + + +def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: + sampling_params = vllm.SamplingParams( + temperature=0, + max_tokens=5, + stop_token_ids=[128001, 128009], # eos_id, eot_id + ) + + inputs = [{ + "prompt": PROMPT_TEMPLATE, + "multi_modal_data": { + "image": asset.pil_image + }, + } for asset in IMAGE_ASSETS] + + outputs = llm.generate( + inputs, + sampling_params, + lora_request=LoRARequest(str(lora_id), lora_id, lora_path) + if lora_id else None, + ) + # Print the outputs. + generated_texts: List[str] = [] + for output in outputs: + prompt = output.prompt + generated_text = output.outputs[0].text.strip() + generated_texts.append(generated_text) + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + return generated_texts + + +@multi_gpu_test(num_gpus=2) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp2(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=2, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) + + +@multi_gpu_test(num_gpus=4) +@pytest.mark.parametrize("fully_sharded", [True, False]) +def test_minicpmv_tp4(minicpmv_lora_files, fully_sharded): + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=2, + max_loras=4, + max_lora_rank=64, + tensor_parallel_size=4, + trust_remote_code=True, + fully_sharded_loras=fully_sharded, + ) + output_tp = do_sample(llm, minicpmv_lora_files, lora_id=1) + for i in range(len(EXPECTED_OUTPUT)): + assert EXPECTED_OUTPUT[i].startswith(output_tp[i]) diff --git a/tests/lora/test_quant_model.py b/tests/lora/test_quant_model.py index 133e0d4514a6d..5636c96435024 100644 --- a/tests/lora/test_quant_model.py +++ b/tests/lora/test_quant_model.py @@ -71,10 +71,10 @@ def format_prompt_tuples(prompt): @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("tp_size", [1]) -def test_quant_model_lora(tinyllama_lora_files, model, tp_size): - # Cannot use as it will initialize torch.cuda too early... - # if torch.cuda.device_count() < tp_size: - # pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") +def test_quant_model_lora(tinyllama_lora_files, num_gpus_available, model, + tp_size): + if num_gpus_available < tp_size: + pytest.skip(f"Not enough GPUs for tensor parallelism {tp_size}") llm = vllm.LLM( model=model.model_path, @@ -164,11 +164,10 @@ def expect_match(output, expected_output): @pytest.mark.parametrize("model", MODELS) -@pytest.mark.skip("Requires multiple GPUs") -def test_quant_model_tp_equality(tinyllama_lora_files, model): - # Cannot use as it will initialize torch.cuda too early... 
- # if torch.cuda.device_count() < 2: - # pytest.skip(f"Not enough GPUs for tensor parallelism {2}") +def test_quant_model_tp_equality(tinyllama_lora_files, num_gpus_available, + model): + if num_gpus_available < 2: + pytest.skip(f"Not enough GPUs for tensor parallelism {2}") llm_tp1 = vllm.LLM( model=model.model_path, diff --git a/tests/lora/test_tokenizer_group.py b/tests/lora/test_tokenizer_group.py index 2dcad23c2b547..daa39b2a3dba1 100644 --- a/tests/lora/test_tokenizer_group.py +++ b/tests/lora/test_tokenizer_group.py @@ -41,7 +41,7 @@ async def test_tokenizer_group_lora(sql_lora_files, tokenizer_group_type): lora_request) -def test_get_lora_tokenizer(sql_lora_files, tmpdir): +def test_get_lora_tokenizer(sql_lora_files, tmp_path): lora_request = None tokenizer = get_lora_tokenizer(lora_request) assert not tokenizer @@ -50,6 +50,6 @@ def test_get_lora_tokenizer(sql_lora_files, tmpdir): tokenizer = get_lora_tokenizer(lora_request) assert tokenizer.get_added_vocab() - lora_request = LoRARequest("1", 1, str(tmpdir)) + lora_request = LoRARequest("1", 1, str(tmp_path)) tokenizer = get_lora_tokenizer(lora_request) assert not tokenizer diff --git a/tests/model_executor/conftest.py b/tests/model_executor/conftest.py new file mode 100644 index 0000000000000..10792b0a04999 --- /dev/null +++ b/tests/model_executor/conftest.py @@ -0,0 +1,49 @@ +import pytest + + +@pytest.fixture +def sample_regex(): + return (r"((25[0-5]|(2[0-4]|1\d|[1-9]|)\d)\.){3}" + r"(25[0-5]|(2[0-4]|1\d|[1-9]|)\d)") + + +@pytest.fixture +def sample_json_schema(): + return { + "type": "object", + "properties": { + "name": { + "type": "string" + }, + "age": { + "type": "integer" + }, + "skills": { + "type": "array", + "items": { + "type": "string", + "maxLength": 10 + }, + "minItems": 3 + }, + "work_history": { + "type": "array", + "items": { + "type": "object", + "properties": { + "company": { + "type": "string" + }, + "duration": { + "type": "number" + }, + "position": { + "type": "string" + } + }, + "required": ["company", "position"] + } + } + }, + "required": ["name", "age", "skills", "work_history"] + } diff --git a/tests/entrypoints/openai/test_guided_processors.py b/tests/model_executor/test_guided_processors.py similarity index 69% rename from tests/entrypoints/openai/test_guided_processors.py rename to tests/model_executor/test_guided_processors.py index 85cb4d52200c3..45fab8e96b968 100644 --- a/tests/entrypoints/openai/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -1,14 +1,12 @@ -# This unit test should be moved to a new -# tests/test_guided_decoding directory. 
import pytest import torch from transformers import AutoTokenizer -from vllm.entrypoints.openai.protocol import CompletionRequest from vllm.model_executor.guided_decoding import ( get_guided_decoding_logits_processor) from vllm.model_executor.guided_decoding.outlines_logits_processors import ( JSONLogitsProcessor, RegexLogitsProcessor) +from vllm.sampling_params import GuidedDecodingParams def test_guided_logits_processors(sample_regex, sample_json_schema): @@ -44,11 +42,9 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') token_ids = tokenizer.encode( f"Give an example IPv4 address with this regex: {sample_regex}") - regex_request = CompletionRequest(model='test', - prompt=token_ids, - guided_regex=sample_regex) + regex_request = GuidedDecodingParams(regex=sample_regex, backend=backend) regex_lp = await get_guided_decoding_logits_processor( - backend, regex_request, tokenizer) + regex_request, tokenizer) assert regex_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) @@ -59,14 +55,31 @@ async def test_guided_logits_processor_black_box(backend: str, sample_regex, token_ids = tokenizer.encode( f"Give an employee profile that fits this schema: {sample_json_schema}" ) - json_request = CompletionRequest(model='test', - prompt=token_ids, - guided_json=sample_json_schema) + json_request = GuidedDecodingParams(json=sample_json_schema, + backend=backend) json_lp = await get_guided_decoding_logits_processor( - backend, json_request, tokenizer) + json_request, tokenizer) assert json_lp is not None tensor = torch.rand(32000) original_tensor = torch.clone(tensor) tensor = json_lp(token_ids, tensor) assert tensor.shape == original_tensor.shape assert not torch.allclose(tensor, original_tensor) + + +def test_multiple_guided_options_not_allowed(sample_json_schema, sample_regex): + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, regex=sample_regex) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, json_object=True) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, choice=["a", "b"]) + + with pytest.raises(ValueError, + match="You can only use one kind of guided"): + GuidedDecodingParams(json=sample_json_schema, grammar="test grammar") diff --git a/tests/models/decoder_only/language/test_granite.py b/tests/models/decoder_only/language/test_granite.py index e5c5ce4a8f745..0b71f0d49c70a 100644 --- a/tests/models/decoder_only/language/test_granite.py +++ b/tests/models/decoder_only/language/test_granite.py @@ -3,7 +3,6 @@ Run `pytest tests/models/test_granite.py`.
""" import pytest -import transformers from ...utils import check_logprobs_close @@ -12,9 +11,6 @@ ] -# GraniteForCausalLM will be in transformers >= 4.45 -@pytest.mark.skipif(transformers.__version__ < "4.45", - reason="granite model test requires transformers >= 4.45") @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 36fa67a22b0f6..408d12cd5ff5c 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,18 +1,16 @@ import pytest +from vllm.sampling_params import SamplingParams from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal -MODELS = ["ai21labs/Jamba-tiny-random"] +MODELS = ["ai21labs/Jamba-tiny-dev"] -# Fails due to usage of MoE as MLP(E=1_, which is different than the HF impl -# TODO: Fix this with trained model -@pytest.mark.skip() @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [10]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) def test_models( hf_runner, vllm_runner, @@ -22,7 +20,14 @@ def test_models( max_tokens: int, ) -> None: - with hf_runner(model, dtype=dtype) as hf_model: + with hf_runner( + model, + dtype=dtype, + model_kwargs={ + "use_mamba_kernels": + False, # mamba kernels are not installed, so HF + # won't use them + }) as hf_model: hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens) with vllm_runner(model, dtype=dtype) as vllm_model: @@ -38,8 +43,8 @@ def test_models( @pytest.mark.parametrize("model", MODELS) -@pytest.mark.parametrize("dtype", ["half"]) -@pytest.mark.parametrize("max_tokens", [5]) +@pytest.mark.parametrize("dtype", ["float"]) +@pytest.mark.parametrize("max_tokens", [96]) def test_batching( vllm_runner, example_prompts, @@ -65,6 +70,107 @@ def test_batching( ) + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["float16"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_mamba_prefill_chunking_with_parallel_sampling( + hf_runner, vllm_runner, example_prompts, model: str, dtype: str, + max_tokens: int) -> None: + # Tests prefill chunking in conjunction with n>1; in this case, + # prefill is populated with decoding tokens and we test that it + # doesn't fail. This test might fail if the cache is not allocated + # correctly for n > 1 decoding steps inside a + # chunked prefill forward pass (where we have both prefills + # and decoding together) + sampling_params = SamplingParams(n=3, + temperature=1, + seed=0, + max_tokens=max_tokens) + with vllm_runner( + model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=30, + max_num_seqs=10 # forces prefill chunks with decoding + ) as vllm_model: + vllm_model.generate(example_prompts, sampling_params) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [10]) +def test_mamba_prefill_chunking(hf_runner, vllm_runner, example_prompts, + model: str, dtype: str, + max_tokens: int) -> None: + # Numerical error during prefill chunking produces different generation + # compared to w/o prefill chunking for these examples; remove them for now + example_prompts.pop(7) + example_prompts.pop(2) + example_prompts.pop(1) + + with hf_runner(
model, + dtype=dtype, + model_kwargs={ + "use_mamba_kernels": + False, # mamba kernels are not installed, so HF + # won't use them + }) as hf_model: + non_chunked = hf_model.generate_greedy(example_prompts, max_tokens) + + with vllm_runner(model, + dtype=dtype, + enable_chunked_prefill=True, + max_num_batched_tokens=5, + max_num_seqs=2) as vllm_model: + chunked = vllm_model.generate_greedy(example_prompts, + max_tokens=max_tokens) + + check_outputs_equal( + outputs_0_lst=chunked, + outputs_1_lst=non_chunked, + name_0="chunked", + name_1="non_chunked", + ) + + +@pytest.mark.parametrize("model", MODELS) +@pytest.mark.parametrize("dtype", ["bfloat16"]) +@pytest.mark.parametrize("max_tokens", [15]) +def test_parallel_sampling( + vllm_runner, + example_prompts, + model: str, + dtype: str, + max_tokens: int, +) -> None: + + with vllm_runner(model, dtype=dtype) as vllm_model: + for_loop_outputs = [] + for _ in range(10): + for_loop_outputs.append( + # using example_prompts index 1 instead of 0 since with 0 the + # logprobs get really close and the test doesn't pass + vllm_model.generate_greedy([example_prompts[1]], max_tokens) + [0]) + sampling_params = SamplingParams(n=10, + temperature=0.001, + seed=0, + max_tokens=max_tokens) + n_lt_1_outputs = vllm_model.generate([example_prompts[1]], + sampling_params) + token_ids, texts = n_lt_1_outputs[0] + n_lt_1_outputs = [(token_id, text) + for token_id, text in zip(token_ids, texts)] + + check_outputs_equal( + outputs_0_lst=n_lt_1_outputs, + outputs_1_lst=for_loop_outputs, + name_0="vllm_n_lt_1_outputs", + name_1="vllm", + ) + + @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [20]) diff --git a/tests/models/decoder_only/language/test_phimoe.py b/tests/models/decoder_only/language/test_phimoe.py index dbdf5a1b934a6..89afbcf1c03ac 100644 --- a/tests/models/decoder_only/language/test_phimoe.py +++ b/tests/models/decoder_only/language/test_phimoe.py @@ -7,6 +7,7 @@ from vllm.utils import is_cpu +from ....utils import large_gpu_test from ...utils import check_logprobs_close MODELS = [ @@ -69,20 +70,10 @@ def test_phimoe_routing_function(): assert torch.equal(topk_ids, ground_truth[test_id]["topk_ids"]) -def get_gpu_memory(): - try: - props = torch.cuda.get_device_properties(torch.cuda.current_device()) - gpu_memory = props.total_memory / (1024**3) - return gpu_memory - except Exception: - return 0 - - @pytest.mark.skipif(condition=is_cpu(), reason="This test takes a lot time to run on CPU, " "and vllm CI's disk space is not enough for this model.") -@pytest.mark.skipif(condition=get_gpu_memory() < 100, - reason="Skip this test if GPU memory is insufficient.") +@large_gpu_test(min_gb=80) @pytest.mark.parametrize("model", MODELS) @pytest.mark.parametrize("dtype", ["bfloat16"]) @pytest.mark.parametrize("max_tokens", [64]) diff --git a/tests/models/decoder_only/vision_language/test_fuyu.py b/tests/models/decoder_only/vision_language/test_fuyu.py index 94b8431424db5..7827ecb19a744 100644 --- a/tests/models/decoder_only/vision_language/test_fuyu.py +++ b/tests/models/decoder_only/vision_language/test_fuyu.py @@ -65,8 +65,8 @@ def run_test( # max_model_len should be greater than image_feature_size with vllm_runner(model, - max_model_len=2560, - max_num_seqs=1, + max_model_len=2048, + max_num_seqs=2, dtype=dtype, tensor_parallel_size=tensor_parallel_size, distributed_executor_backend=distributed_executor_backend, @@ -80,8 +80,6 @@ def run_test( ] with hf_runner(model, dtype=dtype) as
hf_model: - hf_model.model.get_output_embeddings = lambda: \ - hf_model.model.language_model.get_output_embeddings() eos_token_id = hf_model.processor.tokenizer.eos_token_id hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, diff --git a/tests/models/decoder_only/vision_language/test_llava_next_video.py b/tests/models/decoder_only/vision_language/test_llava_next_video.py index d477bcc713611..7b7b23c783e2a 100644 --- a/tests/models/decoder_only/vision_language/test_llava_next_video.py +++ b/tests/models/decoder_only/vision_language/test_llava_next_video.py @@ -1,7 +1,6 @@ from typing import List, Optional, Tuple, Type, overload import pytest -import transformers from transformers import AutoConfig, AutoModelForVision2Seq, AutoTokenizer from vllm.multimodal.utils import (rescale_video_size, resize_video, @@ -158,8 +157,6 @@ def run_test( ) -@pytest.mark.skipif(transformers.__version__ < "4.45", - reason="Waiting for next transformers release") @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", @@ -203,8 +200,6 @@ def test_models(hf_runner, vllm_runner, video_assets, model, size_factors, ) -@pytest.mark.skipif(transformers.__version__ < "4.45", - reason="Waiting for next transformers release") @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "sizes", diff --git a/tests/models/decoder_only/vision_language/test_llava_onevision.py b/tests/models/decoder_only/vision_language/test_llava_onevision.py index d1bffddde59ab..367f25f446279 100644 --- a/tests/models/decoder_only/vision_language/test_llava_onevision.py +++ b/tests/models/decoder_only/vision_language/test_llava_onevision.py @@ -1,7 +1,6 @@ from typing import List, Optional, Tuple, Type, overload import pytest -import transformers from transformers import (AutoConfig, AutoModelForVision2Seq, AutoTokenizer, BatchEncoding) @@ -12,13 +11,13 @@ from ....conftest import (VIDEO_ASSETS, HfRunner, PromptImageInput, VllmRunner, _VideoAssets) +from ....utils import large_gpu_test from ...utils import check_logprobs_close # Video test HF_VIDEO_PROMPTS = VIDEO_ASSETS.prompts({ "sample_demo_1": - "<|im_start|>user