# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Run the model with cuda graph and torch.compile."""

from __future__ import annotations

import bisect
import gc
import inspect
import logging
import os
from contextlib import contextmanager
from typing import TYPE_CHECKING, Callable, Optional, Union

import torch
import tqdm
from torch.profiler import ProfilerActivity, profile

from sglang.srt.custom_op import CustomOp
from sglang.srt.distributed import get_tensor_model_parallel_rank
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    set_graph_pool_id,
)
from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
from sglang.srt.layers.dp_attention import (
    DpPaddingMode,
    get_attention_tp_rank,
    get_attention_tp_size,
    set_dp_buffer_len,
)
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
from sglang.srt.layers.torchao_utils import save_gemlite_cache
from sglang.srt.model_executor.forward_batch_info import (
    CaptureHiddenMode,
    ForwardBatch,
    ForwardMode,
    PPProxyTensors,
    enable_num_token_non_padded,
)
from sglang.srt.patch_torch import monkey_patch_torch_compile
from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin
from sglang.srt.utils import (
    empty_context,
    get_available_gpu_memory,
    get_device_memory_capacity,
    log_info_on_rank0,
    require_attn_tp_gather,
    require_gathered_buffer,
    require_mlp_sync,
    require_mlp_tp_gather,
)

logger = logging.getLogger(__name__)

if TYPE_CHECKING:
    from sglang.srt.model_executor.model_runner import ModelRunner

# Detect whether the current forward pass is in capture mode
is_capture_mode = False


def get_is_capture_mode():
    return is_capture_mode


@contextmanager
def model_capture_mode():
    global is_capture_mode
    is_capture_mode = True

    yield

    is_capture_mode = False


@contextmanager
def freeze_gc(enable_cudagraph_gc: bool):
    """
    Optimize garbage collection during CUDA graph capture.
    Clean up, then freeze all remaining objects from being included
    in future collections if GC is disabled during capture.
    """
    gc.collect()
    should_freeze = not enable_cudagraph_gc
    if should_freeze:
        gc.freeze()
    try:
        yield
    finally:
        if should_freeze:
            gc.unfreeze()


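# Recursively toggle every CustomOp submodule between its custom kernel and its
# native-torch implementation, so torch.compile can trace the model during capture.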
def _to_torch(model: torch.nn.Module, reverse: bool, num_tokens: int):
    for sub in model._modules.values():
        if isinstance(sub, CustomOp):
            if reverse:
                sub.leave_torch_compile()
            else:
                sub.enter_torch_compile(num_tokens=num_tokens)
        if isinstance(sub, torch.nn.Module):
            _to_torch(sub, reverse, num_tokens)


@contextmanager
def patch_model(
    model: torch.nn.Module,
    enable_compile: bool,
    num_tokens: int,
    tp_group: GroupCoordinator,
):
"""Patch the model to make it compatible with with torch.compile"""
    backup_ca_comm = None

    try:
        if enable_compile:
            _to_torch(model, reverse=False, num_tokens=num_tokens)
            backup_ca_comm = tp_group.ca_comm
            # Use custom-allreduce here.
            # We found the custom allreduce is much faster than the built-in allreduce in torch,
            # even with ENABLE_INTRA_NODE_COMM=1.
            # tp_group.ca_comm = None
            yield torch.compile(
                torch.no_grad()(model.forward),
                mode=os.environ.get(
                    "SGLANG_TORCH_COMPILE_MODE", "max-autotune-no-cudagraphs"
                ),
                dynamic=False,
            )
        else:
            yield model.forward
    finally:
        if enable_compile:
            _to_torch(model, reverse=True, num_tokens=num_tokens)
            tp_group.ca_comm = backup_ca_comm


def set_torch_compile_config():
    import torch._dynamo.config
    import torch._inductor.config

    torch._inductor.config.coordinate_descent_tuning = True
    torch._inductor.config.triton.unique_kernel_names = True
    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future

    # FIXME: tmp workaround
    torch._dynamo.config.accumulated_cache_size_limit = 1024
    if hasattr(torch._dynamo.config, "cache_size_limit"):
        torch._dynamo.config.cache_size_limit = 1024

    monkey_patch_torch_compile()


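# Decide which batch sizes get a captured CUDA graph (capture_bs) and which of
# those are additionally compiled with torch.compile (compile_bs).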
def get_batch_sizes_to_capture(model_runner: ModelRunner):
    server_args = model_runner.server_args
    capture_bs = server_args.cuda_graph_bs

    if capture_bs is None:
        if server_args.speculative_algorithm is None:
            if server_args.disable_cuda_graph_padding:
                capture_bs = list(range(1, 33)) + list(range(48, 161, 16))
            else:
                capture_bs = [1, 2, 4, 8] + list(range(16, 161, 8))
        else:
            # Since speculative decoding requires more cuda graph memory, we
            # capture less.
            capture_bs = (
                list(range(1, 9))
                + list(range(10, 33, 2))
                + list(range(40, 64, 8))
                + list(range(80, 161, 16))
            )

        gpu_mem = get_device_memory_capacity()
        if gpu_mem is not None:
            if gpu_mem > 90 * 1024:  # H200, H20
                capture_bs += list(range(160, 257, 8))
            if gpu_mem > 160 * 1000:  # B200, MI300
                capture_bs += list(range(256, 513, 16))

    if max(capture_bs) > model_runner.req_to_token_pool.size:
        # In some cases (e.g., with a small GPU or --max-running-requests), the #max-running-requests
        # is very small. We add more values here to make sure we capture the maximum bs.
        capture_bs += [model_runner.req_to_token_pool.size]

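    # Keep only batch sizes divisible by mul_base: two-batch overlap splits a
    # batch in two, and gathered DP buffers require a multiple of the attention
    # TP size.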
    mul_base = 1

    if server_args.enable_two_batch_overlap:
        mul_base *= 2

    if require_gathered_buffer(server_args):
        mul_base *= get_attention_tp_size()

    capture_bs = [bs for bs in capture_bs if bs % mul_base == 0]

    if server_args.cuda_graph_max_bs:
        capture_bs = [bs for bs in capture_bs if bs <= server_args.cuda_graph_max_bs]
        if max(capture_bs) < server_args.cuda_graph_max_bs:
            capture_bs += list(
                range(max(capture_bs), server_args.cuda_graph_max_bs + 1, 16)
            )
    capture_bs = [bs for bs in capture_bs if bs <= model_runner.req_to_token_pool.size]
    capture_bs = list(sorted(set(capture_bs)))
    assert len(capture_bs) > 0 and capture_bs[0] > 0, f"{capture_bs=}"
    compile_bs = (
        [bs for bs in capture_bs if bs <= server_args.torch_compile_max_bs]
        if server_args.enable_torch_compile
        else []
    )
    return capture_bs, compile_bs


# Reuse this memory pool across all cuda graph runners.
global_graph_memory_pool = None


def get_global_graph_memory_pool():
    return global_graph_memory_pool


def set_global_graph_memory_pool(val):
    global global_graph_memory_pool
    global_graph_memory_pool = val


class CudaGraphRunner:
    """A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""

    def __init__(self, model_runner: ModelRunner):
        # Parse args
        self.model_runner = model_runner
        self.device = model_runner.device
        self.device_module = torch.get_device_module(self.device)
        self.graphs = {}
        self.output_buffers = {}
        self.enable_torch_compile = model_runner.server_args.enable_torch_compile
        self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
        self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
        self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
        self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
        self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
        self.enable_two_batch_overlap = (
            model_runner.server_args.enable_two_batch_overlap
        )
        self.speculative_algorithm = model_runner.server_args.speculative_algorithm
        self.enable_profile_cuda_graph = (
            model_runner.server_args.enable_profile_cuda_graph
        )
        self.tp_size = model_runner.server_args.tp_size
        self.dp_size = model_runner.server_args.dp_size
        self.pp_size = model_runner.server_args.pp_size

        self.attn_tp_size = get_attention_tp_size()
        self.attn_tp_rank = get_attention_tp_rank()

        # Batch sizes to capture
        self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
        log_info_on_rank0(logger, f"Capture cuda graph bs {self.capture_bs}")
        self.capture_forward_mode = ForwardMode.DECODE
        self.capture_hidden_mode = CaptureHiddenMode.NULL
        self.num_tokens_per_bs = 1
        if model_runner.spec_algorithm.is_eagle():
            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen")
            else:
                self.capture_forward_mode = ForwardMode.TARGET_VERIFY
                self.num_tokens_per_bs = (
                    self.model_runner.server_args.speculative_num_draft_tokens
                )

        # If returning hidden states is enabled, set initial capture hidden mode to full to avoid double-capture on startup
        if model_runner.server_args.enable_return_hidden_states:
            self.capture_hidden_mode = CaptureHiddenMode.FULL

        # Attention backend
        self.max_bs = max(self.capture_bs)
        self.max_num_token = self.max_bs * self.num_tokens_per_bs
        self.model_runner.attn_backend.init_cuda_graph_state(
            self.max_bs, self.max_num_token
        )
        self.seq_len_fill_value = (
            self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
        )

        # FIXME(lsyin): leave it here for now, I don't know whether it is necessary
        self.encoder_len_fill_value = 0
        self.seq_lens_cpu = torch.full(
            (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
        )

        if self.enable_torch_compile:
            set_torch_compile_config()

        if self.model_runner.server_args.enable_lora:
            self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)

        # Graph inputs
        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
            self.out_cache_loc = torch.zeros(
                (self.max_num_token,), dtype=self._cache_loc_dtype()
            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
            self.tbo_plugin = TboCudaGraphRunnerPlugin()

            # pipeline parallelism
            if self.pp_size > 1:
                self.pp_proxy_tensors = {
                    "hidden_states": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                    "residual": torch.zeros(
                        (self.max_bs, self.model_runner.model_config.hidden_size),
                        dtype=torch.bfloat16,
                    ),
                }

            # Speculative inference
            if model_runner.spec_algorithm.is_eagle3():
                self.model_runner.model.set_eagle3_layers_to_capture()

            if self.is_encoder_decoder:
                # NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
                self.encoder_lens = torch.full(
                    (self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
                )
            else:
                self.encoder_lens = None

            if self.require_gathered_buffer:
                if self.require_mlp_tp_gather:
                    self.global_num_tokens_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (self.dp_size,), dtype=torch.int32
                    )
                else:
                    assert self.require_attn_tp_gather
                    self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                    self.global_num_tokens_for_logprob_gpu = torch.zeros(
                        (1,), dtype=torch.int32
                    )
            else:
                self.global_num_tokens_gpu = None
                self.global_num_tokens_for_logprob_gpu = None

            self.custom_mask = torch.ones(
                (
                    (self.seq_lens.sum().item() + self.max_num_token)
                    * self.num_tokens_per_bs
                ),
                dtype=torch.bool,
                device=self.device,
            )
            self.next_token_logits_buffer = torch.zeros(
                (self.max_num_token, self.model_runner.model_config.vocab_size),
                dtype=torch.float,
                device=self.device,
            )

        # Capture
        try:
            with model_capture_mode():
                self.capture()
        except RuntimeError as e:
            raise Exception(
                f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
            )

    def _cache_loc_dtype(self):
        return torch.int64

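    # A batch can be replayed only if its (padded) size maps to a captured graph
    # and the DP sync state, encoder lens, two-batch overlap, and requested
    # hidden-state capture mode are all compatible with what was captured.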
    def can_run(self, forward_batch: ForwardBatch):
        if self.require_mlp_tp_gather:
            cuda_graph_bs = (
                max(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max(forward_batch.global_num_tokens_cpu)
            )
        else:
            cuda_graph_bs = forward_batch.batch_size

        is_bs_supported = (
            cuda_graph_bs in self.graphs
            if self.disable_padding
            else cuda_graph_bs <= self.max_bs
        )

        if self.require_mlp_sync:
            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph

        # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
        # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
        # because the full_text_row_masked_out_mask tensor will always be ones
        is_encoder_lens_supported = (
            torch.all(forward_batch.encoder_lens > 0)
            if self.is_encoder_decoder
            else True
        )

        requested_capture_hidden_mode = max(
            forward_batch.capture_hidden_mode,
            (
                forward_batch.spec_info.capture_hidden_mode
                if getattr(forward_batch.spec_info, "capture_hidden_mode", None)
                is not None
                else CaptureHiddenMode.NULL
            ),
        )
        capture_hidden_mode_matches = (
            requested_capture_hidden_mode == CaptureHiddenMode.NULL
            or requested_capture_hidden_mode == self.capture_hidden_mode
        )
        is_tbo_supported = (
            forward_batch.can_run_tbo if self.enable_two_batch_overlap else True
        )

        return (
            is_bs_supported
            and is_encoder_lens_supported
            and is_tbo_supported
            and capture_hidden_mode_matches
        )

    def capture(self) -> None:
        profile_context = empty_context()
        if self.enable_profile_cuda_graph:
            profile_context = profile(
                activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
                record_shapes=True,
            )

        # Trigger CUDA graph capture for specific shapes.
        # Capture the large shapes first so that the smaller shapes
        # can reuse the memory pool allocated for the large shapes.
        with freeze_gc(
            self.model_runner.server_args.enable_cudagraph_gc
        ), graph_capture() as graph_capture_context:
            with profile_context as prof:
                self.stream = graph_capture_context.stream
                avail_mem = get_available_gpu_memory(
                    self.model_runner.device,
                    self.model_runner.gpu_id,
                    empty_cache=False,
                )
                # Reverse the order to enable better memory sharing across cuda graphs.
                capture_range = (
                    tqdm.tqdm(list(reversed(self.capture_bs)))
                    if get_tensor_model_parallel_rank() == 0
                    else reversed(self.capture_bs)
                )
                for i, bs in enumerate(capture_range):
                    if get_tensor_model_parallel_rank() == 0:
                        avail_mem = get_available_gpu_memory(
                            self.model_runner.device,
                            self.model_runner.gpu_id,
                            empty_cache=False,
                        )
                        capture_range.set_description(
                            f"Capturing batches ({bs=} {avail_mem=:.2f} GB)"
                        )

                    with patch_model(
                        self.model_runner.model,
                        bs in self.compile_bs,
                        num_tokens=bs * self.num_tokens_per_bs,
                        tp_group=self.model_runner.tp_group,
                    ) as forward:
                        (
                            graph,
                            output_buffers,
                        ) = self.capture_one_batch_size(bs, forward)
                        self.graphs[bs] = graph
                        self.output_buffers[bs] = output_buffers

                    # Save gemlite cache after each capture
                    save_gemlite_cache()

        if self.enable_profile_cuda_graph:
            log_message = (
                "Sorted by CUDA Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cuda_time_total", row_limit=10
                )
                + "\n\nSorted by CPU Time:\n"
                + prof.key_averages(group_by_input_shape=True).table(
                    sort_by="cpu_time_total", row_limit=10
                )
            )
            logger.info(log_message)

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        with self.device_module.graph(graph, pool=pool, stream=stream):
            out = run_once_fn()
        return out

    def _create_device_graph(self):
        return torch.cuda.CUDAGraph()

    def capture_one_batch_size(self, bs: int, forward: Callable):
        graph = self._create_device_graph()
        stream = self.stream
        num_tokens = bs * self.num_tokens_per_bs

        # Graph inputs
        input_ids = self.input_ids[:num_tokens]
        req_pool_indices = self.req_pool_indices[:bs]
        seq_lens = self.seq_lens[:bs]
        out_cache_loc = self.out_cache_loc[:num_tokens]
        positions = self.positions[:num_tokens]
        if self.is_encoder_decoder:
            encoder_lens = self.encoder_lens[:bs]
        else:
            encoder_lens = None
        mrope_positions = self.mrope_positions[:, :bs]
        next_token_logits_buffer = self.next_token_logits_buffer[:num_tokens]
        self.num_token_non_padded[...] = num_tokens

        # pipeline parallelism
        if self.pp_size > 1:
            pp_proxy_tensors = PPProxyTensors(
                {k: v[:num_tokens] for k, v in self.pp_proxy_tensors.items()}
            )

        if self.require_mlp_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens] * self.dp_size,
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens * self.dp_size
        elif self.require_attn_tp_gather:
            self.global_num_tokens_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            self.global_num_tokens_for_logprob_gpu.copy_(
                torch.tensor(
                    [num_tokens],
                    dtype=torch.int32,
                    device=input_ids.device,
                )
            )
            global_dp_buffer_len = num_tokens
        else:
            global_dp_buffer_len = None

        spec_info = self.get_spec_info(num_tokens)
        if self.capture_hidden_mode != CaptureHiddenMode.FULL:
            self.capture_hidden_mode = (
                spec_info.capture_hidden_mode if spec_info else CaptureHiddenMode.NULL
            )

        if self.model_runner.server_args.enable_lora:
            # It is safe to capture CUDA graph using empty LoRA id, as the LoRA kernels will always be launched whenever
            # `--enable-lora` is set to True (and return immediately if the LoRA id is empty for perf optimization).
            lora_ids = [None] * bs
        else:
            lora_ids = None

        forward_batch = ForwardBatch(
            forward_mode=self.capture_forward_mode,
            batch_size=bs,
            input_ids=input_ids,
            req_pool_indices=req_pool_indices,
            seq_lens=seq_lens,
            next_token_logits_buffer=next_token_logits_buffer,
            orig_seq_lens=seq_lens,
            req_to_token_pool=self.model_runner.req_to_token_pool,
            token_to_kv_pool=self.model_runner.token_to_kv_pool,
            attn_backend=self.model_runner.attn_backend,
            out_cache_loc=out_cache_loc,
            seq_lens_sum=seq_lens.sum().item(),
            encoder_lens=encoder_lens,
            return_logprob=False,
            positions=positions,
            global_num_tokens_gpu=self.global_num_tokens_gpu,
            global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
            global_dp_buffer_len=global_dp_buffer_len,
            mrope_positions=mrope_positions,
            spec_algorithm=self.model_runner.spec_algorithm,
            spec_info=spec_info,
            capture_hidden_mode=self.capture_hidden_mode,
            num_token_non_padded=self.num_token_non_padded,
            global_forward_mode=self.capture_forward_mode,
            lora_ids=lora_ids,
        )
        self.tbo_plugin.capture_one_batch_size(forward_batch, num_tokens=num_tokens)

        if lora_ids is not None:
            self.model_runner.lora_manager.prepare_lora_batch(forward_batch)

        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
            bs,
            num_tokens,
            req_pool_indices,
            seq_lens,
            encoder_lens,
            forward_batch.forward_mode,
            forward_batch.spec_info,
        )

        # Run and capture
        def run_once():
            # Clean intermediate result cache for DP attention
            forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
            set_dp_buffer_len(global_dp_buffer_len, num_tokens)

            kwargs = {}
            if (
                self.pp_size > 1
                and "pp_proxy_tensors" in inspect.signature(forward).parameters
            ):
                kwargs["pp_proxy_tensors"] = PPProxyTensors(
                    {k: v.clone() for k, v in pp_proxy_tensors.tensors.items()}
                )

            logits_output_or_pp_proxy_tensors = forward(
                input_ids,
                forward_batch.positions,
                forward_batch,
                **kwargs,
            )
            return logits_output_or_pp_proxy_tensors

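        # Warm up before capture so that one-time work (kernel autotuning, lazy
        # initialization, workspace allocation) happens outside the graph.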
        for _ in range(2):
            self.device_module.synchronize()
            self.model_runner.tp_group.barrier()
            run_once()

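        # All graphs (across runners) share one global memory pool so later
        # captures can reuse the device memory reserved by earlier, larger ones.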
        if get_global_graph_memory_pool() is None:
            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
        # Set graph pool id globally to be able to use symmetric memory
        set_graph_pool_id(get_global_graph_memory_pool())
        out = self._capture_graph(
            graph, get_global_graph_memory_pool(), stream, run_once
        )

        return graph, out

    def recapture_if_needed(self, forward_batch: ForwardBatch):

        # If the required capture_hidden_mode changes, we need to recapture the graph

        # These are the different factors that can influence the capture_hidden_mode
        capture_hidden_mode_required_by_forward_batch = (
            forward_batch.capture_hidden_mode
        )
        capture_hidden_mode_required_by_spec_info = getattr(
            forward_batch.spec_info, "capture_hidden_mode", CaptureHiddenMode.NULL
        )
        capture_hidden_mode_required_for_returning_hidden_states = (
            CaptureHiddenMode.FULL
            if self.model_runner.server_args.enable_return_hidden_states
            else CaptureHiddenMode.NULL
        )

        # Determine the highest capture_hidden_mode required
        # (If we have FULL, we can emulate LAST or NULL)
        # (If we have LAST, we can emulate NULL)
        required_capture_hidden_mode = max(
            capture_hidden_mode_required_by_forward_batch,
            capture_hidden_mode_required_by_spec_info,
            capture_hidden_mode_required_for_returning_hidden_states,
        )

        # If the current hidden mode is no longer aligned with the required hidden mode, we need to set it to what is required and re-capture
        if self.capture_hidden_mode != required_capture_hidden_mode:
            self.capture_hidden_mode = required_capture_hidden_mode
            self.capture()

    def replay_prepare(
        self,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ):
        self.recapture_if_needed(forward_batch)

        raw_bs = forward_batch.batch_size
        raw_num_token = raw_bs * self.num_tokens_per_bs

        # Pad
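        # Pad to the smallest captured batch size that can hold this batch; the
        # padded slots are reset to neutral fill values below.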
        if self.require_mlp_tp_gather:
            max_num_tokens = max(forward_batch.global_num_tokens_cpu)
            max_batch_size = (
                max_num_tokens / self.num_tokens_per_bs
                if self.model_runner.spec_algorithm.is_eagle()
                else max_num_tokens
            )
            index = bisect.bisect_left(self.capture_bs, max_batch_size)
        else:
            index = bisect.bisect_left(self.capture_bs, raw_bs)
        bs = self.capture_bs[index]
        if bs != raw_bs:
            self.seq_lens.fill_(self.seq_len_fill_value)
            self.out_cache_loc.zero_()

        # Common inputs
        self.input_ids[:raw_num_token].copy_(forward_batch.input_ids)
        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
        self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc)
        self.positions[:raw_num_token].copy_(forward_batch.positions)

        seq_lens_cpu = None
        if forward_batch.seq_lens_cpu is not None:
            if bs != raw_bs:
                self.seq_lens_cpu.fill_(self.seq_len_fill_value)
            self.seq_lens_cpu[:raw_bs].copy_(forward_batch.seq_lens_cpu)
            seq_lens_cpu = self.seq_lens_cpu[:bs]

        if pp_proxy_tensors:
            for key in self.pp_proxy_tensors.keys():
                dim = pp_proxy_tensors[key].shape[0]
                self.pp_proxy_tensors[key][:dim].copy_(pp_proxy_tensors[key])

        if self.is_encoder_decoder:
            self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
        if forward_batch.mrope_positions is not None:
            self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
        if self.require_gathered_buffer:
            self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
            self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
        if enable_num_token_non_padded(self.model_runner.server_args):
            num_token_non_padded = forward_batch.num_token_non_padded
            if self.require_gathered_buffer:
                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
                num_local_token_non_padded = torch.clamp(
                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
                    min=0,
                    max=tokens_per_rank,
                )
                self.num_token_non_padded.copy_(num_local_token_non_padded)
            else:
                self.num_token_non_padded.copy_(num_token_non_padded)
        if self.enable_two_batch_overlap:
            self.tbo_plugin.replay_prepare(
                forward_mode=self.capture_forward_mode,
                bs=bs,
                num_token_non_padded=len(forward_batch.input_ids),
                spec_info=forward_batch.spec_info,
            )
        if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None:
            forward_batch.spec_info.custom_mask = self.custom_mask
        # Attention backend
        self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
            bs,
            self.req_pool_indices[:bs],
            self.seq_lens[:bs],
            forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value,
            self.encoder_lens[:bs] if self.is_encoder_decoder else None,
            self.capture_forward_mode,
            forward_batch.spec_info,
            seq_lens_cpu=seq_lens_cpu,
        )

        # Store fields
        self.raw_bs = raw_bs
        self.raw_num_token = raw_num_token
        self.bs = bs

    def replay(
        self,
        forward_batch: ForwardBatch,
        skip_attn_backend_init: bool = False,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> Union[LogitsProcessorOutput, PPProxyTensors]:
        if not skip_attn_backend_init:
            self.replay_prepare(forward_batch, pp_proxy_tensors)
        else:
            # In speculative decoding, these two fields are still needed.
            self.input_ids[: self.raw_num_token].copy_(forward_batch.input_ids)
            self.positions[: self.raw_num_token].copy_(forward_batch.positions)

        # Replay
        self.graphs[self.bs].replay()

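        # The graph wrote into fixed output buffers sized for the padded batch;
        # slice the results back to the real (unpadded) token count.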
        output = self.output_buffers[self.bs]
        if isinstance(output, LogitsProcessorOutput):
            return LogitsProcessorOutput(
                next_token_logits=output.next_token_logits[: self.raw_num_token],
                hidden_states=(
                    output.hidden_states[: self.raw_num_token]
                    if output.hidden_states is not None
                    else None
                ),
            )
        else:
            assert isinstance(output, PPProxyTensors)
            return PPProxyTensors({k: v[: self.bs] for k, v in output.tensors.items()})

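    # Build a placeholder speculative-decoding input for capture: only the fields
    # that matter for graph shapes (custom mask, draft token count, hidden-state
    # capture mode) are populated; per-request fields stay None.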
    def get_spec_info(self, num_tokens: int):
        spec_info = None
        if self.model_runner.spec_algorithm.is_eagle():
            from sglang.srt.speculative.eagle_utils import EagleVerifyInput

            if self.model_runner.is_draft_worker:
                raise RuntimeError("This should not happen.")
            else:
                spec_info = EagleVerifyInput(
                    draft_token=None,
                    custom_mask=self.custom_mask,
                    positions=None,
                    retrive_index=None,
                    retrive_next_token=None,
                    retrive_next_sibling=None,
                    retrive_cum_len=None,
                    spec_steps=self.model_runner.server_args.speculative_num_steps,
                    topk=self.model_runner.server_args.speculative_eagle_topk,
                    draft_token_num=self.model_runner.server_args.speculative_num_draft_tokens,
                    capture_hidden_mode=CaptureHiddenMode.FULL,
                    seq_lens_sum=None,
                    seq_lens_cpu=None,
                )

        return spec_info


CUDA_GRAPH_CAPTURE_FAILED_MSG = (
    "Possible solutions:\n"
    "1. set --mem-fraction-static to a smaller value (e.g., 0.8 or 0.7)\n"
    "2. set --cuda-graph-max-bs to a smaller value (e.g., 16)\n"
    "3. disable torch compile by not using --enable-torch-compile\n"
    "4. disable CUDA graph by --disable-cuda-graph. (Not recommended. Huge performance loss)\n"
    "Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
)