Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix guided sampling with outlines #226

Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements-common.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ prometheus_client >= 0.18.0
prometheus-fastapi-instrumentator >= 7.0.0
tiktoken >= 0.6.0 # Required for DBRX tokenizer
lm-format-enforcer == 0.10.3
outlines >= 0.0.43, < 0.1 # Requires torch >= 2.1.0
outlines >= 0.0.46, < 0.1 # Requires torch >= 2.1.0
typing_extensions
filelock >= 3.10.4 # filelock starts to support `mode` argument from 3.10.4
pyzmq
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

from vllm.hpu.utils import with_mark_steps

class BaseLogitsProcessor:

def __init__(self, guide: Guide):
self._guide: Guide = guide
self._fsm_state: DefaultDict[int, int] = defaultdict(int)

@with_mark_steps
def __call__(self, input_ids: List[int],
scores: torch.Tensor) -> torch.Tensor:
"""Use the FSM to bias the logits before sampling the next token."""
Expand All @@ -57,18 +59,16 @@ def __call__(self, input_ids: List[int],
raise TypeError(
f"Unsupported instruction type {type(instruction)}")

mask = torch.full((scores.shape[-1], ),
-math.inf,
device=scores.device)
mask = torch.ones((scores.shape[-1], ), device=scores.device, dtype=torch.bool)
mask[allowed_tokens] = 0
scores.add_(mask)
scores.masked_fill_(mask, -math.inf)
return scores


class RegexLogitsProcessor(BaseLogitsProcessor):

@classmethod
@cache()
@lru_cache(maxsize=32)
def _get_guide(cls, regex_string: str,
tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
Expand Down Expand Up @@ -127,7 +127,7 @@ def __init__(self, schema: Union[str, Dict, BaseModel],
class CFGLogitsProcessor(BaseLogitsProcessor):

@classmethod
@cache()
@lru_cache(maxsize=32)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ruff static analysis found that "cache" imported at the top of file is now not used, please remove the import

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am benchmarking the performance of .add versus .masked_fill to determine which has better throughput. I'll resolve the conflict based on the benchmark and fix the ruff issue before merging. I'll let you know once it's completed. Apologies for the delayed response!

def _get_guide(cls, cfg: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
tokenizer = _adapt_tokenizer(tokenizer)
return CFGGuide(cfg, tokenizer)
Expand Down
Loading