diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 820459458..a51a06f09 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -48,6 +48,7 @@ from sglang.srt.utils import ( rank0_log, require_attn_tp_gather, require_gathered_buffer, + require_mlp_sync, require_mlp_tp_gather, ) @@ -212,6 +213,7 @@ class CudaGraphRunner: self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) self.enable_two_batch_overlap = ( model_runner.server_args.enable_two_batch_overlap @@ -337,22 +339,22 @@ class CudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: - total_batch_size = ( + cuda_graph_bs = ( sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() else sum(forward_batch.global_num_tokens_cpu) ) - is_bs_supported = forward_batch.can_run_dp_cuda_graph and ( - total_batch_size in self.graphs - if self.disable_padding - else total_batch_size <= self.max_bs - ) else: - is_bs_supported = ( - forward_batch.batch_size in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) + cuda_graph_bs = forward_batch.batch_size + + is_bs_supported = ( + cuda_graph_bs in self.graphs + if self.disable_padding + else cuda_graph_bs <= self.max_bs + ) + + if self.require_mlp_sync: + is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph # NOTE: cuda graph cannot handle mixed batch (encoder_len = 0) # If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph diff --git a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py index 7120b10ea..6b6c1a777 100644 --- a/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_cuda_graph_runner.py @@ -23,6 +23,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput from sglang.srt.utils import ( require_attn_tp_gather, require_gathered_buffer, + require_mlp_sync, require_mlp_tp_gather, ) @@ -46,6 +47,7 @@ class EAGLEDraftCudaGraphRunner: self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) self.dp_size = self.model_runner.dp_size self.tp_size = self.model_runner.tp_size @@ -127,24 +129,23 @@ class EAGLEDraftCudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: - if not forward_batch.can_run_dp_cuda_graph: - return False - total_batch_size = ( + cuda_graph_bs = ( sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() else sum(forward_batch.global_num_tokens_cpu) ) - is_bs_supported = ( - total_batch_size in self.graphs - if self.disable_padding - else total_batch_size <= self.max_bs - ) else: - is_bs_supported = ( - forward_batch.batch_size in self.graphs - if self.disable_padding - else forward_batch.batch_size <= self.max_bs - ) + cuda_graph_bs = forward_batch.batch_size + + is_bs_supported = ( + cuda_graph_bs in self.graphs + if self.disable_padding + else cuda_graph_bs <= self.max_bs + ) + + if self.require_mlp_sync: + is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph + return is_bs_supported def capture(self): diff --git a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py index bd331a17a..b4ffde60e 100644 --- a/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +++ b/python/sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py @@ -24,6 +24,7 @@ from sglang.srt.speculative.eagle_utils import EagleDraftInput, fast_topk from sglang.srt.utils import ( require_attn_tp_gather, require_gathered_buffer, + require_mlp_sync, require_mlp_tp_gather, ) @@ -42,6 +43,7 @@ class EAGLEDraftExtendCudaGraphRunner: self.disable_padding = model_runner.server_args.disable_cuda_graph_padding self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args) self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args) + self.require_mlp_sync = require_mlp_sync(model_runner.server_args) self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args) self.tp_size = self.model_runner.tp_size self.dp_size = model_runner.server_args.dp_size @@ -130,29 +132,24 @@ class EAGLEDraftExtendCudaGraphRunner: def can_run(self, forward_batch: ForwardBatch): if self.require_mlp_tp_gather: - if not forward_batch.can_run_dp_cuda_graph: - return False - total_batch_size = ( + cuda_graph_bs = ( sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs if self.model_runner.spec_algorithm.is_eagle() else sum(forward_batch.global_num_tokens_cpu) ) - is_bs_supported = ( - total_batch_size in self.graphs - if self.disable_padding - else total_batch_size <= self.max_bs - ) - return is_bs_supported else: - batch_size = forward_batch.seq_lens.numel() + cuda_graph_bs = forward_batch.seq_lens.numel() - is_bs_supported = ( - batch_size in self.graphs - if self.disable_padding - else batch_size <= self.max_bs - ) + is_bs_supported = ( + cuda_graph_bs in self.graphs + if self.disable_padding + else cuda_graph_bs <= self.max_bs + ) - return is_bs_supported + if self.require_mlp_sync: + is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph + + return is_bs_supported def capture(self): CudaGraphRunner.capture(self)