# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A layer that samples the next tokens from the model's outputs."""
import torch
import torch.nn as nn

from vllm.utils import is_pin_memory_available, make_tensor_with_pad
from vllm.v1.outputs import LogprobsTensors, SamplerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.bad_words import apply_bad_words
from vllm.v1.sample.ops.penalties import apply_all_penalties
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler

from .cached_pooler import BinCountTensorPooler

# Temperatures below this epsilon are treated as greedy (see `sample`).
_SAMPLING_EPS = 1e-5


def _convert_to_tensors(output_token_ids: list[list[int]], vocab_size: int,
                        device: torch.device) -> torch.Tensor:
    """
    Convert the different list data structures to tensors.

    Pads the per-request token-id lists into one rectangular int32 tensor
    on CPU (pinned when available), then copies it to ``device``
    asynchronously.

    Args:
        output_token_ids: One list of already-generated token ids per
            request.
        vocab_size: Vocabulary size; also used as the padding value.
        device: Target device for the returned tensor.

    Returns:
        An int32 tensor of shape ``(len(output_token_ids), max_len)`` on
        ``device``, padded with ``vocab_size``.
    """
    output_tokens_tensor = make_tensor_with_pad(
        output_token_ids,
        # Use the value of vocab_size as a pad since we don't have a
        # token_id of this value.
        pad=vocab_size,
        device="cpu",
        dtype=torch.int32,  # init with int32
        pin_memory=is_pin_memory_available(),
    )
    # non_blocking is only truly async when the CPU tensor is pinned (above).
    return output_tokens_tensor.to(device, non_blocking=True)


class Sampler(nn.Module):
    """Token sampler that offloads penalties and top-k/top-p sampling to
    vendor-specific ``torch.vacc`` ops.

    NOTE(review): no ``__init__`` is visible in this chunk, yet `sample`
    uses ``self.topk_topp_sampler`` and `forward` lazily creates
    ``self.bin_count_pooler``. ``TopKTopPSampler`` is imported but never
    instantiated here — presumably ``__init__`` lives elsewhere or was
    removed when the ``torch.vacc`` path was added; confirm before relying
    on the fallback `sample` path.
    """

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        # NOTE(review): untyped; forwarded verbatim to
        # torch.vacc.apply_penalties — looks like a "first step" flag used
        # to (re)initialize the per-request bin-count buffers. Confirm.
        is_first_calculate,
        req_ids: list[str],
    ) -> SamplerOutput:
        # print(sampling_metadata.generators, len(sampling_metadata.generators), sampling_metadata.temperature_cpu)
        # sampling_metadata.temperature = sampling_metadata.temperature.to(logits.device)
        # sampling_metadata.top_p = sampling_metadata.top_p.to(logits.device)
        # sampling_metadata.top_k = sampling_metadata.top_k.to(logits.device)
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # is used for sampling (after penalties and temperature scaling).
        # TODO(rob): provide option for logprobs post sampling.
        # See https://vllm-dev.slack.com/archives/C07UUL8E61Z/p1735907856007919 # noqa: E501
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            # Logprobs are computed from the raw logits on purpose (see
            # NOTE above), before the float32 cast and penalties below.
            raw_logprobs = self.compute_logprobs(logits)

        # Use float32 for the logits.
        logits = logits.to(torch.float32)
        # Apply allowed token ids.
        logits = self.apply_allowed_token_ids(logits, sampling_metadata)
        # Apply bad words exclusion.
        logits = self.apply_bad_words(logits, sampling_metadata)

        # Apply logits processors which can impact greedy sampling
        for processor in (sampling_metadata.logitsprocs.non_argmax_invariant):
            logits = processor.apply(logits)

        # Apply penalties (e.g., min_tokens, freq_penalties).
        # logits = self.apply_penalties(logits, sampling_metadata)
        if not sampling_metadata.no_penalties:
            # Lazily create the pooled per-request bin-count buffers used
            # by the vacc penalty kernel; sized to the vocab dimension.
            if (not hasattr(self, "bin_count_pooler")):
                self.bin_count_pooler = BinCountTensorPooler(
                    logits.shape[-1], logits.device)
            assert sampling_metadata.prompt_token_ids is not None
            buf_bin_buffer = self.bin_count_pooler.request_tensors(req_ids)
            batch, vocab_size = logits.shape
            logits_res = []
            # Penalties are applied one request at a time; only the most
            # recent output token is passed, the running counts live in
            # buf_bin_buffer[i].
            for i in range(batch):
                output_tokens_t = _convert_to_tensors(
                    [sampling_metadata.output_token_ids[i]], vocab_size,
                    logits.device).to(torch.int32)
                output_tokens_t_lastone = output_tokens_t[:, -1:]
                # NOTE(review): torch.vacc.apply_penalties is an opaque
                # vendor op; argument meaning is assumed from names
                # (logits row, last token, bin-count state, vocab size,
                # token count, frequency/presence penalty, first-step
                # flag) — confirm against the vacc extension docs. The
                # two branches differ only in passing logits[i:i+1] vs
                # the whole (batch==1) logits tensor.
                if logits.shape[0] > 1:
                    logits_i = torch.vacc.apply_penalties(
                        logits[i:i+1], output_tokens_t_lastone,
                        buf_bin_buffer[i], vocab_size,
                        output_tokens_t_lastone.shape[-1],
                        [sampling_metadata.frequency_penalties[i]],
                        [sampling_metadata.presence_penalties[i]],
                        is_first_calculate)
                else:
                    logits_i = torch.vacc.apply_penalties(
                        logits, output_tokens_t_lastone,
                        buf_bin_buffer[i], vocab_size,
                        output_tokens_t_lastone.shape[-1],
                        [sampling_metadata.frequency_penalties[i]],
                        [sampling_metadata.presence_penalties[i]],
                        is_first_calculate)
                logits_res.append(logits_i)
            # Re-assemble the per-request rows into a single batch.
            if len(logits_res) > 1:
                logits = torch.concat(logits_res)
            else:
                logits = logits_res[0]

        # Sample the next token.
        # sampled = self.sample(logits, sampling_metadata)
        # NOTE(review): the Python-side `sample` path below is bypassed in
        # favor of the fused vendor kernel; semantics assumed equivalent
        # (greedy/random mix driven by the CPU-side top_p/top_k/temperature
        # and the all_greedy/all_random flags) — confirm.
        sampled, _ = torch.vacc.sampler_v1(
            logits, sampling_metadata.top_p_cpu, sampling_metadata.top_k_cpu,
            sampling_metadata.temperature_cpu,
            int(sampling_metadata.all_greedy),
            int(sampling_metadata.all_random), sampling_metadata.generators)

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = None if num_logprobs is None else \
            self.gather_logprobs(raw_logprobs, num_logprobs,
                                 token_ids=sampled.long())

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
    ) -> torch.Tensor:
        """Divide each row of ``logits`` by its per-request temperature.

        Mutates ``logits`` in place and returns it.
        """
        # Use in-place division to avoid creating a new tensor.
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        """Return the argmax token id per row, flattened to 1D."""
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.

        NOTE(review): this Python path is currently unused by `forward`
        (replaced by torch.vacc.sampler_v1); kept as the reference
        implementation.
        """
        assert not (sampling_metadata.all_greedy
                    and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                # Fast path: every request is greedy, nothing else to do.
                return greedy_sampled

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(logits, sampling_metadata.temperature)

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled

        # Mixed batch: pick the greedy result for near-zero temperatures,
        # the random result otherwise.
        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        """Return row-wise log-softmax of ``logits`` in float32."""
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs,
                                                 num_logprobs,
                                                 dim=-1)

        # Get with the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        # Rank 1 == highest-probability token (count of logprobs >= own).
        token_ranks = (logprobs >= token_logprobs).sum(-1)

        # Concatenate together with the topk.
        # Column 0 is always the sampled/prompt token itself.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Apply presence/frequency/repetition penalties via the generic
        Python helper.

        NOTE(review): currently unused by `forward`, which applies
        penalties through torch.vacc.apply_penalties instead.
        """
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                sampling_metadata.output_token_ids,
            )
        return logits

    def apply_allowed_token_ids(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Mask out disallowed tokens (in place) with -inf, if a mask is
        provided in the metadata."""
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                float("-inf"))
        return logits

    def apply_bad_words(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> torch.Tensor:
        """Suppress configured bad-word sequences (mutates logits via the
        shared `apply_bad_words` op)."""
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                sampling_metadata.output_token_ids,
            )
        return logits