From de5533341ee3c1b7667b1eb1f209b6825335d136 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Wed, 5 Feb 2025 18:12:22 +0800 Subject: [PATCH] Update Triton extend backend interface (#3309) --- .../attention/double_sparsity_backend.py | 4 +- .../srt/layers/attention/triton_backend.py | 68 +++- .../triton_ops/double_sparsity_attention.py | 340 +++++++++++++++++- .../attention/triton_ops/extend_attention.py | 57 ++- test/srt/test_triton_attention_kernels.py | 27 +- 5 files changed, 427 insertions(+), 69 deletions(-) diff --git a/python/sglang/srt/layers/attention/double_sparsity_backend.py b/python/sglang/srt/layers/attention/double_sparsity_backend.py index a5e54f32d..c807e8753 100644 --- a/python/sglang/srt/layers/attention/double_sparsity_backend.py +++ b/python/sglang/srt/layers/attention/double_sparsity_backend.py @@ -17,12 +17,10 @@ class DoubleSparseAttnBackend(AttentionBackend): def __init__(self, model_runner: ModelRunner): # Lazy import to avoid the initialization of cuda context from sglang.srt.layers.attention.triton_ops.double_sparsity_attention import ( + extend_attention_fwd, flash_decode_attention_fwd, flash_decode_sparse_attention_fwd, ) - from sglang.srt.layers.attention.triton_ops.extend_attention import ( - extend_attention_fwd, - ) super().__init__() diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index c0f3bdb83..3475df721 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -37,6 +37,9 @@ class TritonAttnBackend(AttentionBackend): (max_bs + 1,), dtype=torch.int32, device=model_runner.device ) self.req_to_token = model_runner.req_to_token_pool.req_to_token + self.qo_indptr = torch.zeros( + (max_bs + 1,), dtype=torch.int32, device=model_runner.device + ) self.num_head = ( model_runner.model_config.num_attention_heads // get_attention_tp_size() @@ -54,6 +57,9 @@ class TritonAttnBackend(AttentionBackend): def init_forward_metadata(self, forward_batch: ForwardBatch): """Init auxiliary variables for triton attention backend.""" + bs = forward_batch.batch_size + kv_indptr = self.kv_indptr + if forward_batch.forward_mode.is_decode(): attn_logits = torch.empty( ( @@ -68,31 +74,59 @@ class TritonAttnBackend(AttentionBackend): max_extend_len = None - kv_indptr = self.kv_indptr - bs = len(forward_batch.req_pool_indices) kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0) kv_indptr = kv_indptr[: bs + 1] kv_indices = torch.empty( - forward_batch.seq_lens_sum, dtype=torch.int32, device="cuda" + forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device ) create_flashinfer_kv_indices_triton[(bs,)]( - forward_batch.req_to_token_pool.req_to_token, + self.req_to_token, forward_batch.req_pool_indices, forward_batch.seq_lens, kv_indptr, None, kv_indices, - forward_batch.req_to_token_pool.req_to_token.stride(0), + self.req_to_token.stride(0), ) + qo_indptr = None + custom_mask = None else: + kv_indptr[1 : bs + 1] = torch.cumsum( + forward_batch.extend_prefix_lens, dim=0 + ) + kv_indptr = kv_indptr[: bs + 1] + kv_indices = torch.empty( + forward_batch.extend_prefix_lens.sum().item(), + dtype=torch.int32, + device=self.device, + ) + create_flashinfer_kv_indices_triton[(bs,)]( + self.req_to_token, + forward_batch.req_pool_indices, + forward_batch.extend_prefix_lens, + kv_indptr, + None, + kv_indices, + self.req_to_token.stride(0), + ) + + qo_indptr = self.qo_indptr + qo_indptr[1 : bs + 1] = torch.cumsum(forward_batch.extend_seq_lens, dim=0) + qo_indptr = qo_indptr[: bs + 1] + custom_mask = None + attn_logits = None max_extend_len = torch.max(forward_batch.extend_seq_lens).item() - kv_indptr = None - kv_indices = None - - self.forward_metadata = attn_logits, max_extend_len, kv_indptr, kv_indices + self.forward_metadata = ( + attn_logits, + max_extend_len, + kv_indptr, + kv_indices, + qo_indptr, + custom_mask, + ) def init_cuda_graph_state(self, max_bs: int): self.cuda_graph_max_total_num_tokens = max_bs * self.cuda_graph_max_seq_len @@ -144,6 +178,8 @@ class TritonAttnBackend(AttentionBackend): None, kv_indptr, kv_indices, + None, + None, ) def init_forward_metadata_replay_cuda_graph( @@ -197,7 +233,9 @@ class TritonAttnBackend(AttentionBackend): layer, forward_batch.out_cache_loc, k, v ) - _, max_extend_len, _, _ = self.forward_metadata + _, max_extend_len, kv_indptr, kv_indices, qo_indptr, custom_mask = ( + self.forward_metadata + ) self.extend_attention_fwd( q.view(-1, layer.tp_q_head_num, layer.qk_head_dim), k.contiguous(), @@ -205,11 +243,9 @@ class TritonAttnBackend(AttentionBackend): o.view(-1, layer.tp_q_head_num, layer.v_head_dim), forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id), forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id), - forward_batch.req_to_token_pool.req_to_token, - forward_batch.req_pool_indices, - forward_batch.seq_lens, - forward_batch.extend_seq_lens, - forward_batch.extend_start_loc, + qo_indptr, + kv_indptr, + kv_indices, max_extend_len, layer.scaling, layer.logit_cap, @@ -235,7 +271,7 @@ class TritonAttnBackend(AttentionBackend): else: o = torch.empty_like(q) - attn_logits, _, kv_indptr, kv_indices = self.forward_metadata + attn_logits, _, kv_indptr, kv_indices, _, _ = self.forward_metadata if save_kv_cache: forward_batch.token_to_kv_pool.set_kv_buffer( diff --git a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py index 79e148e9c..db0fb6b4d 100644 --- a/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py @@ -3,6 +3,13 @@ import triton import triton.language as tl from sglang.srt.managers.schedule_batch import global_server_args_dict +from sglang.srt.utils import is_hip + +is_cuda_available = torch.cuda.is_available() +if is_cuda_available: + CUDA_CAPABILITY = torch.cuda.get_device_capability() + +is_hip_ = is_hip() if global_server_args_dict.get("attention_reduce_in_fp32", False): REDUCE_TRITON_TYPE = tl.float32 @@ -274,9 +281,6 @@ def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, O, block_seq): return -import torch - - def flash_decode_attention_fwd( q, k_buffer, @@ -770,3 +774,333 @@ def flash_decode_sparse_attention_fwd( ) sparse_flash_decode_stage3(heavy_token_num, mid_out, mid_o_logexpsum, o, BLOCK_SEQ) + + +# Extend attention kernel for Double Sparsity +# Moved from https://github.com/sgl-project/sglang/blob/v0.4.2.post1/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +@triton.jit +def _fwd_kernel( + Q_Extend, + K_Extend, + V_Extend, + O_Extend, + K_Buffer, + V_Buffer, + Req_to_tokens, + B_req_idx, + B_Seq_Len, + B_Start_Loc_Extend, + B_Seq_Len_Extend, + sm_scale, + kv_group_num, + stride_qbs, + stride_qh, + stride_kbs, + stride_kh, + stride_vbs, + stride_vh, + stride_obs, + stride_oh, + stride_buf_kbs, + stride_buf_kh, + stride_buf_vbs, + stride_buf_vh, + stride_req_to_tokens_b, + logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_seq = tl.program_id(0) + cur_head = tl.program_id(1) + cur_block_m = tl.program_id(2) + cur_kv_head = cur_head // kv_group_num + + cur_seq_len = tl.load(B_Seq_Len + cur_seq) + cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) + cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend + + cur_seq_prefix_start_in_loc = 0 + cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) + cur_batch_req_idx = tl.load(B_req_idx + cur_seq) + + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + offs_m = tl.arange(0, BLOCK_M) + mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + + offs_q = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] + ) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) + + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + offs_qpe = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_qbs + + cur_head * stride_qh + + offs_dpe[None, :] + ) + qpe = tl.load(Q_Extend + offs_qpe, mask=mask_m[:, None], other=0.0) + + # stage 1: compute scores with prefix + offs_n = tl.arange(0, BLOCK_N) + + acc = tl.zeros([BLOCK_M, BLOCK_DV], dtype=tl.float32) + deno = tl.zeros([BLOCK_M], dtype=tl.float32) + e_max = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + + for start_n in range(0, cur_seq_len_prefix, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_seq_len_prefix + offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( + cur_seq_prefix_start_in_loc + start_n + offs_n + ) + offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) + + # load k in transposed way + offs_buf_k = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q.to(k.dtype), k) + if BLOCK_DPE > 0: + offs_kpe = ( + offs_kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe.to(kpe.dtype), kpe) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(mask_m[:, None] & mask_n[None, :], qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_buf_v = ( + offs_kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + # stage 2: compute the trianlge part + + cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M) + for start_n in range(0, cur_block_m_end, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + mask_n = (start_n + offs_n) < cur_block_m_end + + # load k in transposed way + offs_k = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs + + cur_kv_head * stride_kh + + offs_d[:, None] + ) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) + + qk = tl.dot(q, k, out_dtype=tl.float32) + if BLOCK_DPE > 0: + offs_kpe = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) + * stride_kbs + + cur_kv_head * stride_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Extend + offs_kpe, + mask=mask_n[None, :], + other=0.0, + ) + qk += tl.dot(qpe, kpe) + + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= ( + start_n + offs_n[None, :] + ) + mask_causual &= mask_m[:, None] & mask_n[None, :] + qk = tl.where(mask_causual, qk, float("-inf")) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + deno = deno * re_scale + tl.sum(p, 1) + + offs_v = ( + (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + + cur_kv_head * stride_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) + p = p.to(v.dtype) + acc = acc * re_scale[:, None] + tl.dot(p, v) + + e_max = n_e_max + + offs_o = ( + (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + * stride_obs + + cur_head * stride_oh + + offs_dv[None, :] + ) + tl.store( + O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :] + ) + + +def extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_extend, + b_start_loc_extend, + max_len_extend, + sm_scale=None, + logit_cap=0.0, +): + """ + q_extend, k_extend, v_extend, o_extend: contiguous tensors + + k_buffer, v_buffer: (prefix + extend) tensors in mem_manager + """ + Lq, Lk, Lv = ( + q_extend.shape[-1], + k_extend.shape[-1], + v_extend.shape[-1], + ) + + if Lq == 576: + BLOCK_DMODEL = 512 + BLOCK_DPE = 64 + elif Lq == 288: + BLOCK_DMODEL = 256 + BLOCK_DPE = 32 + elif Lq == 192: + BLOCK_DMODEL = 128 + BLOCK_DPE = 64 + else: + BLOCK_DMODEL = triton.next_power_of_2(Lq) + BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) + + if is_hip_: + BLOCK_M, BLOCK_N = (64, 64) + num_warps = 4 + + else: + if is_cuda_available and CUDA_CAPABILITY[0] >= 9: + if Lq <= 256: + BLOCK_M, BLOCK_N = (128, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + elif is_cuda_available and CUDA_CAPABILITY[0] >= 8: + if Lq <= 128: + BLOCK_M, BLOCK_N = (128, 128) + elif Lq <= 256: + BLOCK_M, BLOCK_N = (64, 64) + else: + BLOCK_M, BLOCK_N = (32, 64) + else: + BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32) + + num_warps = 4 if Lk <= 64 else 8 + + sm_scale = sm_scale or 1.0 / (Lq**0.5) + batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] + kv_group_num = q_extend.shape[1] // k_extend.shape[1] + + grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) + num_stages = 1 + + extra_kargs = {} + if is_hip_: + extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} + + _fwd_kernel[grid]( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_start_loc_extend, + b_seq_len_extend, + sm_scale, + kv_group_num, + q_extend.stride(0), + q_extend.stride(1), + k_extend.stride(0), + k_extend.stride(1), + v_extend.stride(0), + v_extend.stride(1), + o_extend.stride(0), + o_extend.stride(1), + k_buffer.stride(0), + k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), + req_to_tokens.stride(0), + logit_cap=logit_cap, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, + BLOCK_M=BLOCK_M, + BLOCK_N=BLOCK_N, + Lq=Lq, + Lv=Lv, + num_warps=num_warps, + num_stages=num_stages, + **extra_kargs, + ) diff --git a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py index b2654f1f7..6c9976931 100644 --- a/python/sglang/srt/layers/attention/triton_ops/extend_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/extend_attention.py @@ -46,11 +46,9 @@ def _fwd_kernel( O_Extend, K_Buffer, V_Buffer, - Req_to_tokens, - B_req_idx, - B_Seq_Len, - B_Start_Loc_Extend, - B_Seq_Len_Extend, + qo_indptr, + kv_indptr, + kv_indices, sm_scale, kv_group_num, stride_qbs, @@ -65,7 +63,6 @@ def _fwd_kernel( stride_buf_kh, stride_buf_vbs, stride_buf_vh, - stride_req_to_tokens_b, logit_cap: tl.constexpr, Lq: tl.constexpr, Lv: tl.constexpr, @@ -80,13 +77,10 @@ def _fwd_kernel( cur_block_m = tl.program_id(2) cur_kv_head = cur_head // kv_group_num - cur_seq_len = tl.load(B_Seq_Len + cur_seq) - cur_seq_len_extend = tl.load(B_Seq_Len_Extend + cur_seq) - cur_seq_len_prefix = cur_seq_len - cur_seq_len_extend - - cur_seq_prefix_start_in_loc = 0 - cur_seq_extend_start_contiguous = tl.load(B_Start_Loc_Extend + cur_seq) - cur_batch_req_idx = tl.load(B_req_idx + cur_seq) + cur_seq_extend_start_idx = tl.load(qo_indptr + cur_seq) + cur_seq_len_extend = tl.load(qo_indptr + cur_seq + 1) - cur_seq_extend_start_idx + cur_seq_kv_start_idx = tl.load(kv_indptr + cur_seq) + cur_seq_len_prefix = tl.load(kv_indptr + cur_seq + 1) - cur_seq_kv_start_idx offs_d = tl.arange(0, BLOCK_DMODEL) offs_dv = tl.arange(0, BLOCK_DV) @@ -97,7 +91,7 @@ def _fwd_kernel( mask_dv = offs_dv < Lv offs_q = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] @@ -109,7 +103,7 @@ def _fwd_kernel( if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) offs_qpe = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_dpe[None, :] @@ -126,10 +120,9 @@ def _fwd_kernel( for start_n in range(0, cur_seq_len_prefix, BLOCK_N): start_n = tl.multiple_of(start_n, BLOCK_N) mask_n = (start_n + offs_n) < cur_seq_len_prefix - offs_b_loc_prefix = cur_batch_req_idx * stride_req_to_tokens_b + ( - cur_seq_prefix_start_in_loc + start_n + offs_n + offs_kv_loc = tl.load( + kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0 ) - offs_kv_loc = tl.load(Req_to_tokens + offs_b_loc_prefix, mask=mask_n, other=0) # load k in transposed way offs_buf_k = ( @@ -188,7 +181,7 @@ def _fwd_kernel( # load k in transposed way offs_k = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) * stride_kbs + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] ) @@ -199,8 +192,7 @@ def _fwd_kernel( qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: offs_kpe = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[None, :]) - * stride_kbs + (cur_seq_extend_start_idx + start_n + offs_n[None, :]) * stride_kbs + cur_kv_head * stride_kh + offs_dpe[:, None] ) @@ -228,7 +220,7 @@ def _fwd_kernel( deno = deno * re_scale + tl.sum(p, 1) offs_v = ( - (cur_seq_extend_start_contiguous + start_n + offs_n[:, None]) * stride_vbs + (cur_seq_extend_start_idx + start_n + offs_n[:, None]) * stride_vbs + cur_kv_head * stride_vh + offs_dv[None, :] ) @@ -241,7 +233,7 @@ def _fwd_kernel( e_max = n_e_max offs_o = ( - (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) + (cur_seq_extend_start_idx + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_obs + cur_head * stride_oh + offs_dv[None, :] @@ -258,11 +250,9 @@ def extend_attention_fwd( o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_extend, - b_start_loc_extend, + qo_indptr, + kv_indptr, + kv_indices, max_len_extend, sm_scale=None, logit_cap=0.0, @@ -315,7 +305,7 @@ def extend_attention_fwd( num_warps = 4 if Lk <= 64 else 8 sm_scale = sm_scale or 1.0 / (Lq**0.5) - batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1] + batch_size, head_num = qo_indptr.shape[0] - 1, q_extend.shape[1] kv_group_num = q_extend.shape[1] // k_extend.shape[1] grid = (batch_size, head_num, triton.cdiv(max_len_extend, BLOCK_M)) @@ -332,11 +322,9 @@ def extend_attention_fwd( o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_start_loc_extend, - b_seq_len_extend, + qo_indptr, + kv_indptr, + kv_indices, sm_scale, kv_group_num, q_extend.stride(0), @@ -351,7 +339,6 @@ def extend_attention_fwd( k_buffer.stride(1), v_buffer.stride(0), v_buffer.stride(1), - req_to_tokens.stride(0), logit_cap=logit_cap, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 52a20771b..3617e17be 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -45,16 +45,20 @@ class TestTritonAttention(unittest.TestCase): max_len_in_batch = torch.max(b_seq_len, 0)[0].item() b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - req_to_tokens = torch.empty( - (B, max_len_in_batch), dtype=torch.int32, device="cuda" - ) b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0) + kv_indices = torch.zeros( + (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda" + ) + for i in range(B): - req_to_tokens[i, : b_seq_len[i]] = torch.arange( - b_start_loc[i], b_start_loc[i] + b_seq_len[i] + kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i] ) total_token_num = torch.sum(b_seq_len).item() @@ -90,9 +94,10 @@ class TestTritonAttention(unittest.TestCase): ) b_seq_len_extend = b_seq_len - b_seq_len_prefix - b_start_loc_extend = torch.zeros_like(b_seq_len) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda") + qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0) + extend_attention_fwd( q_extend, k_extend, @@ -100,11 +105,9 @@ class TestTritonAttention(unittest.TestCase): o_extend, k_buffer, v_buffer, - req_to_tokens, - b_req_idx, - b_seq_len, - b_seq_len_extend, - b_start_loc_extend, + qo_indptr, + kv_indptr, + kv_indices, max_len_extend, )