From c77762d57f4161efae8222ad828b818d95f8d268 Mon Sep 17 00:00:00 2001 From: Ke Bao Date: Mon, 28 Oct 2024 01:54:38 +0800 Subject: [PATCH] Fix Triton decode kernel & ut (#1819) --- .../attention/triton_ops/decode_attention.py | 135 +++++++++++++----- .../attention/triton_ops/prefill_attention.py | 2 +- test/srt/run_suite.py | 3 +- ...nd.py => test_triton_attention_backend.py} | 0 test/srt/test_triton_attention_kernels.py | 120 ++++++++++++++-- 5 files changed, 218 insertions(+), 42 deletions(-) rename test/srt/{test_triton_attn_backend.py => test_triton_attention_backend.py} (100%) diff --git a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py index 9e06b068c..9dafbb513 100644 --- a/python/sglang/srt/layers/attention/triton_ops/decode_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/decode_attention.py @@ -296,12 +296,18 @@ def _fwd_grouped_kernel_stage1( Lk: tl.constexpr, ): cur_batch = tl.program_id(0) - cur_kv_head = tl.program_id(1) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) start_n = tl.program_id(2) reduce_dtype = Att_Out.dtype.element_ty - cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_kv_head + 1) * kv_group_num + + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = mask_h & (cur_head < q_head_num) offs_d = tl.arange(0, BLOCK_DMODEL) @@ -400,10 +406,15 @@ def _fwd_grouped_kernel_stage2( Lv: tl.constexpr, ): cur_batch = tl.program_id(0) - cur_kv_head = tl.program_id(1) + cur_head_id = tl.program_id(1) + cur_kv_head = cur_head_id // tl.cdiv(kv_group_num, BLOCK_H) - cur_head = cur_kv_head * kv_group_num + tl.arange(0, BLOCK_H) - mask_h = cur_head < (cur_kv_head + 1) * kv_group_num + if BLOCK_H < kv_group_num: + VALID_BLOCK_H: tl.constexpr = BLOCK_H + else: + VALID_BLOCK_H: tl.constexpr = kv_group_num + cur_head = cur_head_id * VALID_BLOCK_H + tl.arange(0, BLOCK_H) + mask_h = cur_head < (cur_head_id + 1) * VALID_BLOCK_H mask_h = mask_h & (cur_head < q_head_num) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) @@ -485,7 +496,7 @@ def _decode_grouped_att_m_fwd( batch, head_num = B_req_idx.shape[0], q.shape[1] kv_group_num = q.shape[1] // k_buffer.shape[1] - BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) + BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) grid = ( batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), @@ -534,7 +545,7 @@ def _decode_grouped_softmax_reducev_fwd( BLOCK = 128 batch, head_num = b_seq_len.shape[0], logits.shape[0] kv_group_num = logits.shape[0] // v_buffer.shape[1] - BLOCK_H = max(16, triton.next_power_of_2(kv_group_num)) + BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num))) grid = (batch, triton.cdiv(head_num, min(BLOCK_H, kv_group_num)), 1) num_warps = 8 @@ -567,6 +578,80 @@ def _decode_grouped_softmax_reducev_fwd( ) +def decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + attn_logits, + max_len_in_batch, + sm_scale, + logit_cap=0.0, +): + _decode_att_m_fwd( + q, + k_buffer, + attn_logits, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + sm_scale, + logit_cap, + ) + _decode_softmax_reducev_fwd( + attn_logits, + v_buffer, + o, + 
req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + ) + + +def decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + attn_logits, + max_len_in_batch, + sm_scale, + logit_cap=0.0, +): + _decode_grouped_att_m_fwd( + q, + k_buffer, + attn_logits, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + max_len_in_batch, + sm_scale, + logit_cap, + ) + _decode_grouped_softmax_reducev_fwd( + attn_logits, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + ) + + def decode_attention_fwd( q, k_buffer, @@ -585,47 +670,33 @@ def decode_attention_fwd( if kv_group_num == 1: # MHA - _decode_att_m_fwd( + decode_attention_fwd_normal( q, k_buffer, - attn_logits, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - max_len_in_batch, - sm_scale, - logit_cap, - ) - _decode_softmax_reducev_fwd( - attn_logits, v_buffer, o, req_to_token, b_req_idx, b_start_loc, b_seq_len, + attn_logits, + max_len_in_batch, + sm_scale, + logit_cap, ) else: # GQA/MQA/MLA - _decode_grouped_att_m_fwd( + decode_attention_fwd_grouped( q, k_buffer, - attn_logits, - req_to_token, - b_req_idx, - b_start_loc, - b_seq_len, - max_len_in_batch, - sm_scale, - logit_cap, - ) - _decode_grouped_softmax_reducev_fwd( - attn_logits, v_buffer, o, req_to_token, b_req_idx, b_start_loc, b_seq_len, + attn_logits, + max_len_in_batch, + sm_scale, + logit_cap, ) diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index c90aac1cc..7906aca1c 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -168,7 +168,7 @@ def _fwd_kernel( def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): - if is_cuda_available and CUDA_CAPABILITY[0] >= 8: + if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: BLOCK = 64 diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index 3f8a1fecb..1237df709 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -26,7 +26,8 @@ suites = { "test_srt_endpoint.py", "test_torch_compile.py", "test_torchao.py", - "test_triton_attn_backend.py", + "test_triton_attention_kernels.py", + "test_triton_attention_backend.py", "test_update_weights.py", "test_vision_openai_server.py", ], diff --git a/test/srt/test_triton_attn_backend.py b/test/srt/test_triton_attention_backend.py similarity index 100% rename from test/srt/test_triton_attn_backend.py rename to test/srt/test_triton_attention_backend.py diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py index 539b4d4e0..44abfd61b 100644 --- a/test/srt/test_triton_attention_kernels.py +++ b/test/srt/test_triton_attention_kernels.py @@ -3,7 +3,11 @@ import unittest import torch -from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd +from sglang.srt.layers.attention.triton_ops.decode_attention import ( + decode_attention_fwd, + decode_attention_fwd_grouped, + decode_attention_fwd_normal, +) from sglang.srt.layers.attention.triton_ops.extend_attention import ( extend_attention_fwd, redundant_attention, @@ -13,7 +17,7 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import ( ) -class TestExtendAttention(unittest.TestCase): +class TestTritonAttention(unittest.TestCase): def _set_all_seeds(self, seed): """Set all random seeds for reproducibility.""" 
@@ -127,7 +131,7 @@ class TestExtendAttention(unittest.TestCase): for value in attention_values: self._test_extend_attention_once(19, 12331, 12, 4, value) - def _test_context_attention_once(self, head_dim): + def _test_context_attention_once(self, head_dim, is_causal): # Set up a simple test case num_heads = 4 seq_lens = [8, 12] @@ -143,15 +147,35 @@ class TestExtendAttention(unittest.TestCase): b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") b_seq_len = torch.tensor(seq_lens, device="cuda") - context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len) + context_attention_fwd( + q, k, v, o, b_start_loc, b_seq_len, max_seq_len, is_causal=is_causal + ) + + cu_seq_lens = [0] * (len(seq_lens) + 1) + for i, seq_len in enumerate(seq_lens): + cu_seq_lens[i + 1] = cu_seq_lens[i] + seq_len + + for i in range(len(seq_lens)): + start, end = cu_seq_lens[i], cu_seq_lens[i + 1] + o_torch = torch.nn.functional.scaled_dot_product_attention( + q[start:end].permute(1, 0, 2), + k[start:end].permute(1, 0, 2), + v[start:end].permute(1, 0, 2), + is_causal=is_causal, + ).permute(1, 0, 2) + + cos_sim = torch.nn.functional.cosine_similarity( + o[start:end].flatten(), o_torch.flatten(), dim=0 + ) + self.assertTrue(cos_sim.item() > 1 - (1e-5)) + self.assertTrue(torch.allclose(o[start:end], o_torch, atol=1e-2)) def test_context_attention(self): - # Here we just to ensure there is no error - # TODO: correctnesss test head_dim = [128, 96, 80, 13] for dim in head_dim: - self._test_context_attention_once(dim) + for is_causal in [True, False]: + self._test_context_attention_once(dim, is_causal) def _test_decode_attention_once(self, B, H_Q, H_KV, D): dtype = torch.bfloat16 @@ -174,6 +198,12 @@ class TestExtendAttention(unittest.TestCase): b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") b_seq_len = torch.full((B,), seq_len, device="cuda") + attn_logits = torch.empty( + (H_Q, total_tokens), + dtype=dtype, + device="cuda", + ) + decode_attention_fwd( q, k_buffer, @@ -183,8 +213,8 @@ class TestExtendAttention(unittest.TestCase): b_req_idx, b_start_loc, b_seq_len, + attn_logits, seq_len, - total_tokens, sm_scale, ) @@ -203,6 +233,80 @@ class TestExtendAttention(unittest.TestCase): for B, H_Q, H_KV, D in configs: self._test_decode_attention_once(B, H_Q, H_KV, D) + def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V): + dtype = torch.bfloat16 + seq_len = 10 # This represents the number of tokens already in the sequence + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device="cuda") + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + o_grouped = torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + + req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) + b_req_idx = torch.arange(B, device="cuda") + b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") + b_seq_len = torch.full((B,), seq_len, device="cuda") + + attn_logits = torch.empty( + (H_Q, total_tokens), + dtype=dtype, + device="cuda", + ) + + decode_attention_fwd_normal( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + attn_logits, + seq_len, + sm_scale, + ) + + 
decode_attention_fwd_grouped( + q, + k_buffer, + v_buffer, + o_grouped, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + attn_logits, + seq_len, + sm_scale, + ) + + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_grouped.flatten(), dim=0 + ) + self.assertTrue(cos_sim.item() > 0.99) + self.assertTrue(torch.allclose(o, o_grouped, atol=3e-2)) + + def test_grouped_decode_attention(self): + configs = [ + (2, 16, 1, 64, 64), + (2, 64, 1, 13, 13), + (2, 128, 1, 80, 80), + (2, 128, 2, 512, 512), + (2, 128, 1, 576, 512), + ] + + for B, H_Q, H_KV, D, D_V in configs: + self._test_grouped_decode_attention_once(B, H_Q, H_KV, D, D_V) + if __name__ == "__main__": unittest.main()
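
The core of this fix is the head-block mapping in the grouped decode kernels: the second grid axis now walks blocks of query heads (with BLOCK_H capped at 64) instead of assuming one block per KV head, so layouts whose query-to-KV head ratio exceeds a single block (e.g. MLA-style models) are split across several programs. Below is a minimal Python sketch of that mapping; the helper names grouped_decode_head_grid and heads_for_program are illustrative and not part of the patch, but the arithmetic mirrors _decode_grouped_softmax_reducev_fwd and _fwd_grouped_kernel_stage1/2 above.

import triton


def grouped_decode_head_grid(batch, q_head_num, kv_head_num):
    # Batch/head axes of the launch grid after the fix
    # (matches _decode_grouped_softmax_reducev_fwd above).
    kv_group_num = q_head_num // kv_head_num
    BLOCK_H = max(16, min(64, triton.next_power_of_2(kv_group_num)))
    grid = (batch, triton.cdiv(q_head_num, min(BLOCK_H, kv_group_num)), 1)
    return BLOCK_H, grid


def heads_for_program(cur_head_id, BLOCK_H, kv_group_num, q_head_num):
    # Mirrors _fwd_grouped_kernel_stage1/2: each program covers VALID_BLOCK_H
    # query heads that all read from the same KV head.
    valid_block_h = min(BLOCK_H, kv_group_num)
    cur_kv_head = cur_head_id // triton.cdiv(kv_group_num, BLOCK_H)
    heads = [
        h
        for h in range(cur_head_id * valid_block_h, (cur_head_id + 1) * valid_block_h)
        if h < q_head_num
    ]
    return cur_kv_head, heads


# An MLA-style layout (128 query heads, 1 KV head) is now covered by two
# programs per batch element on the head axis instead of one oversized block.
BLOCK_H, grid = grouped_decode_head_grid(batch=2, q_head_num=128, kv_head_num=1)
print(BLOCK_H, grid)                              # 64 (2, 2, 1)
print(heads_for_program(1, BLOCK_H, 128, 128))    # (0, [64, 65, ..., 127])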
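
For callers, the visible change is the signature of decode_attention_fwd: it now takes a caller-allocated attn_logits buffer of shape (H_Q, total_tokens), drops the separate total_tokens argument, and internally dispatches to decode_attention_fwd_normal (MHA) or decode_attention_fwd_grouped (GQA/MQA/MLA) based on the query-to-KV head ratio. The usage sketch below mirrors the updated unit test; the shapes and values are illustrative and it assumes a CUDA device with Triton installed.

import torch

from sglang.srt.layers.attention.triton_ops.decode_attention import decode_attention_fwd

B, H_Q, H_KV, D, seq_len = 2, 16, 1, 64, 10   # one new token per request
total_tokens = B * seq_len
dtype, device = torch.bfloat16, "cuda"

q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
o = torch.zeros(B, H_Q, D, dtype=dtype, device=device)

req_to_token = torch.arange(total_tokens, device=device).reshape(B, seq_len)
b_req_idx = torch.arange(B, device=device)
b_start_loc = torch.arange(0, total_tokens, seq_len, device=device)
b_seq_len = torch.full((B,), seq_len, device=device)

# Intermediate logits buffer is now allocated by the caller.
attn_logits = torch.empty((H_Q, total_tokens), dtype=dtype, device=device)

decode_attention_fwd(
    q, k_buffer, v_buffer, o,
    req_to_token, b_req_idx, b_start_loc, b_seq_len,
    attn_logits, seq_len, 1.0 / D**0.5,
)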