From ac5010e0ba1491892ef5f5aeb603df98237ff7f9 Mon Sep 17 00:00:00 2001
From: Cheng Wan <54331508+ch-wan@users.noreply.github.com>
Date: Sun, 22 Jun 2025 20:35:58 -0700
Subject: [PATCH] Fix CUDA Graph Check under Deepep with DP FFN (#7451)

---
 .../srt/model_executor/cuda_graph_runner.py   | 24 ++++++++-------
 .../eagle_draft_cuda_graph_runner.py          | 27 ++++++++---------
 .../eagle_draft_extend_cuda_graph_runner.py   | 29 +++++++++----------
 3 files changed, 40 insertions(+), 40 deletions(-)

diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 820459458..a51a06f09 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -48,6 +48,7 @@ from sglang.srt.utils import (
     rank0_log,
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
 
@@ -212,6 +213,7 @@ class CudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.enable_two_batch_overlap = (
             model_runner.server_args.enable_two_batch_overlap
@@ -337,22 +339,22 @@ class CudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
 
         # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
         # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
index 7120b10ea..6b6c1a777 100644
--- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
@@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
 
@@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner:
         self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.dp_size = self.model_runner.dp_size
         self.tp_size = self.model_runner.tp_size
@@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
         else:
-            is_bs_supported = (
-                forward_batch.batch_size in self.graphs
-                if self.disable_padding
-                else forward_batch.batch_size <= self.max_bs
-            )
+            cuda_graph_bs = forward_batch.batch_size
+
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
+
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
         return is_bs_supported
 
     def capture(self):
diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
index bd331a17a..b4ffde60e 100644
--- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
+++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py
@@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk
 from sglang.srt.utils import (
     require_attn_tp_gather,
     require_gathered_buffer,
+    require_mlp_sync,
     require_mlp_tp_gather,
 )
 
@@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner:
         self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
         self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
         self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
+        self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
         self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
         self.tp_size = self.model_runner.tp_size
         self.dp_size = model_runner.server_args.dp_size
@@ -130,29 +132,24 @@ class EAGLEDraftExtendCudaGraphRunner:
 
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
-            if not forward_batch.can_run_dp_cuda_graph:
-                return False
-            total_batch_size = (
+            cuda_graph_bs = (
                 sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
                 if self.model_runner.spec_algorithm.is_eagle()
                 else sum(forward_batch.global_num_tokens_cpu)
             )
-            is_bs_supported = (
-                total_batch_size in self.graphs
-                if self.disable_padding
-                else total_batch_size <= self.max_bs
-            )
-            return is_bs_supported
         else:
-            batch_size = forward_batch.seq_lens.numel()
+            cuda_graph_bs = forward_batch.seq_lens.numel()
 
-            is_bs_supported = (
-                batch_size in self.graphs
-                if self.disable_padding
-                else batch_size <= self.max_bs
-            )
+        is_bs_supported = (
+            cuda_graph_bs in self.graphs
+            if self.disable_padding
+            else cuda_graph_bs <= self.max_bs
+        )
 
-            return is_bs_supported
+        if self.require_mlp_sync:
+            is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
+
+        return is_bs_supported
 
     def capture(self):
         CudaGraphRunner.capture(self)
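
Note on the change (commentary, not part of the patch to apply): all three
runners now converge on the same two-step check, and the DP-wide agreement
flag forward_batch.can_run_dp_cuda_graph is consulted whenever
require_mlp_sync holds, rather than only on the require_mlp_tp_gather path
as before. The sketch below is a minimal standalone rendering of that logic
with free-standing parameters substituted for the runner attributes
(self.graphs, self.max_bs, etc.) seen in the diff above; function names and
signatures here are illustrative, not the SGLang source.

    def resolve_cuda_graph_bs_sketch(
        global_num_tokens: list,   # mirrors forward_batch.global_num_tokens_cpu
        num_tokens_per_bs: int,    # tokens per request under EAGLE drafting
        is_eagle: bool,
    ) -> int:
        # When the MLP is TP-gathered, the graph key is the global batch
        # size: token counts from all DP ranks summed, divided back into
        # requests when EAGLE multiplies tokens per request.
        total_tokens = sum(global_num_tokens)
        return total_tokens // num_tokens_per_bs if is_eagle else total_tokens

    def can_run_sketch(
        cuda_graph_bs: int,           # batch size resolved as above, or the local batch size
        graphs: set,                  # batch sizes with a captured CUDA graph
        max_bs: int,                  # largest captured batch size
        disable_padding: bool,        # if True, require an exact-size graph
        require_mlp_sync: bool,       # e.g. DeepEP with a data-parallel FFN
        can_run_dp_cuda_graph: bool,  # all DP ranks agree the batch is graphable
    ) -> bool:
        # Step 1: the batch size must be served by a captured graph, either
        # exactly (padding disabled) or by padding up to max_bs.
        is_bs_supported = (
            cuda_graph_bs in graphs if disable_padding else cuda_graph_bs <= max_bs
        )
        # Step 2: whenever MLP-layer synchronization is required (a broader
        # condition than require_mlp_tp_gather), the DP-wide agreement flag
        # must also hold; the subject indicates the old gating missed the
        # DeepEP-with-DP-FFN configuration.
        if require_mlp_sync:
            is_bs_supported = is_bs_supported and can_run_dp_cuda_graph
        return is_bs_supported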