From 2fc1299562075a1ca0f6fe6a7ddd4203181e3fdb Mon Sep 17 00:00:00 2001 From: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com> Date: Sun, 8 Jun 2025 20:09:17 +0800 Subject: [PATCH] Remove unnecessary kernels of num_token_non_padded (#6965) --- .../srt/model_executor/cuda_graph_runner.py | 24 ++++++++------ .../srt/model_executor/forward_batch_info.py | 33 +++++++++++-------- python/sglang/srt/two_batch_overlap.py | 3 -- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index d443baa1c..36d3a1b25 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -35,6 +35,7 @@ from sglang.srt.model_executor.forward_batch_info import ( ForwardBatch, ForwardMode, PPProxyTensors, + enable_num_token_non_padded, ) from sglang.srt.patch_torch import monkey_patch_torch_compile from sglang.srt.two_batch_overlap import TboCudaGraphRunnerPlugin @@ -190,6 +191,9 @@ class CudaGraphRunner: self.is_encoder_decoder = model_runner.model_config.is_encoder_decoder self.enable_dp_attention = model_runner.server_args.enable_dp_attention self.enable_sp_layernorm = model_runner.server_args.enable_sp_layernorm + self.enable_two_batch_overlap = ( + model_runner.server_args.enable_two_batch_overlap + ) self.speculative_algorithm = model_runner.server_args.speculative_algorithm self.tp_size = model_runner.server_args.tp_size self.dp_size = model_runner.server_args.dp_size @@ -327,9 +331,7 @@ class CudaGraphRunner: ) is_tbo_supported = ( - forward_batch.can_run_tbo - if self.model_runner.server_args.enable_two_batch_overlap - else True + forward_batch.can_run_tbo if self.enable_two_batch_overlap else True ) return is_bs_supported and is_encoder_lens_supported and is_tbo_supported @@ -549,13 +551,7 @@ class CudaGraphRunner: self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens) self.out_cache_loc[:raw_num_token].copy_(forward_batch.out_cache_loc) self.positions[:raw_num_token].copy_(forward_batch.positions) - num_token_non_padded = len(forward_batch.input_ids) - self.num_token_non_padded[...] = num_token_non_padded - self.tbo_plugin.replay_prepare( - forward_mode=forward_batch.forward_mode, - bs=bs, - num_token_non_padded=num_token_non_padded, - ) + if forward_batch.seq_lens_cpu is not None: if bs != raw_bs: self.seq_lens_cpu.fill_(1) @@ -572,6 +568,14 @@ class CudaGraphRunner: self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions) if self.enable_dp_attention or self.enable_sp_layernorm: self.global_num_tokens_gpu.copy_(forward_batch.global_num_tokens_gpu) + if enable_num_token_non_padded(self.model_runner.server_args): + self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) + if self.enable_two_batch_overlap: + self.tbo_plugin.replay_prepare( + forward_mode=forward_batch.forward_mode, + bs=bs, + num_token_non_padded=len(forward_batch.input_ids), + ) # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( diff --git a/python/sglang/srt/model_executor/forward_batch_info.py b/python/sglang/srt/model_executor/forward_batch_info.py index d068b44d2..2d0328e29 100644 --- a/python/sglang/srt/model_executor/forward_batch_info.py +++ b/python/sglang/srt/model_executor/forward_batch_info.py @@ -118,6 +118,7 @@ class ForwardMode(IntEnum): class CaptureHiddenMode(IntEnum): + # Do not capture anything. NULL = auto() # Capture hidden states of all tokens. FULL = auto() @@ -253,6 +254,7 @@ class ForwardBatch: # For Qwen2-VL mrope_positions: torch.Tensor = None + # For two-batch overlap tbo_split_seq_index: Optional[int] = None tbo_parent_token_range: Optional[Tuple[int, int]] = None tbo_children: Optional[List["ForwardBatch"]] = None @@ -265,12 +267,6 @@ class ForwardBatch: ): from sglang.srt.two_batch_overlap import TboForwardBatchPreparer - device = model_runner.device - extend_input_logprob_token_ids_gpu = None - if batch.extend_input_logprob_token_ids is not None: - extend_input_logprob_token_ids_gpu = ( - batch.extend_input_logprob_token_ids.to(device, non_blocking=True) - ) ret = cls( forward_mode=batch.forward_mode, batch_size=len(batch.seq_lens), @@ -284,6 +280,7 @@ class ForwardBatch: encoder_lens_cpu=batch.encoder_lens_cpu, encoder_out_cache_loc=batch.encoder_out_cache_loc, seq_lens_sum=batch.seq_lens_sum, + seq_lens_cpu=batch.seq_lens_cpu, return_logprob=batch.return_logprob, top_logprobs_nums=batch.top_logprobs_nums, token_ids_logprobs=batch.token_ids_logprobs, @@ -298,12 +295,19 @@ class ForwardBatch: spec_info=batch.spec_info, capture_hidden_mode=batch.capture_hidden_mode, input_embeds=batch.input_embeds, - extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu, - num_token_non_padded=torch.tensor( - len(batch.input_ids), dtype=torch.int32 - ).to(device, non_blocking=True), tbo_split_seq_index=batch.tbo_split_seq_index, ) + device = model_runner.device + + if batch.extend_input_logprob_token_ids is not None: + ret.extend_input_logprob_token_ids_gpu = ( + batch.extend_input_logprob_token_ids.to(device, non_blocking=True) + ) + + if enable_num_token_non_padded(model_runner.server_args): + ret.num_token_non_padded = torch.tensor( + len(batch.input_ids), dtype=torch.int32 + ).to(device, non_blocking=True) # For DP attention if batch.global_num_tokens is not None: @@ -323,6 +327,7 @@ class ForwardBatch: dtype=model_runner.dtype, device=device, ) + if ret.forward_mode.is_idle(): ret.positions = torch.empty((0,), device=device) TboForwardBatchPreparer.prepare(ret) @@ -335,10 +340,6 @@ class ForwardBatch: ): ret.positions = ret.spec_info.positions - # Get seq_lens_cpu if needed - if ret.seq_lens_cpu is None: - ret.seq_lens_cpu = batch.seq_lens_cpu - # Init position information if ret.forward_mode.is_decode(): if ret.positions is None: @@ -605,6 +606,10 @@ class ForwardBatch: return self.tbo_split_seq_index is not None +def enable_num_token_non_padded(server_args): + return server_args.enable_ep_moe or server_args.enable_deepep_moe + + class PPProxyTensors: # adapted from https://github.com/vllm-project/vllm/blob/d14e98d924724b284dc5eaf8070d935e214e50c0/vllm/sequence.py#L1103 tensors: Dict[str, torch.Tensor] diff --git a/python/sglang/srt/two_batch_overlap.py b/python/sglang/srt/two_batch_overlap.py index 814a7f95d..9e83e0ba5 100644 --- a/python/sglang/srt/two_batch_overlap.py +++ b/python/sglang/srt/two_batch_overlap.py @@ -131,9 +131,6 @@ class TboCudaGraphRunnerPlugin: def replay_prepare( self, forward_mode: ForwardMode, bs: int, num_token_non_padded: int ): - if not global_server_args_dict["enable_two_batch_overlap"]: - return - tbo_split_seq_index, tbo_split_token_index = ( compute_split_indices_for_cuda_graph_replay( forward_mode=forward_mode,