# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project

import time
from collections import defaultdict
from contextlib import contextmanager
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, NamedTuple, Optional, Union

import torch
import torch.distributed as dist

import vllm.envs as envs
from vllm.config import CUDAGraphMode, ParallelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.v1.worker.ubatch_utils import UBatchSlices, is_second_ubatch_empty

if TYPE_CHECKING:
    from vllm.attention.backends.abstract import AttentionMetadata

logger = init_logger(__name__)

# Per-batchsize forward-time tracking is enabled when the logging interval
# is non-negative.
track_batchsize: bool = envs.VLLM_LOG_BATCHSIZE_INTERVAL >= 0
last_logging_time: float = 0
forward_start_time: float = 0
batchsize_logging_interval: float = envs.VLLM_LOG_BATCHSIZE_INTERVAL
# Maps batchsize -> list of forward times in milliseconds.
batchsize_forward_time: defaultdict = defaultdict(list)


class BatchDescriptor(NamedTuple):
    """
    Batch descriptor for cudagraph dispatching. We should keep the num of
    items as minimal as possible to properly and uniquely describe the padded
    batch for cudagraph.
    """
    num_tokens: int
    uniform_decode: bool = False
    """
    False can also be used for a uniform decode batch to dispatch to the
    cudagraph supporting non-uniform batches.
    """

    @property
    def non_uniform(self) -> "BatchDescriptor":
        """
        Return a non-uniform version of current batch descriptor.
        """
        return BatchDescriptor(self.num_tokens, uniform_decode=False)


def _sp_num_tokens_tensor(num_tokens_across_dp_cpu: torch.Tensor,
                          sequence_parallel_size: int) -> torch.Tensor:
    """Return the per-SP-rank token counts as a tensor.

    Each DP rank's token count is ceil-divided by ``sequence_parallel_size``
    and the quotient is repeated ``sequence_parallel_size`` times, so the
    result has ``len(num_tokens_across_dp_cpu) * sequence_parallel_size``
    entries (one per DP rank / SP rank pair, DP-major).
    """
    sp_tokens = ((num_tokens_across_dp_cpu + sequence_parallel_size - 1) //
                 sequence_parallel_size)
    return sp_tokens.repeat_interleave(sequence_parallel_size)


def _compute_sp_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                           sequence_parallel_size: int) -> list[int]:
    """Return the per-SP-rank token counts as a Python list.

    See ``_sp_num_tokens_tensor`` for how the counts are derived.
    """
    return _sp_num_tokens_tensor(num_tokens_across_dp_cpu,
                                 sequence_parallel_size).tolist()


def _compute_chunked_local_num_tokens(num_tokens_across_dp_cpu: torch.Tensor,
                                      sequence_parallel_size: int,
                                      max_num_tokens: int,
                                      chunk_idx: int) -> list[int]:
    """Return each SP rank's token count for chunk ``chunk_idx``.

    The tokens on every rank (taking into account sharding if the MoE
    activation is sequence parallel) are split into chunks of at most
    ``max_num_tokens``. For the requested chunk, each entry is the number of
    tokens that rank processes; a rank whose tokens are already exhausted
    reports 1 so that all ranks keep running in lockstep.
    """
    sp_tokens = _compute_sp_num_tokens(num_tokens_across_dp_cpu,
                                       sequence_parallel_size)
    chunk_start = max_num_tokens * chunk_idx
    # max(1, ...) ensures lockstep even if a rank is done.
    return [
        max(1, min(max_num_tokens, tokens - chunk_start))
        for tokens in sp_tokens
    ]


@dataclass
class DPMetadata:
    max_tokens_across_dp_cpu: torch.Tensor
    num_tokens_across_dp_cpu: torch.Tensor

    # NOTE: local_sizes should only be set by the chunked_sizes /
    # sp_local_sizes context managers; it is None outside of them.
    local_sizes: Optional[list[int]] = None

    @staticmethod
    def num_tokens_across_dp(num_tokens: int, dp_size: int,
                             dp_rank: int) -> torch.Tensor:
        """
        Gather the num_tokens across all DP ranks and return results in a
        CPU tensor of size dp_size.
        """
        from vllm.distributed.parallel_state import get_dp_group
        device = current_platform.device_type
        group = get_dp_group().device_group

        # Transferring this tensor from GPU to CPU will introduce a GPU sync
        # point that could adversely affect performance of vllm with async
        # scheduling. This environment variable exists to quickly disable
        # this optimization if we run into this case.
        if envs.VLLM_DISABLE_NCCL_FOR_DP_SYNCHRONIZATION:
            logger.info_once(
                "Using CPU all reduce to syncronize DP padding between ranks.")
            device = "cpu"
            group = get_dp_group().cpu_group

        # Each rank fills only its own slot; the sum-all-reduce then yields
        # every rank's count.
        num_tokens_across_dp = [0] * dp_size
        num_tokens_across_dp[dp_rank] = num_tokens
        num_tokens_tensor = torch.tensor(num_tokens_across_dp,
                                         device=device,
                                         dtype=torch.int32)
        dist.all_reduce(num_tokens_tensor, group=group)
        return num_tokens_tensor.cpu()

    def cu_tokens_across_sp(self, sp_size: int) -> torch.Tensor:
        """
        Get the cumulative tokens across sequence parallel ranks.

        In this case the input to the MoEs will be distributed w.r.t both
        DP and TP rank. When sp_size == 1, this is just the cumulative num
        tokens across DP.
        """
        return torch.cumsum(_sp_num_tokens_tensor(self.num_tokens_across_dp_cpu,
                                                  sp_size),
                            dim=0)

    @staticmethod
    def should_ubatch_across_dp(
            should_ubatch: bool, orig_num_tokens_per_ubatch: int,
            padded_num_tokens_per_ubatch: int, dp_size: int,
            dp_rank: int) -> tuple[bool, Optional[torch.Tensor]]:
        """
        1. Decides if each DP rank is going to microbatch. Either all ranks
        run with microbatching or none of them do. If this function decides
        not to run with microbatching, it will "abort", meaning that no
        padding information will be returned to the caller. It will return
        (False, None).

        2. Determines the total number of tokens that each rank will run.
        All ranks will be padded out so that they run with the same number
        of tokens.

        Returns: tuple[
            should_ubatch: Are all DP ranks going to microbatch
            num_tokens_after_padding: A tensor containing the total number of
            tokens per-microbatch for each DP rank including padding. Will be
            None if should_ubatch is False
        ]
        """

        device = current_platform.device_type
        # Row 0: original tokens per ubatch, row 1: padded tokens per ubatch,
        # row 2: whether the rank wants to ubatch. Each rank fills only its
        # own column; the sum-all-reduce gathers all columns.
        tensor = torch.zeros(3, dp_size, device=device, dtype=torch.int32)
        tensor[0][dp_rank] = orig_num_tokens_per_ubatch
        tensor[1][dp_rank] = padded_num_tokens_per_ubatch
        tensor[2][dp_rank] = 1 if should_ubatch else 0

        from vllm.distributed.parallel_state import get_dp_group
        dist.all_reduce(tensor, group=get_dp_group().device_group)

        result: bool = bool(torch.all(tensor[2] == 1).item())
        if not result:
            return result, None

        orig_num_tokens_tensor = tensor[0, :]
        padded_num_tokens_tensor = tensor[1, :]

        orig_min_num_tokens = int(orig_num_tokens_tensor.min().item())
        padded_max_num_tokens = int(padded_num_tokens_tensor.max().item())
        if is_second_ubatch_empty(orig_min_num_tokens, padded_max_num_tokens):
            logger.debug("Aborting ubatching %s %s", orig_min_num_tokens,
                         padded_max_num_tokens)
            return False, None
        return result, padded_num_tokens_tensor.cpu()

    @staticmethod
    def make(
            parallel_config: ParallelConfig,
            attn_metadata: Any,
            num_tokens: int,
            num_tokens_across_dp_cpu: Optional[torch.Tensor] = None
    ) -> "DPMetadata":
        """Build the DP metadata for the current forward pass.

        This rank's batch size is taken from v0 attention metadata
        (prefill + decode tokens) when present, otherwise from
        ``num_tokens``. If ``num_tokens_across_dp_cpu`` is not supplied it is
        gathered from all DP ranks via all_reduce.
        """

        assert parallel_config.data_parallel_size > 1
        dp_size = parallel_config.data_parallel_size
        dp_rank = parallel_config.data_parallel_rank
        if attn_metadata is not None and hasattr(attn_metadata,
                                                 "num_prefill_tokens"):
            # for v0 attention backends
            batchsize = attn_metadata.num_prefill_tokens + \
                attn_metadata.num_decode_tokens
        else:
            # for v1 attention backends or no attn_metadata
            batchsize = num_tokens

        # If num_tokens_across_dp is None, it will be computed by all_reduce
        # Otherwise, num_tokens_across_dp[dp_rank] should be equal to batchsize
        assert (num_tokens_across_dp_cpu is None
                or num_tokens_across_dp_cpu[dp_rank] == batchsize
                ), f"{num_tokens_across_dp_cpu[dp_rank]} {batchsize}"
        if num_tokens_across_dp_cpu is None:
            num_tokens_across_dp_cpu = DPMetadata.num_tokens_across_dp(
                batchsize, dp_size, dp_rank)
        max_tokens_across_dp_cpu = torch.max(num_tokens_across_dp_cpu)
        return DPMetadata(max_tokens_across_dp_cpu, num_tokens_across_dp_cpu)

    @contextmanager
    def chunked_sizes(self, sequence_parallel_size: int,
                      max_chunk_size_per_rank: int, chunk_idx: int):
        """
        Context manager to compute and temporarily set the per-rank local
        token sizes for a specific chunk during chunked forward execution.

        This is necessary to ensure each DP (data parallel) rank processes its
        designated portion of tokens in lockstep with others, even when the
        token counts are uneven or some ranks have completed their input
        early.

        For chunked execution, we break up the total tokens on each rank into
        multiple chunks (of at most `max_chunk_size_per_rank`), and for a given
        `chunk_idx`, this context manager sets `self.local_sizes` to the number
        of tokens to process in that chunk on each rank.

        `self.local_sizes` is only valid inside the context.

        Args:
            sequence_parallel_size: When Attn is TP and MoE layers are EP,
                                    we use SP between the layers to avoid
                                    redundant ops. We need this value to
                                    compute the chunked sizes.
            max_chunk_size_per_rank: The max number of tokens each rank is
                                     allowed to process in this chunk.
            chunk_idx: The index of the chunk to compute sizes for.
        """
        self.local_sizes = _compute_chunked_local_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size,
            max_chunk_size_per_rank, chunk_idx)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    @contextmanager
    def sp_local_sizes(self, sequence_parallel_size: int):
        """
        Context manager for setting self.local_sizes. Same as
        self.chunked_sizes but without any chunking.
        """
        self.local_sizes = _compute_sp_num_tokens(
            self.num_tokens_across_dp_cpu, sequence_parallel_size)
        try:
            yield self.local_sizes
        finally:
            self.local_sizes = None

    def get_chunk_sizes_across_dp_rank(self) -> Optional[list[int]]:
        """Return the local sizes set by the active sizing context manager.

        Must be called inside `chunked_sizes` or `sp_local_sizes`.
        """
        assert self.local_sizes is not None
        return self.local_sizes


@dataclass
class ForwardContext:
    # copy from vllm_config.compilation_config.static_forward_context
    no_compile_layers: dict[str, Any]
    """
    Type AttentionMetadata for v0,
    Type Dict[str, AttentionMetadata] for v1, map from layer_name of each
    attention layer to its attention metadata
    Type List[Dict[str, AttentionMetadata]] for DBO. List of size two, one
    for each microbatch.
    Set dynamically for each forward pass
    """
    attn_metadata: Union["AttentionMetadata", dict[str, "AttentionMetadata"],
                         list[dict[str, "AttentionMetadata"]]]
    # TODO: remove after making all virtual_engines share the same kv cache
    virtual_engine: int  # set dynamically for each forward pass
    # set dynamically for each forward pass
    dp_metadata: Optional[DPMetadata] = None
    # determine the cudagraph style at runtime to be FULL, PIECEWISE, or NONE.
    # by default NONE, no cudagraph is used.
    cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE
    batch_descriptor: Optional[BatchDescriptor] = None

    ubatch_slices: Optional[UBatchSlices] = None

    def __post_init__(self):
        assert self.cudagraph_runtime_mode in [
            CUDAGraphMode.NONE, CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], \
            f"Invalid cudagraph runtime mode: {self.cudagraph_runtime_mode}"


_forward_context: Optional[ForwardContext] = None


def get_forward_context() -> ForwardContext:
    """Get the current forward context."""
    assert _forward_context is not None, (
        "Forward context is not set. "
        "Please use `set_forward_context` to set the forward context.")
    return _forward_context


def create_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        dp_metadata: Optional[DPMetadata] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """Build a ForwardContext without installing it as the current one."""
    return ForwardContext(no_compile_layers=vllm_config.compilation_config.
                          static_forward_context,
                          virtual_engine=virtual_engine,
                          attn_metadata=attn_metadata,
                          dp_metadata=dp_metadata,
                          cudagraph_runtime_mode=cudagraph_runtime_mode,
                          batch_descriptor=batch_descriptor,
                          ubatch_slices=ubatch_slices)


@contextmanager
def override_forward_context(forward_context: Optional[ForwardContext]):
    """A context manager that overrides the current forward context.
    This is used to override the forward context for a specific
    forward pass. The previous context is restored on exit.
    """
    global _forward_context
    prev_context = _forward_context
    _forward_context = forward_context
    try:
        yield
    finally:
        _forward_context = prev_context


@contextmanager
def set_forward_context(
        attn_metadata: Any,
        vllm_config: VllmConfig,
        virtual_engine: int = 0,
        num_tokens: Optional[int] = None,
        num_tokens_across_dp: Optional[torch.Tensor] = None,
        cudagraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
        batch_descriptor: Optional[BatchDescriptor] = None,
        ubatch_slices: Optional[UBatchSlices] = None):
    """A context manager that stores the current forward context,
    can be attention metadata, etc.
    Here we can inject common logic for every model forward pass.

    When batchsize tracking is enabled (VLLM_LOG_BATCHSIZE_INTERVAL >= 0)
    and attn_metadata is given, the forward time is recorded per batchsize
    on exit and median stats are logged at most once per interval.
    """
    global forward_start_time, last_logging_time
    need_to_track_batchsize = track_batchsize and attn_metadata is not None
    if need_to_track_batchsize:
        forward_start_time = time.perf_counter()
    dp_metadata: Optional[DPMetadata] = None
    if vllm_config.parallel_config.data_parallel_size > 1 and (
            attn_metadata is not None or num_tokens is not None):
        dp_metadata = DPMetadata.make(vllm_config.parallel_config,
                                      attn_metadata, num_tokens or 0,
                                      num_tokens_across_dp)

    forward_context = create_forward_context(attn_metadata, vllm_config,
                                             virtual_engine, dp_metadata,
                                             cudagraph_runtime_mode,
                                             batch_descriptor, ubatch_slices)

    try:
        with override_forward_context(forward_context):
            yield
    finally:
        if need_to_track_batchsize:
            if hasattr(attn_metadata, "num_prefill_tokens"):
                # for v0 attention backends
                batchsize = attn_metadata.num_prefill_tokens + \
                    attn_metadata.num_decode_tokens
            else:
                # for v1 attention backends
                batchsize = num_tokens
            # we use synchronous scheduling right now,
            # adding a sync point here should not affect
            # scheduling of the next batch
            # Resolve the platform at call time rather than relying on the
            # module-level binding.
            from vllm.platforms import current_platform
            synchronize = current_platform.synchronize
            if synchronize is not None:
                synchronize()
            now = time.perf_counter()
            # time measurement is in milliseconds
            batchsize_forward_time[batchsize].append(
                (now - forward_start_time) * 1000)
            if now - last_logging_time > batchsize_logging_interval:
                last_logging_time = now
                forward_stats = []
                for bs, times in batchsize_forward_time.items():
                    if len(times) <= 1:
                        # can be cudagraph / profiling run
                        continue
                    median = torch.quantile(torch.tensor(times), q=0.5).item()
                    median = round(median, 2)
                    forward_stats.append((bs, len(times), median))
                # Most frequently seen batchsizes first.
                forward_stats.sort(key=lambda x: x[1], reverse=True)
                if forward_stats:
                    logger.info(("Batchsize forward time stats "
                                 "(batchsize, count, median_time(ms)): %s"),
                                forward_stats)