init src 0.9.2
This commit is contained in:
0
vllm/v1/sample/ops/__init__.py
Normal file
0
vllm/v1/sample/ops/__init__.py
Normal file
39
vllm/v1/sample/ops/bad_words.py
Normal file
39
vllm/v1/sample/ops/bad_words.py
Normal file
@@ -0,0 +1,39 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
_SMALLEST_LOGIT = float("-inf")
|
||||
|
||||
|
||||
def _apply_bad_words_single_batch(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: list[list[int]],
|
||||
past_tokens_ids: list[int],
|
||||
) -> None:
|
||||
for bad_word_ids in bad_words_token_ids:
|
||||
if len(bad_word_ids) > len(past_tokens_ids) + 1:
|
||||
continue
|
||||
|
||||
prefix_length = len(bad_word_ids) - 1
|
||||
last_token_id = bad_word_ids[-1]
|
||||
if prefix_length > 0:
|
||||
actual_prefix = past_tokens_ids[-prefix_length:]
|
||||
else:
|
||||
actual_prefix = []
|
||||
expected_prefix = bad_word_ids[:prefix_length]
|
||||
|
||||
assert len(actual_prefix) == len(expected_prefix)
|
||||
|
||||
if actual_prefix == expected_prefix:
|
||||
logits[last_token_id] = _SMALLEST_LOGIT
|
||||
|
||||
|
||||
def apply_bad_words(
|
||||
logits: torch.Tensor,
|
||||
bad_words_token_ids: dict[int, list[list[int]]],
|
||||
past_tokens_ids: list[list[int]],
|
||||
) -> None:
|
||||
for i, bad_words_ids in bad_words_token_ids.items():
|
||||
_apply_bad_words_single_batch(logits[i], bad_words_ids,
|
||||
past_tokens_ids[i])
|
||||
43
vllm/v1/sample/ops/penalties.py
Normal file
43
vllm/v1/sample/ops/penalties.py
Normal file
@@ -0,0 +1,43 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
import torch
|
||||
|
||||
from vllm.model_executor.layers.utils import apply_penalties
|
||||
from vllm.utils import is_pin_memory_available, make_tensor_with_pad
|
||||
|
||||
|
||||
def apply_all_penalties(
|
||||
logits: torch.Tensor,
|
||||
prompt_token_ids: torch.Tensor,
|
||||
presence_penalties: torch.Tensor,
|
||||
frequency_penalties: torch.Tensor,
|
||||
repetition_penalties: torch.Tensor,
|
||||
output_token_ids: list[list[int]],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Applies presence, frequency and repetition penalties to the logits.
|
||||
"""
|
||||
_, vocab_size = logits.shape
|
||||
output_tokens_t = _convert_to_tensors(output_token_ids, vocab_size,
|
||||
logits.device)
|
||||
return apply_penalties(logits, prompt_token_ids, output_tokens_t,
|
||||
presence_penalties, frequency_penalties,
|
||||
repetition_penalties)
|
||||
|
||||
|
||||
def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
|
||||
device: torch.device) -> torch.Tensor:
|
||||
"""
|
||||
Convert the different list data structures to tensors.
|
||||
"""
|
||||
output_tokens_tensor = make_tensor_with_pad(
|
||||
output_token_ids,
|
||||
# Use the value of vocab_size as a pad since we don't have a
|
||||
# token_id of this value.
|
||||
pad=vocab_size,
|
||||
device="cpu",
|
||||
dtype=torch.int64,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
)
|
||||
return output_tokens_tensor.to(device, non_blocking=True)
|
||||
296
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
296
vllm/v1/sample/ops/topk_topp_sampler.py
Normal file
@@ -0,0 +1,296 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm import envs
|
||||
from vllm.logger import init_logger
|
||||
from vllm.platforms import current_platform
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
try:
|
||||
import flashinfer.sampling
|
||||
is_flashinfer_available = True
|
||||
except ImportError:
|
||||
is_flashinfer_available = False
|
||||
|
||||
|
||||
class TopKTopPSampler(nn.Module):
|
||||
"""
|
||||
Module that performs optional top-k and top-p filtering followed by
|
||||
weighted random sampling of logits.
|
||||
|
||||
Implementations may update the logits tensor in-place.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
if current_platform.is_cuda():
|
||||
if is_flashinfer_available:
|
||||
flashinfer_version = flashinfer.__version__
|
||||
if flashinfer_version < "0.2.3":
|
||||
logger.warning(
|
||||
"FlashInfer version >= 0.2.3 required. "
|
||||
"Falling back to default sampling implementation.")
|
||||
self.forward = self.forward_native
|
||||
elif envs.VLLM_USE_FLASHINFER_SAMPLER is not False:
|
||||
# NOTE(woosuk): The V0 sampler doesn't use FlashInfer for
|
||||
# sampling unless VLLM_USE_FLASHINFER_SAMPLER=1 (i.e., by
|
||||
# default it is unused). For backward compatibility, we set
|
||||
# `VLLM_USE_FLASHINFER_SAMPLER` as None by default and
|
||||
# interpret it differently in V0 and V1 samplers: In V0,
|
||||
# None means False, while in V1, None means True. This is
|
||||
# why we use the condition
|
||||
# `envs.VLLM_USE_FLASHINFER_SAMPLER is not False` here.
|
||||
logger.info("Using FlashInfer for top-p & top-k sampling.")
|
||||
self.forward = self.forward_cuda
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer is available, but it is not enabled. "
|
||||
"Falling back to the PyTorch-native implementation of "
|
||||
"top-p & top-k sampling. For the best performance, "
|
||||
"please set VLLM_USE_FLASHINFER_SAMPLER=1.")
|
||||
self.forward = self.forward_native
|
||||
else:
|
||||
logger.warning(
|
||||
"FlashInfer is not available. Falling back to the PyTorch-"
|
||||
"native implementation of top-p & top-k sampling. For the "
|
||||
"best performance, please install FlashInfer.")
|
||||
self.forward = self.forward_native
|
||||
elif current_platform.is_tpu():
|
||||
self.forward = self.forward_tpu
|
||||
else:
|
||||
self.forward = self.forward_native
|
||||
|
||||
def forward_native(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
PyTorch-native implementation of top-k and top-p sampling.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
logits = apply_top_k_top_p(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
def forward_cuda(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""More optimized implementation for top-k and top-p sampling."""
|
||||
if k is None and p is None:
|
||||
# We prefer `random_sample` over `flashinfer_sample` when sorting is
|
||||
# not needed. This is because `random_sample` does not require
|
||||
# CPU-GPU synchronization while `flashinfer_sample` does.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
if generators:
|
||||
logger.warning("FlashInfer 0.2.3+ does not support "
|
||||
"per-request generators. Falling back to "
|
||||
"PyTorch-native implementation.")
|
||||
return self.forward_native(logits, generators, k, p)
|
||||
# flashinfer sampling functions expect contiguous logits.
|
||||
# In flex_attn/triton_attn fp32 inference, logits can be non-contiguous
|
||||
# because of slicing operation in logits_processor.
|
||||
return flashinfer_sample(logits.contiguous(), k, p, generators)
|
||||
|
||||
def forward_tpu(
|
||||
self,
|
||||
logits: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
logits = apply_top_k_top_p_tpu(logits, k, p)
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
return random_sample(probs, generators)
|
||||
|
||||
|
||||
def apply_top_k_top_p_tpu(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
p: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k and top-p optimized for TPU.
|
||||
|
||||
This algorithm avoids using torch.scatter which is extremely slow on TPU.
|
||||
This is achieved by finding a "cut-off" element in the original logit, and
|
||||
after thresholding the logit using this cut-off, the remaining elements
|
||||
shall constitute the top-p set.
|
||||
|
||||
Note: in the case of tie (i.e. multipple cut-off elements present in the
|
||||
logit), all tie elements are included in the top-p set. In other words,
|
||||
this function does not break ties. Instead, these tie tokens have equal
|
||||
chance of being chosen during final sampling, so we can consider the tie
|
||||
being broken then.
|
||||
"""
|
||||
probs = logits.softmax(dim=-1)
|
||||
probs_sort, _ = probs.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
top_k_count = probs_sort.size(1) - k.to(torch.long) # shape: (batch, )
|
||||
top_k_count = top_k_count.unsqueeze(dim=1)
|
||||
top_k_cutoff = probs_sort.gather(-1, top_k_count)
|
||||
|
||||
# Make sure the no top-k rows are no-op.
|
||||
no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
|
||||
top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))
|
||||
|
||||
elements_to_discard = probs < top_k_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
cumprob = torch.cumsum(probs_sort, dim=-1)
|
||||
top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
|
||||
top_p_mask[:, -1] = False # at least one
|
||||
|
||||
top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
|
||||
top_p_cutoff = probs_sort.gather(-1, top_p_count)
|
||||
elements_to_discard = probs < top_p_cutoff
|
||||
logits.masked_fill_(elements_to_discard, -float("inf"))
|
||||
|
||||
return logits
|
||||
|
||||
|
||||
def apply_top_k_top_p(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
) -> torch.Tensor:
|
||||
"""Apply top-k and top-p masks to the logits.
|
||||
|
||||
If a top-p is used, this function will sort the logits tensor,
|
||||
which can be slow for large batches.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
if p is None:
|
||||
if k is None:
|
||||
return logits
|
||||
|
||||
# Avoid sorting vocab for top-k only case.
|
||||
return apply_top_k_only(logits, k)
|
||||
|
||||
logits_sort, logits_idx = logits.sort(dim=-1, descending=False)
|
||||
|
||||
if k is not None:
|
||||
# Apply top-k.
|
||||
top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B
|
||||
# Get all the top_k values.
|
||||
top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1))
|
||||
top_k_mask = logits_sort < top_k_mask
|
||||
logits_sort.masked_fill_(top_k_mask, -float("inf"))
|
||||
|
||||
if p is not None:
|
||||
# Apply top-p.
|
||||
probs_sort = logits_sort.softmax(dim=-1)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort)
|
||||
top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1)
|
||||
# at least one
|
||||
top_p_mask[:, -1] = False
|
||||
logits_sort.masked_fill_(top_p_mask, -float("inf"))
|
||||
|
||||
# Re-sort the probabilities.
|
||||
logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort)
|
||||
return logits
|
||||
|
||||
|
||||
def apply_top_k_only(
|
||||
logits: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Apply top-k mask to the logits.
|
||||
|
||||
This implementation doesn't involve sorting the entire vocab.
|
||||
|
||||
The logits tensor may be updated in-place.
|
||||
"""
|
||||
no_top_k_mask = k == logits.shape[1]
|
||||
# Set non-top-k rows to 1 so that we can gather.
|
||||
k = k.masked_fill(no_top_k_mask, 1)
|
||||
max_top_k = k.max()
|
||||
# topk.values tensor has shape [batch_size, max_top_k].
|
||||
# Convert top k to 0-based index in range [0, max_top_k).
|
||||
k_index = k.sub_(1).unsqueeze(1)
|
||||
top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long())
|
||||
# Handle non-topk rows.
|
||||
top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf"))
|
||||
logits.masked_fill_(logits < top_k_mask, -float("inf"))
|
||||
return logits
|
||||
|
||||
|
||||
def random_sample(
|
||||
probs: torch.Tensor,
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Randomly sample from the probabilities.
|
||||
|
||||
We use this function instead of torch.multinomial because torch.multinomial
|
||||
causes CPU-GPU synchronization.
|
||||
"""
|
||||
q = torch.empty_like(probs)
|
||||
# NOTE(woosuk): To batch-process the requests without their own seeds,
|
||||
# which is the common case, we first assume that every request does
|
||||
# not have its own seed. Then, we overwrite the values for the requests
|
||||
# that have their own seeds.
|
||||
if len(generators) != probs.shape[0]:
|
||||
q.exponential_()
|
||||
if generators:
|
||||
# TODO(woosuk): This can be slow because we handle each request
|
||||
# one by one. Optimize this.
|
||||
for i, generator in generators.items():
|
||||
q[i].exponential_(generator=generator)
|
||||
return probs.div_(q).argmax(dim=-1).view(-1)
|
||||
|
||||
|
||||
def flashinfer_sample(
|
||||
logits: torch.Tensor,
|
||||
k: Optional[torch.Tensor],
|
||||
p: Optional[torch.Tensor],
|
||||
generators: dict[int, torch.Generator],
|
||||
) -> torch.Tensor:
|
||||
"""Sample from the logits using FlashInfer.
|
||||
|
||||
Statistically, this function is equivalent to the `random_sample` function.
|
||||
However, this function is faster because it avoids sorting the logits tensor
|
||||
via rejection sampling.
|
||||
|
||||
NOTE: The outputs of this function do not necessarily match the outputs of
|
||||
the `random_sample` function. It only guarantees that the outputs are
|
||||
statistically equivalent.
|
||||
|
||||
NOTE: This function includes CPU-GPU synchronization, while `random_sample`
|
||||
does not. Call this function at the end of the forward pass to minimize
|
||||
the synchronization overhead.
|
||||
"""
|
||||
assert not (k is None and p is None)
|
||||
if k is None:
|
||||
# Top-p only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_p_sampling_from_probs(
|
||||
probs, p, deterministic=True)
|
||||
elif p is None:
|
||||
# Top-k only.
|
||||
probs = logits.softmax(dim=-1, dtype=torch.float32)
|
||||
next_token_ids = flashinfer.sampling.top_k_sampling_from_probs(
|
||||
probs, k, deterministic=True)
|
||||
else:
|
||||
# Both top-k and top-p.
|
||||
next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits(
|
||||
logits, k, p, deterministic=True)
|
||||
|
||||
return next_token_ids.view(-1)
|
||||
Reference in New Issue
Block a user