diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py
index fade8ed29..c0f3bdb83 100644
--- a/python/sglang/srt/layers/attention/triton_backend.py
+++ b/python/sglang/srt/layers/attention/triton_backend.py
@@ -5,6 +5,9 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
 from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    create_flashinfer_kv_indices_triton,
+)
 from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
@@ -29,6 +32,12 @@ class TritonAttnBackend(AttentionBackend):
         self.decode_attention_fwd = decode_attention_fwd
         self.extend_attention_fwd = extend_attention_fwd
 
+        max_bs = model_runner.req_to_token_pool.size
+        self.kv_indptr = torch.zeros(
+            (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
@@ -58,11 +67,32 @@ class TritonAttnBackend(AttentionBackend):
             )
 
             max_extend_len = None
+
+            kv_indptr = self.kv_indptr
+            bs = len(forward_batch.req_pool_indices)
+            kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda"
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                forward_batch.req_to_token_pool.req_to_token,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                forward_batch.req_to_token_pool.req_to_token.stride(0),
+            )
+
         else:
             attn_logits = None
             max_extend_len = torch.max(forward_batch.extend_seq_lens).item()
 
-        self.forward_metadata = attn_logits, max_extend_len
+            kv_indptr = None
+            kv_indices = None
+
+        self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices
 
     def init_cuda_graph_state(self, max_bs: int):
         self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len
@@ -73,7 +103,12 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_attn_logits = torch.empty(
             (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1),
             dtype=torch.float32,
-            device="cuda",
+            device=self.device,
+        )
+        self.cuda_graph_kv_indices = torch.zeros(
+            (max_bs * self.cuda_graph_max_seq_len),
+            dtype=torch.int32,
+            device=self.device,
         )
 
     def init_forward_metadata_capture_cuda_graph(
@@ -90,9 +125,25 @@ class TritonAttnBackend(AttentionBackend):
         assert forward_mode.is_decode(), "Not supported"
         assert spec_info is None, "Not supported"
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
         self.forward_metadata = (
             self.cuda_graph_attn_logits,
             None,
+            kv_indptr,
+            kv_indices,
         )
 
     def init_forward_metadata_replay_cuda_graph(
@@ -109,6 +160,20 @@ class TritonAttnBackend(AttentionBackend):
         self.cuda_graph_start_loc.zero_()
         self.cuda_graph_start_loc[1:bs] = torch.cumsum(seq_lens[: bs - 1], dim=0)
 
+        kv_indptr = self.kv_indptr
+        kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+        kv_indptr = kv_indptr[: bs + 1]
+        kv_indices = self.cuda_graph_kv_indices
+        create_flashinfer_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            kv_indptr,
+            None,
+            kv_indices,
+            self.req_to_token.stride(0),
+        )
+
     def get_cuda_graph_seq_len_fill_value(self):
         return 1
 
@@ -132,7 +197,7 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )
 
-        _, max_extend_len = self.forward_metadata
+        _, max_extend_len, _, _ = self.forward_metadata
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -170,7 +235,7 @@ class TritonAttnBackend(AttentionBackend):
         else:
             o = torch.empty_like(q)
 
-        attn_logits, _ = self.forward_metadata
+        attn_logits, _, kv_indptr, kv_indices = self.forward_metadata
 
         if save_kv_cache:
             forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -182,9 +247,8 @@ class TritonAttnBackend(AttentionBackend):
             forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
             forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
             o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
-            forward_batch.req_to_token_pool.req_to_token,
-            forward_batch.req_pool_indices,
-            forward_batch.seq_lens,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             self.num_kv_splits,
             layer.scaling,
diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
index 512900bd3..f2274322c 100644
--- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
+++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py
@@ -49,11 +49,9 @@ def _fwd_kernel_stage1(
     Q,
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -82,8 +80,9 @@
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d
     q = tl.load(Q + off_q, mask=mask_d, other=0.0)
@@ -100,7 +99,7 @@
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
-                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+                kv_indices + cur_batch_kv_start_idx + offs_n,
                 mask=offs_n < split_kv_end,
                 other=0,
             )
@@ -173,9 +172,8 @@ def _decode_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -188,7 +186,7 @@ def _decode_att_m_fwd(
     Lk = k_buffer.shape[-1]
     Lv = v_buffer.shape[-1]
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
 
     grid = (batch, head_num, NUM_KV_SPLITS)
     kv_group_num = q.shape[1] // k_buffer.shape[1]
@@ -208,11 +206,9 @@ def _decode_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -241,11 +237,9 @@ def _fwd_grouped_kernel_stage1(
     K_Buffer,
     V_Buffer,
     sm_scale,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     Att_Out,
-    stride_req_to_tokens_b,
     stride_qbs,
     stride_qh,
     stride_buf_kbs,
@@ -284,8 +278,9 @@
     offs_dv = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lk
     mask_dv = offs_dv < Lv
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
-    cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
+
+    cur_batch_kv_start_idx = tl.load(kv_indptr + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - cur_batch_kv_start_idx
 
     offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :]
     q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0)
@@ -312,7 +307,7 @@
         for start_n in range(split_kv_start, split_kv_end, BLOCK_N):
             offs_n = start_n + tl.arange(0, BLOCK_N)
             kv_loc = tl.load(
-                Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n,
+                kv_indices + cur_batch_kv_start_idx + offs_n,
                 mask=offs_n < split_kv_end,
                 other=0,
             )
@@ -400,9 +395,8 @@ def _decode_grouped_att_m_fwd(
     k_buffer,
     v_buffer,
     att_out,
-    Req_to_tokens,
-    B_req_idx,
-    B_Seqlen,
+    kv_indptr,
+    kv_indices,
     num_kv_splits,
     sm_scale,
     logit_cap,
@@ -426,7 +420,7 @@ def _decode_grouped_att_m_fwd(
         BLOCK_DPE = 0
     BLOCK_DV = triton.next_power_of_2(Lv)
 
-    batch, head_num = B_req_idx.shape[0], q.shape[1]
+    batch, head_num = kv_indptr.shape[0] - 1, q.shape[1]
     kv_group_num = q.shape[1] // k_buffer.shape[1]
 
     BLOCK_H = 16
@@ -450,11 +444,9 @@ def _decode_grouped_att_m_fwd(
         k_buffer,
         v_buffer,
         sm_scale,
-        Req_to_tokens,
-        B_req_idx,
-        B_Seqlen,
+        kv_indptr,
+        kv_indices,
         att_out,
-        Req_to_tokens.stride(0),
         q.stride(0),
         q.stride(1),
         k_buffer.stride(0),
@@ -485,7 +477,7 @@ def _fwd_kernel_stage2(
     Mid_O,
     O,
-    B_Seqlen,
+    kv_indptr,
     stride_mid_ob,
     stride_mid_oh,
     stride_mid_os,
@@ -498,7 +490,9 @@
     cur_batch = tl.program_id(0)
     cur_head = tl.program_id(1)
 
-    cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
+    cur_batch_seq_len = tl.load(kv_indptr + cur_batch + 1) - tl.load(
+        kv_indptr + cur_batch
+    )
 
     offs_d = tl.arange(0, BLOCK_DV)
     mask_d = offs_d < Lv
@@ -542,7 +536,7 @@ def _decode_softmax_reducev_fwd(
     q,
     o,
     v_buffer,
-    b_seq_len,
+    kv_indptr,
     num_kv_splits,
 ):
     batch, head_num = q.shape[0], q.shape[1]
@@ -561,7 +555,7 @@ def _decode_softmax_reducev_fwd(
     _fwd_kernel_stage2[grid](
         logits,
         o,
-        b_seq_len,
+        kv_indptr,
         logits.stride(0),
         logits.stride(1),
         logits.stride(2),
@@ -581,9 +575,8 @@ def decode_attention_fwd_normal(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -594,14 +587,13 @@ def decode_attention_fwd_normal(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
     )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd_grouped(
@@ -609,9 +601,8 @@ def decode_attention_fwd_grouped(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -622,14 +613,13 @@ def decode_attention_fwd_grouped(
         k_buffer,
         v_buffer,
         attn_logits,
-        req_to_token,
-        b_req_idx,
-        b_seq_len,
+        kv_indptr,
+        kv_indices,
         num_kv_splits,
         sm_scale,
         logit_cap,
    )
-    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, b_seq_len, num_kv_splits)
+    _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, kv_indptr, num_kv_splits)
 
 
 def decode_attention_fwd(
@@ -637,9 +627,8 @@ def decode_attention_fwd(
     k_buffer,
     v_buffer,
     o,
-    req_to_token,
-    b_req_idx,
-    b_seq_len,
+    kv_indptr,
+    kv_indices,
     attn_logits,
     num_kv_splits,
     sm_scale,
@@ -655,9 +644,8 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -670,9 +658,8 @@ def decode_attention_fwd(
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 2398af9b0..52a20771b 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -194,10 +194,12 @@ class TestTritonAttention(unittest.TestCase):
         # o will have the same shape as q
         o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda")
 
-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")
+
         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D + 1),
             dtype=torch.float32,
@@ -209,9 +211,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -250,10 +251,12 @@ class TestTritonAttention(unittest.TestCase):
         o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
         o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda")
 
-        req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len)
-        b_req_idx = torch.arange(B, device="cuda")
         b_seq_len = torch.full((B,), seq_len, device="cuda")
 
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len[:B], dim=0)
+        kv_indices = torch.arange(total_tokens, device="cuda")
+
         attn_logits = torch.empty(
             (B, H_Q, num_kv_splits, D_V + 1),
             dtype=torch.float32,
@@ -265,9 +268,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits,
             num_kv_splits,
             sm_scale,
@@ -284,9 +286,8 @@ class TestTritonAttention(unittest.TestCase):
             k_buffer,
             v_buffer,
             o_grouped,
-            req_to_token,
-            b_req_idx,
-            b_seq_len,
+            kv_indptr,
+            kv_indices,
             attn_logits1,
             num_kv_splits,
             sm_scale,
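Reviewer note: the diff replaces the (req_to_token, b_req_idx, b_seq_len) triple with a flattened CSR-style pair (kv_indptr, kv_indices), matching the layout already used by the FlashInfer backend. The sketch below is a hypothetical host-side reference, not part of this PR, that mirrors what create_flashinfer_kv_indices_triton is expected to write under the assumptions visible in the diff: kv_indptr is the exclusive prefix sum of seq_lens, and kv_indices[kv_indptr[i]:kv_indptr[i+1]] holds request i's KV-cache slot numbers copied from req_to_token[req_pool_indices[i], :seq_lens[i]]. The helper name build_kv_indices_reference is made up for illustration and runs on CPU so it needs neither Triton nor a GPU.

# Hypothetical pure-PyTorch reference (not part of this PR) for the CSR-style
# kv_indptr / kv_indices layout that create_flashinfer_kv_indices_triton fills.
import torch


def build_kv_indices_reference(req_to_token, req_pool_indices, seq_lens):
    # kv_indptr[i] marks where request i's KV slot indices start (exclusive prefix sum).
    bs = req_pool_indices.numel()
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    # kv_indices concatenates each request's first seq_lens[i] entries of its
    # req_to_token row, so the kernels can index the KV pool without the 2D table.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        row = req_pool_indices[i]
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[row, : seq_lens[i]]
    return kv_indptr, kv_indices


if __name__ == "__main__":
    # Toy pool: 4 request slots, up to 6 tokens each, arbitrary slot numbering.
    req_to_token = torch.arange(24, dtype=torch.int32).reshape(4, 6)
    req_pool_indices = torch.tensor([2, 0])  # batch of two requests
    seq_lens = torch.tensor([3, 5])          # their current KV lengths
    kv_indptr, kv_indices = build_kv_indices_reference(
        req_to_token, req_pool_indices, seq_lens
    )
    print(kv_indptr.tolist())   # [0, 3, 8]
    print(kv_indices.tolist())  # [12, 13, 14, 0, 1, 2, 3, 4]

With this layout each stage-1 and stage-2 kernel recovers a request's KV length as kv_indptr[i + 1] - kv_indptr[i], which is why the B_Seqlen, B_req_idx, and Req_to_tokens arguments (and stride_req_to_tokens_b) can be dropped from the kernel signatures above.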