HPU: offload logits processing to CPU #358

Merged: 4 commits, Oct 29, 2024
58 changes: 41 additions & 17 deletions vllm/model_executor/guided_decoding/outlines_logits_processors.py
@@ -30,11 +30,48 @@
from transformers import PreTrainedTokenizerBase


# Unfortunately we cannot use lru_cache as it breaks pickling
# so we use a simpler implementation
def _cached(fn):

Reviewer comment: How many masks are expected at maximum? Since _create_mask_tensor is limited and the _cached function is not, won't we get into a situation where we get a None result for the least recently used entries (128+)?

Reviewer follow-up: You can ignore my last comment; the lru_cache maxsize has no effect here, since the _cached cache keeps the objects. If you reverse the condition in line 39, the else can be replaced with return cache[args]:
    if args not in cache:
        cache[args] = fn(*args)
    return cache[args]

Author reply (@madamczykhabana, Oct 29, 2024):

Yeah, but at the same time we'd need to calculate the hash and access the map one additional time.

  1. if args in cache
  2. cache[args] = fn(*args)
  3. return cache[args]

In the current version you access it twice:

  1. args in cache
  2. cache[args] = result

To be honest I didn't measure how big of a perf impact this might be, but I'd like to be on the safe side.
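
For reference, a minimal standalone sketch (not part of the PR) that counts dictionary operations on a cache miss for the two variants discussed above; CountingDict and the two wrapper factories are hypothetical names introduced only for the illustration:

# Sketch only: compares dict traffic of the two caching variants on a cache miss.
class CountingDict(dict):
    """A dict that counts membership tests, reads and writes."""

    def __init__(self):
        super().__init__()
        self.ops = 0

    def __contains__(self, key):
        self.ops += 1
        return super().__contains__(key)

    def __getitem__(self, key):
        self.ops += 1
        return super().__getitem__(key)

    def __setitem__(self, key, value):
        self.ops += 1
        super().__setitem__(key, value)


def cached_early_return(fn, cache):
    # Reviewer's suggestion: always exit through cache[args].
    def wrapper(*args):
        if args not in cache:
            cache[args] = fn(*args)
        return cache[args]

    return wrapper


def cached_merged(fn, cache):
    # Merged version: the result is kept in a local variable.
    def wrapper(*args):
        if args in cache:
            result = cache[args]
        else:
            result = fn(*args)
            cache[args] = result
        return result

    return wrapper


for variant in (cached_early_return, cached_merged):
    counting_cache = CountingDict()
    cached_square = variant(lambda x: x * x, counting_cache)
    cached_square(3)  # first call, i.e. a cache miss
    print(variant.__name__, "dict operations on a miss:", counting_cache.ops)
# Prints 3 for cached_early_return and 2 for cached_merged; on a hit both do 2.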

    cache = {}

    def cached_fn(*args):
        if args in cache:
            result = cache[args]
        else:
            result = fn(*args)
            cache[args] = result
        return result

    return cached_fn


class BaseLogitsProcessor:

    def __init__(self, guide: Guide):
        self._guide: Guide = guide
        self._fsm_state: DefaultDict[int, int] = defaultdict(int)
        self._cached_get_mask_tensor = _cached(self._get_mask_tensor)

    @staticmethod
    @lru_cache(maxsize=128)
    def _create_mask_tensor(allowed_tokens, vocab_size, device):
        mask = torch.full((vocab_size, ), -math.inf, device=device)
        mask[list(allowed_tokens)] = 0
        return mask

    def _get_mask_tensor(self, state_id, vocab_size, device):
        instruction = self._guide.get_next_instruction(state=state_id)
        if type(instruction) == Generate:  # noqa: E721
            allowed_tokens = instruction.tokens
        elif type(instruction) == Write:  # noqa: E721
            # TODO: support fast forward tokens
            allowed_tokens = [instruction.tokens[0]]
        else:
            raise TypeError(
                f"Unsupported instruction type {type(instruction)}")
        return BaseLogitsProcessor._create_mask_tensor(tuple(allowed_tokens),
                                                       vocab_size, device)

    def __call__(self, input_ids: List[int],
                 scores: torch.Tensor) -> torch.Tensor:
@@ -64,23 +101,10 @@ def __call__(self, input_ids: List[int],
import_paths=[grammars.GRAMMAR_PATH],
)

        instruction = self._guide.get_next_instruction(
            state=self._fsm_state[seq_id])

        if type(instruction) == Generate:  # noqa: E721
            allowed_tokens = instruction.tokens
        elif type(instruction) == Write:  # noqa: E721
            # TODO: support fast forward tokens
            allowed_tokens = [instruction.tokens[0]]
        else:
            raise TypeError(
                f"Unsupported instruction type {type(instruction)}")

        mask = torch.full((scores.shape[-1], ),
                          -math.inf,
                          device=scores.device)
        mask[allowed_tokens] = 0
        scores = scores.add(mask)
        state_id = self._fsm_state[seq_id]
        mask = self._cached_get_mask_tensor(state_id, scores.size(-1),
                                            scores.device)
        scores.add_(mask)
        return scores
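
To illustrate the caching behaviour this diff introduces, here is a small self-contained sketch; it uses CPU tensors, a toy vocabulary, and a renamed create_mask_tensor that mirrors the static method above, rather than the vllm classes themselves. The -inf mask is built once per allowed-token set via lru_cache and then reused with an in-place add:

# Sketch only: per-key mask caching, mirroring _create_mask_tensor above.
import math
from functools import lru_cache

import torch


@lru_cache(maxsize=128)
def create_mask_tensor(allowed_tokens, vocab_size, device):
    mask = torch.full((vocab_size, ), -math.inf, device=device)
    mask[list(allowed_tokens)] = 0
    return mask


vocab_size = 8
scores = torch.zeros(vocab_size)
allowed = (1, 3, 5)  # a tuple, so the key is hashable for lru_cache

mask_first = create_mask_tensor(allowed, vocab_size, scores.device)
mask_second = create_mask_tensor(allowed, vocab_size, scores.device)
assert mask_first is mask_second  # second call is a cache hit, no new tensor

scores.add_(mask_first)
print(scores)  # 0.0 at allowed token ids, -inf everywhere else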


23 changes: 20 additions & 3 deletions vllm/model_executor/layers/logits_processor.py
@@ -118,12 +118,28 @@ def _prune_hidden_states(
    return hidden_states


def get_num_parameters(logits_processor):
    """Extracts the number of parameters from the
    signature and stores it for further use"""
    if hasattr(logits_processor, 'num_parameters'):
        return logits_processor.num_parameters
    logits_processor.num_parameters = len(
        inspect.signature(logits_processor).parameters)
    return logits_processor.num_parameters
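
As a quick usage note, the sketch below shows the side effect get_num_parameters relies on: the parameter count is computed once with inspect.signature and then stored as an attribute on the callable, so repeated calls skip the introspection. The my_processor function is a hypothetical 3-parameter processor, not something from the diff:

# Sketch only: demonstrates the attribute caching done by get_num_parameters.
import inspect


def get_num_parameters(logits_processor):
    if hasattr(logits_processor, 'num_parameters'):
        return logits_processor.num_parameters
    logits_processor.num_parameters = len(
        inspect.signature(logits_processor).parameters)
    return logits_processor.num_parameters


def my_processor(prompt_token_ids, past_token_ids, logits_row):
    # Hypothetical user-supplied processor with the 3-parameter signature
    # that _apply_logits_processors dispatches on below.
    return logits_row


print(get_num_parameters(my_processor))  # 3, computed with inspect.signature
print(my_processor.num_parameters)       # 3, now cached on the function object
print(get_num_parameters(my_processor))  # 3, served from the cached attribute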


def _apply_logits_processors(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    found_logits_processors = False
    logits_processed = 0
    found_logits_processors = any(
        seq_group.sampling_params.logits_processors
        for seq_group in sampling_metadata.seq_groups)
    offload_to_cpu = current_platform.is_hpu() and found_logits_processors
    if offload_to_cpu:
        logits_device = logits.device
        logits = logits.cpu()
    for seq_group in sampling_metadata.seq_groups:
        seq_ids = seq_group.seq_ids
        sampling_params = seq_group.sampling_params
@@ -138,8 +154,7 @@ def _apply_logits_processors(
                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

                for logits_processor in logits_processors:
                    parameters = inspect.signature(logits_processor).parameters
                    if len(parameters) == 3:
                    if get_num_parameters(logits_processor) == 3:
                        logits_row = logits_processor(prompt_tokens_ids,
                                                      past_tokens_ids,
                                                      logits_row)
@@ -155,4 +170,6 @@ def _apply_logits_processors(
    if found_logits_processors:
        # verifies that no rows in logits were missed unexpectedly
        assert logits_processed == logits.shape[0]
    if offload_to_cpu:
        logits = logits.to(logits_device)
    return logits
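
Taken together, the change implements the pattern in the PR title: when logits processors are present on HPU, the logits tensor is moved to CPU, the Python-level processors run there row by row, and the result is moved back to the original device. A simplified, self-contained sketch of that flow follows; the function name, the forbid_token_zero processor, and the is_hpu flag are illustrative and not vllm API:

# Sketch only: offload logits processing to CPU and move the result back.
import torch


def apply_processors_with_offload(logits, processors, is_hpu=False):
    offload_to_cpu = is_hpu and bool(processors)
    if offload_to_cpu:
        logits_device = logits.device
        logits = logits.cpu()  # row-wise Python processing is cheaper on CPU

    for row_idx in range(logits.shape[0]):
        logits_row = logits[row_idx]
        for processor in processors:
            logits_row = processor(logits_row)
        logits[row_idx] = logits_row

    if offload_to_cpu:
        logits = logits.to(logits_device)  # restore the original device
    return logits


# Example: a processor that forbids token 0 by pushing its logit to -inf.
def forbid_token_zero(logits_row):
    logits_row[0] = float('-inf')
    return logits_row


out = apply_processors_with_offload(torch.zeros(2, 5), [forbid_token_zero])
print(out)  # column 0 is -inf in every row, the rest stays 0.0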