From 7e4129008208d76a69431b2b3d73b0fdb3823aab Mon Sep 17 00:00:00 2001
From: Ke Bao
Date: Thu, 29 May 2025 15:13:07 +0800
Subject: [PATCH] Add draft extend CUDA graph for Triton backend (#6705)

---
 .../srt/layers/attention/triton_backend.py    | 46 +++++++++++++++++++
 python/sglang/srt/speculative/eagle_worker.py |  6 ++-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 2bedcf077..175d723b2 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -128,6 +128,7 @@ class TritonAttnBackend(AttentionBackend):
         )
 
         self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
+        self.speculative_num_steps = model_runner.server_args.speculative_num_steps
 
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
@@ -424,6 +425,34 @@ class TritonAttnBackend(AttentionBackend):
             num_kv_splits = None
             attn_logits = None
             attn_lse = None
+        elif forward_mode.is_draft_extend():
+            num_tokens_per_bs = self.speculative_num_steps + 1
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[: bs + 1] = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                step=num_tokens_per_bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+            custom_mask = None
+            mask_indptr = None
+            max_extend_len = num_tokens_per_bs
+            num_kv_splits = None
+            attn_logits = None
+            attn_lse = None
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph capture."
@@ -504,6 +533,23 @@ class TritonAttnBackend(AttentionBackend):
             seq_mask_len = self.num_draft_tokens * (seq_lens + self.num_draft_tokens)
             mask_indptr = self.mask_indptr[: bs + 1]
             mask_indptr[1 : bs + 1] = torch.cumsum(seq_mask_len, dim=0)
+        elif forward_mode.is_draft_extend():
+            seq_lens = seq_lens[:bs]
+            accept_lens = spec_info.accept_length[:bs]
+            qo_indptr = self.qo_indptr[: bs + 1]
+            qo_indptr[1 : bs + 1] = torch.cumsum(accept_lens, dim=0)
+            kv_indptr = self.kv_indptr[: bs + 1]
+            kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+            kv_indices = self.cuda_graph_kv_indices
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
         else:
             raise ValueError(
                 f"Invalid forward mode: {forward_mode=} for CUDA Graph replay."
diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py
index 86a8df534..1c78714b7 100644
--- a/python/sglang/srt/speculative/eagle_worker.py
+++ b/python/sglang/srt/speculative/eagle_worker.py
@@ -179,6 +179,7 @@ class EAGLEWorker(TpModelWorker):
             self.has_prefill_wrapper_verify = True
         elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
+                TritonAttnBackend,
                 TritonMultiStepDraftBackend,
             )
 
@@ -187,7 +188,10 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-            self.draft_extend_attn_backend = None
+            self.draft_extend_attn_backend = TritonAttnBackend(
+                self.draft_model_runner,
+                skip_prefill=False,
+            )
             self.padded_static_len = self.speculative_num_steps + 1
             self.has_prefill_wrapper_verify = False
         elif self.server_args.attention_backend == "fa3":
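Why the capture and replay paths build qo_indptr differently: CUDA graph capture must fix tensor shapes before the real accepted lengths are known, so every request is padded to speculative_num_steps + 1 query tokens; at replay, the prefix sum of spec_info.accept_length overwrites the same buffer with the true per-request offsets. A minimal illustrative sketch of the two constructions, using made-up inputs rather than the backend's preallocated buffers:

    import torch

    def capture_qo_indptr(bs: int, speculative_num_steps: int) -> torch.Tensor:
        # Capture time: accepted lengths are unknown, so pad every request
        # to a fixed speculative_num_steps + 1 query tokens (static shape).
        num_tokens_per_bs = speculative_num_steps + 1
        return torch.arange(
            0, bs * num_tokens_per_bs + 1, step=num_tokens_per_bs, dtype=torch.int32
        )

    def replay_qo_indptr(accept_lens: torch.Tensor) -> torch.Tensor:
        # Replay time: the prefix sum of per-request accepted lengths gives
        # the actual query offsets; entry 0 stays 0.
        indptr = torch.zeros(accept_lens.numel() + 1, dtype=torch.int32)
        indptr[1:] = torch.cumsum(accept_lens, dim=0)
        return indptr

    # Batch of 3 requests, 2 draft steps -> padded to 3 tokens per request.
    print(capture_qo_indptr(3, 2))                    # tensor([0, 3, 6, 9], dtype=torch.int32)
    print(replay_qo_indptr(torch.tensor([2, 3, 1])))  # tensor([0, 2, 5, 6], dtype=torch.int32)

Writing into slices of preallocated buffers (self.qo_indptr[: bs + 1]) instead of allocating fresh tensors is what keeps replay valid: the captured graph reads from fixed addresses, so replay may only update values in place.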
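The create_flashinfer_kv_indices_triton launch in both new branches flattens the per-request KV-cache slot lists into one kv_indices array. Below is a rough pure-PyTorch stand-in for that gather, an assumption about the kernel's semantics rather than its actual implementation (the real kernel runs one Triton program per request); all tensors here are toy inputs:

    import torch

    def build_kv_indices(req_to_token, req_pool_indices, seq_lens, kv_indptr):
        # For each request, copy its first seq_len KV slot indices out of the
        # req_to_token table into a flat array at the kv_indptr offsets.
        kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int64)
        for i in range(seq_lens.numel()):
            pool = req_pool_indices[i]
            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[pool, : seq_lens[i]]
        return kv_indices

    req_to_token = torch.arange(12).reshape(3, 4)  # 3 request slots, up to 4 tokens each
    req_pool_indices = torch.tensor([2, 0])        # batch of 2 active requests
    seq_lens = torch.tensor([3, 2])
    kv_indptr = torch.tensor([0, 3, 5])            # leading 0 + cumsum(seq_lens)
    print(build_kv_indices(req_to_token, req_pool_indices, seq_lens, kv_indptr))
    # tensor([ 8,  9, 10,  0,  1])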