# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""

import torch
import torch.nn as nn

from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

from .cached_pooler import BinCountTensorPooler

_SAMPLING_EPS = 1e-5


def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """Convert the nested token-id lists to a padded tensor on `device`."""
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int32,
        pin_memory=is_pin_memory_available(),
    )
    return output_tokens_tensor.to(device, non_blocking=True)
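
# Example: _convert_to_tensors([[1], [2, 3]], vocab_size=4, device=...) yields
# [[1, 4], [2, 3]] (int32, on the target device); the shorter row is padded
# with 4 == vocab_size, which can never collide with a real token id.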


class Sampler(nn.Module):

    def __init__(self):
        super().__init__()
        # Used by self.sample().
        self.topk_topp_sampler = TopKTopPSampler()

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        is_first_calculate: bool,
        req_ids: list[str],
    ) -> SamplerOutput:
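        """Sample next tokens for the batch described by `sampling_metadata`.

        `req_ids` (one per row of `logits`) keys the pooled per-request
        bincount buffers used by the penalty kernel below, and
        `is_first_calculate` presumably marks the first sampling step for
        these requests, when those buffers still have to be built (an
        assumption based on how the arguments are used, not on upstream
        documentation).
        """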
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)
        # Apply penalties (e.g., frequency and presence penalties) through
        # the torch.vacc kernel; this path replaces self.apply_penalties().
        if not sampling_metadata.no_penalties:
            if not hasattr(self, "bin_count_pooler"):
                self.bin_count_pooler = BinCountTensorPooler(
                    logits.shape[-1], logits.device)
            assert sampling_metadata.prompt_token_ids is not None
            buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
            batch, vocab_size = logits.shape
            logits_res = []

            for i in range(batch):
                output_tokens_t = _convert_to_tensors(
                    [sampling_metadata.output_token_ids[i]], vocab_size,
                    logits.device)
                output_tokens_t_lastone = output_tokens_t[:, -1:]
                # logits[i:i + 1] also covers the single-request case,
                # where it is simply a view of the whole logits tensor.
                logits_i = torch.vacc.apply_penalties(
                    logits[i:i + 1],
                    output_tokens_t_lastone,
                    buf_bin_buffer[i],
                    vocab_size,
                    output_tokens_t_lastone.shape[-1],
                    [sampling_metadata.frequency_penalties[i]],
                    [sampling_metadata.presence_penalties[i]],
                    is_first_calculate)
                logits_res.append(logits_i)

            logits = (torch.cat(logits_res)
                      if len(logits_res) > 1 else logits_res[0])
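
        # Note (assumption): only the most recent output token is passed to
        # the kernel because the pooled bincount buffer already carries the
        # counts accumulated from earlier steps.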

        # Sample the next token through the torch.vacc kernel; this
        # replaces self.sample().
        sampled, _ = torch.vacc.sampler_v1(
            logits,
            sampling_metadata.top_p_cpu,
            sampling_metadata.top_k_cpu,
            sampling_metadata.temperature_cpu,
            int(sampling_metadata.all_greedy),
            int(sampling_metadata.all_random),
            sampling_metadata.generators)
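        # Note (assumption): sampler_v1 appears to fuse temperature scaling,
        # top-k/top-p filtering, and greedy vs. random selection in a single
        # kernel driven by the CPU-side parameter tensors; its second return
        # value is unused here.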

        # Gather the top-k logprobs and rank of the sampled token
        # (if requested).
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs,
                                 token_ids=sampled.long())

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to a 2D tensor with shape
            # [num_requests, 1], where each row represents the one token
            # generated for that request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant).
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

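        # Mix per request: rows with a near-zero temperature take the greedy
        # token; all others take the random sample.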
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
            logprobs: (num tokens) x (vocab) tensor
            num_logprobs: minimum number of logprobs to
                          retain per token
            token_ids: prompt tokens (if prompt logprobs)
                       or sampled tokens (if sampled
                       logprobs); 1D token ID tensor
                       with (num tokens) elements.
                       Must be int64.

        Returns:
            Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
            Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
            Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the rank of the actual token.
        token_ranks = (logprobs >= token_logprobs).sum(-1)
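        # Rank is 1-indexed: it counts how many vocab entries have a logprob
        # >= the token's own. E.g., for a row [-0.1, -2.0, -3.0] and token id
        # 1 (logprob -2.0), two entries satisfy >= -2.0, so the rank is 2.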

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits