From 5f5b3b2449f9bb55c545741fa8593bd9680c01c9 Mon Sep 17 00:00:00 2001 From: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Date: Tue, 12 Aug 2025 12:23:46 -0700 Subject: [PATCH] [5/n] DP Enhancement: Correct `num_token_non_padded` (#9107) --- .../srt/model_executor/cuda_graph_runner.py | 21 +++++++++++++++++-- 1 file changed, 19 insertions(+), 2 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index 303919505..537dab9eb 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import ( set_graph_pool_id, ) from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture -from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size +from sglang.srt.layers.dp_attention import ( + DPPaddingMode, + get_attention_tp_rank, + get_attention_tp_size, +) from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.torchao_utils import save_gemlite_cache from sglang.srt.model_executor.forward_batch_info import ( @@ -255,6 +259,9 @@ class CudaGraphRunner: self.dp_size = model_runner.server_args.dp_size self.pp_size = model_runner.server_args.pp_size + self.attn_tp_size = get_attention_tp_size() + self.attn_tp_rank = get_attention_tp_rank() + # Batch sizes to capture self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner) rank0_log(f"Capture cuda graph bs {self.capture_bs}") @@ -749,7 +756,17 @@ class CudaGraphRunner: self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs) self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs) if enable_num_token_non_padded(self.model_runner.server_args): - self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) + num_token_non_padded = forward_batch.num_token_non_padded + if self.require_gathered_buffer: + tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs + num_local_token_non_padded = torch.clamp( + num_token_non_padded - tokens_per_rank * self.attn_tp_rank, + min=0, + max=tokens_per_rank, + ) + self.num_token_non_padded.copy_(num_local_token_non_padded) + else: + self.num_token_non_padded.copy_(num_token_non_padded) if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( forward_mode=self.capture_forward_mode,