diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py
index 303919505..537dab9eb 100644
--- a/python/sglang/srt/model_executor/cuda_graph_runner.py
+++ b/python/sglang/srt/model_executor/cuda_graph_runner.py
@@ -33,7 +33,11 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
     set_graph_pool_id,
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
-from sglang.srt.layers.dp_attention import DPPaddingMode, get_attention_tp_size
+from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
+    get_attention_tp_rank,
+    get_attention_tp_size,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
 from sglang.srt.model_executor.forward_batch_info import (
@@ -255,6 +259,9 @@ class CudaGraphRunner:
         self.dp_size = model_runner.server_args.dp_size
         self.pp_size = model_runner.server_args.pp_size
 
+        self.attn_tp_size = get_attention_tp_size()
+        self.attn_tp_rank = get_attention_tp_rank()
+
         # Batch sizes to capture
         self.capture_bs, self.compile_bs = get_batch_sizes_to_capture(model_runner)
         rank0_log(f"Capture cuda graph bs {self.capture_bs}")
@@ -749,7 +756,17 @@ class CudaGraphRunner:
             self.global_num_tokens_gpu.fill_(bs * self.num_tokens_per_bs)
             self.global_num_tokens_for_logprob_gpu.fill_(bs * self.num_tokens_per_bs)
         if enable_num_token_non_padded(self.model_runner.server_args):
-            self.num_token_non_padded.copy_(forward_batch.num_token_non_padded)
+            num_token_non_padded = forward_batch.num_token_non_padded
+            if self.require_gathered_buffer:
+                tokens_per_rank = bs // self.attn_tp_size * self.num_tokens_per_bs
+                num_local_token_non_padded = torch.clamp(
+                    num_token_non_padded - tokens_per_rank * self.attn_tp_rank,
+                    min=0,
+                    max=tokens_per_rank,
+                )
+                self.num_token_non_padded.copy_(num_local_token_non_padded)
+            else:
+                self.num_token_non_padded.copy_(num_token_non_padded)
         if self.enable_two_batch_overlap:
             self.tbo_plugin.replay_prepare(
                 forward_mode=self.capture_forward_mode,
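
A minimal standalone sketch of the per-rank clamping added in the last hunk, for illustration only; the helper name `local_non_padded_tokens` is hypothetical and not part of the diff. It shows how, when a gathered buffer is required, each attention-TP rank derives its local non-padded token count from the global count by subtracting the tokens owned by lower ranks and clamping to its own slice:

```python
import torch

def local_non_padded_tokens(
    num_token_non_padded: torch.Tensor,  # global count of non-padded tokens
    bs: int,                             # padded batch size captured by the CUDA graph
    num_tokens_per_bs: int,              # tokens per request in this graph
    attn_tp_size: int,                   # attention tensor-parallel world size
    attn_tp_rank: int,                   # this rank's index in the attention TP group
) -> torch.Tensor:
    # Each rank owns a contiguous slice of the padded token range.
    tokens_per_rank = bs // attn_tp_size * num_tokens_per_bs
    # Subtract the tokens owned by lower ranks, then clamp to [0, tokens_per_rank]:
    # ranks fully covered by real tokens get tokens_per_rank, ranks entirely in the
    # padded region get 0, and the boundary rank gets the remainder.
    return torch.clamp(
        num_token_non_padded - tokens_per_rank * attn_tp_rank,
        min=0,
        max=tokens_per_rank,
    )

# Example: bs=8, 1 token per request, 4 attention-TP ranks, 5 real tokens.
# tokens_per_rank = 2, so ranks 0..3 see 2, 2, 1, 0 local non-padded tokens.
print([
    int(local_non_padded_tokens(torch.tensor(5), 8, 1, 4, r))
    for r in range(4)
])  # [2, 2, 1, 0]
```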