# File: sglang/python/sglang/srt/model_executor/cuda_graph_runner.py
# Snapshot: 2025-08-24 23:17:55 -07:00
#
# 861 lines
# 33 KiB
# Python

# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""
from __future__ import annotations
import bisect
import gc
import inspect
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union
import torch
import tqdm
from torch.profiler import ProfilerActivity, profile
from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
set_graph_pool_id,
)
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import (
DpPaddingMode,
get_attention_tp_rank,
get_attention_tp_size,
set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.forward_batch_info import (
CaptureHiddenMode,
ForwardBatch,
ForwardMode,
PPProxyTensors,
enable_num_token_non_padded,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
empty_context,
get_available_gpu_memory,
get_device_memory_capacity,
log_info_on_rank0,
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from sglang.srt.model_executor.model_runner import ModelRunner
# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    """Return True while execution is inside a ``model_capture_mode()`` block."""
    return is_capture_mode


@contextmanager
def model_capture_mode():
    """Mark the enclosed code as running in CUDA graph capture mode.

    Sets the module-level ``is_capture_mode`` flag to True on entry and back
    to False on exit. The reset happens in a ``finally`` clause so that an
    exception raised inside the ``with`` body cannot leave the flag stuck at
    True for the rest of the process.
    """
    global is_capture_mode
    is_capture_mode = True
    try:
        yield
    finally:
        is_capture_mode = False
@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Tune garbage collection for the duration of CUDA graph capture.

    A full collection is always run first. When ``enable_cudagraph_gc`` is
    False, every object that survived is then frozen (excluded from later
    collections) until the ``with`` body finishes, at which point the objects
    are unfrozen again. When it is True, nothing is frozen.
    """
    gc.collect()

    if enable_cudagraph_gc:
        # GC stays active during capture: nothing to freeze or unfreeze.
        yield
        return

    gc.freeze()
    try:
        yield
    finally:
        gc.unfreeze()
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
for sub in model._modules.values():
if isinstance(sub, CustomOp):
if reverse:
sub.leave_torch_compile()
else:
sub.enter_torch_compile(num_tokens=num_tokens)
if isinstance(sub, torch.nn.Module):
_to_torch(sub, reverse, num_tokens)
@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
    """Patch the model to make it compatible with torch.compile.

    Args:
        model: The model whose ``forward`` is yielded (optionally compiled).
        enable_compile: If True, switch the model's ``CustomOp`` submodules
            into torch.compile mode and yield a ``torch.compile``-wrapped,
            ``no_grad`` forward. If False, yield ``model.forward`` unchanged.
        num_tokens: Token count forwarded to the ``CustomOp`` hooks.
        tp_group: Group coordinator whose ``ca_comm`` attribute is saved on
            entry and restored on exit when compiling.

    Yields:
        The callable to use as the forward function inside the ``with`` body.
    """
    # Snapshot before any patching so the `finally` clause can never overwrite
    # `tp_group.ca_comm` with None if `_to_torch` fails part-way through.
    backup_ca_comm = tp_group.ca_comm if enable_compile else None

    try:
        if enable_compile:
            _to_torch(model, reverse=False, num_tokens=num_tokens)
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in allreduce in torch,
            # even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),
                mode=os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                ),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, num_tokens=num_tokens)
            tp_group.ca_comm = backup_ca_comm
def set_torch_compile_config():
    """Apply the inductor and dynamo settings used when torch.compile is enabled."""
    import torch._dynamo.config
    import torch._inductor.config

    inductor_cfg = torch._inductor.config
    dynamo_cfg = torch._dynamo.config

    inductor_cfg.coordinate_descent_tuning = True
    inductor_cfg.triton.unique_kernel_names = True
    # Experimental feature to reduce compilation times, will be on by default in future
    inductor_cfg.fx_graph_cache = True

    # FIXME: tmp workaround
    dynamo_cfg.accumulated_cache_size_limit = 1024
    if hasattr(dynamo_cfg, "cache_size_limit"):
        dynamo_cfg.cache_size_limit = 1024

    monkey_patch_torch_compile()
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    """Compute the batch sizes to capture as CUDA graphs and to torch.compile.

    Starts from ``server_args.cuda_graph_bs`` when given, otherwise from a
    built-in default list (extended on devices with large memory), then:

    * adds ``req_to_token_pool.size`` when the largest candidate exceeds it;
    * keeps only sizes divisible by the factor required by two-batch overlap
      and gathered buffers;
    * applies ``server_args.cuda_graph_max_bs`` when set;
    * drops sizes above ``req_to_token_pool.size``, de-duplicates and sorts.

    Args:
        model_runner: Provides ``server_args`` and ``req_to_token_pool``.

    Returns:
        ``(capture_bs, compile_bs)``: the sorted batch sizes to capture, and
        the subset (``<= torch_compile_max_bs``) to compile, which is empty
        when torch.compile is disabled.

    Raises:
        AssertionError: If no positive batch size remains after filtering.
    """
    server_args = model_runner.server_args
    pool_size = model_runner.req_to_token_pool.size

    if server_args.cuda_graph_bs is None:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
            else:
                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
        else:
            # Since speculative decoding requires more cuda graph memory, we
            # capture less.
            capture_bs = (
                list(range(1, 9))
                + list(range(10, 33, 2))
                + list(range(40, 64, 8))
                + list(range(80, 161, 16))
            )

        gpu_mem = get_device_memory_capacity()
        if gpu_mem is not None:
            if gpu_mem > 90 * 1024:  # H200, H20
                capture_bs += list(range(160, 257, 8))
            if gpu_mem > 160 * 1000:  # B200, MI300
                capture_bs += list(range(256, 513, 16))
    else:
        # Copy: the in-place `+=` below must not mutate the list owned by
        # server_args (it would otherwise grow on every call).
        capture_bs = list(server_args.cuda_graph_bs)

    # `capture_bs and ...` keeps an empty list from raising a bare ValueError
    # in max(); the assertion at the end reports it with context instead.
    if capture_bs and max(capture_bs) > pool_size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs += [pool_size]

    mul_base = 1
    if server_args.enable_two_batch_overlap:
        mul_base *= 2
    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()
    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
        if capture_bs and max(capture_bs) < server_args.cuda_graph_max_bs:
            capture_bs += list(
                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
            )

    capture_bs = sorted({bs for bs in capture_bs if bs <= pool_size})
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"

    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs
# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    """Return the graph memory pool handle shared by all runners (None until set)."""
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    """Replace the module-level shared graph memory pool handle with ``val``."""
    global global_graph_memory_pool
    global_graph_memory_pool = val
class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        """Allocate the static graph input buffers and capture all graphs.

        Args:
            model_runner: Runner providing the model, server args, attention
                backend and memory pools used for capture and replay.

        Raises:
            Exception: If capture fails with a ``RuntimeError``. The message
                includes ``CUDA_GRAPH_CAPTURE_FAILED_MSG`` and the original
                error is chained as the cause.
        """
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.device_module = torch.get_device_module(self.device)
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
                    self.model_runner.server_args.speculative_num_draft_tokens
                )

        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.model_runner.attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        if self.model_runner.server_args.enable_lora:
            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros(
                (self.max_num_token,), dtype=self._cache_loc_dtype()
            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
            self.tbo_plugin = TboCudaGraphRunnerPlugin()

            # pipeline parallelism
            if self.pp_size > 1:
                self.pp_proxy_tensors = {
                    "hidden_states": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                    "residual": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                }

            # Speculative_inference
            if model_runner.spec_algorithm.is_eagle3():
                self.model_runner.model.set_eagle3_layers_to_capture()

            if self.is_encoder_decoder:
                # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
                self.encoder_lens = torch.full(
                    (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
                )
            else:
                self.encoder_lens = None

            if self.require_gathered_buffer:
                if self.require_mlp_tp_gather:
                    self.global_num_tokens_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                else:
                    assert self.require_attn_tp_gather
                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (1,), dtype=torch.int32
                    )
            else:
                self.global_num_tokens_gpu = None
                self.global_num_tokens_for_logprob_gpu = None

            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )
            self.next_token_logits_buffer = torch.zeros(
                (self.max_num_token, self.model_runner.model_config.vocab_size),
                dtype=torch.float,
                device=self.device,
            )

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            # Chain explicitly so the original traceback is reported as the cause.
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            ) from e

    def _cache_loc_dtype(self):
        """Return the dtype used for the ``out_cache_loc`` buffer."""
        return torch.int64

    def can_run(self, forward_batch: ForwardBatch):
        """Return whether ``forward_batch`` can be served by a captured graph.

        All of the following must hold: the (possibly DP-global) batch size is
        supported, encoder lengths are all positive for encoder-decoder
        models, two-batch overlap is runnable when enabled, and the requested
        hidden-state capture mode is NULL or equals the captured mode.
        """
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        # Without padding only exactly-captured sizes work; with padding any
        # size up to the largest captured one can be padded up.
        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )

        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
        # because the full_text_row_masked_out_mask tensor will always be ones
        is_encoder_lens_supported = (
            torch.all(forward_batch.encoder_lens > 0)
            if self.is_encoder_decoder
            else True
        )

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

        return (
            is_bs_supported
            and is_encoder_lens_supported
            and is_tbo_supported
            and capture_hidden_mode_matches
        )

    def capture(self) -> None:
        """Capture one graph per entry of ``self.capture_bs``, largest first.

        Results are stored in ``self.graphs`` and ``self.output_buffers``
        keyed by batch size. When ``enable_profile_cuda_graph`` is set, the
        capture is profiled and a summary is logged afterwards.
        """
        profile_context = empty_context()
        if self.enable_profile_cuda_graph:
            profile_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
            )

        # Only TP rank 0 shows a progress bar and queries memory per step.
        is_rank0 = get_tensor_model_parallel_rank() == 0

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )
                # Reverse the order to enable better memory sharing across cuda graphs.
                capture_range = (
                    tqdm.tqdm(list(reversed(self.capture_bs)))
                    if is_rank0
                    else reversed(self.capture_bs)
                )
                for bs in capture_range:
                    if is_rank0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        (
                            graph,
                            output_buffers,
                        ) = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs] = graph
                        self.output_buffers[bs] = output_buffers

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

        if self.enable_profile_cuda_graph:
            log_message = (
                "Sorted by CUDA Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cuda_time_total", row_limit=10
                )
                + "\n\nSorted by CPU Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=10
                )
            )
            logger.info(log_message)

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        """Run ``run_once_fn`` inside the device graph-capture context and return its result."""
        with self.device_module.graph(graph, pool=pool, stream=stream):
            out = run_once_fn()
        return out

    def _create_device_graph(self):
        """Create a new, empty graph object to capture into."""
        return torch.cuda.CUDAGraph()

    def capture_one_batch_size(self, bs: int, forward: Callable):
        """Capture a graph of ``forward`` for batch size ``bs``.

        Builds a ``ForwardBatch`` from slices of the preallocated input
        buffers, runs the forward function twice outside of capture, then
        captures it once into a new graph.

        Args:
            bs: Batch size to capture.
            forward: Forward callable, as yielded by ``patch_model``.

        Returns:
            ``(graph, out)``: the captured graph and the output object
            produced by the captured run.
        """
        graph = self._create_device_graph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]
        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
        self.num_token_non_padded[...] = num_tokens

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens
        else:
            global_dp_buffer_len = None

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
            lora_ids=lora_ids,
        )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

            kwargs = {}
            # Only pass proxy tensors to forward functions that accept them.
            if (
                self.pp_size > 1
                and "pp_proxy_tensors" in inspect.signature(forward).parameters
            ):
                kwargs["pp_proxy_tensors"] = PPProxyTensors(
                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
                )

            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
                **kwargs,
            )
            return logits_output_or_pp_proxy_tensors

        # Two runs outside of capture before the captured run.
        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())
        out = self._capture_graph(
            graph, get_global_graph_memory_pool(), stream, run_once
        )

        return graph, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):
        """Re-capture all graphs if the required hidden-state capture mode changed.

        The required mode is the maximum of the modes requested by the
        forward batch, by its ``spec_info`` (a missing or ``None`` value is
        treated as NULL, matching ``can_run``), and by the
        ``enable_return_hidden_states`` server argument.
        """
        # If the required capture_hidden_mode changes, we need to recapture the graph

        # These are the different factors that can influence the capture_hidden_mode
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", None
        )
        if capture_hidden_mode_required_by_spec_info is None:
            # An attribute that exists but is None would make max() below fail.
            capture_hidden_mode_required_by_spec_info = CaptureHiddenMode.NULL
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        """Copy the batch into the static graph buffers and prepare the attention backend.

        Picks the smallest captured batch size that fits the batch, resets
        the padded buffers when padding is needed, copies the inputs, and
        stores ``raw_bs``, ``raw_num_token`` and ``bs`` on ``self`` for the
        following ``replay`` call.

        Args:
            forward_batch: The batch to run.
            pp_proxy_tensors: Optional pipeline-parallel inputs to copy into
                the preallocated proxy buffers.
        """
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
        if self.require_mlp_tp_gather:
            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
            max_batch_size = (
                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max_num_tokens
            )
            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            # Reset the buffers so the padded tail holds fill values.
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        seq_lens_cpu = None
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            seq_lens_cpu = self.seq_lens_cpu[:bs]

        if pp_proxy_tensors:
            for key, buffer in self.pp_proxy_tensors.items():
                dim = pp_proxy_tensors[key].shape[0]
                buffer[:dim].copy_(pp_proxy_tensors[key])

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            num_token_non_padded = forward_batch.num_token_non_padded
            if self.require_gathered_buffer:
                # Clamp the global count to this attention-TP rank's share.
                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
                num_local_token_non_padded = torch.clamp(
                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
                    min=0,
                    max=tokens_per_rank,
                )
                self.num_token_non_padded.copy_(num_local_token_non_padded)
            else:
                self.num_token_non_padded.copy_(num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
                spec_info=forward_batch.spec_info,
            )
        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
            forward_batch.spec_info.custom_mask = self.custom_mask

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            self.capture_forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        """Replay the captured graph for ``forward_batch`` and return its output.

        Args:
            forward_batch: The batch to run.
            skip_attn_backend_init: If True, ``replay_prepare`` is skipped and
                only ``input_ids`` and ``positions`` are refreshed; the
                ``raw_num_token`` and ``bs`` values stored by an earlier
                ``replay_prepare`` call are reused.
            pp_proxy_tensors: Optional pipeline-parallel inputs, forwarded to
                ``replay_prepare``.

        Returns:
            A ``LogitsProcessorOutput`` sliced to the first ``raw_num_token``
            rows, or a ``PPProxyTensors`` sliced to the first ``bs`` rows.
        """
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[self.bs].replay()

        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

    def get_spec_info(self, num_tokens: int):
        """Build the speculative-decoding input used during capture.

        Returns:
            An ``EagleVerifyInput`` when the speculative algorithm is EAGLE,
            otherwise ``None``.

        Raises:
            RuntimeError: If EAGLE is enabled and the model runner is a draft
                worker.
        """
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info
# Troubleshooting hints appended to the error raised when graph capture fails.
CUDA_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)