0    v1/sample/ops/__init__.py  (new file, empty)
BIN  v1/sample/ops/__pycache__/__init__.cpython-312.pyc  (new file; binary file not shown)
BIN  v1/sample/ops/__pycache__/bad_words.cpython-312.pyc  (new file; binary file not shown)
BIN  v1/sample/ops/__pycache__/logprobs.cpython-312.pyc  (new file; binary file not shown)
BIN  v1/sample/ops/__pycache__/penalties.cpython-312.pyc  (new file; binary file not shown)
BIN  v1/sample/ops/__pycache__/topk_topp_sampler.cpython-312.pyc  (new file; binary file not shown)
52   v1/sample/ops/bad_words.py  (new file)
@@ -0,0 +1,52 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

_SMALLEST_LOGIT = float("-inf")


def _apply_bad_words_single_batch(
    logits: torch.Tensor,
    bad_words_token_ids: list[list[int]],
    past_tokens_ids: list[int],
) -> None:
    for bad_word_ids in bad_words_token_ids:
        if len(bad_word_ids) > len(past_tokens_ids) + 1:
            continue

        prefix_length = len(bad_word_ids) - 1
        last_token_id = bad_word_ids[-1]
        actual_prefix = past_tokens_ids[-prefix_length:] if prefix_length > 0 else []
        expected_prefix = bad_word_ids[:prefix_length]

        assert len(actual_prefix) == len(expected_prefix)

        if actual_prefix == expected_prefix:
            logits[last_token_id] = _SMALLEST_LOGIT


def apply_bad_words(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
) -> None:
    for i, bad_words_ids in bad_words_token_ids.items():
        _apply_bad_words_single_batch(logits[i], bad_words_ids, past_tokens_ids[i])


def apply_bad_words_with_drafts(
    logits: torch.Tensor,
    bad_words_token_ids: dict[int, list[list[int]]],
    past_tokens_ids: list[list[int]],
    num_draft_tokens: list[int],
) -> None:
    start_idx = 0
    for i, bad_words_ids in bad_words_token_ids.items():
        for draft_idx in range(num_draft_tokens[i]):
            _apply_bad_words_single_batch(
                logits[start_idx + draft_idx],
                bad_words_ids,
                past_tokens_ids[start_idx + draft_idx],
            )
        start_idx += num_draft_tokens[i]
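
For reference, a minimal sketch of how apply_bad_words behaves (illustrative only, not part of the committed file; the import path is assumed from this diff's layout): with the bad-word sequence [3, 7], the logit of token 7 is masked only when the request's most recent token is 3.

# Illustrative sketch, not part of the diff.
import torch

from v1.sample.ops.bad_words import apply_bad_words  # import path assumed from this diff

logits = torch.zeros(2, 10)                # 2 requests, vocab size 10
bad_words = {0: [[3, 7]], 1: [[3, 7]]}     # per-request bad-word token-id sequences
past = [[1, 3],                            # request 0 just produced token 3
        [1, 2]]                            # request 1 did not
apply_bad_words(logits, bad_words, past)
assert logits[0, 7] == float("-inf")       # emitting 7 would complete the bad word
assert logits[1, 7] == 0.0                 # prefix does not match, logit untouched
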
25   v1/sample/ops/logprobs.py  (new file)
@@ -0,0 +1,25 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Some utilities for logprobs, including logits."""

import torch

from vllm.platforms import current_platform


@torch.compile(dynamic=True, backend=current_platform.simple_compile_backend)
def batched_count_greater_than(x: torch.Tensor, values: torch.Tensor) -> torch.Tensor:
    """
    Counts the elements in each row of x that are greater than or equal to the
    corresponding value in values. torch.compile is used to generate an
    optimized kernel for this function; otherwise, it would create additional
    copies of the input tensors and cause memory issues.

    Args:
        x (torch.Tensor): A 2D tensor of shape (batch_size, n_elements).
        values (torch.Tensor): A 2D tensor of shape (batch_size, 1).

    Returns:
        torch.Tensor: A 1D tensor of shape (batch_size,) with the counts.
    """
    return (x >= values).sum(-1)
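
A usage sketch (illustrative only, not part of the committed file): because the count uses >=, it includes the chosen value itself, so it directly yields the 1-based rank of a sampled token's logprob within its row.

# Illustrative sketch, not part of the diff; this inlines the same
# computation as batched_count_greater_than to stay self-contained.
import torch

logprobs = torch.tensor([[-0.1, -2.0, -0.5],
                         [-1.0, -0.2, -3.0]])   # (batch, vocab)
chosen = torch.tensor([[-0.5], [-0.2]])         # sampled-token logprob per row, (batch, 1)
ranks = (logprobs >= chosen).sum(-1)
print(ranks)                                    # tensor([2, 1]) -- 1-based ranks
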
57   v1/sample/ops/penalties.py  (new file)
@@ -0,0 +1,57 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import torch

from vllm.model_executor.layers.utils import apply_penalties
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.utils.torch_utils import make_tensor_with_pad


def apply_all_penalties(
    logits: torch.Tensor,
    prompt_token_ids: torch.Tensor,
    presence_penalties: torch.Tensor,
    frequency_penalties: torch.Tensor,
    repetition_penalties: torch.Tensor,
    output_token_ids: list[list[int]],
) -> torch.Tensor:
    """
    Applies presence, frequency and repetition penalties to the logits.
    """
    _, vocab_size = logits.shape
    output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size, logits.device)

    # In the async scheduling case, rows that won't have penalties applied may
    # contain -1 placeholder token ids. We must replace these with valid token
    # ids so that the scatter done in apply_penalties is valid.
    # NOTE(nick): The penalties implementation is currently quite inefficient
    # and will be reworked anyhow.
    output_tokens_t.masked_fill_(output_tokens_t == -1, vocab_size)

    return apply_penalties(
        logits,
        prompt_token_ids,
        output_tokens_t,
        presence_penalties,
        frequency_penalties,
        repetition_penalties,
    )


def _convert_to_tensors(
    output_token_ids: list[list[int]], vocab_size: int, device: torch.device
) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int64,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
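
apply_penalties itself is imported from vllm.model_executor.layers.utils and is not shown in this diff. As a rough, non-authoritative sketch of the usual semantics (an assumption about the standard penalty math, not the imported kernel): the repetition penalty divides positive logits and multiplies negative logits of previously seen tokens, the frequency penalty subtracts count * penalty, and the presence penalty subtracts a flat penalty for any token that already appeared in the output.

# Rough reference sketch of the usual penalty math (assumption, not the
# imported vLLM kernel).
import torch

def sketch_apply_penalties(
    logits: torch.Tensor,          # (batch, vocab)
    output_counts: torch.Tensor,   # (batch, vocab) counts of generated tokens
    seen_mask: torch.Tensor,       # (batch, vocab) bool: token in prompt or output
    presence: torch.Tensor,        # (batch,)
    frequency: torch.Tensor,       # (batch,)
    repetition: torch.Tensor,      # (batch,)
) -> torch.Tensor:
    rep = repetition.unsqueeze(1)
    logits = torch.where(seen_mask & (logits > 0), logits / rep, logits)
    logits = torch.where(seen_mask & (logits <= 0), logits * rep, logits)
    logits = logits - frequency.unsqueeze(1) * output_counts
    logits = logits - presence.unsqueeze(1) * (output_counts > 0)
    return logits
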
290  v1/sample/ops/topk_topp_sampler.py  (new file)
@@ -0,0 +1,290 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project


import torch
import torch.nn as nn
from packaging import version

from vllm import envs
from vllm.config.model import LogprobsMode
from vllm.logger import init_logger
from vllm.platforms import CpuArchEnum, current_platform

logger = init_logger(__name__)

class TopKTopPSampler(nn.Module):
    """
    Module that performs optional top-k and top-p filtering followed by
    weighted random sampling of logits.

    Implementations may update the logits tensor in-place.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None:
        super().__init__()
        self.logprobs_mode = logprobs_mode
        # flashinfer optimization does not apply if intermediate
        # logprobs/logits after top_k/top_p need to be returned
        if (
            logprobs_mode not in ("processed_logits", "processed_logprobs")
            and current_platform.is_cuda()
        ):
            if envs.VLLM_USE_FLASHINFER_SAMPLER:
                # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1.
                logger.info_once(
                    "Using FlashInfer for top-p & top-k sampling.",
                    scope="global",
                )
                self.forward = self.forward_cuda
            else:
                logger.debug_once(
                    "FlashInfer top-p/top-k sampling is available but disabled "
                    "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in "
                    "after verifying accuracy for your workloads."
                )
                self.forward = self.forward_native
        elif current_platform.is_cpu():
            arch = current_platform.get_cpu_architecture()
            # Fall back to native implementation for POWERPC and RISCV.
            # On PowerPC argmax produces incorrect output with torch.compile.
            # PR: https://github.com/vllm-project/vllm/pull/26987
            if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC):
                self.forward = self.forward_native
            else:
                self.forward = self.forward_cpu
        else:
            self.forward = self.forward_native

        self.apply_top_k_top_p = apply_top_k_top_p

    def forward_native(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling.

        The logits tensor may be updated in-place.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        return random_sample(probs, generators), logits_to_return

    def forward_cuda(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """More optimized implementation for top-k and top-p sampling."""
        # We prefer `random_sample` over `flashinfer_sample` when sorting is
        # not needed. This is because `random_sample` does not require
        # CPU-GPU synchronization while `flashinfer_sample` does.
        if (k is None and p is None) or generators:
            if generators:
                logger.debug_once(
                    "FlashInfer 0.2.3+ does not support "
                    "per-request generators. Falling back to "
                    "PyTorch-native implementation."
                )
            return self.forward_native(logits, generators, k, p)
        assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), (
            "FlashInfer does not support returning logits/logprobs"
        )
        # flashinfer sampling functions expect contiguous logits.
        # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
        # because of slicing operation in logits_processor.
        return flashinfer_sample(logits.contiguous(), k, p, generators), None

    def forward_cpu(
        self,
        logits: torch.Tensor,
        generators: dict[int, torch.Generator],
        k: torch.Tensor | None,
        p: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """
        PyTorch-native implementation of top-k and top-p sampling for CPU.

        The logits tensor may be updated in-place.
        """
        logits = self.apply_top_k_top_p(logits, k, p)
        logits_to_return = None
        if self.logprobs_mode == "processed_logits":
            logits_to_return = logits
        elif self.logprobs_mode == "processed_logprobs":
            logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32)

        if len(generators) != logits.shape[0]:
            return compiled_random_sample(logits), logits_to_return
        else:
            probs = logits.softmax(dim=-1, dtype=torch.float32)
            q = torch.empty_like(probs)
            q.exponential_()
            for i, generator in generators.items():
                q[i].exponential_(generator=generator)

            return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return
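
A minimal usage sketch of the class above (illustrative only; the shapes and values are assumptions based on the signatures in this file, and k equal to the vocab size means "no top-k" for that row):

# Illustrative sketch, not part of the diff.
import torch

sampler = TopKTopPSampler(logprobs_mode="raw_logprobs")
logits = torch.randn(4, 32_000)               # (batch, vocab)
k = torch.tensor([50, 50, 32_000, 20])        # k == vocab size disables top-k for that row
p = torch.tensor([0.9, 1.0, 0.95, 0.8])       # p == 1.0 effectively disables top-p
generators: dict[int, torch.Generator] = {}   # no per-request seeds
token_ids, processed = sampler.forward(logits, generators, k, p)
print(token_ids.shape)                        # torch.Size([4]); processed is None for raw_logprobs
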
# Note: this is a workaround for
# https://github.com/pytorch/pytorch/pull/151218
@torch.compile(dynamic=True)
def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor:
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    q = torch.empty_like(probs)
    q.exponential_()
    return probs.div(q).argmax(dim=-1).view(-1)
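
The probs.div(q).argmax(...) pattern above (and in random_sample below) is the exponential-race form of the Gumbel-max trick: with independent q_i ~ Exp(1), argmax_i p_i / q_i is distributed exactly like a categorical draw with probabilities p_i, so sampling needs only elementwise ops and an argmax and never synchronizes with the CPU. A quick empirical check (illustrative only):

# Illustrative sketch, not part of the diff.
import torch

torch.manual_seed(0)
probs = torch.tensor([0.1, 0.3, 0.6]).expand(100_000, 3).contiguous()
q = torch.empty_like(probs).exponential_()
samples = probs.div(q).argmax(dim=-1)
print(torch.bincount(samples, minlength=3) / samples.numel())
# ~= tensor([0.1, 0.3, 0.6]) up to sampling noise
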
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
) -> torch.Tensor:
    """Apply top-k and top-p masks to the logits.

    If a top-p is used, this function will sort the logits tensor,
    which can be slow for large batches.

    The logits tensor may be updated in-place.
    """
    if p is None:
        if k is None:
            return logits

        # Avoid sorting vocab for top-k only case.
        return apply_top_k_only(logits, k)

    logits_sort, logits_idx = logits.sort(dim=-1, descending=False)

    if k is not None:
        # Apply top-k.
        top_k_mask = logits_sort.size(1) - k.to(torch.long)  # shape: B
        # Get all the top_k values.
        top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
        top_k_mask = logits_sort < top_k_mask
        logits_sort.masked_fill_(top_k_mask, -float("inf"))

    if p is not None:
        # Apply top-p.
        probs_sort = logits_sort.softmax(dim=-1)
        probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
        top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
        # Keep at least one token per row.
        top_p_mask[:, -1] = False
        logits_sort.masked_fill_(top_p_mask, -float("inf"))

    # Re-sort the probabilities.
    logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
    return logits
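
A small worked example of the masking above (illustrative only): with k=2 and p=0.9, top-k keeps the two largest logits, and top-p then removes the largest prefix of the ascending-sorted distribution whose cumulative probability stays below 1 - p.

# Illustrative sketch, not part of the diff.
import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
k = torch.tensor([2])
p = torch.tensor([0.9])
print(apply_top_k_top_p(logits.clone(), k, p))
# tensor([[2., 1., -inf, -inf]]) -- tokens outside the top-2 are masked;
# top-p keeps both survivors since their combined mass exceeds p.
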
def apply_top_k_only(
    logits: torch.Tensor,
    k: torch.Tensor,
) -> torch.Tensor:
    """
    Apply top-k mask to the logits.

    This implementation doesn't involve sorting the entire vocab.

    The logits tensor may be updated in-place.
    """
    no_top_k_mask = k == logits.shape[1]
    # Set non-top-k rows to 1 so that we can gather.
    k = k.masked_fill(no_top_k_mask, 1)
    max_top_k = k.max()
    # topk.values tensor has shape [batch_size, max_top_k].
    # Convert top k to 0-based index in range [0, max_top_k).
    k_index = k.sub_(1).unsqueeze(1)
    top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
    # Handle non-topk rows.
    top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
    logits.masked_fill_(logits < top_k_mask, -float("inf"))
    return logits
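
Per-row k behaves as follows (illustrative only): a k equal to the vocab size disables top-k for that row.

# Illustrative sketch, not part of the diff.
import torch

logits = torch.tensor([[0.1, 0.4, 0.3, 0.2],
                       [0.1, 0.4, 0.3, 0.2]])
k = torch.tensor([1, 4])               # row 0: keep only the best token; row 1: no filtering
print(apply_top_k_only(logits.clone(), k))
# tensor([[  -inf, 0.4000,   -inf,   -inf],
#         [0.1000, 0.4000, 0.3000, 0.2000]])
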
def random_sample(
    probs: torch.Tensor,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Randomly sample from the probabilities.

    We use this function instead of torch.multinomial because torch.multinomial
    causes CPU-GPU synchronization.
    """
    q = torch.empty_like(probs)
    # NOTE(woosuk): To batch-process the requests without their own seeds,
    # which is the common case, we first assume that every request does
    # not have its own seed. Then, we overwrite the values for the requests
    # that have their own seeds.
    if len(generators) != probs.shape[0]:
        q.exponential_()
    if generators:
        # TODO(woosuk): This can be slow because we handle each request
        # one by one. Optimize this.
        for i, generator in generators.items():
            q[i].exponential_(generator=generator)
    return probs.div_(q).argmax(dim=-1).view(-1)
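
Seeded and unseeded requests can be mixed in one batch (illustrative only; the generator's device must match probs):

# Illustrative sketch, not part of the diff.
import torch

probs = torch.tensor([[0.2, 0.8],
                      [0.5, 0.5],
                      [0.9, 0.1]])
generators = {1: torch.Generator().manual_seed(42)}   # only request 1 carries its own seed
token_ids = random_sample(probs.clone(), generators)  # clone: probs is modified in-place
print(token_ids)                                      # shape (3,); index 1 is reproducible
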
def flashinfer_sample(
    logits: torch.Tensor,
    k: torch.Tensor | None,
    p: torch.Tensor | None,
    generators: dict[int, torch.Generator],
) -> torch.Tensor:
    """Sample from the logits using FlashInfer.

    Statistically, this function is equivalent to the `random_sample` function.
    However, this function is faster because it avoids sorting the logits tensor
    via rejection sampling.

    NOTE: The outputs of this function do not necessarily match the outputs of
    the `random_sample` function. It only guarantees that the outputs are
    statistically equivalent.

    NOTE: This function includes CPU-GPU synchronization, while `random_sample`
    does not. Call this function at the end of the forward pass to minimize
    the synchronization overhead.
    """
    import flashinfer

    if version.parse(flashinfer.__version__) < version.parse("0.2.3"):
        raise ImportError(
            "FlashInfer version >= 0.2.3 is required for top-k and top-p sampling."
        )

    assert not (k is None and p is None)
    if k is None:
        # Top-p only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
            probs, p, deterministic=True
        )
    elif p is None:
        # Top-k only.
        probs = logits.softmax(dim=-1, dtype=torch.float32)
        next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
            probs, k, deterministic=True
        )
    else:
        # Both top-k and top-p.
        next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
            logits, k, p, deterministic=True
        )

    return next_token_ids.view(-1)
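
If FlashInfer >= 0.2.3 is installed on a CUDA system, the helper above can be exercised directly; otherwise forward_cuda already falls back to the native path. A guarded sketch (illustrative only; assumes a CUDA device):

# Illustrative sketch, not part of the diff.
import torch

if torch.cuda.is_available():
    logits = torch.randn(4, 32_000, device="cuda")
    k = torch.tensor([50] * 4, device="cuda")
    p = torch.tensor([0.9] * 4, device="cuda")
    try:
        token_ids = flashinfer_sample(logits.contiguous(), k, p, generators={})
    except ImportError:
        # FlashInfer missing or too old: use the native path instead.
        probs = apply_top_k_top_p(logits, k, p).softmax(dim=-1, dtype=torch.float32)
        token_ids = random_sample(probs, generators={})
    print(token_ids.shape)                    # torch.Size([4])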