From 778eb79d57e468632bfaee622d1d409f092cc226 Mon Sep 17 00:00:00 2001 From: Taesu Kim Date: Mon, 5 Aug 2024 08:32:42 +0000 Subject: [PATCH 1/4] fix guided sampling with outlines --- .../guided_decoding/outlines_logits_processors.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 1c8f6cccb3e9a..7cc70f7b4c666 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -57,11 +57,9 @@ def __call__(self, input_ids: List[int], raise TypeError( f"Unsupported instruction type {type(instruction)}") - mask = torch.full((scores.shape[-1], ), - -math.inf, - device=scores.device) + mask = torch.ones((scores.shape[-1], ), device=scores.device, dtype=torch.bool) mask[allowed_tokens] = 0 - scores.add_(mask) + scores.masked_fill_(mask, -math.inf) return scores From 4bea4e3fe5f95fa840bf069e35a5e5ad4d5fad84 Mon Sep 17 00:00:00 2001 From: Taesu Kim Date: Fri, 23 Aug 2024 05:17:36 +0000 Subject: [PATCH 2/4] add with_mark_steps for faster inference --- .../guided_decoding/outlines_logits_processors.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 7cc70f7b4c666..d894c1eee24a4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -27,6 +27,7 @@ from pydantic import BaseModel from transformers import PreTrainedTokenizerBase +from vllm.hpu.utils import with_mark_steps class BaseLogitsProcessor: @@ -34,6 +35,7 @@ def __init__(self, guide: Guide): self._guide: Guide = guide self._fsm_state: DefaultDict[int, int] = defaultdict(int) + @with_mark_steps def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor: """Use the FSM to bias the logits before sampling the next token.""" From 6d57c1882767c730fb40208a7db4551f3086c873 Mon Sep 17 00:00:00 2001 From: JonghoLee Date: Mon, 9 Sep 2024 05:28:19 +0000 Subject: [PATCH 3/4] fix: high latency due to @cache() & update outlines version --- requirements-common.txt | 2 +- .../guided_decoding/outlines_logits_processors.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/requirements-common.txt b/requirements-common.txt index 3b8d473c1fe7a..746ba8d31552f 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -18,7 +18,7 @@ prometheus_client >= 0.18.0 prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer == 0.10.3 -outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0 +outlines >= 0.0.46, < 0.1 # Requires torch >= 2.1.0 typing_extensions filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4 pyzmq diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index d894c1eee24a4..791a0ac88c890 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -68,7 +68,7 @@ def __call__(self, input_ids: List[int], class RegexLogitsProcessor(BaseLogitsProcessor): @classmethod - @cache() + @lru_cache(maxsize=32) def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) @@ -127,7 +127,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel], class CFGLogitsProcessor(BaseLogitsProcessor): @classmethod - @cache() + @lru_cache(maxsize=32) def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide: tokenizer = _adapt_tokenizer(tokenizer) return CFGGuide(cfg, tokenizer) From c04af23979190848f0b413a75ab4ca937b163cae Mon Sep 17 00:00:00 2001 From: Taesu Kim Date: Mon, 23 Sep 2024 06:21:06 +0000 Subject: [PATCH 4/4] remove unused cache --- .../guided_decoding/outlines_logits_processors.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py index 791a0ac88c890..c527f23ced1f4 100644 --- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py +++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py @@ -21,7 +21,6 @@ from typing import Callable, DefaultDict, Dict, List, Union import torch -from outlines.caching import cache from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write from outlines.fsm.json_schema import build_regex_from_schema from pydantic import BaseModel @@ -59,9 +58,11 @@ def __call__(self, input_ids: List[int], raise TypeError( f"Unsupported instruction type {type(instruction)}") - mask = torch.ones((scores.shape[-1], ), device=scores.device, dtype=torch.bool) + mask = torch.full((scores.shape[-1], ), + -math.inf, + device=scores.device) mask[allowed_tokens] = 0 - scores.masked_fill_(mask, -math.inf) + scores = scores.add(mask) return scores