From 8e6bdf851c4aa6619baa584fc450af748720319d Mon Sep 17 00:00:00 2001 From: Byron Hsu Date: Mon, 9 Sep 2024 01:30:24 -0700 Subject: [PATCH] [triton] Support head_dim not 2^n in triton extend and decode attention (#1281) --- python/sglang/srt/layers/decode_attention.py | 50 ++++++++++++------ python/sglang/srt/layers/extend_attention.py | 51 +++++++++++++------ python/sglang/srt/layers/prefill_attention.py | 20 +++++--- 3 files changed, 84 insertions(+), 37 deletions(-) diff --git a/python/sglang/srt/layers/decode_attention.py b/python/sglang/srt/layers/decode_attention.py index dc92a6548..9c9822b85 100644 --- a/python/sglang/srt/layers/decode_attention.py +++ b/python/sglang/srt/layers/decode_attention.py @@ -60,6 +60,7 @@ def _fwd_kernel_stage1( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -97,7 +98,7 @@ def _fwd_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[:, None] < cur_batch_end_index, + mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) att_value = tl.sum(q[None, :] * k, 1) @@ -128,6 +129,7 @@ def _fwd_kernel_stage2( kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -170,14 +172,16 @@ def _fwd_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) acc = acc * old_scale + tl.sum(p[:, None] * v, 0) e_max = n_e_max acc = acc / e_sum off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d out_ptrs = Out + off_o - tl.store(out_ptrs, acc) + tl.store(out_ptrs, acc, mask=(offs_d < Lv)) def _decode_att_m_fwd( @@ -196,7 +200,7 @@ def _decode_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -208,6 +212,8 @@ def _decode_att_m_fwd( else: num_warps = 2 + BLOCK_DMODEL = triton.next_power_of_2(Lk) + _fwd_kernel_stage1[grid]( q, k_buffer, @@ -224,11 +230,12 @@ def _decode_att_m_fwd( k_buffer.stride(1), att_out.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -248,6 +255,9 @@ def _decode_softmax_reducev_fwd( num_warps = 1 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_kernel_stage2[grid]( logics, v_buffer, @@ -263,10 +273,11 @@ def _decode_softmax_reducev_fwd( o.stride(1), req_to_tokens.stride(0), kv_group_num=kv_group_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, num_warps=num_warps, num_stages=3, + Lv=Lv, ) @@ -293,6 +304,7 @@ def _fwd_grouped_kernel_stage1( BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, logit_cap: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -324,9 +336,9 @@ def _fwd_grouped_kernel_stage1( block_mask = tl.where(block_stard_index < cur_batch_seq_len, 1, 0) for start_mark in range(0, block_mask, 1): - q = tl.load(Q + offs_q + start_mark, mask=mask_h[:, None]).to( - REDUCE_TRITON_TYPE - ) + q = tl.load( + Q + offs_q + start_mark, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk) + ).to(REDUCE_TRITON_TYPE) offs_n_new = cur_batch_start_index + offs_n k_loc = tl.load( Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, @@ -340,7 +352,7 @@ def _fwd_grouped_kernel_stage1( ) k = tl.load( K_Buffer + offs_buf_k, - mask=offs_n_new[None, :] < cur_batch_end_index, + mask=(offs_n_new[None, :] < cur_batch_end_index) & (offs_d[:, None] < Lk), other=0.0, ).to(REDUCE_TRITON_TYPE) qk = tl.dot(q, k) @@ -395,6 +407,7 @@ def _fwd_grouped_kernel_stage2( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, + Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_kv_head = tl.program_id(1) @@ -441,7 +454,9 @@ def _fwd_grouped_kernel_stage2( old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max[:, None]) e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load(v_ptrs + v_index[:, None] * stride_buf_vbs) + v = tl.load( + v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + ) p = p.to(v.dtype) acc = acc * old_scale[:, None] + tl.dot(p, v) e_max = n_e_max @@ -449,7 +464,7 @@ def _fwd_grouped_kernel_stage2( acc = acc / e_sum[:, None] off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=mask_h[:, None]) + tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( @@ -468,13 +483,13 @@ def _decode_grouped_att_m_fwd( # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] assert Lq == Lk - assert Lk in {16, 32, 64, 128, 256, 576} + assert Lk in {16, 32, 64, 96, 128, 256, 576} if Lk == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lk + BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -513,6 +528,7 @@ def _decode_grouped_att_m_fwd( logit_cap=logit_cap, num_warps=num_warps, num_stages=1, + Lk=Lk, ) @@ -533,6 +549,9 @@ def _decode_grouped_softmax_reducev_fwd( num_warps = 8 + Lv = v_buffer.shape[-1] + BLOCK_DMODEL = triton.next_power_of_2(Lv) + _fwd_grouped_kernel_stage2[grid]( logics, v_buffer, @@ -549,11 +568,12 @@ def _decode_grouped_softmax_reducev_fwd( req_to_tokens.stride(0), kv_group_num=kv_group_num, q_head_num=head_num, - BLOCK_DMODEL=v_buffer.shape[-1], + BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, num_warps=num_warps, num_stages=1, + Lv=Lv, ) diff --git a/python/sglang/srt/layers/extend_attention.py b/python/sglang/srt/layers/extend_attention.py index 6c7686971..888062285 100644 --- a/python/sglang/srt/layers/extend_attention.py +++ b/python/sglang/srt/layers/extend_attention.py @@ -15,7 +15,7 @@ limitations under the License. """ Memory-efficient attention for prefill. -It supporst page size = 1 and prefill with KV cache (i.e. extend). +It supports page size = 1 and prefill with KV cache (i.e. extend). """ import torch @@ -67,6 +67,8 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, logit_cap: tl.constexpr, + Lq: tl.constexpr, + Lv: tl.constexpr, ): cur_seq = tl.program_id(0) cur_head = tl.program_id(1) @@ -86,13 +88,18 @@ def _fwd_kernel( offs_m = tl.arange(0, BLOCK_M) mask_m = (cur_block_m * BLOCK_M + offs_m) < cur_seq_len_extend + mask_d = offs_d < Lq + mask_dv = offs_dv < Lv + offs_q = ( (cur_seq_extend_start_contiguous + cur_block_m * BLOCK_M + offs_m[:, None]) * stride_qbs + cur_head * stride_qh + offs_d[None, :] ) - q = tl.load(Q_Extend + offs_q, mask=mask_m[:, None], other=0.0) + q = tl.load( + Q_Extend + offs_q, mask=(mask_m[:, None]) & (mask_d[None, :]), other=0.0 + ) if BLOCK_DPE > 0: offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) @@ -125,7 +132,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_kh + offs_d[:, None] ) - k = tl.load(K_Buffer + offs_buf_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Buffer + offs_buf_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q.to(k.dtype), k) if BLOCK_DPE > 0: @@ -157,7 +166,9 @@ def _fwd_kernel( + cur_kv_head * stride_buf_vh + offs_dv[None, :] ) - v = tl.load(V_Buffer + offs_buf_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Buffer + offs_buf_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -176,7 +187,9 @@ def _fwd_kernel( + cur_kv_head * stride_kh + offs_d[:, None] ) - k = tl.load(K_Extend + offs_k, mask=mask_n[None, :], other=0.0) + k = tl.load( + K_Extend + offs_k, mask=(mask_n[None, :]) & (mask_d[:, None]), other=0.0 + ) qk = tl.dot(q, k, out_dtype=tl.float32) if BLOCK_DPE > 0: @@ -214,7 +227,9 @@ def _fwd_kernel( + cur_kv_head * stride_vh + offs_dv[None, :] ) - v = tl.load(V_Extend + offs_v, mask=mask_n[:, None], other=0.0) + v = tl.load( + V_Extend + offs_v, mask=mask_n[:, None] & mask_dv[None, :], other=0.0 + ) p = p.to(v.dtype) acc = acc * re_scale[:, None] + tl.dot(p, v) @@ -226,7 +241,9 @@ def _fwd_kernel( + cur_head * stride_oh + offs_dv[None, :] ) - tl.store(O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None]) + tl.store( + O_Extend + offs_o, acc / deno[:, None], mask=mask_m[:, None] & mask_dv[None, :] + ) def extend_attention_fwd( @@ -261,16 +278,18 @@ def extend_attention_fwd( ) assert Lq == Lk and Lv == Lo - assert Lq in {16, 32, 64, 128, 256, 576} - assert Lv in {16, 32, 64, 128, 256, 512} + + # TODO: is the assertion necessary? + assert Lq in {16, 32, 64, 96, 128, 256, 576} + assert Lv in {16, 32, 64, 96, 128, 256, 512} if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 else: - BLOCK_DMODEL = Lq + BLOCK_DMODEL = triton.next_power_of_2(Lq) BLOCK_DPE = 0 - BLOCK_DV = Lv + BLOCK_DV = triton.next_power_of_2(Lv) if CUDA_CAPABILITY[0] >= 9: if Lq <= 256: @@ -330,6 +349,8 @@ def extend_attention_fwd( num_warps=num_warps, num_stages=num_stages, logit_cap=logit_cap, + Lq=Lq, + Lv=Lv, ) @@ -373,10 +394,7 @@ def redundant_attention( pt += cur_seq_len_extend -def test(): - torch.manual_seed(0) - - B, N_CTX, H_Q, H_KV, D = 19, 12331, 12, 4, 128 +def test_once(B, N_CTX, H_Q, H_KV, D): dtype = torch.float16 b_seq_len_prefix = torch.randint( @@ -473,4 +491,5 @@ def test(): if __name__ == "__main__": - test() + test_once(19, 12331, 12, 4, 128) + test_once(19, 12331, 12, 4, 96) diff --git a/python/sglang/srt/layers/prefill_attention.py b/python/sglang/srt/layers/prefill_attention.py index 99343a4df..fbf9976fb 100644 --- a/python/sglang/srt/layers/prefill_attention.py +++ b/python/sglang/srt/layers/prefill_attention.py @@ -48,6 +48,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, + Lk: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) @@ -72,7 +73,11 @@ def _fwd_kernel( off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] - q = tl.load(Q + off_q, mask=offs_m[:, None] < cur_batch_seq_len, other=0.0) + mask_d = offs_d < Lk + + q = tl.load( + Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d), other=0.0 + ) k_ptrs = K + off_k v_ptrs = V + off_v @@ -89,7 +94,7 @@ def _fwd_kernel( # -- compute qk ---- k = tl.load( k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, - mask=(start_n + offs_n[None, :]) < cur_batch_seq_len, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (mask_d[:, None]), other=0.0, ) # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) @@ -118,7 +123,7 @@ def _fwd_kernel( # update acc v = tl.load( v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, - mask=(start_n + offs_n[:, None]) < cur_batch_seq_len, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (mask_d[None, :]), other=0.0, ) @@ -134,7 +139,9 @@ def _fwd_kernel( + offs_d[None, :] ) out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=offs_m[:, None] < cur_batch_seq_len) + tl.store( + out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (mask_d[None, :]) + ) def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): @@ -145,7 +152,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 128, 256} + assert Lk in {16, 32, 64, 96, 128, 256} sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] @@ -172,8 +179,9 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): o.stride(1), kv_group_num=kv_group_num, BLOCK_M=BLOCK, - BLOCK_DMODEL=Lk, + BLOCK_DMODEL=triton.next_power_of_2(Lk), BLOCK_N=BLOCK, num_warps=num_warps, num_stages=1, + Lk=Lk, )