Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

to make repetition penalty faster #442

Open
wants to merge 5 commits into
base: habana_main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion vllm/model_executor/layers/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,7 @@ def _apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
output_tokens_tensor, vocab_size, num_seqs)

repetition_penalties = repetition_penalties[:, None].repeat(1, vocab_size)
repetition_penalties[~(prompt_mask | output_mask)] = 1.0
repetition_penalties.masked_fill_(~(prompt_mask | output_mask), 1.0)
logits = torch.where(logits > 0, logits / repetition_penalties,
logits * repetition_penalties)

Expand Down
75 changes: 53 additions & 22 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad)
is_pin_memory_available,
make_tensor_with_pad,
make_tensor_with_pad_align)

_SAMPLING_EPS = 1e-5

Expand Down Expand Up @@ -522,20 +524,38 @@ def from_lists(
do_penalties = prompt_tokens or output_tokens

if do_penalties:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
if current_platform.is_hpu():
prompt_t = make_tensor_with_pad_align(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
output_t = make_tensor_with_pad_align(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
else:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
Expand All @@ -545,47 +565,58 @@ def from_lists(
temperatures,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
top_ps_t = torch.tensor(
top_ps,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
min_ps_t = torch.tensor(
min_ps,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
presence_penalties_t = torch.tensor(
presence_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
frequency_penalties_t = torch.tensor(
frequency_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
repetition_penalties_t = torch.tensor(
repetition_penalties,
device="cpu",
dtype=dtype,
pin_memory=pin_memory,
)
top_ks_t = torch.tensor(
top_ks,
device="cpu",
dtype=torch.int,
pin_memory=pin_memory,
)
# Because the memory is pinned, we can do non-blocking
# transfer to device.

if pin_memory:
if not current_platform.is_hpu():
temperatures_t.pin_memory()
top_ps_t.pin_memory()
min_ps_t.pin_memory()
frequency_penalties_t.pin_memory()
presence_penalties_t.pin_memory()
repetition_penalties_t.pin_memory()
top_ks_t.pin_memory()
else:
temperatures_t.pin_memory(device="hpu")
top_ps_t.pin_memory(device="hpu")
min_ps_t.pin_memory(device="hpu")
frequency_penalties_t.pin_memory(device="hpu")
presence_penalties_t.pin_memory(device="hpu")
repetition_penalties_t.pin_memory(device="hpu")
top_ks_t.pin_memory(device="hpu")

return cls(
temperatures=temperatures_t.to(device=device, non_blocking=True),
top_ps=top_ps_t.to(device=device, non_blocking=True),
Expand Down
64 changes: 59 additions & 5 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import gc
import inspect
import ipaddress
import math
import os
import socket
import subprocess
Expand Down Expand Up @@ -752,9 +753,6 @@
elif current_platform.is_neuron():
print_warning_once("Pin memory is not supported on Neuron.")
return False
elif current_platform.is_hpu():
print_warning_once("Pin memory is not supported on HPU.")
return False
elif current_platform.is_cpu() or current_platform.is_openvino():
return False
return True
Expand Down Expand Up @@ -812,6 +810,29 @@

return padded_x

def make_ndarray_with_pad_align(
x: List[List[T]],
pad: T,
dtype: npt.DTypeLike,
*,
max_len_align: Optional[int] = None,
) -> npt.NDArray:
"""
Make a padded array from 2D inputs.

The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
# Unlike for most functions, map is faster than a genexpr over `len`
max_len = max(map(len, x), default=0)
max_len_aligned = math.ceil(max_len / max_len_align) * max_len_align

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Unsupported operand types for * ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.9)

Unsupported operand types for / ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Unsupported operand types for * ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.10)

Unsupported operand types for / ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Unsupported operand types for * ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.11)

Unsupported operand types for / ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Unsupported operand types for * ("int" and "None") [operator]

Check failure on line 828 in vllm/utils.py

View workflow job for this annotation

GitHub Actions / mypy (3.12)

Unsupported operand types for / ("int" and "None") [operator]
padded_x = np.full((len(x), max_len_aligned), pad, dtype=dtype)

for ind, blocktb in enumerate(x):
assert len(blocktb) <= max_len_aligned
padded_x[ind, :len(blocktb)] = blocktb

return padded_x

def make_tensor_with_pad(
x: List[List[T]],
Expand All @@ -833,10 +854,37 @@

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
tensor = tensor.pin_memory()
if not current_platform.is_hpu():
tensor = tensor.pin_memory()
else:
tensor = tensor.pin_memory("hpu")
Comment on lines +857 to +860

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be removed, as it wont be called now

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hi Michal, this method make_tensor_with_pad is still called from different places. It is replaced by make_tensor_with_pad_align in the repetition penaly. So I think we still need the check here.


return tensor

def make_tensor_with_pad_align(
x: List[List[T]],
pad: T,
dtype: torch.dtype,
*,
max_len_align: Optional[int] = None,
device: Optional[Union[str, torch.device]] = None,
pin_memory: bool = False,
) -> torch.Tensor:
"""
Make a padded tensor from 2D inputs.

The padding is applied to the end of each inner list until it reaches
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad_align(x, pad, np_dtype,
max_len_align=max_len_align)

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if not needed, since this method is called for hpu

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

HI Michal, I removed the device check logic here and kept the pin_memory check. In this way, this method behavior is exactly the same to the un-aligned version.

tensor = tensor.pin_memory("hpu")

return tensor

def async_tensor_h2d(
data: list,
Expand All @@ -845,7 +893,13 @@
pin_memory: bool,
) -> torch.Tensor:
"""Asynchronously create a tensor and copy it from host to device."""
t = torch.tensor(data, dtype=dtype, pin_memory=pin_memory, device="cpu")
t = torch.tensor(data, dtype=dtype, device="cpu")
if pin_memory:
if not current_platform.is_hpu():
t.pin_memory()
else:
t.pin_memory(device="hpu")

return t.to(device=target_device, non_blocking=True)


Expand Down
Loading