# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import numpy as np
import torch

import vllm.envs as envs
from vllm.config.model import LogprobsMode
from vllm.sampling_params import SamplingParams
from vllm.v1.worker.gpu.metrics.logits import get_num_nans
from vllm.v1.worker.gpu.sample.bad_words import BadWordsState
from vllm.v1.worker.gpu.sample.gumbel import gumbel_sample
from vllm.v1.worker.gpu.sample.logit_bias import LogitBiasState
from vllm.v1.worker.gpu.sample.logprob import compute_topk_logprobs
from vllm.v1.worker.gpu.sample.output import SamplerOutput
from vllm.v1.worker.gpu.sample.penalties import PenaltiesState
from vllm.v1.worker.gpu.sample.states import NO_LOGPROBS, SamplingStates
from vllm.v1.worker.gpu.states import RequestState


class Sampler:
    """Turn model logits into sampled token ids.

    The sampler owns four per-request state holders (sampling parameters,
    penalties, logit bias and bad words). On every call it runs the logits
    through logit bias, penalties, bad-words masking, temperature, min_p and
    top_k/top_p, then draws tokens with ``gumbel_sample``. Top-k logprobs and
    a NaN count are produced on demand.
    """

    def __init__(
        self,
        max_num_reqs: int,
        vocab_size: int,
        device: torch.device,
        req_states: RequestState,
        logprobs_mode: LogprobsMode = "raw_logprobs",
        num_speculative_tokens: int = 1,
    ):
        """Validate the configuration and build the per-request state holders.

        Args:
            max_num_reqs: Forwarded to ``SamplingStates`` and
                ``LogitBiasState``.
            vocab_size: Forwarded to ``SamplingStates``.
            device: Forwarded to ``LogitBiasState``.
            req_states: Forwarded to ``PenaltiesState`` and ``BadWordsState``.
            logprobs_mode: Which logits the returned logprobs are computed
                from: the input logits (``"raw_logprobs"``) or the logits
                returned by ``sample`` (``"processed_logprobs"``).
            num_speculative_tokens: Forwarded to
                ``PenaltiesState.apply_penalties`` on every ``sample`` call.

        Raises:
            NotImplementedError: If ``logprobs_mode`` is not one of the two
                supported modes.
        """
        supported_modes = ("processed_logprobs", "raw_logprobs")
        if logprobs_mode not in supported_modes:
            raise NotImplementedError(f"Unsupported logprobs_mode: {logprobs_mode}")

        self.logprobs_mode = logprobs_mode
        # False by default.
        self.compute_nans = envs.VLLM_COMPUTE_NANS_IN_LOGITS

        self.sampling_states = SamplingStates(max_num_reqs, vocab_size)
        self.penalties_state = PenaltiesState(req_states)
        self.logit_bias_state = LogitBiasState(max_num_reqs, device)
        self.bad_words_state = BadWordsState(req_states)
        self.num_speculative_tokens = num_speculative_tokens

    def add_request(
        self, req_idx: int, prompt_len: int, sampling_params: SamplingParams
    ) -> None:
        """Register a request's sampling parameters with every state holder.

        Args:
            req_idx: Index of the request, passed to each state holder.
            prompt_len: Prompt length; only the logit-bias state receives it.
            sampling_params: Sampling parameters of the request.
        """
        self.sampling_states.add_request(req_idx, sampling_params)
        self.penalties_state.add_request(req_idx, sampling_params)
        # The logit-bias state is the only one that also takes the prompt length.
        self.logit_bias_state.add_request(req_idx, prompt_len, sampling_params)
        self.bad_words_state.add_request(req_idx, sampling_params)

    def apply_staged_writes(self) -> None:
        """Call ``apply_staged_writes`` on each state holder, in a fixed order."""
        state_holders = (
            self.sampling_states,
            self.penalties_state,
            self.logit_bias_state,
            self.bad_words_state,
        )
        for holder in state_holders:
            holder.apply_staged_writes()

    def __call__(
        self,
        logits: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        cu_num_logits_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> SamplerOutput:
        """Sample tokens and optionally compute top-k logprobs and a NaN count.

        Args:
            logits: Model logits to sample from; not modified here because
                ``sample`` works on an FP32 copy.
            expanded_idx_mapping: Index mapping tensor forwarded to ``sample``.
            idx_mapping_np: NumPy index mapping forwarded to ``sample`` and
                used to look up the maximum number of requested logprobs.
            cu_num_logits_np: Converted to a list and handed to
                ``compute_topk_logprobs`` only when the number of logits rows
                differs from the length of ``idx_mapping_np``.
            pos: Forwarded to ``sample``.
            input_ids: Forwarded to ``sample``.
            expanded_local_pos: Forwarded to ``sample``.

        Returns:
            A ``SamplerOutput`` whose ``sampled_token_ids`` is the sampled
            tensor viewed as ``(-1, 1)``. ``logprobs_tensors`` is ``None`` when
            no logprobs are requested, and ``num_nans`` is ``None`` unless NaN
            counting is enabled.
        """
        # NOTE(woosuk): We intentionally compute num_nans before sampling to make
        # clear that num_nans is computed before applying penalties and
        # temperature.
        if self.compute_nans:
            nan_count = get_num_nans(logits)
        else:
            nan_count = None

        token_ids, processed_logits = self.sample(
            logits,
            expanded_idx_mapping,
            idx_mapping_np,
            pos,
            input_ids,
            expanded_local_pos,
        )

        topk_logprobs = None
        max_num_logprobs = self.sampling_states.max_num_logprobs(idx_mapping_np)
        if max_num_logprobs != NO_LOGPROBS:
            # Choose which logits the logprobs are derived from.
            if self.logprobs_mode == "processed_logprobs":
                logprob_logits = processed_logits
            else:
                logprob_logits = logits

            # More (or fewer) logits rows than mapping entries means the
            # logits are expanded, so the cumulative counts are needed.
            if logprob_logits.shape[0] != idx_mapping_np.shape[0]:
                cu_num_logits = cu_num_logits_np.tolist()
            else:
                cu_num_logits = None

            topk_logprobs = compute_topk_logprobs(
                logprob_logits, max_num_logprobs, token_ids, cu_num_logits
            )

        # These are GPU tensors.
        return SamplerOutput(
            # The sampled tokens are reshaped into a 2D tensor with a single
            # column, i.e. one generated token per row.
            sampled_token_ids=token_ids.view(-1, 1),
            logprobs_tensors=topk_logprobs,
            num_nans=nan_count,
        )

    def sample(
        self,
        logits: torch.Tensor,
        expanded_idx_mapping: torch.Tensor,
        idx_mapping_np: np.ndarray,
        pos: torch.Tensor,
        input_ids: torch.Tensor,
        expanded_local_pos: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Process an FP32 copy of ``logits`` and draw one token per row.

        The processing steps run in a fixed order because most of them mutate
        the working tensor in place: logit bias, penalties, bad-words masking,
        temperature, min_p, then top_k/top_p.

        Args:
            logits: Input logits; left untouched.
            expanded_idx_mapping: Index mapping tensor passed to every step
                and to ``gumbel_sample``.
            idx_mapping_np: NumPy index mapping passed to every processing
                step.
            pos: Passed to the logit-bias step and to ``gumbel_sample``.
            input_ids: Passed to the penalties and bad-words steps.
            expanded_local_pos: Passed to the penalties and bad-words steps.

        Returns:
            A ``(sampled, processed_logits)`` tuple: the output of
            ``gumbel_sample`` and the FP32 logits it was drawn from.
        """
        # Work on a fresh FP32 tensor so the caller's logits stay intact.
        work_logits = torch.empty_like(logits, dtype=torch.float32)
        work_logits.copy_(logits)

        # Logit bias (e.g., allowed_token_ids, min_tokens), in place.
        self.logit_bias_state.apply_logit_bias(
            work_logits, expanded_idx_mapping, idx_mapping_np, pos
        )

        # Penalties, in place.
        self.penalties_state.apply_penalties(
            work_logits,
            expanded_idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
            self.num_speculative_tokens,
        )

        # Bad-words masking, in place.
        self.bad_words_state.apply_bad_words(
            work_logits,
            expanded_idx_mapping,
            idx_mapping_np,
            input_ids,
            expanded_local_pos,
        )

        # Temperature, in place.
        self.sampling_states.apply_temperature(
            work_logits, expanded_idx_mapping, idx_mapping_np
        )

        # min_p, in place.
        self.sampling_states.apply_min_p(
            work_logits, expanded_idx_mapping, idx_mapping_np
        )

        # top_k and/or top_p. This might or might not return a new tensor, so
        # keep using the returned object from here on.
        final_logits = self.sampling_states.apply_top_k_top_p(
            work_logits, expanded_idx_mapping, idx_mapping_np
        )

        # Draw the next token. Temperature was already applied above, so the
        # sampler is told not to apply it again.
        next_tokens = gumbel_sample(
            final_logits,
            expanded_idx_mapping,
            self.sampling_states.temperature.gpu,
            self.sampling_states.seeds.gpu,
            pos,
            apply_temperature=False,
        )
        return next_tokens, final_logits