# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch import torch.nn as nn from packaging import version from vllm import envs from vllm._aiter_ops import rocm_aiter_ops from vllm.config.model import LogprobsMode from vllm.logger import init_logger from vllm.platforms import CpuArchEnum, current_platform logger = init_logger(__name__) class TopKTopPSampler(nn.Module): """ Module that performs optional top-k and top-p filtering followed by weighted random sampling of logits. Implementations may update the logits tensor in-place. """ def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs") -> None: super().__init__() self.logprobs_mode = logprobs_mode # flashinfer optimization does not apply if intermediate # logprobs/logits after top_k/top_p need to be returned if ( logprobs_mode not in ("processed_logits", "processed_logprobs") and current_platform.is_cuda() ): if envs.VLLM_USE_FLASHINFER_SAMPLER: from vllm.v1.attention.backends.flashinfer import FlashInferBackend capability = current_platform.get_device_capability() assert capability is not None if not FlashInferBackend.supports_compute_capability(capability): capability_str = capability.as_version_str() raise RuntimeError( "FlashInfer does not support compute capability " f"{capability_str}, unset VLLM_USE_FLASHINFER_SAMPLER=1." ) # Users must opt in explicitly via VLLM_USE_FLASHINFER_SAMPLER=1. logger.info_once( "Using FlashInfer for top-p & top-k sampling.", scope="global", ) self.forward = self.forward_cuda else: logger.debug_once( "FlashInfer top-p/top-k sampling is available but disabled " "by default. Set VLLM_USE_FLASHINFER_SAMPLER=1 to opt in " "after verifying accuracy for your workloads." ) self.forward = self.forward_native elif current_platform.is_cpu(): arch = current_platform.get_cpu_architecture() # Fall back to native implementation for POWERPC and RISCV. # On PowerPC argmax produces incorrect output with torch.compile. # PR: https://github.com/vllm-project/vllm/pull/26987 if arch in (CpuArchEnum.RISCV, CpuArchEnum.POWERPC): self.forward = self.forward_native else: self.forward = self.forward_cpu elif ( logprobs_mode not in ("processed_logits", "processed_logprobs") and rocm_aiter_ops.is_enabled() ): try: import aiter.ops.sampling # noqa: F401 self.aiter_ops = torch.ops.aiter logger.info_once( "Using aiter sampler on ROCm (lazy import, sampling-only)." ) self.forward = self.forward_hip except ImportError: logger.warning_once( "aiter.ops.sampling is not available on ROCm. " "Falling back to forward_native implementation." ) self.forward = self.forward_native else: self.forward = self.forward_native self.apply_top_k_top_p = apply_top_k_top_p def forward_native( self, logits: torch.Tensor, generators: dict[int, torch.Generator], k: torch.Tensor | None, p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ PyTorch-native implementation of top-k and top-p sampling. The logits tensor may be updated in-place. """ logits = self.apply_top_k_top_p(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) probs = logits.softmax(dim=-1, dtype=torch.float32) return random_sample(probs, generators), logits_to_return def forward_cuda( self, logits: torch.Tensor, generators: dict[int, torch.Generator], k: torch.Tensor | None, p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """More optimized implementation for top-k and top-p sampling.""" # We prefer `random_sample` over `flashinfer_sample` when sorting is # not needed. This is because `random_sample` does not require # CPU-GPU synchronization while `flashinfer_sample` does. if (k is None and p is None) or generators: if generators: logger.debug_once( "FlashInfer 0.2.3+ does not support " "per-request generators. Falling back to " "PyTorch-native implementation." ) return self.forward_native(logits, generators, k, p) assert self.logprobs_mode not in ("processed_logits", "processed_logprobs"), ( "FlashInfer does not support returning logits/logprobs" ) # flashinfer sampling functions expect contiguous logits. # In flex_attn/triton_attn fp32 inference, logits can be non-contiguous # because of slicing operation in logits_processor. return flashinfer_sample(logits.contiguous(), k, p, generators), None def forward_cpu( self, logits: torch.Tensor, generators: dict[int, torch.Generator], k: torch.Tensor | None, p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """ PyTorch-native implementation of top-k and top-p sampling for CPU. The logits tensor may be updated in-place. """ logits = self.apply_top_k_top_p(logits, k, p) logits_to_return = None if self.logprobs_mode == "processed_logits": logits_to_return = logits elif self.logprobs_mode == "processed_logprobs": logits_to_return = logits.log_softmax(dim=-1, dtype=torch.float32) if len(generators) != logits.shape[0]: return compiled_random_sample(logits), logits_to_return else: probs = logits.softmax(dim=-1, dtype=torch.float32) q = torch.empty_like(probs) q.exponential_() for i, generator in generators.items(): q[i].exponential_(generator=generator) return probs.div_(q).argmax(dim=-1).view(-1), logits_to_return def forward_hip( self, logits: torch.Tensor, generators: dict[int, torch.Generator], k: torch.Tensor | None, p: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor | None]: """Optimized ROCm/aiter path (same structure as forward_cuda).""" if (k is None and p is None) or generators: if generators: logger.warning_once( "aiter sampler does not support per-request generators; " "falling back to PyTorch-native." ) return self.forward_native(logits, generators, k, p) assert self.logprobs_mode not in ( "processed_logits", "processed_logprobs", ), "aiter sampler does not support returning logits/logprobs." return self.aiter_sample(logits, k, p, generators), None def aiter_sample( self, logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, generators: dict[int, torch.Generator], ) -> torch.Tensor: """Sample from logits using aiter ops.""" use_top_k = k is not None use_top_p = p is not None # Joint k+p path if use_top_p and use_top_k: probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous() next_token_ids = self.aiter_ops.top_k_top_p_sampling_from_probs( probs, None, *_to_tensor_scalar_tuple(k), *_to_tensor_scalar_tuple(p), deterministic=True, ) return next_token_ids.view(-1) # Top-p only path elif use_top_p: probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous() next_token_ids = self.aiter_ops.top_p_sampling_from_probs( probs, None, *_to_tensor_scalar_tuple(p), deterministic=True ) return next_token_ids.view(-1) # Top-k only path elif use_top_k: probs = logits.softmax(dim=-1, dtype=torch.float32).contiguous() renorm_probs = self.aiter_ops.top_k_renorm_probs( probs, *_to_tensor_scalar_tuple(k) ) return torch.multinomial(renorm_probs, num_samples=1).view(-1) raise RuntimeError("aiter_sample was called with no active top-k or top-p.") # Note: this is a workaround for # https://github.com/pytorch/pytorch/pull/151218 @torch.compile(dynamic=True) def compiled_random_sample(logits: torch.Tensor) -> torch.Tensor: probs = logits.softmax(dim=-1, dtype=torch.float32) q = torch.empty_like(probs) q.exponential_() return probs.div(q).argmax(dim=-1).view(-1) def apply_top_k_top_p( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, ) -> torch.Tensor: """Apply top-k and top-p masks to the logits. If a top-p is used, this function will sort the logits tensor, which can be slow for large batches. The logits tensor may be updated in-place. """ if p is None: if k is None: return logits # Avoid sorting vocab for top-k only case. return apply_top_k_only(logits, k) logits_sort, logits_idx = logits.sort(dim=-1, descending=False) if k is not None: # Apply top-k. top_k_mask = logits_sort.size(1) - k.to(torch.long) # shape: B # Get all the top_k values. top_k_mask = logits_sort.gather(1, top_k_mask.unsqueeze(dim=1)) top_k_mask = logits_sort < top_k_mask logits_sort.masked_fill_(top_k_mask, -float("inf")) if p is not None: # Apply top-p. probs_sort = logits_sort.softmax(dim=-1) probs_sum = torch.cumsum(probs_sort, dim=-1, out=probs_sort) top_p_mask = probs_sum <= 1 - p.unsqueeze(dim=1) # at least one top_p_mask[:, -1] = False logits_sort.masked_fill_(top_p_mask, -float("inf")) # Re-sort the probabilities. logits = logits_sort.scatter(dim=-1, index=logits_idx, src=logits_sort) return logits def apply_top_k_only( logits: torch.Tensor, k: torch.Tensor, ) -> torch.Tensor: """ Apply top-k mask to the logits. This implementation doesn't involve sorting the entire vocab. The logits tensor may be updated in-place. """ no_top_k_mask = k == logits.shape[1] # Set non-top-k rows to 1 so that we can gather. k = k.masked_fill(no_top_k_mask, 1) max_top_k = k.max() # topk.values tensor has shape [batch_size, max_top_k]. # Convert top k to 0-based index in range [0, max_top_k). k_index = k.sub_(1).unsqueeze(1) top_k_mask = logits.topk(max_top_k, dim=1).values.gather(1, k_index.long()) # Handle non-topk rows. top_k_mask.masked_fill_(no_top_k_mask.unsqueeze(1), -float("inf")) logits.masked_fill_(logits < top_k_mask, -float("inf")) return logits def random_sample( probs: torch.Tensor, generators: dict[int, torch.Generator], ) -> torch.Tensor: """Randomly sample from the probabilities. We use this function instead of torch.multinomial because torch.multinomial causes CPU-GPU synchronization. """ q = torch.empty_like(probs) # NOTE(woosuk): To batch-process the requests without their own seeds, # which is the common case, we first assume that every request does # not have its own seed. Then, we overwrite the values for the requests # that have their own seeds. if len(generators) != probs.shape[0]: q.exponential_() if generators: # TODO(woosuk): This can be slow because we handle each request # one by one. Optimize this. for i, generator in generators.items(): q[i].exponential_(generator=generator) return probs.div_(q).argmax(dim=-1).view(-1) def flashinfer_sample( logits: torch.Tensor, k: torch.Tensor | None, p: torch.Tensor | None, generators: dict[int, torch.Generator], ) -> torch.Tensor: """Sample from the logits using FlashInfer. Statistically, this function is equivalent to the `random_sample` function. However, this function is faster because it avoids sorting the logits tensor via rejection sampling. NOTE: The outputs of this function do not necessarily match the outputs of the `random_sample` function. It only guarantees that the outputs are statistically equivalent. NOTE: This function includes CPU-GPU synchronization, while `random_sample` does not. Call this function at the end of the forward pass to minimize the synchronization overhead. """ import flashinfer if version.parse(flashinfer.__version__) < version.parse("0.2.3"): raise ImportError( "FlashInfer version >= 0.2.3 required for top-k and top-p sampling. " ) assert not (k is None and p is None) if k is None: # Top-p only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_p_sampling_from_probs( probs, p, deterministic=True ) elif p is None: # Top-k only. probs = logits.softmax(dim=-1, dtype=torch.float32) next_token_ids = flashinfer.sampling.top_k_sampling_from_probs( probs, k, deterministic=True ) else: # Both top-k and top-p. next_token_ids = flashinfer.sampling.top_k_top_p_sampling_from_logits( logits, k, p, deterministic=True ) return next_token_ids.view(-1) def _to_tensor_scalar_tuple(x): if isinstance(x, torch.Tensor): return (x, 0) else: return (None, x)