From ed0a0b692cf6bb91e97f623d3317b2d8f1c4792c Mon Sep 17 00:00:00 2001 From: u4lr451 Date: Tue, 24 Jun 2025 08:34:13 +0800 Subject: [PATCH] Perormance: Enable cuda graph for dp idle batch (#7269) Co-authored-by: austindeng Co-authored-by: Cheng Wan <54331508+ch-wan@users.noreply.github.com> Co-authored-by: ch-wan --- .../attention/flashattention_backend.py | 20 ++++++--- python/sglang/srt/managers/scheduler.py | 5 --- .../srt/model_executor/cuda_graph_runner.py | 43 +++++++++---------- python/sglang/srt/speculative/eagle_utils.py | 2 + python/sglang/srt/speculative/eagle_worker.py | 31 ++++++------- 5 files changed, 51 insertions(+), 50 deletions(-) diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py index 9871ca3a8..9cbdce35c 100644 --- a/python/sglang/srt/layers/attention/flashattention_backend.py +++ b/python/sglang/srt/layers/attention/flashattention_backend.py @@ -1704,14 +1704,15 @@ class FlashAttentionBackend(AttentionBackend): # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk) metadata_expand = self.target_verify_metadata_topk_expand[bs] + # metadata_expand.max_seq_len_q = 1, already set in capture # metadata_expand.cu_seqlens_q already set in capture - offsets = torch.arange( self.speculative_num_draft_tokens, device=device ).unsqueeze( 0 ) # shape: (1, self.speculative_num_draft_tokens) + cols = offsets.expand(seq_lens.numel(), -1) + seq_lens.unsqueeze(1) cum_len = torch.nn.functional.pad( torch.cumsum( @@ -1728,17 +1729,20 @@ class FlashAttentionBackend(AttentionBackend): ).view(1, -1) # avoid extracting padded seq indices which will be out of boundary mask_extraction_indices[ - :, spec_info.positions.numel() * self.speculative_num_draft_tokens : + :, + spec_info.positions.numel() * self.speculative_num_draft_tokens :, ].fill_(0) - mask = spec_info.custom_mask[mask_extraction_indices].view( -1, self.speculative_num_draft_tokens ) # (bsz * draft_num, draft_num) + col_indices = offsets.expand( mask.shape[0], self.speculative_num_draft_tokens ) keys = torch.where( - mask, col_indices, col_indices + self.speculative_num_draft_tokens + mask, + col_indices, + col_indices + self.speculative_num_draft_tokens, ) _, sort_order = torch.sort(keys, dim=1) @@ -1747,6 +1751,7 @@ class FlashAttentionBackend(AttentionBackend): .gather(1, cols) .repeat_interleave(self.speculative_num_draft_tokens, dim=0) ) # (bsz, draft_num) + metadata_expand.page_table.copy_( non_masked_page_table.gather(1, sort_order) ) @@ -1758,6 +1763,7 @@ class FlashAttentionBackend(AttentionBackend): dtype=torch.int32, ) ) + elif forward_mode.is_draft_extend(): metadata = self.draft_extend_metadata[bs] metadata.cache_seqlens_int32.copy_(seq_lens) @@ -1767,7 +1773,11 @@ class FlashAttentionBackend(AttentionBackend): torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32) ) accept_length = spec_info.accept_length[:bs] - metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1 + if spec_info.accept_length_cpu: + metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1 + else: + metadata.max_seq_len_q = 1 + metadata.cu_seqlens_q[1:].copy_( torch.cumsum(accept_length, dim=0, dtype=torch.int32) ) diff --git a/python/sglang/srt/managers/scheduler.py b/python/sglang/srt/managers/scheduler.py index 5079d27dd..50f029cd3 100644 --- a/python/sglang/srt/managers/scheduler.py +++ b/python/sglang/srt/managers/scheduler.py @@ -1821,11 +1821,6 @@ class Scheduler( else: can_cuda_graph = 0 - if not spec_algorithm.is_none(): - # TODO(sang): Support cuda graph when idle batch is there. - if local_batch is None or local_batch.forward_mode.is_idle(): - can_cuda_graph = 0 - is_extend_in_batch = ( local_batch.forward_mode.is_extend() if local_batch else False ) diff --git a/python/sglang/srt/model_executor/cuda_graph_runner.py b/python/sglang/srt/model_executor/cuda_graph_runner.py index a51a06f09..1c6720847 100644 --- a/python/sglang/srt/model_executor/cuda_graph_runner.py +++ b/python/sglang/srt/model_executor/cuda_graph_runner.py @@ -306,28 +306,30 @@ class CudaGraphRunner: self.encoder_lens = None if self.require_gathered_buffer: + self.gathered_buffer = torch.zeros( + ( + self.max_num_token, + self.model_runner.model_config.hidden_size, + ), + dtype=self.model_runner.dtype, + ) if self.require_mlp_tp_gather: - self.gathered_buffer = torch.zeros( - ( - self.max_bs * self.dp_size * self.num_tokens_per_bs, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) self.global_num_tokens_gpu = torch.zeros( (self.dp_size,), dtype=torch.int32 ) else: assert self.require_attn_tp_gather - self.gathered_buffer = torch.zeros( - ( - self.max_bs * self.num_tokens_per_bs, - self.model_runner.model_config.hidden_size, - ), - dtype=self.model_runner.dtype, - ) self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32) + self.custom_mask = torch.ones( + ( + (self.seq_lens.sum().item() + self.max_num_token) + * self.num_tokens_per_bs + ), + dtype=torch.bool, + device="cuda", + ) + # Capture try: with model_capture_mode(): @@ -674,11 +676,12 @@ class CudaGraphRunner: self.num_token_non_padded.copy_(forward_batch.num_token_non_padded) if self.enable_two_batch_overlap: self.tbo_plugin.replay_prepare( - forward_mode=forward_batch.forward_mode, + forward_mode=self.capture_forward_mode, bs=bs, num_token_non_padded=len(forward_batch.input_ids), ) - + if forward_batch.forward_mode.is_idle() and forward_batch.spec_info is not None: + forward_batch.spec_info.custom_mask = self.custom_mask # Attention backend self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph( bs, @@ -686,7 +689,7 @@ class CudaGraphRunner: self.seq_lens[:bs], forward_batch.seq_lens_sum + (bs - raw_bs) * self.seq_len_fill_value, self.encoder_lens[:bs] if self.is_encoder_decoder else None, - forward_batch.forward_mode, + self.capture_forward_mode, forward_batch.spec_info, seq_lens_cpu=self.seq_lens_cpu[:bs], ) @@ -736,11 +739,7 @@ class CudaGraphRunner: else: spec_info = EagleVerifyInput( draft_token=None, - custom_mask=torch.ones( - (num_tokens * self.model_runner.model_config.context_len), - dtype=torch.bool, - device="cuda", - ), + custom_mask=self.custom_mask, positions=None, retrive_index=None, retrive_next_token=None, diff --git a/python/sglang/srt/speculative/eagle_utils.py b/python/sglang/srt/speculative/eagle_utils.py index 1db3448a1..83724b385 100644 --- a/python/sglang/srt/speculative/eagle_utils.py +++ b/python/sglang/srt/speculative/eagle_utils.py @@ -99,6 +99,8 @@ class EagleDraftInput: topk_p=torch.empty((0, topk), device=device, dtype=torch.float32), topk_index=torch.empty((0, topk), device=device, dtype=torch.int64), capture_hidden_mode=capture_hidden_mode, + accept_length=torch.empty((0,), device=device, dtype=torch.int32), + accept_length_cpu=[], ) def prepare_extend_after_decode( diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index c9e57702d..effcbae4a 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -322,13 +322,11 @@ class EAGLEWorker(TpModelWorker): logits_output, verify_output, model_worker_batch, can_run_cuda_graph = ( self.verify(batch, spec_info) ) - need_forward, can_run_draft_extend_cuda_graph = ( - self.check_forward_draft_extend_after_decode(batch) - ) - if need_forward: + + if self.check_forward_draft_extend_after_decode(batch): with self.draft_tp_context(self.draft_model_runner.tp_group): self.forward_draft_extend_after_decode( - batch, can_run_draft_extend_cuda_graph + batch, ) return ( logits_output, @@ -344,7 +342,7 @@ class EAGLEWorker(TpModelWorker): and batch.spec_info.verified_id.shape[0] > 0 ) if not self.server_args.enable_dp_attention: - return local_need_forward, True + return local_need_forward global_need_forward = torch.tensor( [ @@ -357,10 +355,7 @@ class EAGLEWorker(TpModelWorker): ) global_need_forward_cnt = global_need_forward[0].item() need_forward = global_need_forward_cnt > 0 - can_run_draft_extend_cuda_graph = ( - global_need_forward_cnt == get_tensor_model_parallel_world_size() - ) - return need_forward, can_run_draft_extend_cuda_graph + return need_forward def forward_target_extend( self, batch: ScheduleBatch @@ -816,15 +811,12 @@ class EAGLEWorker(TpModelWorker): assert forward_batch.spec_info is batch.spec_info self.capture_for_decode(logits_output, forward_batch.spec_info) - def forward_draft_extend_after_decode( - self, batch: ScheduleBatch, can_run_draft_extend_cuda_graph: bool - ): + def forward_draft_extend_after_decode(self, batch: ScheduleBatch): # Backup fields that will be modified in-place seq_lens_backup = batch.seq_lens.clone() req_pool_indices_backup = batch.req_pool_indices accept_length_backup = batch.spec_info.accept_length return_logprob_backup = batch.return_logprob - input_is_idle = batch.forward_mode.is_idle() if not input_is_idle: # Prepare metadata @@ -836,14 +828,18 @@ class EAGLEWorker(TpModelWorker): else: batch = batch.copy() batch.prepare_for_idle() + hidden_size = ( + self.model_config.hidden_size * 3 + if self.speculative_algorithm.is_eagle3() + else self.model_config.hidden_size + ) batch.spec_info = EagleDraftInput.create_idle_input( device=self.device, - hidden_size=self.model_config.hidden_size, + hidden_size=hidden_size, dtype=self.model_config.dtype, topk=self.topk, capture_hidden_mode=CaptureHiddenMode.LAST, ) - batch.return_hidden_states = False model_worker_batch = batch.get_model_worker_batch() model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens @@ -858,8 +854,7 @@ class EAGLEWorker(TpModelWorker): # Run can_cuda_graph = ( - can_run_draft_extend_cuda_graph - and self.cuda_graph_runner_for_draft_extend + self.cuda_graph_runner_for_draft_extend and self.cuda_graph_runner_for_draft_extend.can_run(forward_batch) ) if can_cuda_graph: