Fix CUDA Graph Check under Deepep with DP FFN (#7451)

This commit is contained in:
Cheng Wan
2025-06-22 20:35:58 -07:00
committed by GitHub
parent 3cee035e99
commit ac5010e0ba
3 changed files with 40 additions and 40 deletions

View File

@@ -48,6 +48,7 @@ from sglang.srt.utils import (
rank0_log,
require_attn_tp_gather,
require_gathered_buffer,
require_mlp_sync,
require_mlp_tp_gather,
)
@@ -212,6 +213,7 @@ class CudaGraphRunner:
self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder
self.require_gathered_buffer = require_gathered_buffer(model_runner.server_args)
self.require_mlp_tp_gather = require_mlp_tp_gather(model_runner.server_args)
self.require_mlp_sync = require_mlp_sync(model_runner.server_args)
self.require_attn_tp_gather = require_attn_tp_gather(model_runner.server_args)
self.enable_two_batch_overlap = (
model_runner.server_args.enable_two_batch_overlap
@@ -337,22 +339,22 @@ class CudaGraphRunner:
def can_run(self, forward_batch: ForwardBatch):
if self.require_mlp_tp_gather:
total_batch_size = (
cuda_graph_bs = (
sum(forward_batch.global_num_tokens_cpu) // self.num_tokens_per_bs
if self.model_runner.spec_algorithm.is_eagle()
else sum(forward_batch.global_num_tokens_cpu)
)
is_bs_supported = forward_batch.can_run_dp_cuda_graph and (
total_batch_size in self.graphs
if self.disable_padding
else total_batch_size <= self.max_bs
)
else:
is_bs_supported = (
forward_batch.batch_size in self.graphs
if self.disable_padding
else forward_batch.batch_size <= self.max_bs
)
cuda_graph_bs = forward_batch.batch_size
is_bs_supported = (
cuda_graph_bs in self.graphs
if self.disable_padding
else cuda_graph_bs <= self.max_bs
)
if self.require_mlp_sync:
is_bs_supported = is_bs_supported and forward_batch.can_run_dp_cuda_graph
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph