# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from collections.abc import Callable from typing import Any import numpy as np import torch import torch.nn as nn from tqdm import tqdm from vllm.config import VllmConfig from vllm.config.compilation import CUDAGraphMode from vllm.distributed.parallel_state import graph_capture, is_global_first_rank from vllm.forward_context import BatchDescriptor, set_forward_context from vllm.model_executor.offloader.base import get_offloader from vllm.utils.math_utils import cdiv from vllm.v1.kv_cache_interface import KVCacheConfig from vllm.v1.worker.gpu.attn_utils import ( build_attn_metadata, build_slot_mappings_by_layer, ) from vllm.v1.worker.gpu.block_table import BlockTables from vllm.v1.worker.gpu.dp_utils import make_num_tokens_across_dp from vllm.v1.worker.gpu.input_batch import InputBuffers from vllm.v1.worker.utils import AttentionGroup class CudaGraphManager: def __init__( self, vllm_config: VllmConfig, uses_mrope: bool, use_aux_hidden_state_outputs: bool, device: torch.device, ): self.vllm_config = vllm_config self.scheduler_config = vllm_config.scheduler_config self.uses_mrope = uses_mrope self.use_aux_hidden_state_outputs = use_aux_hidden_state_outputs self.device = device self.max_model_len = vllm_config.model_config.max_model_len self.max_num_reqs = self.scheduler_config.max_num_seqs self.max_num_tokens = self.scheduler_config.max_num_batched_tokens self.dp_size = vllm_config.parallel_config.data_parallel_size self.uniform_decode_query_len = 1 spec_config = vllm_config.speculative_config if spec_config is not None: self.uniform_decode_query_len += spec_config.num_speculative_tokens self.compilation_config = vllm_config.compilation_config assert self.compilation_config is not None self.cudagraph_mode = self.compilation_config.cudagraph_mode use_uniform_decode_cudagraph = ( self.cudagraph_mode.decode_mode() == CUDAGraphMode.FULL and self.cudagraph_mode.separate_routine() ) self.cudagraph_sizes, self.uniform_decode_cudagraph_sizes = get_cudagraph_sizes( self.compilation_config.cudagraph_capture_sizes, self.max_num_reqs, self.max_num_tokens, self.cudagraph_mode, self.uniform_decode_query_len, use_uniform_decode_cudagraph, ) self.graphs: dict[int, torch.cuda.CUDAGraph] = {} self.pool = None if self.cudagraph_mode != CUDAGraphMode.NONE: self.pool = torch.cuda.graph_pool_handle() self.hidden_states: torch.Tensor | None = None self.aux_hidden_states: list[torch.Tensor] = [] def needs_capture(self) -> bool: return len(self.cudagraph_sizes) > 0 def get_cudagraph_size( self, num_tokens: int, uniform_decode: bool = False ) -> int | None: if uniform_decode and self.uniform_decode_cudagraph_sizes: return self.uniform_decode_cudagraph_sizes.get(num_tokens) return self.cudagraph_sizes.get(num_tokens) def capture_graph( self, num_tokens: int, capture_cg_mode: CUDAGraphMode, model: nn.Module, input_buffers: InputBuffers, mrope_positions: torch.Tensor | None, inputs_embeds: torch.Tensor | None, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, has_lora: bool = False, uniform_decode: bool = False, ) -> None: # select and check capture function assert capture_cg_mode in [CUDAGraphMode.PIECEWISE, CUDAGraphMode.FULL], ( f"Invalid capture_cudagraph_mode for capture: {capture_cg_mode}" ) if capture_cg_mode == CUDAGraphMode.PIECEWISE: capture_fn = self._capture_piecewise_graph else: capture_fn = self._capture_full_graph # prepare inputs if uniform_decode: num_reqs = min( cdiv(num_tokens, self.uniform_decode_query_len), self.max_num_reqs, ) else: num_reqs = min(num_tokens, self.max_num_reqs) input_ids = input_buffers.input_ids[:num_tokens] positions = input_buffers.positions[:num_tokens] if self.uses_mrope: assert mrope_positions is not None positions = mrope_positions[:, :num_tokens] if inputs_embeds is not None: inputs_embeds = inputs_embeds[:num_tokens] attn_metadata, slot_mappings = prepare_inputs_to_capture( num_reqs, num_tokens, input_buffers, block_tables, attn_groups, self.max_model_len, kv_cache_config, uniform_decode_query_len=( self.uniform_decode_query_len if uniform_decode else 0 ), ) num_tokens_across_dp = make_num_tokens_across_dp(self.dp_size, num_tokens) # Warm up. with set_forward_context( attn_metadata, self.vllm_config, num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, slot_mapping=slot_mappings, ): model_output = model( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, ) if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output else: hidden_states = model_output aux_hidden_states = None # Allocate output buffers if not already done. if self.hidden_states is None: self.hidden_states = torch.empty_like(hidden_states) if self.use_aux_hidden_state_outputs and not self.aux_hidden_states: self.aux_hidden_states = [torch.empty_like(x) for x in aux_hidden_states] capture_fn( num_tokens=num_tokens, num_reqs=num_reqs, model=model, input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, num_tokens_across_dp=num_tokens_across_dp, attn_metadata=attn_metadata, slot_mappings=slot_mappings, has_lora=has_lora, ) def _capture_full_graph( self, num_tokens: int, num_reqs: int, model: nn.Module, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor | None, num_tokens_across_dp: torch.Tensor, attn_metadata: dict[str, Any] | None, slot_mappings: dict[str, torch.Tensor] | None, has_lora: bool = False, ) -> None: assert attn_metadata is not None # Capture the graph. assert num_tokens not in self.graphs graph = torch.cuda.CUDAGraph() # Sync offloader's copy stream before capture. # Ensure any pre-capture prefetches from offloader are complete. get_offloader().sync_prev_onload() with ( set_forward_context( attn_metadata=attn_metadata, vllm_config=self.vllm_config, num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.NONE, num_tokens_across_dp=num_tokens_across_dp, slot_mapping=slot_mappings, ), torch.cuda.graph(graph, self.pool), ): model_output = model( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, ) # Join offloader's copy stream after forward to avoid unjoined # stream error. The last layer's start_prefetch forks copy_stream, # but wait_prefetch only happens in the next forward pass. get_offloader().join_after_forward() if self.use_aux_hidden_state_outputs: hidden_states, aux_hidden_states = model_output else: hidden_states = model_output aux_hidden_states = None # Copy outputs to the output buffers. assert self.hidden_states is not None self.hidden_states[:num_tokens] = hidden_states if self.use_aux_hidden_state_outputs: for i, aux_hidden in enumerate(aux_hidden_states): self.aux_hidden_states[i][:num_tokens] = aux_hidden self.graphs[num_tokens] = graph def _capture_piecewise_graph( self, num_tokens: int, num_reqs: int, model: nn.Module, input_ids: torch.Tensor, positions: torch.Tensor, inputs_embeds: torch.Tensor | None, num_tokens_across_dp: torch.Tensor, attn_metadata: dict[str, Any] | None, slot_mappings: dict[str, torch.Tensor] | None, has_lora: bool = False, ) -> None: # create batch descriptor for piecewise cudagraph dispatch key batch_descriptor = BatchDescriptor(num_tokens=num_tokens, has_lora=has_lora) # Capture run - CUDAGraphWrapper inside torch.compile will auto capture. with set_forward_context( attn_metadata=None, # piecewise no need attn_metadata vllm_config=self.vllm_config, num_tokens=num_tokens, cudagraph_runtime_mode=CUDAGraphMode.PIECEWISE, num_tokens_across_dp=num_tokens_across_dp, batch_descriptor=batch_descriptor, slot_mapping=slot_mappings, ): model( input_ids=input_ids, positions=positions, inputs_embeds=inputs_embeds, ) @torch.inference_mode() def capture( self, model: nn.Module, input_buffers: InputBuffers, mrope_positions: torch.Tensor | None, inputs_embeds: torch.Tensor | None, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], kv_cache_config: KVCacheConfig, has_lora: bool = False, ) -> None: common_kwargs = dict( device=self.device, capture_fn=self.capture_graph, model=model, input_buffers=input_buffers, mrope_positions=mrope_positions, inputs_embeds=inputs_embeds, block_tables=block_tables, attn_groups=attn_groups, kv_cache_config=kv_cache_config, has_lora=has_lora, ) # Phase 1: Capture for mixed prefill-decode batches if needed. mixed_mode = self.cudagraph_mode.mixed_mode() if mixed_mode != CUDAGraphMode.NONE: capture_graphs( cudagraph_sizes=self.cudagraph_sizes, capture_cudagraph_mode=mixed_mode, desc=f"Capturing CUDA graphs (mixed, {mixed_mode.name})", uniform_decode=False, **common_kwargs, ) # Phase 2: Capture FULL graphs for uniform decode batches if needed. # This is only needed if we use a separate routine for decode batches # and the decode_mode is FULL. if self.uniform_decode_cudagraph_sizes: capture_graphs( cudagraph_sizes=self.uniform_decode_cudagraph_sizes, capture_cudagraph_mode=CUDAGraphMode.FULL, desc="Capturing CUDA graphs (decode, FULL)", uniform_decode=True, **common_kwargs, ) def get_cudagraph_runtime_mode( self, num_reqs: int, num_tokens: int, max_query_len: int ) -> tuple[CUDAGraphMode, int | None]: is_uniform_decode = (max_query_len == self.uniform_decode_query_len) and ( num_tokens == max_query_len * num_reqs ) cudagraph_size = self.get_cudagraph_size(num_tokens, is_uniform_decode) if cudagraph_size is None: cudagraph_mode = CUDAGraphMode.NONE elif is_uniform_decode: cudagraph_mode = self.cudagraph_mode.decode_mode() else: cudagraph_mode = self.cudagraph_mode.mixed_mode() if ( cudagraph_mode == CUDAGraphMode.FULL and cudagraph_size is not None and cudagraph_size not in self.graphs ): # If graph wasn't captured yet, fall back to eager. # This might happen when the dummy run is called before capture. cudagraph_mode = CUDAGraphMode.NONE cudagraph_size = None return cudagraph_mode, cudagraph_size def run_fullgraph( self, num_tokens: int ) -> torch.Tensor | tuple[torch.Tensor, list[torch.Tensor]]: assert num_tokens in self.graphs, f"No cudagraph for {num_tokens} tokens" # Sync offloader before replay - needed when transitioning from # eager/piecewise to full cudagraph (e.g., prefill → decode). # The previous eager iteration's start_prefetch may have queued # H2D copies on copy_stream that the graph's captured events # cannot see. Without this, replay could overwrite static buffers # while those copies are still in flight. get_offloader().sync_prev_onload() self.graphs[num_tokens].replay() assert self.hidden_states is not None hidden_states = self.hidden_states[:num_tokens] if not self.use_aux_hidden_state_outputs: return hidden_states return hidden_states, [x[:num_tokens] for x in self.aux_hidden_states] def get_cudagraph_sizes( capture_sizes: list[int] | None, max_num_reqs: int, max_num_tokens: int, cudagraph_mode: CUDAGraphMode, uniform_decode_query_len: int = 1, uniform_decode_cudagraph: bool = False, ) -> tuple[dict[int, int], dict[int, int]]: # Support both FULL and PIECEWISE cudagraph modes if cudagraph_mode == CUDAGraphMode.NONE: return {}, {} if not capture_sizes: return {}, {} capture_sizes = sorted(capture_sizes) if not capture_sizes: return {}, {} cudagraph_sizes: dict[int, int] = {} for i in range(1, capture_sizes[-1] + 1): for x in capture_sizes: if i <= x: cudagraph_sizes[i] = x break uniform_decode_cudagraph_sizes: dict[int, int] = {} if uniform_decode_cudagraph: max_num_tokens = max_num_reqs * uniform_decode_query_len uniform_decode_cudagraph_sizes = { k: v for k, v in cudagraph_sizes.items() if v <= max_num_tokens and v >= uniform_decode_query_len } return cudagraph_sizes, uniform_decode_cudagraph_sizes def capture_graphs( cudagraph_sizes: dict[int, int], device: torch.device, capture_fn: Callable, capture_cudagraph_mode: CUDAGraphMode, desc: str = "Capturing CUDA graphs", **capture_kwargs, ) -> None: # Capture larger graphs first. sizes_to_capture = sorted(set(cudagraph_sizes.values()), reverse=True) if is_global_first_rank(): sizes_to_capture = tqdm(sizes_to_capture, desc=desc) with graph_capture(device=device): for size in sizes_to_capture: capture_fn(size, capture_cudagraph_mode, **capture_kwargs) def prepare_inputs_to_capture( num_reqs: int, num_tokens: int, input_buffers: InputBuffers, block_tables: BlockTables, attn_groups: list[list[AttentionGroup]], max_model_len: int, kv_cache_config: KVCacheConfig, uniform_decode_query_len: int = 0, ) -> tuple[dict[str, Any], dict[str, torch.Tensor]]: if uniform_decode_query_len > 0: num_tokens_per_req = uniform_decode_query_len else: num_tokens_per_req = num_tokens // num_reqs query_start_loc_np = np.arange(num_reqs + 1, dtype=np.int32) * num_tokens_per_req query_start_loc_np[-1] = num_tokens query_start_loc_cpu = torch.from_numpy(query_start_loc_np) input_buffers.query_start_loc[: num_reqs + 1] = query_start_loc_cpu input_buffers.query_start_loc[num_reqs + 1 :] = num_tokens query_start_loc = input_buffers.query_start_loc[: num_reqs + 1] # HACK(woosuk): For faster warmup, we set seq_lens (GPU) to num_tokens # rather than max_model_len. input_buffers.seq_lens[:num_reqs] = num_tokens input_buffers.seq_lens[num_reqs:] = 0 input_buffers.dcp_local_seq_lens[:num_reqs] = num_tokens input_buffers.dcp_local_seq_lens[num_reqs:] = 0 input_block_tables = [x[:num_reqs] for x in block_tables.input_block_tables] slot_mappings = block_tables.slot_mappings[:, :num_tokens] slot_mappings_by_layer = build_slot_mappings_by_layer( slot_mappings, kv_cache_config ) attn_metadata = build_attn_metadata( attn_groups=attn_groups, num_reqs=num_reqs, num_tokens=num_tokens, query_start_loc_gpu=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, max_query_len=num_tokens_per_req, seq_lens=input_buffers.seq_lens, max_seq_len=max_model_len, block_tables=input_block_tables, slot_mappings=slot_mappings, kv_cache_config=kv_cache_config, dcp_local_seq_lens=input_buffers.dcp_local_seq_lens, ) return attn_metadata, slot_mappings_by_layer