Piecewise CUDA Graph Support & Torch Compile Backend (#10062)

Signed-off-by: Oasis-Git <ayw.sirius19@gmail.com>
This commit is contained in:
Yuwei An
2025-10-11 20:55:57 -07:00
committed by GitHub
parent 20a6c0a63d
commit 4ac8e09df0
21 changed files with 2706 additions and 19 deletions

View File

@@ -108,8 +108,15 @@ from sglang.srt.mem_cache.memory_pool import (
)
from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_executor.forward_batch_info import (
ForwardBatch,
ForwardMode,
PPProxyTensors,
)
from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
PiecewiseCudaGraphRunner,
)
from sglang.srt.model_loader import get_model
from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
@@ -307,6 +314,26 @@ class ModelRunner:
self._model_update_group = {}
self._weights_send_group = {}
if (
self.server_args.enable_piecewise_cuda_graph
and self.can_run_piecewise_cuda_graph()
):
self.attention_layers = []
for layer in self.model.model.layers:
if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
self.attention_layers.append(layer.self_attn.attn)
if len(self.attention_layers) < self.model_config.num_hidden_layers:
# TODO(yuwei): support Non-Standard GQA
log_info_on_rank0(
logger,
"Disable piecewise CUDA graph because some layers do not apply Standard GQA",
)
self.piecewise_cuda_graph_runner = None
else:
self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
else:
self.piecewise_cuda_graph_runner = None
def initialize(self, min_per_gpu_memory: float):
server_args = self.server_args
@@ -692,6 +719,7 @@ class ModelRunner:
pipeline_model_parallel_size=self.pp_size,
expert_model_parallel_size=self.moe_ep_size,
duplicate_tp_group=self.server_args.enable_pdmux,
torch_compile=self.server_args.enable_piecewise_cuda_graph,
)
initialize_dp_attention(
server_args=self.server_args,
@@ -1411,6 +1439,27 @@ class ModelRunner:
f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
)
def can_run_piecewise_cuda_graph(self):
    """Decide whether the piecewise CUDA graph runner may be used.

    Checks the blocking conditions in order and, for the first one that
    applies, logs the reason on rank 0 and returns ``False``. Returns
    ``True`` only when none of them apply.
    """
    # Select the first applicable blocker; checks are evaluated lazily
    # so only the first failing condition is inspected and reported.
    blocker = None
    if self.server_args.disable_cuda_graph:
        blocker = "Disable piecewise CUDA graph because disable_cuda_graph is set"
    elif self.server_args.enable_torch_compile:
        blocker = "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile"
    elif self.pp_size > 1:
        # TODO(yuwei): support PP
        blocker = "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP"

    if blocker is not None:
        log_info_on_rank0(logger, blocker)
        return False
    return True
def init_memory_pool(
self,
total_gpu_memory: int,
@@ -1932,6 +1981,11 @@ class ModelRunner:
kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
if not self.is_generation:
kwargs["get_embedding"] = True
if self.piecewise_cuda_graph_runner is not None:
if self.piecewise_cuda_graph_runner.can_run(forward_batch):
return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
return self.model.forward(
forward_batch.input_ids,
forward_batch.positions,