diff --git a/python/sglang/srt/layers/triton_attention/decode_attention.py b/python/sglang/srt/layers/triton_attention/decode_attention.py index ebf29cc59..82ce6efc5 100644 --- a/python/sglang/srt/layers/triton_attention/decode_attention.py +++ b/python/sglang/srt/layers/triton_attention/decode_attention.py @@ -199,8 +199,6 @@ def _decode_att_m_fwd( BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 96, 128, 256} batch, head_num = B_req_idx.shape[0], q.shape[1] @@ -482,8 +480,6 @@ def _decode_grouped_att_m_fwd( BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k_buffer.shape[-1] - assert Lq == Lk - assert Lk in {16, 32, 64, 96, 128, 256, 576, 288} if Lk == 576: BLOCK_DMODEL = 512 diff --git a/python/sglang/srt/layers/triton_attention/extend_attention.py b/python/sglang/srt/layers/triton_attention/extend_attention.py index 81039e676..1193c4124 100644 --- a/python/sglang/srt/layers/triton_attention/extend_attention.py +++ b/python/sglang/srt/layers/triton_attention/extend_attention.py @@ -277,12 +277,6 @@ def extend_attention_fwd( o_extend.shape[-1], ) - assert Lq == Lk and Lv == Lo - - # TODO: is the assertion necessary? 
- assert Lq in {16, 32, 64, 96, 128, 256, 576, 288} - assert Lv in {16, 32, 64, 96, 128, 256, 512} - if Lq == 576: BLOCK_DMODEL = 512 BLOCK_DPE = 64 @@ -395,104 +389,3 @@ def redundant_attention( pl, pr = b_start_loc[i] + b_seq_len_prefix[i], b_start_loc[i] + b_seq_len[i] o_extend[pt : pt + cur_seq_len_extend] = o_buffer[pl:pr] pt += cur_seq_len_extend - - -def test_once(B, N_CTX, H_Q, H_KV, D): - dtype = torch.float16 - - b_seq_len_prefix = torch.randint( - 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" - ) - b_seq_len_extend = torch.randint( - 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" - ) - b_seq_len = b_seq_len_prefix + b_seq_len_extend - max_len_in_batch = torch.max(b_seq_len, 0)[0].item() - - b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") - req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32, device="cuda") - b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") - b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) - b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - for i in range(B): - req_to_tokens[i, : b_seq_len[i]] = torch.arange( - b_start_loc[i], b_start_loc[i] + b_seq_len[i] - ) - - total_token_num = torch.sum(b_seq_len).item() - extend_token_num = torch.sum(b_seq_len_extend).item() - k_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - v_buffer = torch.empty( - (total_token_num, H_KV, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - - k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") - q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - for i in range(B): - extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] - extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] - extend_start = 
b_start_loc_extend[i] - extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] - k_extend[extend_start:extend_end] = k_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - v_extend[extend_start:extend_end] = v_buffer[ - extend_start_in_buffer:extend_end_in_buffer - ] - q_extend[extend_start:extend_end] = torch.empty( - (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" - ).normal_(mean=0.1, std=0.2) - - o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - o_redundant = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") - - b_seq_len_extend = b_seq_len - b_seq_len_prefix - b_start_loc_extend = torch.zeros_like(b_seq_len) - b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) - max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() - extend_attention_fwd( - q_extend, - k_extend, - v_extend, - o_extend, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_seq_len_prefix, - b_start_loc_extend, - b_seq_len_extend, - max_len_in_batch, - max_len_extend, - ) - - redundant_attention( - q_extend, - k_extend, - v_extend, - o_redundant, - k_buffer, - v_buffer, - req_to_tokens, - b_req_idx, - b_start_loc, - b_seq_len, - b_seq_len_prefix, - max_len_in_batch, - ) - - print("Mean: ", torch.mean(torch.abs(o_extend - o_redundant))) - print("Max: ", torch.max(torch.abs(o_extend - o_redundant))) - - assert torch.allclose(o_extend, o_redundant, rtol=1e-2) - - -if __name__ == "__main__": - test_once(19, 12331, 12, 4, 128) - test_once(19, 12331, 12, 4, 96) diff --git a/python/sglang/srt/layers/triton_attention/prefill_attention.py b/python/sglang/srt/layers/triton_attention/prefill_attention.py index fbf9976fb..e19e73ec1 100644 --- a/python/sglang/srt/layers/triton_attention/prefill_attention.py +++ b/python/sglang/srt/layers/triton_attention/prefill_attention.py @@ -151,8 +151,6 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len): BLOCK = 64 Lq, Lk, Lv 
= q.shape[-1], k.shape[-1], v.shape[-1] - assert Lq == Lk and Lk == Lv - assert Lk in {16, 32, 64, 96, 128, 256} sm_scale = 1.0 / (Lq**0.5) batch, head = b_seq_len.shape[0], q.shape[1] diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py new file mode 100644 index 000000000..0d094b557 --- /dev/null +++ b/test/srt/test_triton_attention_kernels.py @@ -0,0 +1,213 @@ +import random +import unittest + +import torch + +from sglang.srt.layers.triton_attention.decode_attention import decode_attention_fwd +from sglang.srt.layers.triton_attention.extend_attention import ( + extend_attention_fwd, + redundant_attention, +) +from sglang.srt.layers.triton_attention.prefill_attention import context_attention_fwd + + +class TestExtendAttention(unittest.TestCase): + + def _set_all_seeds(self, seed): + """Set all random seeds for reproducibility.""" + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + def setUp(self): + # Set seeds before each test method + self._set_all_seeds(42) + + def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D): + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len_extend = torch.randint( + 1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda" + ) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32, device="cuda") + req_to_tokens = torch.empty( + (B, max_len_in_batch), dtype=torch.int32, device="cuda" + ) + b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda") + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + for i in 
range(B): + req_to_tokens[i, : b_seq_len[i]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + k_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + v_buffer = torch.empty( + (total_token_num, H_KV, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda") + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.empty( + (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda" + ).normal_(mean=0.1, std=0.2) + + o_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda") + o_redundant = torch.empty( + (extend_token_num, H_Q, D), dtype=dtype, device="cuda" + ) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + b_start_loc_extend = torch.zeros_like(b_seq_len) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + extend_attention_fwd( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + b_start_loc_extend, + b_seq_len_extend, + max_len_in_batch, + max_len_extend, + ) + + redundant_attention( + q_extend, + k_extend, + v_extend, + 
o_redundant, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_seq_len_prefix, + max_len_in_batch, + ) + + self.assertTrue(torch.allclose(o_extend, o_redundant, rtol=1e-2)) + + def test_extend_attention(self): + + # Define the varying parameter values + attention_values = [128, 96, 80, 13] + + # Loop through the values and call the method + for value in attention_values: + self._test_extend_attention_once(19, 12331, 12, 4, value) + + def _test_context_attention_once(self, head_dim): + # Set up a simple test case + batch_size = 2 + num_heads = 4 + seq_lens = [8, 12] + max_seq_len = max(seq_lens) + + # Create random input tensors + q = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + k = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + v = torch.randn(sum(seq_lens), num_heads, head_dim, device="cuda") + o = torch.zeros(sum(seq_lens), num_heads, head_dim, device="cuda") + + # Create b_start_loc and b_seq_len tensors + b_start_loc = torch.tensor([0, seq_lens[0]], device="cuda") + b_seq_len = torch.tensor(seq_lens, device="cuda") + + context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_seq_len) + + def test_context_attention(self): + # Here we just want to ensure there is no error + # TODO: correctness test + head_dim = [128, 96, 80, 13] + + for dim in head_dim: + self._test_context_attention_once(dim) + + def _test_decode_attention_once(self, B, H_Q, H_KV, D): + dtype = torch.bfloat16 + seq_len = 10 # This represents the number of tokens already in the sequence + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device="cuda") + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + v_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device="cuda") + + # o will have the same shape as q + o = 
torch.zeros(B, H_Q, D, dtype=dtype, device="cuda") + + req_to_token = torch.arange(total_tokens, device="cuda").reshape(B, seq_len) + b_req_idx = torch.arange(B, device="cuda") + b_start_loc = torch.arange(0, total_tokens, seq_len, device="cuda") + b_seq_len = torch.full((B,), seq_len, device="cuda") + + decode_attention_fwd( + q, + k_buffer, + v_buffer, + o, + req_to_token, + b_req_idx, + b_start_loc, + b_seq_len, + seq_len, + total_tokens, + sm_scale, + ) + + def test_decode_attention(self): + # Here we just want to ensure there is no error + # TODO: correctness test + + # Test configurations + configs = [ + (2, 4, 4, 64), # MHA + (2, 4, 2, 64), # GQA + (2, 4, 4, 80), # Non-standard head dim + (2, 4, 4, 13), # Prime number head dim + ] + + for B, H_Q, H_KV, D in configs: + self._test_decode_attention_once(B, H_Q, H_KV, D) + + +if __name__ == "__main__": + unittest.main()