From 5a33c3aae72ae0c356bc113b328aa3a904a3365d Mon Sep 17 00:00:00 2001
From: Liangsheng Yin
Date: Tue, 14 Oct 2025 20:08:32 +0800
Subject: [PATCH] Optimize Triton Draft Backend (#11556)

---
 .../srt/layers/attention/triton_backend.py   | 77 ++++++++++++-------
 1 file changed, 50 insertions(+), 27 deletions(-)

diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index 71c034dd7..4fab75700 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional, Union
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import triton
@@ -12,6 +12,7 @@ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_trito
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 from sglang.srt.utils import (
     get_bool_env_var,
     get_device_core_count,
@@ -423,6 +424,7 @@ class TritonAttnBackend(AttentionBackend):
         max_bs: int,
         max_num_tokens: int,
         kv_indices_buf: Optional[torch.Tensor] = None,
+        cuda_graph_num_kv_splits_buf: Optional[torch.Tensor] = None,
     ):
         self.cuda_graph_attn_logits = torch.zeros(
             (max_num_tokens, self.num_head, self.max_kv_splits, self.v_head_dim),
@@ -434,9 +436,17 @@ class TritonAttnBackend(AttentionBackend):
             dtype=torch.float32,
             device=self.device,
         )
-        self.cuda_graph_num_kv_splits = torch.full(
-            (max_num_tokens,), self.max_kv_splits, dtype=torch.int32, device=self.device
-        )
+
+        if cuda_graph_num_kv_splits_buf is None:
+            self.cuda_graph_num_kv_splits = torch.full(
+                (max_num_tokens,),
+                self.max_kv_splits,
+                dtype=torch.int32,
+                device=self.device,
+            )
+        else:
+            self.cuda_graph_num_kv_splits = cuda_graph_num_kv_splits_buf
+
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_num_tokens * self.max_context_len),
@@ -683,9 +693,7 @@ class TritonAttnBackend(AttentionBackend):
                 )
 
             else:
-                kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
-                kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
-                num_token = spec_info.kv_indptr.shape[0] - 1
+                assert False, "Multi-step cuda graph init is not done here."
             self.get_num_kv_splits(num_kv_splits[:num_token], seq_lens[:bs])
 
         elif forward_mode.is_target_verify():
@@ -898,11 +906,8 @@ class TritonMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
-
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
-        self.generate_draft_decode_kv_indices = generate_draft_decode_kv_indices
         max_bs = model_runner.req_to_token_pool.size * self.topk
         self.kv_indptr = torch.zeros(
             (
@@ -912,7 +917,7 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int32,
             device=model_runner.device,
         )
-        self.attn_backends = []
+        self.attn_backends: List[TritonAttnBackend] = []
        for i in range(self.speculative_num_steps):
             self.attn_backends.append(
                 TritonAttnBackend(
@@ -931,13 +936,19 @@ class TritonMultiStepDraftBackend:
         self.page_size = model_runner.server_args.page_size
 
     def common_template(
-        self, forward_batch: ForwardBatch, kv_indices_buffer: torch.Tensor, call_fn: int
+        self,
+        forward_batch: ForwardBatch,
+        kv_indices_buffer: Optional[torch.Tensor],
+        call_fn: int,
     ):
+        if kv_indices_buffer is None:
+            kv_indices_buffer = self.cuda_graph_kv_indices
+
         num_seqs = forward_batch.batch_size
         bs = self.topk * num_seqs
         seq_lens_sum = forward_batch.seq_lens_sum
 
-        self.generate_draft_decode_kv_indices[
+        generate_draft_decode_kv_indices[
             (self.speculative_num_steps, num_seqs, self.topk)
         ](
             forward_batch.req_pool_indices,
@@ -955,6 +966,9 @@ class TritonMultiStepDraftBackend:
             self.page_size,
         )
 
+        if call_fn is None:
+            return
+
         for i in range(self.speculative_num_steps):
             forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
             forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
@@ -989,9 +1003,18 @@ class TritonMultiStepDraftBackend:
             dtype=torch.int64,
             device=self.device,
         )
+        self.cuda_graph_num_kv_splits = torch.full(
+            (max_num_tokens,),
+            self.attn_backends[0].max_kv_splits,
+            dtype=torch.int32,
+            device=self.device,
+        )
         for i in range(self.speculative_num_steps):
             self.attn_backends[i].init_cuda_graph_state(
-                max_bs, max_num_tokens, kv_indices_buf=self.cuda_graph_kv_indices[i]
+                max_bs,
+                max_num_tokens,
+                kv_indices_buf=self.cuda_graph_kv_indices[i],
+                cuda_graph_num_kv_splits_buf=self.cuda_graph_num_kv_splits,
             )
 
     def init_forward_metadata_capture_cuda_graph(self, forward_batch: ForwardBatch):
@@ -1006,24 +1029,24 @@ class TritonMultiStepDraftBackend:
                 spec_info=forward_batch.spec_info,
             )
 
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        self.common_template(forward_batch, None, call_fn)
 
     def init_forward_metadata_replay_cuda_graph(
         self, forward_batch: ForwardBatch, bs: int
     ):
-        def call_fn(i, forward_batch):
-            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
-                bs,
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                seq_lens_sum=-1,
-                encoder_lens=None,
-                forward_mode=ForwardMode.DECODE,
-                spec_info=forward_batch.spec_info,
-                seq_lens_cpu=None,
-            )
+        self.common_template(forward_batch, None, None)
 
-        self.common_template(forward_batch, self.cuda_graph_kv_indices, call_fn)
+        # NOTE: Multi-step's attention backends use the slice of
+        # - kv_indptr buffer (cuda graph and non-cuda graph)
+        # - kv_indices buffer (cuda graph only)
+        # So we don't need to assign the KV indices inside the attention backend.
+
+        # Compute num_kv_splits only once
+        num_token = forward_batch.batch_size * self.topk
+        self.attn_backends[-1].get_num_kv_splits(
+            self.attn_backends[-1].cuda_graph_num_kv_splits[:num_token],
+            forward_batch.seq_lens[:bs],
+        )
 
 
 @triton.jit
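
Note on the patch above: its central optimization is that all per-step draft backends adopt one shared cuda_graph_num_kv_splits tensor, so the split counts are written once per CUDA-graph replay (via attn_backends[-1]) instead of once per speculative step. The sketch below illustrates that buffer-sharing pattern in isolation; StepBackend, MultiStepDraft, and the num_kv_splits_buf parameter are illustrative stand-ins, not the real sglang classes or signatures.

import torch


class StepBackend:
    """Illustrative stand-in for one per-step TritonAttnBackend."""

    def __init__(self, max_kv_splits: int = 16):
        self.max_kv_splits = max_kv_splits
        self.cuda_graph_num_kv_splits = None

    def init_cuda_graph_state(self, max_num_tokens, num_kv_splits_buf=None):
        # Old behavior: always allocate a private buffer per backend.
        # New behavior: adopt the caller's shared buffer when one is given.
        if num_kv_splits_buf is None:
            self.cuda_graph_num_kv_splits = torch.full(
                (max_num_tokens,), self.max_kv_splits, dtype=torch.int32
            )
        else:
            self.cuda_graph_num_kv_splits = num_kv_splits_buf


class MultiStepDraft:
    """Illustrative stand-in for TritonMultiStepDraftBackend."""

    def __init__(self, num_steps: int, max_num_tokens: int):
        self.backends = [StepBackend() for _ in range(num_steps)]
        # Allocate one buffer and hand it to every step's backend.
        shared = torch.full(
            (max_num_tokens,), self.backends[0].max_kv_splits, dtype=torch.int32
        )
        for b in self.backends:
            b.init_cuda_graph_state(max_num_tokens, num_kv_splits_buf=shared)


draft = MultiStepDraft(num_steps=3, max_num_tokens=8)
# All steps alias the same storage, so writing the split counts once at
# replay time (as the patch does on the last backend) serves every step.
assert all(
    b.cuda_graph_num_kv_splits.data_ptr()
    == draft.backends[0].cuda_graph_num_kv_splits.data_ptr()
    for b in draft.backends
)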