diff --git a/test/srt/test_triton_attention_kernels.py b/test/srt/test_triton_attention_kernels.py
index 47eb16a9b..b15684f9a 100644
--- a/test/srt/test_triton_attention_kernels.py
+++ b/test/srt/test_triton_attention_kernels.py
@@ -2,6 +2,7 @@ import random
 import unittest
 
 import torch
+import torch.nn.functional as F
 
 from sglang.srt.layers.attention.triton_ops.decode_attention import (
     decode_attention_fwd,
@@ -18,6 +19,80 @@ from sglang.srt.layers.attention.triton_ops.prefill_attention import (
 from sglang.test.test_utils import CustomTestCase
 
 
+def extend_attention_fwd_torch(
+    q: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k: torch.Tensor,  # [extend_tokens, H_KV, D]
+    v: torch.Tensor,  # [extend_tokens, H_KV, D]
+    o: torch.Tensor,  # [extend_tokens, H_Q, D]
+    k_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    v_cache: torch.Tensor,  # [total_tokens, H_KV, D]
+    qo_indptr: torch.Tensor,  # [B+1]
+    kv_indptr: torch.Tensor,  # [B+1]
+    kv_indices: torch.Tensor,  # [prefix_tokens]
+    sliding_window_size: int,
+):
+    B = qo_indptr.size(0) - 1
+    _, H_Q, D = q.shape
+    _, H_KV, _ = k.shape
+
+    group_size = H_Q // H_KV
+    scale = 1.0 / D**0.5
+
+    for i in range(B):
+        q_start = int(qo_indptr[i].item())
+        q_end = int(qo_indptr[i + 1].item())
+        kv_start = int(kv_indptr[i].item())
+        kv_end = int(kv_indptr[i + 1].item())
+
+        prefix_indices = kv_indices[kv_start:kv_end]
+        k_prefix = k_cache[prefix_indices]  # [prefix_len, H_KV, D]
+        v_prefix = v_cache[prefix_indices]  # [prefix_len, H_KV, D]
+
+        k_extend = k[q_start:q_end]  # [extend_len, H_KV, D]
+        v_extend = v[q_start:q_end]  # [extend_len, H_KV, D]
+        q_extend = q[q_start:q_end]  # [extend_len, H_Q, D]
+
+        k_full = torch.cat([k_prefix, k_extend], dim=0)  # [total_len, H_KV, D]
+        v_full = torch.cat([v_prefix, v_extend], dim=0)  # [total_len, H_KV, D]
+
+        if group_size != 1:
+            k_full_hq = k_full.repeat_interleave(
+                group_size, dim=1
+            )  # [total_len, H_Q, D]
+            v_full_hq = v_full.repeat_interleave(
+                group_size, dim=1
+            )  # [total_len, H_Q, D]
+        else:
+            k_full_hq = k_full
+            v_full_hq = v_full
+
+        prefix_len = k_prefix.size(0)
+        extend_len = k_extend.size(0)
+        total_len = prefix_len + extend_len
+
+        # causal mask: each extend token attends to keys up to its own position
+        pos_keys = torch.arange(total_len, device=q.device)
+        t = prefix_len + torch.arange(extend_len, device=q.device)  # [extend_len]
+        causal_mask = pos_keys.unsqueeze(0) <= t.unsqueeze(1)
+
+        # sliding window mask: keys at most sliding_window_size positions behind t
+        if sliding_window_size is not None and sliding_window_size > 0:
+            start = (t - sliding_window_size).clamp_min(0)  # [extend_len]
+        else:
+            start = torch.zeros_like(t)
+        window_mask = pos_keys.unsqueeze(0) >= start.unsqueeze(1)
+
+        final_mask = causal_mask & window_mask
+
+        attn_scores = (
+            torch.einsum("qhd,khd->qhk", q_extend, k_full_hq) * scale
+        )  # [extend_len, H_Q, total_len]
+        attn_scores = attn_scores.masked_fill(~final_mask.unsqueeze(1), float("-inf"))
+
+        attn_weights = F.softmax(attn_scores, dim=-1)
+        o[q_start:q_end] = torch.einsum("qhk,khd->qhd", attn_weights, v_full_hq)
+
+
 class TestTritonAttention(CustomTestCase):
 
     def _set_all_seeds(self, seed):
@@ -180,6 +255,115 @@ class TestTritonAttention(CustomTestCase):
         for value in attention_values:
             self._test_extend_attention_once(19, 12331, 12, 4, value)
 
+    def _test_extend_attention_sliding_window_once(
+        self, B, N_CTX, H_Q, H_KV, D, WINDOW_SIZE
+    ):
+        dtype = torch.bfloat16
+
+        b_seq_len_prefix = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len_extend = torch.randint(
+            1, N_CTX // 2, (B,), dtype=torch.int32, device="cuda"
+        )
+        b_seq_len = b_seq_len_prefix + b_seq_len_extend
+
+        b_start_loc = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
+        b_start_loc_extend = torch.zeros((B,), dtype=torch.int32, device="cuda")
+        b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
+
+        kv_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        kv_indptr[1 : B + 1] = torch.cumsum(b_seq_len_prefix[:B], dim=0)
+        kv_indices = torch.zeros(
+            (b_seq_len_prefix.sum().item(),), dtype=torch.int32, device="cuda"
+        )
+
+        for i in range(B):
+            kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = torch.arange(
+                b_start_loc[i], b_start_loc[i] + b_seq_len_prefix[i]
+            )
+
+        total_token_num = torch.sum(b_seq_len).item()
+        extend_token_num = torch.sum(b_seq_len_extend).item()
+        k_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+        v_buffer = torch.empty(
+            (total_token_num, H_KV, D), dtype=dtype, device="cuda"
+        ).normal_(mean=0.1, std=0.2)
+
+        k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        v_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype, device="cuda")
+        q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype, device="cuda")
+        for i in range(B):
+            extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
+            extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
+            extend_start = b_start_loc_extend[i]
+            extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
+            k_extend[extend_start:extend_end] = k_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            v_extend[extend_start:extend_end] = v_buffer[
+                extend_start_in_buffer:extend_end_in_buffer
+            ]
+            q_extend[extend_start:extend_end] = torch.empty(
+                (b_seq_len_extend[i], H_Q, D), dtype=dtype, device="cuda"
+            ).normal_(mean=0.1, std=0.2)
+
+        o_extend_triton = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+        o_extend_torch = torch.empty(
+            (extend_token_num, H_Q, D), dtype=dtype, device="cuda"
+        )
+
+        b_seq_len_extend = b_seq_len - b_seq_len_prefix
+        max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
+        qo_indptr = torch.zeros((B + 1,), dtype=torch.int32, device="cuda")
+        qo_indptr[1 : B + 1] = torch.cumsum(b_seq_len_extend[:B], dim=0)
+
+        extend_attention_fwd(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_triton,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            custom_mask=None,
+            is_causal=True,
+            mask_indptr=None,
+            max_len_extend=max_len_extend,
+            sliding_window_size=WINDOW_SIZE,
+        )
+
+        extend_attention_fwd_torch(
+            q_extend,
+            k_extend,
+            v_extend,
+            o_extend_torch,
+            k_buffer,
+            v_buffer,
+            qo_indptr,
+            kv_indptr,
+            kv_indices,
+            WINDOW_SIZE,
+        )
+
+        self.assertTrue(
+            torch.allclose(o_extend_triton, o_extend_torch, rtol=1e-3, atol=1e-3)
+        )
+
+    def test_extend_attention_sliding_window(self):
+        window_sizes = [-1, 127]
+        for window_size in window_sizes:
+            self._test_extend_attention_sliding_window_once(
+                19, 12331, 64, 8, 128, window_size
+            )
+
     def _test_context_attention_once(self, head_dim, is_causal):
         # Set up a simple test case
         num_heads = 4
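
Note on the window convention exercised above: in the PyTorch oracle extend_attention_fwd_torch, a query at absolute position t may attend to key positions max(0, t - sliding_window_size) through t inclusive, i.e. itself plus at most sliding_window_size earlier tokens, while any non-positive window (the WINDOW_SIZE = -1 case) falls back to plain causal attention. The sketch below isolates just that mask construction so the convention can be checked on a toy example; build_extend_mask is a hypothetical helper written for illustration and is not part of the test file or the kernel API.

    import torch

    def build_extend_mask(
        prefix_len: int, extend_len: int, sliding_window_size: int
    ) -> torch.Tensor:
        """Boolean [extend_len, prefix_len + extend_len] mask; True = may attend."""
        total_len = prefix_len + extend_len
        pos_keys = torch.arange(total_len)  # absolute key positions
        t = prefix_len + torch.arange(extend_len)  # absolute query positions
        causal = pos_keys.unsqueeze(0) <= t.unsqueeze(1)  # never attend to the future
        if sliding_window_size > 0:
            start = (t - sliding_window_size).clamp_min(0)
            return causal & (pos_keys.unsqueeze(0) >= start.unsqueeze(1))
        return causal  # non-positive window (e.g. -1) disables the sliding window

    # Each of the 3 extend tokens sees its 2 preceding keys plus itself.
    mask = build_extend_mask(prefix_len=5, extend_len=3, sliding_window_size=2)
    assert mask.sum(dim=-1).tolist() == [3, 3, 3]
    # With the window disabled, the mask reduces to plain causal attention.
    assert build_extend_mask(5, 3, -1).sum(dim=-1).tolist() == [6, 7, 8]

The allclose check in _test_extend_attention_sliding_window_once is what pins the Triton kernel to this same inclusive convention.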