diff --git a/vllm/model_executor/guided_decoding/outlines_logits_processors.py b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
index 092c143bd59b0..e1b7c11eb00a6 100644
--- a/vllm/model_executor/guided_decoding/outlines_logits_processors.py
+++ b/vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -30,11 +30,48 @@
 from transformers import PreTrainedTokenizerBase
 
 
+# Unfortunately we cannot use lru_cache as it breaks pickling
+# so we use a simpler implementation
+def _cached(fn):
+    cache = {}
+
+    def cached_fn(*args):
+        if args in cache:
+            result = cache[args]
+        else:
+            result = fn(*args)
+            cache[args] = result
+        return result
+
+    return cached_fn
+
+
 class BaseLogitsProcessor:
 
     def __init__(self, guide: Guide):
         self._guide: Guide = guide
         self._fsm_state: DefaultDict[int, int] = defaultdict(int)
+        self._cached_get_mask_tensor = _cached(self._get_mask_tensor)
+
+    @staticmethod
+    @lru_cache(maxsize=128)
+    def _create_mask_tensor(allowed_tokens, vocab_size, device):
+        mask = torch.full((vocab_size, ), -math.inf, device=device)
+        mask[list(allowed_tokens)] = 0
+        return mask
+
+    def _get_mask_tensor(self, state_id, vocab_size, device):
+        instruction = self._guide.get_next_instruction(state=state_id)
+        if type(instruction) == Generate:  # noqa: E721
+            allowed_tokens = instruction.tokens
+        elif type(instruction) == Write:  # noqa: E721
+            # TODO: support fast forward tokens
+            allowed_tokens = [instruction.tokens[0]]
+        else:
+            raise TypeError(
+                f"Unsupported instruction type {type(instruction)}")
+        return BaseLogitsProcessor._create_mask_tensor(tuple(allowed_tokens),
+                                                       vocab_size, device)
 
     def __call__(self, input_ids: List[int],
                  scores: torch.Tensor) -> torch.Tensor:
@@ -64,23 +101,10 @@ def __call__(self, input_ids: List[int],
                 import_paths=[grammars.GRAMMAR_PATH],
             )
 
-        instruction = self._guide.get_next_instruction(
-            state=self._fsm_state[seq_id])
-
-        if type(instruction) == Generate:  # noqa: E721
-            allowed_tokens = instruction.tokens
-        elif type(instruction) == Write:  # noqa: E721
-            # TODO: support fast forward tokens
-            allowed_tokens = [instruction.tokens[0]]
-        else:
-            raise TypeError(
-                f"Unsupported instruction type {type(instruction)}")
-
-        mask = torch.full((scores.shape[-1], ),
-                          -math.inf,
-                          device=scores.device)
-        mask[allowed_tokens] = 0
-        scores = scores.add(mask)
+        state_id = self._fsm_state[seq_id]
+        mask = self._cached_get_mask_tensor(state_id, scores.size(-1),
+                                            scores.device)
+        scores.add_(mask)
 
         return scores
 
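Note on the caching scheme above: there are two layers. The per-instance
`_cached` wrapper memoizes the finished mask tensor per FSM `state_id`, while
the class-level `lru_cache` on `_create_mask_tensor` is keyed on the hashable
tuple of allowed tokens, so identical token sets share one mask tensor across
states and across processor instances. The switch from `scores.add(mask)` to
the in-place `scores.add_(mask)` also avoids allocating a fresh tensor every
step. Below is a minimal self-contained sketch of the same two-layer pattern;
`ToyProcessor` and its `TABLE` are illustrative stand-ins for the outlines
`Guide`, not part of the patch:

import math
from functools import lru_cache

import torch


def _cached(fn):
    """Plain-dict memoization, mirroring the patch above."""
    cache = {}

    def cached_fn(*args):
        if args in cache:
            return cache[args]
        result = fn(*args)
        cache[args] = result
        return result

    return cached_fn


class ToyProcessor:
    # state -> allowed token ids; illustrative stand-in for the FSM guide
    TABLE = {0: (1, 5, 7), 1: (2, ), 2: (1, 5, 7)}

    def __init__(self):
        # Per-instance cache keyed on (state_id, vocab_size, device).
        self._cached_get_mask = _cached(self._get_mask)

    @staticmethod
    @lru_cache(maxsize=128)
    def _create_mask(allowed_tokens, vocab_size, device):
        # Class-level cache: states 0 and 2 above produce the same key,
        # so only one tensor is ever allocated for them.
        mask = torch.full((vocab_size, ), -math.inf, device=device)
        mask[list(allowed_tokens)] = 0
        return mask

    def _get_mask(self, state_id, vocab_size, device):
        return ToyProcessor._create_mask(self.TABLE[state_id], vocab_size,
                                         device)


p = ToyProcessor()
# Same tensor object for states 0 and 2, courtesy of the lru_cache layer.
assert p._cached_get_mask(0, 10, "cpu") is p._cached_get_mask(2, 10, "cpu")

The returned mask is never mutated (only `scores` is updated in place in the
patch), which is what makes sharing one tensor across states safe.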
diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py
index bee3d38565f4c..e0194b36652a2 100644
--- a/vllm/model_executor/layers/logits_processor.py
+++ b/vllm/model_executor/layers/logits_processor.py
@@ -118,12 +118,28 @@ def _prune_hidden_states(
     return hidden_states
 
 
+def get_num_parameters(logits_processor):
+    """Extracts the number of parameters from the
+    signature and stores it for further use"""
+    if hasattr(logits_processor, 'num_parameters'):
+        return logits_processor.num_parameters
+    logits_processor.num_parameters = len(
+        inspect.signature(logits_processor).parameters)
+    return logits_processor.num_parameters
+
+
 def _apply_logits_processors(
     logits: torch.Tensor,
     sampling_metadata: SamplingMetadata,
 ) -> torch.Tensor:
-    found_logits_processors = False
     logits_processed = 0
+    found_logits_processors = any(
+        seq_group.sampling_params.logits_processors
+        for seq_group in sampling_metadata.seq_groups)
+    offload_to_cpu = current_platform.is_hpu() and found_logits_processors
+    if offload_to_cpu:
+        logits_device = logits.device
+        logits = logits.cpu()
     for seq_group in sampling_metadata.seq_groups:
         seq_ids = seq_group.seq_ids
         sampling_params = seq_group.sampling_params
@@ -138,8 +154,7 @@ def _apply_logits_processors(
                 prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids
 
                 for logits_processor in logits_processors:
-                    parameters = inspect.signature(logits_processor).parameters
-                    if len(parameters) == 3:
+                    if get_num_parameters(logits_processor) == 3:
                         logits_row = logits_processor(prompt_tokens_ids,
                                                       past_tokens_ids,
                                                       logits_row)
@@ -155,4 +170,6 @@ def _apply_logits_processors(
     if found_logits_processors:
         # verifies that no rows in logits were missed unexpectedly
         assert logits_processed == logits.shape[0]
+    if offload_to_cpu:
+        logits = logits.to(logits_device)
     return logits
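The `get_num_parameters` helper exists because the old code ran
`inspect.signature` inside the hottest loop, once per processor per sequence
per decoding step; signature introspection is far slower than an attribute
lookup, so the arity is now computed once and stashed on the callable itself.
A small sketch of the two calling conventions it distinguishes; the
`two_arg`/`three_arg` processors are illustrative, not from the patch:

import inspect


def get_num_parameters(logits_processor):
    # Memoize the arity on the callable itself so inspect.signature
    # runs at most once per processor.
    if hasattr(logits_processor, 'num_parameters'):
        return logits_processor.num_parameters
    logits_processor.num_parameters = len(
        inspect.signature(logits_processor).parameters)
    return logits_processor.num_parameters


# vLLM accepts two conventions for user logits processors:
def two_arg(past_token_ids, logits_row):
    return logits_row


def three_arg(prompt_token_ids, past_token_ids, logits_row):
    return logits_row


assert get_num_parameters(two_arg) == 2
assert get_num_parameters(three_arg) == 3
assert two_arg.num_parameters == 2  # memoized on the function object

One caveat worth flagging in review: assigning the attribute works on plain
functions and most callable objects, but would raise AttributeError on a
bound method or a `functools.partial`; the original per-call
`inspect.signature` had no such restriction.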
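On the offload path, the rationale is that logits processors touch the logits
row by row from Python, which on HPU turns into many tiny synchronizing
accesses; one bulk device-to-host copy, all processing on the CPU, and one
bulk copy back is cheaper. The `any(...)` scan ensures the copies happen only
when at least one sequence actually carries a processor. A simplified sketch
of that round-trip shape; `apply_row_processors` and the flat row loop are
illustrative stand-ins for the real per-seq-group bookkeeping, and the
boolean flag stands in for the `current_platform.is_hpu()` check in the
patch:

import torch


def apply_row_processors(logits: torch.Tensor, processors,
                         offload_to_cpu: bool) -> torch.Tensor:
    # One device->host transfer up front instead of many tiny
    # synchronizing row accesses on the accelerator.
    if offload_to_cpu:
        logits_device = logits.device
        logits = logits.cpu()
    for row_idx in range(logits.shape[0]):
        logits_row = logits[row_idx]
        for processor in processors:
            logits_row = processor([], logits_row)  # 2-arg convention
        logits[row_idx] = logits_row
    if offload_to_cpu:
        # Single host->device transfer back once all rows are done.
        logits = logits.to(logits_device)
    return logits


# Usage: ban token 0 on every row.
def ban_token_zero(past_token_ids, logits_row):
    logits_row[0] = float('-inf')
    return logits_row


out = apply_row_processors(torch.zeros(4, 8), [ban_token_zero], True)
assert torch.isinf(out[:, 0]).all()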