diff --git a/python/sglang/srt/layers/attention/triton_backend.py b/python/sglang/srt/layers/attention/triton_backend.py index 1b7c4c46d..1ea193ae7 100644 --- a/python/sglang/srt/layers/attention/triton_backend.py +++ b/python/sglang/srt/layers/attention/triton_backend.py @@ -40,6 +40,9 @@ class TritonAttnBackend(AttentionBackend): else: self.reduce_dtype = torch.float16 + self.num_kv_splits = model_runner.server_args.triton_attention_num_kv_splits + self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1] + self.forward_metadata = None self.cuda_graph_max_seq_len = model_runner.model_config.context_len @@ -53,10 +56,14 @@ class TritonAttnBackend(AttentionBackend): start_loc = torch.zeros_like(forward_batch.seq_lens, dtype=torch.int32) start_loc[1:] = torch.cumsum(forward_batch.seq_lens[:-1], dim=0) - total_num_tokens = forward_batch.seq_lens_sum attn_logits = torch.empty( - (self.num_head, total_num_tokens), - dtype=self.reduce_dtype, + ( + forward_batch.batch_size, + self.num_head, + self.num_kv_splits, + self.v_head_dim + 1, + ), + dtype=torch.float32, device=self.device, ) @@ -75,11 +82,8 @@ class TritonAttnBackend(AttentionBackend): (max_bs,), dtype=torch.int32, device=self.device ) self.cuda_graph_attn_logits = torch.empty( - ( - self.num_head, - self.cuda_graph_max_total_num_tokens, - ), - dtype=self.reduce_dtype, + (max_bs, self.num_head, self.num_kv_splits, self.v_head_dim + 1), + dtype=torch.float32, device="cuda", ) @@ -189,6 +193,7 @@ class TritonAttnBackend(AttentionBackend): forward_batch.seq_lens, attn_logits, max_seq_len, + self.num_kv_splits, layer.scaling, layer.logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 56d38693f..9eeb98a29 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -17,8 +17,8 @@ It supports page size = 1. """ # Adapted from -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_nopad_att1.py -# https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/token_attention_softmax_and_reducev.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage1.py +# https://github.com/ModelTC/lightllm/blob/96353e868a840db4d103138caf15ed9dbea8c186/lightllm/models/deepseek2/triton_kernel/gqa_flash_decoding_stage2.py import triton import triton.language as tl @@ -37,10 +37,10 @@ def tanh(x): def _fwd_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -48,152 +48,137 @@ def _fwd_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, - kv_group_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - SPLIT_K: tl.constexpr, - logit_cap: tl.constexpr, - Lk: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head = tl.program_id(1) - split_k_id = tl.program_id(2) - - reduce_dtype = Att_Out.dtype.element_ty - cur_kv_head = cur_head // kv_group_num - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d - q = tl.load(Q + off_q).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[:, None] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[None, :] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[:, None] < split_k_end) & (offs_d[None, :] < Lk), - other=0.0, - ).to(reduce_dtype) - att_value = tl.sum(q[None, :] * k, 1) - att_value *= sm_scale - - if logit_cap > 0: - att_value = logit_cap * tanh(att_value / logit_cap) - - off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) - tl.store(Att_Out + off_o, att_value, mask=offs_n < split_k_end) - - -@triton.jit -def _fwd_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, stride_buf_vbs, stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, + logit_cap: tl.constexpr, + Lk: tl.constexpr, Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) + split_kv_id = tl.program_id(2) cur_kv_head = cur_head // kv_group_num + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + q = tl.load(Q + off_q, mask=mask_d, other=0.0) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) - e_max = float("-inf") + e_max = -float("inf") e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[:, None] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[None, :] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[:, None] < split_kv_end) & (mask_d[None, :]), + other=0.0, + ) + qk = tl.sum(q[None, :] * k, 1) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where(offs_n < split_kv_end, qk, float("-inf")) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max) + acc *= re_scale + acc += tl.sum(p[:, None] * v, 0) + + e_sum = e_sum * re_scale + tl.sum(p, 0) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv ) - qk = tl.load( - logits - + cur_head * stride_logic_h - + (cur_batch_start_loc + start_n + offs_n), - mask=start_n + offs_n < cur_batch_seq_len, - other=float("-inf"), + tl.store( + Att_Out + offs_mid_o, + acc / e_sum, + mask=(mask_dv), ) - n_e_max = tl.maximum(tl.max(qk, 0), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max) - e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) - e_max = n_e_max - acc = acc / e_sum - off_o = cur_batch * stride_obs + cur_head * stride_oh + offs_d - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(offs_d < Lv)) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + ) def _decode_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 32 - SPLIT_K = 8 + BLOCK = 64 + NUM_KV_SPLITS = num_kv_splits Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] batch, head_num = B_req_idx.shape[0], q.shape[1] - grid = (batch, head_num, SPLIT_K) + grid = (batch, head_num, NUM_KV_SPLITS) kv_group_num = q.shape[1] // k_buffer.shape[1] if kv_group_num == 1: @@ -202,14 +187,15 @@ def _decode_att_m_fwd( num_warps = 2 BLOCK_DMODEL = triton.next_power_of_2(Lk) + BLOCK_DV = triton.next_power_of_2(Lv) _fwd_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -217,56 +203,20 @@ def _decode_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), - att_out.stride(0), - kv_group_num=kv_group_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - SPLIT_K=SPLIT_K, - logit_cap=logit_cap, - num_warps=num_warps, - num_stages=1, - Lk=Lk, - ) - - -def _decode_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, -): - BLOCK = 64 - batch, head = b_seq_len.shape[0], logits.shape[0] - grid = (batch, head, 1) - kv_group_num = logits.shape[0] // v_buffer.shape[1] - - num_warps = 1 - - Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) - - _fwd_kernel_stage2[grid]( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - logits.stride(0), v_buffer.stride(0), v_buffer.stride(1), - o.stride(0), - o.stride(1), - req_to_tokens.stride(0), + att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, + NUM_KV_SPLITS=NUM_KV_SPLITS, + logit_cap=logit_cap, num_warps=num_warps, - num_stages=3, + num_stages=2, + Lk=Lk, Lv=Lv, ) @@ -275,10 +225,10 @@ def _decode_softmax_reducev_fwd( def _fwd_grouped_kernel_stage1( Q, K_Buffer, + V_Buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, Att_Out, stride_req_to_tokens_b, @@ -286,124 +236,27 @@ def _fwd_grouped_kernel_stage1( stride_qh, stride_buf_kbs, stride_buf_kh, - att_stride_h, + stride_buf_vbs, + stride_buf_vh, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, kv_group_num: tl.constexpr, q_head_num: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_DPE: tl.constexpr, + BLOCK_DV: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_H: tl.constexpr, - SPLIT_K: tl.constexpr, + NUM_KV_SPLITS: tl.constexpr, logit_cap: tl.constexpr, Lk: tl.constexpr, -): - cur_batch = tl.program_id(0) - cur_head_id = tl.program_id(1) - cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - split_k_id = tl.program_id(2) - - reduce_dtype = Att_Out.dtype.element_ty - - if BLOCK_H < kv_group_num: - VALID_BLOCK_H: tl.constexpr = BLOCK_H - else: - VALID_BLOCK_H: tl.constexpr = kv_group_num - cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H - mask_h = mask_h & (cur_head < q_head_num) - - offs_d = tl.arange(0, BLOCK_DMODEL) - cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) - cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - - offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] - q = tl.load( - Q + offs_q, mask=(mask_h[:, None]) & (offs_d[None, :] < Lk), other=0.0 - ).to(reduce_dtype) - - if BLOCK_DPE > 0: - offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) - off_qpe = ( - cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] - ) - qpe = tl.load(Q + off_qpe, mask=mask_h[:, None], other=0.0).to(reduce_dtype) - - kv_len_per_split = tl.cdiv(cur_batch_seq_len, SPLIT_K) - split_k_start = kv_len_per_split * split_k_id - split_k_end = tl.minimum(split_k_start + kv_len_per_split, cur_batch_seq_len) - - for start_n in range(split_k_start, split_k_end, BLOCK_N): - offs_n = start_n + tl.arange(0, BLOCK_N) - k_loc = tl.load( - Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, - mask=offs_n < split_k_end, - other=0, - ) - offs_buf_k = ( - k_loc[None, :] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_d[:, None] - ) - k = tl.load( - K_Buffer + offs_buf_k, - mask=(offs_n[None, :] < split_k_end) & (offs_d[:, None] < Lk), - other=0.0, - ).to(reduce_dtype) - qk = tl.dot(q, k) - if BLOCK_DPE > 0: - offs_buf_kpe = ( - k_loc[None, :] * stride_buf_kbs - + cur_kv_head * stride_buf_kh - + offs_dpe[:, None] - ) - kpe = tl.load( - K_Buffer + offs_buf_kpe, - mask=offs_n[None, :] < split_k_end, - other=0.0, - ).to(reduce_dtype) - qk += tl.dot(qpe, kpe) - qk *= sm_scale - - if logit_cap > 0: - qk = logit_cap * tanh(qk / logit_cap) - - offs_o = cur_head[:, None] * att_stride_h + ( - cur_batch_in_all_start_index + offs_n[None, :] - ) - - tl.store( - Att_Out + offs_o, - qk, - mask=mask_h[:, None] & (offs_n[None, :] < split_k_end), - ) - - -@triton.jit -def _fwd_grouped_kernel_stage2( - logits, - V_Buffer, - Out, - Req_to_tokens, - B_req_idx, - B_Start_Loc, - B_Seqlen, - stride_logic_h, - stride_buf_vbs, - stride_buf_vh, - stride_obs, - stride_oh, - stride_req_to_token_b, - kv_group_num: tl.constexpr, - q_head_num: tl.constexpr, - BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr, - BLOCK_H: tl.constexpr, Lv: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head_id = tl.program_id(1) cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) + split_kv_id = tl.program_id(2) if BLOCK_H < kv_group_num: VALID_BLOCK_H: tl.constexpr = BLOCK_H @@ -413,71 +266,137 @@ def _fwd_grouped_kernel_stage2( mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = mask_h & (cur_head < q_head_num) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_dv = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lk + mask_dv = offs_dv < Lv cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Start_Loc + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - offs_n = tl.arange(0, BLOCK_N) - offs_d = tl.arange(0, BLOCK_DMODEL) + offs_q = cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_d[None, :] + q = tl.load(Q + offs_q, mask=(mask_h[:, None]) & (mask_d[None, :]), other=0.0) - offs_buf_v = cur_kv_head * stride_buf_vh + offs_d[None, :] - v_ptrs = V_Buffer + offs_buf_v + if BLOCK_DPE > 0: + offs_dpe = BLOCK_DMODEL + tl.arange(0, BLOCK_DPE) + mask_dpe = offs_dpe < Lk + off_qpe = ( + cur_batch * stride_qbs + cur_head[:, None] * stride_qh + offs_dpe[None, :] + ) + qpe = tl.load( + Q + off_qpe, mask=(mask_h[:, None]) & (mask_dpe[None, :]), other=0.0 + ) + + kv_len_per_split = tl.cdiv(cur_batch_seq_len, NUM_KV_SPLITS) + split_kv_start = kv_len_per_split * split_kv_id + split_kv_end = tl.minimum(split_kv_start + kv_len_per_split, cur_batch_seq_len) e_max = tl.zeros([BLOCK_H], dtype=tl.float32) - float("inf") e_sum = tl.zeros([BLOCK_H], dtype=tl.float32) - acc = tl.zeros([BLOCK_H, BLOCK_DMODEL], dtype=tl.float32) + acc = tl.zeros([BLOCK_H, BLOCK_DV], dtype=tl.float32) - for start_n in range(0, cur_batch_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) - v_index = tl.load( - Req_to_tokens - + cur_batch_req_idx * stride_req_to_token_b - + (start_n + offs_n), - mask=(start_n + offs_n) < cur_batch_seq_len, - other=0, + if split_kv_end > split_kv_start: + for start_n in range(split_kv_start, split_kv_end, BLOCK_N): + offs_n = start_n + tl.arange(0, BLOCK_N) + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n, + mask=offs_n < split_kv_end, + other=0, + ) + offs_buf_k = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_d[:, None] + ) + k = tl.load( + K_Buffer + offs_buf_k, + mask=(offs_n[None, :] < split_kv_end) & (mask_d[:, None]), + other=0.0, + ) + qk = tl.dot(q, k.to(q.dtype)) + if BLOCK_DPE > 0: + offs_buf_kpe = ( + kv_loc[None, :] * stride_buf_kbs + + cur_kv_head * stride_buf_kh + + offs_dpe[:, None] + ) + kpe = tl.load( + K_Buffer + offs_buf_kpe, + mask=(offs_n[None, :] < split_kv_end) & (mask_dpe[:, None]), + other=0.0, + ) + qk += tl.dot(qpe, kpe.to(qpe.dtype)) + qk *= sm_scale + + if logit_cap > 0: + qk = logit_cap * tanh(qk / logit_cap) + + qk = tl.where( + mask_h[:, None] & (offs_n[None, :] < split_kv_end), qk, float("-inf") + ) + + offs_buf_v = ( + kv_loc[:, None] * stride_buf_vbs + + cur_kv_head * stride_buf_vh + + offs_dv[None, :] + ) + v = tl.load( + V_Buffer + offs_buf_v, + mask=(offs_n[:, None] < split_kv_end) & (mask_dv[None, :]), + other=0.0, + ) + + n_e_max = tl.maximum(tl.max(qk, 1), e_max) + re_scale = tl.exp(e_max - n_e_max) + p = tl.exp(qk - n_e_max[:, None]) + acc *= re_scale[:, None] + acc += tl.dot(p.to(v.dtype), v) + + e_sum = e_sum * re_scale + tl.sum(p, 1) + e_max = n_e_max + + offs_mid_o = ( + cur_batch * stride_mid_ob + + cur_head[:, None] * stride_mid_oh + + split_kv_id * stride_mid_os + + offs_dv[None, :] ) - offs_qk = cur_head[:, None] * stride_logic_h + ( - cur_batch_start_loc + start_n + offs_n[None, :] + tl.store( + Att_Out + offs_mid_o, + acc / e_sum[:, None], + mask=(mask_h[:, None]) & (mask_dv[None, :]), ) - qk = tl.load( - logits + offs_qk, - mask=mask_h[:, None] & (start_n + offs_n[None, :] < cur_batch_seq_len), - other=float("-inf"), + offs_mid_o_1 = ( + cur_batch * stride_mid_ob + + cur_head * stride_mid_oh + + split_kv_id * stride_mid_os + + Lv ) - n_e_max = tl.maximum(tl.max(qk, 1), e_max) - old_scale = tl.exp(e_max - n_e_max) - p = tl.exp(qk - n_e_max[:, None]) - e_sum = e_sum * old_scale + tl.sum(p, 1) - v = tl.load( - v_ptrs + v_index[:, None] * stride_buf_vbs, mask=(offs_d[None, :] < Lv) + tl.store( + Att_Out + offs_mid_o_1, + e_max + tl.log(e_sum), + mask=mask_h, ) - p = p.to(v.dtype) - acc = acc * old_scale[:, None] + tl.dot(p, v) - e_max = n_e_max - - acc = acc / e_sum[:, None] - off_o = cur_batch * stride_obs + cur_head[:, None] * stride_oh + offs_d[None, :] - out_ptrs = Out + off_o - tl.store(out_ptrs, acc, mask=(mask_h[:, None]) & (offs_d[None, :] < Lv)) def _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, att_out, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ): - BLOCK = 64 + BLOCK = 32 Lk = k_buffer.shape[-1] + Lv = v_buffer.shape[-1] if Lk == 576: BLOCK_DMODEL = 512 @@ -488,20 +407,19 @@ def _decode_grouped_att_m_fwd( else: BLOCK_DMODEL = triton.next_power_of_2(Lk) BLOCK_DPE = 0 + BLOCK_DV = triton.next_power_of_2(Lv) batch, head_num = B_req_idx.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - SPLIT_K = 8 + BLOCK_H = 16 + NUM_KV_SPLITS = num_kv_splits grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), - SPLIT_K, + NUM_KV_SPLITS, ) - num_warps = 4 - extra_kargs = {} if is_hip_: # https://rocm.docs.amd.com/en/docs-6.2.0/how-to/llm-fine-tuning-optimization/optimizing-triton-kernel.html @@ -511,10 +429,10 @@ def _decode_grouped_att_m_fwd( _fwd_grouped_kernel_stage1[grid]( q, k_buffer, + v_buffer, sm_scale, Req_to_tokens, B_req_idx, - B_Start_Loc, B_Seqlen, att_out, Req_to_tokens.stride(0), @@ -522,41 +440,88 @@ def _decode_grouped_att_m_fwd( q.stride(1), k_buffer.stride(0), k_buffer.stride(1), + v_buffer.stride(0), + v_buffer.stride(1), att_out.stride(0), + att_out.stride(1), + att_out.stride(2), kv_group_num=kv_group_num, q_head_num=head_num, BLOCK_DMODEL=BLOCK_DMODEL, BLOCK_DPE=BLOCK_DPE, + BLOCK_DV=BLOCK_DV, BLOCK_N=BLOCK, BLOCK_H=BLOCK_H, - SPLIT_K=SPLIT_K, + NUM_KV_SPLITS=NUM_KV_SPLITS, logit_cap=logit_cap, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, Lk=Lk, + Lv=Lv, **extra_kargs, ) -def _decode_grouped_softmax_reducev_fwd( - logits, - v_buffer, - o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, +@triton.jit +def _fwd_kernel_stage2( + Mid_O, + O, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_obs, + stride_oh, + NUM_KV_SPLITS: tl.constexpr, + BLOCK_DV: tl.constexpr, + Lv: tl.constexpr, ): - BLOCK = 128 - batch, head_num = b_seq_len.shape[0], logits.shape[0] - kv_group_num = logits.shape[0] // v_buffer.shape[1] - BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) - grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) - num_warps = 8 + offs_d = tl.arange(0, BLOCK_DV) + mask_d = offs_d < Lv + e_sum = 0.0 + e_max = -float("inf") + acc = tl.zeros([BLOCK_DV], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + Lv + + for split_kv_id in range(0, NUM_KV_SPLITS): + tv = tl.load( + Mid_O + offs_v + split_kv_id * stride_mid_os, mask=mask_d, other=0.0 + ) + tlogic = tl.load(Mid_O + offs_logic + split_kv_id * stride_mid_os) + n_e_max = tl.maximum(tlogic, e_max) + + old_scale = tl.exp(e_max - n_e_max) + acc *= old_scale + exp_logic = tl.exp(tlogic - n_e_max) + acc += exp_logic * tv + + e_sum = e_sum * old_scale + exp_logic + e_max = n_e_max + + tl.store( + O + cur_batch * stride_obs + cur_head * stride_oh + offs_d, + acc / e_sum, + mask=mask_d, + ) + + +def _decode_softmax_reducev_fwd( + logits, + q, + o, + v_buffer, + num_kv_splits, +): + batch, head_num = q.shape[0], q.shape[1] Lv = v_buffer.shape[-1] - BLOCK_DMODEL = triton.next_power_of_2(Lv) + BLOCK_DV = triton.next_power_of_2(Lv) + + NUM_KV_SPLITS = num_kv_splits extra_kargs = {} if is_hip_: @@ -564,28 +529,20 @@ def _decode_grouped_softmax_reducev_fwd( # https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/compiler.py extra_kargs = {"waves_per_eu": 4, "matrix_instr_nonkdim": 16, "kpack": 2} - _fwd_grouped_kernel_stage2[grid]( + grid = (batch, head_num) + _fwd_kernel_stage2[grid]( logits, - v_buffer, o, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, logits.stride(0), - v_buffer.stride(0), - v_buffer.stride(1), + logits.stride(1), + logits.stride(2), o.stride(0), o.stride(1), - req_to_tokens.stride(0), - kv_group_num=kv_group_num, - q_head_num=head_num, - BLOCK_DMODEL=BLOCK_DMODEL, - BLOCK_N=BLOCK, - BLOCK_H=BLOCK_H, + NUM_KV_SPLITS=NUM_KV_SPLITS, + BLOCK_DV=BLOCK_DV, Lv=Lv, - num_warps=num_warps, - num_stages=1, + num_warps=4, + num_stages=2, **extra_kargs, ) @@ -597,34 +554,27 @@ def decode_attention_fwd_normal( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd_grouped( @@ -634,34 +584,27 @@ def decode_attention_fwd_grouped( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): _decode_grouped_att_m_fwd( q, k_buffer, + v_buffer, attn_logits, req_to_token, b_req_idx, - b_start_loc, b_seq_len, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) - _decode_grouped_softmax_reducev_fwd( - attn_logits, - v_buffer, - o, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - ) + _decode_softmax_reducev_fwd(attn_logits, q, o, v_buffer, num_kv_splits) def decode_attention_fwd( @@ -675,9 +618,11 @@ def decode_attention_fwd( b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap=0.0, ): + assert num_kv_splits == attn_logits.shape[2] kv_group_num = q.shape[1] // v_buffer.shape[1] if kv_group_num == 1: @@ -689,10 +634,10 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) @@ -705,10 +650,10 @@ def decode_attention_fwd( o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, max_len_in_batch, + num_kv_splits, sm_scale, logit_cap, ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index c2e75a642..fe12d961d 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -141,6 +141,7 @@ class ServerArgs: enable_nan_detection: bool = False enable_p2p_check: bool = False triton_attention_reduce_in_fp32: bool = False + triton_attention_num_kv_splits: int = 8 num_continuous_decode_steps: int = 1 delete_ckpt_after_loading: bool = False @@ -753,6 +754,12 @@ class ServerArgs: help="Cast the intermidiate attention results to fp32 to avoid possible crashes related to fp16." "This only affects Triton attention kernels.", ) + parser.add_argument( + "--triton-attention-num-kv-splits", + type=int, + default=ServerArgs.triton_attention_num_kv_splits, + help="The number of KV splits in flash decoding Triton kernel. Larger value is better in longer context scenarios. The default value is 8.", + ) parser.add_argument( "--num-continuous-decode-steps", type=int, diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 44abfd61b..b7917345b 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -182,6 +182,7 @@ class TestTritonAttention(unittest.TestCase): seq_len = 10 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -199,8 +200,8 @@ class TestTritonAttention(unittest.TestCase): b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D + 1), + dtype=torch.float32, device="cuda", ) @@ -215,6 +216,7 @@ class TestTritonAttention(unittest.TestCase): b_seq_len, attn_logits, seq_len, + num_kv_splits, sm_scale, ) @@ -235,9 +237,10 @@ class TestTritonAttention(unittest.TestCase): def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): dtype = torch.bfloat16 - seq_len = 10 # This represents the number of tokens already in the sequence + seq_len = 128 # This represents the number of tokens already in the sequence total_tokens = B * seq_len sm_scale = 1.0 / (D**0.5) + num_kv_splits = 8 # q represents the new token being generated, one per batch q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") @@ -247,8 +250,8 @@ class TestTritonAttention(unittest.TestCase): v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") # o will have the same shape as q - o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") - o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device="cuda") req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) b_req_idx = torch.arange(B, device="cuda") @@ -256,8 +259,8 @@ class TestTritonAttention(unittest.TestCase): b_seq_len = torch.full((B,), seq_len, device="cuda") attn_logits = torch.empty( - (H_Q, total_tokens), - dtype=dtype, + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, device="cuda", ) @@ -268,13 +271,19 @@ class TestTritonAttention(unittest.TestCase): o, req_to_token, b_req_idx, - b_start_loc, b_seq_len, attn_logits, seq_len, + num_kv_splits, sm_scale, ) + attn_logits1 = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device="cuda", + ) + decode_attention_fwd_grouped( q, k_buffer, @@ -282,21 +291,23 @@ class TestTritonAttention(unittest.TestCase): o_grouped, req_to_token, b_req_idx, - b_start_loc, b_seq_len, - attn_logits, + attn_logits1, seq_len, + num_kv_splits, sm_scale, ) cos_sim = torch.nn.functional.cosine_similarity( o.flatten(), o_grouped.flatten(), dim=0 ) + print(cos_sim.item()) self.assertTrue(cos_sim.item() > 0.99) self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2)) def test_grouped_decode_attention(self): configs = [ + (2, 16, 16, 64, 64), (2, 16, 1, 64, 64), (2, 64, 1, 13, 13), (2, 128, 1, 80, 80),