From c5f865013e729a6449384c595492018041e9fb64 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Sat, 23 Nov 2024 16:51:46 +0800 Subject: [PATCH] Fix grid size in Triton decoding kernel (#2134) --- .../attention/triton_ops/decode_attention.py | 72 +++++++++---------- 1 file changed, 34 insertions(+), 38 deletions(-) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 5ce03e49c..8fa0cb4b0 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -50,12 +50,13 @@ def _fwd_kernel_stage1( kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + SPLIT_K: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) - start_n = tl.program_id(2) + split_k_id = tl.program_id(2) reduce_dtype = Att_Out.dtype.element_ty cur_kv_head = cur_head // kv_group_num @@ -65,22 +66,18 @@ def _fwd_kernel_stage1( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = 0 - cur_batch_end_index = cur_batch_seq_len - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q).to(reduce_dtype) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) + split_k_start = kv_len_per_split * split_k_id + split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark).to(reduce_dtype) - offs_n_new = cur_batch_start_index + offs_n + for start_n in range(split_k_start, split_k_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_k_end, other=0, ) offs_buf_k = ( @@ -90,7 +87,7 @@ def _fwd_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk), + mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk), other=0.0, ).to(reduce_dtype) att_value = tl.sum(q[None, :] * k, 1) @@ -100,7 +97,7 @@ def _fwd_kernel_stage1( att_value = logit_cap * tanh(att_value / logit_cap) off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) - tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) + tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end) @triton.jit @@ -189,11 +186,12 @@ def _decode_att_m_fwd( logit_cap, ): BLOCK = 32 + SPLIT_K = 8 Lk = k_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK)) + grid = (batch, head_num, SPLIT_K) kv_group_num = q.shape[1] // k_buffer.shape[1] if kv_group_num == 1: @@ -221,6 +219,7 @@ def _decode_att_m_fwd( kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, + SPLIT_K=SPLIT_K, logit_cap=logit_cap, num_warps=num_warps, num_stages=1, @@ -292,13 +291,14 @@ def _fwd_grouped_kernel_stage1( BLOCK_DPE: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, + SPLIT_K: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - start_n = tl.program_id(2) + split_k_id = tl.program_id(2) reduce_dtype = Att_Out.dtype.element_ty @@ -315,30 +315,27 @@ def _fwd_grouped_kernel_stage1( cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = 0 - cur_batch_end_index = cur_batch_seq_len - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load( + Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 + ).to(reduce_dtype) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) off_qpe = ( cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] ) + qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) + kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) + split_k_start = kv_len_per_split * split_k_id + split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) - - for start_mark in range(0, block_mask, 1): - q = tl.load( - Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk) - ).to(reduce_dtype) - offs_n_new = cur_batch_start_index + offs_n + for start_n in range(split_k_start, split_k_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, - mask=offs_n_new < cur_batch_end_index, + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_k_end, other=0, ) offs_buf_k = ( @@ -348,14 +345,11 @@ def _fwd_grouped_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk), + mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), other=0.0, ).to(reduce_dtype) qk = tl.dot(q, k) if BLOCK_DPE > 0: - qpe = tl.load(Q + off_qpe + start_mark, mask=mask_h[:, None]).to( - reduce_dtype - ) offs_buf_kpe = ( k_loc[None, :] * stride_buf_kbs + cur_kv_head * stride_buf_kh @@ -363,7 +357,7 @@ def _fwd_grouped_kernel_stage1( ) kpe = tl.load( K_Buffer + offs_buf_kpe, - mask=offs_n_new[None, :] < cur_batch_end_index, + mask=offs_n[None, :] < split_k_end, other=0.0, ).to(reduce_dtype) qk += tl.dot(qpe, kpe) @@ -379,7 +373,7 @@ def _fwd_grouped_kernel_stage1( tl.store( Att_Out + offs_o, qk, - mask=mask_h[:, None] & (offs_n_new[None, :] < cur_batch_end_index), + mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), ) @@ -497,10 +491,11 @@ def _decode_grouped_att_m_fwd( kv_group_num = q.shape[1] // k_buffer.shape[1] BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) + SPLIT_K = 8 grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - triton.cdiv(max_len_in_batch, BLOCK), + SPLIT_K, ) num_warps = 4 @@ -532,6 +527,7 @@ def _decode_grouped_att_m_fwd( BLOCK_DPE=BLOCK_DPE, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, + SPLIT_K=SPLIT_K, logit_cap=logit_cap, num_warps=num_warps, num_stages=1,