Skip to content

Commit

Permalink
fix yapf
Browse files Browse the repository at this point in the history
  • Loading branch information
ccrhx4 committed Nov 13, 2024
1 parent 6bb5cb9 commit ec38d67
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 18 deletions.
51 changes: 34 additions & 17 deletions vllm/model_executor/sampling_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,8 @@
from vllm.sequence import (VLLM_TOKEN_ID_ARRAY_TYPE, SequenceData,
SequenceGroupMetadata)
from vllm.utils import (PyObjectCache, async_tensor_h2d,
is_pin_memory_available, make_tensor_with_pad,
is_pin_memory_available,
make_tensor_with_pad,
make_tensor_with_pad_align)

_SAMPLING_EPS = 1e-5
Expand Down Expand Up @@ -523,22 +524,38 @@ def from_lists(
do_penalties = prompt_tokens or output_tokens

if do_penalties:
prompt_t = make_tensor_with_pad_align(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
output_t = make_tensor_with_pad_align(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
if current_platform.is_hpu():
prompt_t = make_tensor_with_pad_align(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
output_t = make_tensor_with_pad_align(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
max_len_align=1024,
)
else:
prompt_t = make_tensor_with_pad(
prompt_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
output_t = make_tensor_with_pad(
output_tokens,
vocab_size,
device="cpu",
dtype=torch.int64,
pin_memory=pin_memory,
)
else:
empty_tensor = torch.empty(0, device=device, dtype=torch.long)
prompt_t = empty_tensor
Expand Down
3 changes: 2 additions & 1 deletion vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -877,7 +877,8 @@ def make_tensor_with_pad_align(
`max_len`.
"""
np_dtype = TORCH_DTYPE_TO_NUMPY_DTYPE[dtype]
padded_x = make_ndarray_with_pad_align(x, pad, np_dtype, max_len_align=max_len_align)
padded_x = make_ndarray_with_pad_align(x, pad, np_dtype,
max_len_align=max_len_align)

tensor = torch.from_numpy(padded_x).to(device)
if pin_memory:
Expand Down

0 comments on commit ec38d67

Please sign in to comment.