# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with CUDA graph and torch.compile."""

from __future__ import annotations

import bisect
import gc
import inspect
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import tqdm
from torch.profiler import ProfilerActivity, profile

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import (
    DpPaddingMode,
    get_attention_tp_rank,
    get_attention_tp_size,
    set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
    enable_num_token_non_padded,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_device_memory_capacity,
    log_info_on_rank0,
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    global is_capture_mode
    is_capture_mode = True

    yield

    is_capture_mode = False


@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Optimize garbage collection during CUDA graph capture.

    Run a full collection first; then, if GC is disabled during capture,
    freeze all surviving objects so they are excluded from future collections.
    """
    gc.collect()
    should_freeze = not enable_cudagraph_gc

    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()


def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if reverse:
                sub.leave_torch_compile()
            else:
                sub.enter_torch_compile(num_tokens=num_tokens)
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)
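
# An illustrative round trip with `_to_torch` (a sketch; assumes `model`
# contains CustomOp submodules):
#
#     _to_torch(model, reverse=False, num_tokens=8)  # swap in torch-native ops
#     compiled = torch.compile(model.forward)        # compile / capture here
#     _to_torch(model, reverse=True, num_tokens=8)   # restore the custom kernels
#
# `patch_model` below wraps exactly this enter/leave pattern in a context manager.
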
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile."""
    backup_ca_comm = None

    try:
        if enable_compile:
            _to_torch(model, reverse=False, num_tokens=num_tokens)
            backup_ca_comm = tp_group.ca_comm
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in
            # allreduce in torch, even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),
                mode=os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                ),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, num_tokens=num_tokens)
            tp_group.ca_comm = backup_ca_comm


def set_torch_compile_config():
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    # Experimental feature to reduce compilation times; will be on by default in the future.
    torch._inductor.config.fx_graph_cache = True

    # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024

    monkey_patch_torch_compile()


def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs

    if capture_bs is None:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
            else:
                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
        else:
            # Since speculative decoding requires more cuda graph memory, we
            # capture fewer batch sizes.
            capture_bs = (
                list(range(1, 9))
                + list(range(10, 33, 2))
                + list(range(40, 64, 8))
                + list(range(80, 161, 16))
            )

        gpu_mem = get_device_memory_capacity()
        if gpu_mem is not None:
            if gpu_mem > 90 * 1024:  # H200, H20
                capture_bs += list(range(160, 257, 8))
            if gpu_mem > 160 * 1000:  # B200, MI300
                capture_bs += list(range(256, 513, 16))

    if max(capture_bs) > model_runner.req_to_token_pool.size:
        # In some cases (e.g., with a small GPU or a small --max-running-requests),
        # the maximum number of running requests is very small.
        # Add it here to make sure we capture the maximum batch size.
        capture_bs += [model_runner.req_to_token_pool.size]

    mul_base = 1
    if server_args.enable_two_batch_overlap:
        mul_base *= 2
    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()
    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
        if max(capture_bs) < server_args.cuda_graph_max_bs:
            capture_bs += list(
                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
            )
    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
    capture_bs = list(sorted(set(capture_bs)))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"

    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
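
# Worked example (illustrative, not executed): in the default non-speculative,
# padding-enabled configuration, capture_bs starts as
# [1, 2, 4, 8] + [16, 24, ..., 160]. At replay time, replay_prepare pads a raw
# batch size up to the next captured size with bisect, e.g.:
#
#     import bisect
#     capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
#     bs = capture_bs[bisect.bisect_left(capture_bs, 13)]  # -> 16
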
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val


class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.device_module = torch.get_device_module(self.device)
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size
        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
                    self.model_runner.server_args.speculative_num_draft_tokens
                )

        # If returning hidden states is enabled, set the initial capture hidden
        # mode to FULL to avoid double-capture on startup.
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.model_runner.attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )
        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        if self.model_runner.server_args.enable_lora:
            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
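
        # NOTE: CUDA graph replay requires every graph input to live at a fixed
        # device address, so the buffers below are allocated once at the maximum
        # captured size; capture uses per-batch slices of them, and replay only
        # copies fresh values into the same storage.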
        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros(
                (self.max_num_token,), dtype=self._cache_loc_dtype()
            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
            self.tbo_plugin = TboCudaGraphRunnerPlugin()

            # pipeline parallelism
            if self.pp_size > 1:
                self.pp_proxy_tensors = {
                    "hidden_states": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                    "residual": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                }

            # Speculative inference
            if model_runner.spec_algorithm.is_eagle3():
                self.model_runner.model.set_eagle3_layers_to_capture()

            if self.is_encoder_decoder:
                # NOTE: encoder_lens can influence the full_text_row_masked_out_mask
                # tensor when doing mixed batch
                self.encoder_lens = torch.full(
                    (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
                )
            else:
                self.encoder_lens = None

            if self.require_gathered_buffer:
                if self.require_mlp_tp_gather:
                    self.global_num_tokens_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                else:
                    assert self.require_attn_tp_gather
                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (1,), dtype=torch.int32
                    )
            else:
                self.global_num_tokens_gpu = None
                self.global_num_tokens_for_logprob_gpu = None

            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )
            self.next_token_logits_buffer = torch.zeros(
                (self.max_num_token, self.model_runner.model_config.vocab_size),
                dtype=torch.float,
                device=self.device,
            )

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def _cache_loc_dtype(self):
        return torch.int64

    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )

        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        # NOTE: cuda graph cannot handle mixed batches (encoder_len = 0).
        # If mixed batches did not need to be supported, encoder_lens could be
        # removed from the cuda graph, because the full_text_row_masked_out_mask
        # tensor would always be ones.
        is_encoder_lens_supported = (
            torch.all(forward_batch.encoder_lens > 0)
            if self.is_encoder_decoder
            else True
        )

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

        return (
            is_bs_supported
            and is_encoder_lens_supported
            and is_tbo_supported
            and capture_hidden_mode_matches
        )
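
    # Example of the padding rule above (illustrative; assumes the default
    # capture_bs going up to 160): a decode batch of size 13 satisfies
    # `cuda_graph_bs <= self.max_bs` and is padded to the nearest captured size
    # at replay time, while with --disable-cuda-graph-padding it would need an
    # exactly captured graph (13 in self.graphs) to run.
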
    def capture(self) -> None:
        profile_context = empty_context()
        if self.enable_profile_cuda_graph:
            profile_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
            )

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )
                # Reverse the order to enable better memory sharing across cuda graphs.
                capture_range = (
                    tqdm.tqdm(list(reversed(self.capture_bs)))
                    if get_tensor_model_parallel_rank() == 0
                    else reversed(self.capture_bs)
                )
                for i, bs in enumerate(capture_range):
                    if get_tensor_model_parallel_rank() == 0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        (
                            graph,
                            output_buffers,
                        ) = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs] = graph
                        self.output_buffers[bs] = output_buffers

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

        if self.enable_profile_cuda_graph:
            log_message = (
                "Sorted by CUDA Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cuda_time_total", row_limit=10
                )
                + "\n\nSorted by CPU Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=10
                )
            )
            logger.info(log_message)

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        with self.device_module.graph(graph, pool=pool, stream=stream):
            out = run_once_fn()
        return out

    def _create_device_graph(self):
        return torch.cuda.CUDAGraph()
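
    # capture_one_batch_size (below) builds a ForwardBatch whose tensors are
    # views into the preallocated max-size buffers, e.g. (illustrative):
    #
    #     input_ids = self.input_ids[:num_tokens]  # fixed storage, graph-safe
    #
    # so that replaying the captured graph later only requires copying new
    # values into the same storage, never reallocating.
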
    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = self._create_device_graph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]
        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
        self.num_token_non_padded[...] = num_tokens

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens
        else:
            global_dp_buffer_len = None

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture the CUDA graph with an empty LoRA id, as the
            # LoRA kernels are always launched whenever `--enable-lora` is set
            # (and return immediately if the LoRA id is empty, for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
            lora_ids=lora_ids,
        )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
            # Clean the intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

            kwargs = {}
            if (
                self.pp_size > 1
                and "pp_proxy_tensors" in inspect.signature(forward).parameters
            ):
                kwargs["pp_proxy_tensors"] = PPProxyTensors(
                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
                )

            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
                **kwargs,
            )
            return logits_output_or_pp_proxy_tensors

        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
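
        # The warmup runs above likely serve to trigger one-time work (kernel
        # autotuning, lazy workspace allocation) before capture, so it is not
        # baked into the graph; the memory pool created here is then shared by
        # every graph captured afterwards (largest batch first, so smaller
        # graphs can reuse the pool's existing blocks).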
        # Set the graph pool id globally so that symmetric memory can be used.
        set_graph_pool_id(get_global_graph_memory_pool())
        out = self._capture_graph(
            graph, get_global_graph_memory_pool(), stream, run_once
        )

        return graph, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):
        # If the required capture_hidden_mode changes, we need to recapture the graph.
        # These are the different factors that can influence the capture_hidden_mode:
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required
        # hidden mode, set it to what is required and re-capture.
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
        if self.require_mlp_tp_gather:
            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
            max_batch_size = (
                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max_num_tokens
            )
            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        seq_lens_cpu = None
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            seq_lens_cpu = self.seq_lens_cpu[:bs]

        if pp_proxy_tensors:
            for key in self.pp_proxy_tensors.keys():
                dim = pp_proxy_tensors[key].shape[0]
                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            num_token_non_padded = forward_batch.num_token_non_padded
            if self.require_gathered_buffer:
                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
                num_local_token_non_padded = torch.clamp(
                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
                    min=0,
                    max=tokens_per_rank,
                )
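                # Worked example: with attn_tp_size=2, num_tokens_per_bs=1 and a
                # padded bs of 16, tokens_per_rank = 8; if 13 tokens are real,
                # rank 0 keeps clamp(13 - 0, 0, 8) = 8 non-padded tokens and
                # rank 1 keeps clamp(13 - 8, 0, 8) = 5.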
                self.num_token_non_padded.copy_(num_local_token_non_padded)
            else:
                self.num_token_non_padded.copy_(num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
                spec_info=forward_batch.spec_info,
            )
        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
            forward_batch.spec_info.custom_mask = self.custom_mask

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            self.capture_forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info


CUDA_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)
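
# Typical usage (an illustrative sketch; the real call sites live in
# ModelRunner, whose construction is assumed):
#
#     runner = CudaGraphRunner(model_runner)     # captures all graphs up front
#     if runner.can_run(forward_batch):
#         out = runner.replay(forward_batch)     # padded CUDA graph replay
#     else:
#         out = model_runner.forward(forward_batch)  # eager fallback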