From 32cc66efa586af348b7c51f21a68e4771db7219c Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Tue, 20 May 2025 12:23:17 +0800 Subject: [PATCH] Update extend/decode attention kernel for CPU in sgl-kernel and add UTs (#6405) Co-authored-by: mingfeima --- sgl-kernel/csrc/cpu/decode.cpp | 122 +++++++++++-- sgl-kernel/csrc/cpu/torch_extension_cpu.cpp | 7 +- test/srt/cpu/test_decode.py | 167 +++++++++++++++++ test/srt/cpu/test_extend.py | 187 ++++++++++++++++++++ 4 files changed, 464 insertions(+), 19 deletions(-) create mode 100644 test/srt/cpu/test_decode.py create mode 100644 test/srt/cpu/test_extend.py diff --git a/sgl-kernel/csrc/cpu/decode.cpp b/sgl-kernel/csrc/cpu/decode.cpp index d1305f351..899987677 100644 --- a/sgl-kernel/csrc/cpu/decode.cpp +++ b/sgl-kernel/csrc/cpu/decode.cpp @@ -34,6 +34,19 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc, } } +template +inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) { + using bVec = at::vec::Vectorized; + int64_t d = 0; + for (; d <= size - bVec::size(); d += bVec::size()) { + bVec out_bvec = bVec::loadu(src + d); + out_bvec.store(out + d); + } + for (; d < size; ++d) { + out[d] = src[d]; + } +} + // GEMM handles query @ key (indexed) x scale // A : [M, K] // B : [N, K] indexed @@ -611,8 +624,11 @@ void decode_attention_kernel_impl( scalar_t* __restrict__ output, float* __restrict__ attn_logits, const scalar_t* __restrict__ query, - const scalar_t* __restrict__ k_buffer, - const scalar_t* __restrict__ v_buffer, + scalar_t* __restrict__ k_buffer, + scalar_t* __restrict__ v_buffer, + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + const int64_t* __restrict__ loc, const index_t* __restrict__ req_to_token, const int64_t* __restrict__ req_pool_indices, const int64_t* __restrict__ seq_lens, @@ -625,11 +641,33 @@ void decode_attention_kernel_impl( int64_t k_strideH, int64_t v_strideN, int64_t v_strideH, + int64_t nk_strideN, + int64_t nk_strideH, + int64_t nv_strideN, + int64_t nv_strideH, float scaling, float logit_cap, int64_t max_num_reqs, int64_t max_context_len, int64_t max_total_num_tokens) { + at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, head_id{0}; + data_index_init(begin, bs, batches, head_id, num_heads); + + for (int64_t i = begin; i < end; i++) { + int64_t loc_val = loc[bs]; + scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH; + scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH; + const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH; + const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH; + copy_stub(k_buffer_ptr, new_key_ptr, head_size); + copy_stub(v_buffer_ptr, new_value_ptr, head_size_v); + + // move to the next index + data_index_step(bs, batches, head_id, num_heads); + } + }); + using Vec = at::vec::Vectorized; // block length for k_buffer and v_buffer @@ -791,8 +829,11 @@ void decode_attention_grouped_kernel_impl( scalar_t* __restrict__ output, float* __restrict__ attn_logits, const scalar_t* __restrict__ query, - const scalar_t* __restrict__ k_buffer, - const scalar_t* __restrict__ v_buffer, + scalar_t* __restrict__ k_buffer, + scalar_t* __restrict__ v_buffer, + const scalar_t* __restrict__ key, + const scalar_t* __restrict__ value, + const int64_t* __restrict__ loc, const index_t* __restrict__ req_to_token, const int64_t* __restrict__ req_pool_indices, const int64_t* __restrict__ seq_lens, @@ -806,11 +847,33 @@ void decode_attention_grouped_kernel_impl( int64_t k_strideH, int64_t v_strideN, int64_t v_strideH, + int64_t nk_strideN, + int64_t nk_strideH, + int64_t nv_strideN, + int64_t nv_strideH, float scaling, float logit_cap, int64_t max_num_reqs, int64_t max_context_len, int64_t max_total_num_tokens) { + at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, head_kv_id{0}; + data_index_init(begin, bs, batches, head_kv_id, num_heads_kv); + + for (int64_t i = begin; i < end; i++) { + int64_t loc_val = loc[bs]; + scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH; + scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH; + const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH; + const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH; + copy_stub(k_buffer_ptr, new_key_ptr, head_size); + copy_stub(v_buffer_ptr, new_value_ptr, head_size_v); + + // move to the next index + data_index_step(bs, batches, head_kv_id, num_heads_kv); + } + }); + using Vec = at::vec::Vectorized; // block length for k_buffer and v_buffer @@ -833,14 +896,12 @@ void decode_attention_grouped_kernel_impl( // partition the heads into blocks for parallel const int64_t num_groups = num_heads / num_heads_kv; - const int64_t num_blocks = div_up(num_heads, std::min(BLOCK_H, num_groups)); - const int64_t num_groups_per_block = div_up(num_groups, BLOCK_H); - const int64_t num_heads_per_block = std::min(num_groups, BLOCK_H); + const int64_t num_blocks = div_up(num_groups, BLOCK_H); - // parallel on [batches, num_blocks, num_kv_splits] - at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) { - int64_t bs{0}, head_id{0}, kv_id{0}; - data_index_init(begin, bs, batches, head_id, num_blocks, kv_id, num_kv_splits); + // parallel on [batches, num_heads_kv, num_blocks, num_kv_splits] + at::parallel_for(0, batches * num_heads_kv * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) { + int64_t bs{0}, head_kv_id{0}, block_id{0}, kv_id{0}; + data_index_init(begin, bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits); alignas(64) float s_i[BLOCK_H * BLOCK_N]; float* __restrict__ s_delta = s_i; @@ -850,15 +911,13 @@ void decode_attention_grouped_kernel_impl( alignas(64) float m_delta[BLOCK_H]; for (int64_t i = begin; i < end; ++i) { - const int64_t h_start = head_id * num_heads_per_block; - const int64_t h_end = std::min(h_start + num_heads_per_block, num_heads); + const int64_t h_start = head_kv_id * num_groups + block_id * BLOCK_H; + const int64_t h_end = head_kv_id * num_groups + std::min(block_id * BLOCK_H + BLOCK_H, num_groups); const int64_t h_size = h_end - h_start; // get query const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH; - // kv head id and valid block head size - int64_t head_kv_id = head_id / num_groups_per_block; int64_t seq_len_kv = seq_lens[bs]; int64_t req_pool_id = req_pool_indices[bs]; TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!"); @@ -952,7 +1011,7 @@ void decode_attention_grouped_kernel_impl( } // move to the next index - data_index_step(bs, batches, head_id, num_blocks, kv_id, num_kv_splits); + data_index_step(bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits); } }); @@ -1004,9 +1063,12 @@ void decode_attention_grouped_kernel_impl( // void decode_attention_cpu( at::Tensor& query, - at::Tensor& output, at::Tensor& k_buffer, at::Tensor& v_buffer, + at::Tensor& output, + at::Tensor& key, + at::Tensor& value, + at::Tensor& loc, at::Tensor& attn_logits, at::Tensor& req_to_token, at::Tensor& req_pool_indices, @@ -1021,9 +1083,15 @@ void decode_attention_cpu( CHECK_INPUT(query); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer); CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer); + // for MLA, key and value shares the same storage and value could be non-contiguous + CHECK_LAST_DIM_CONTIGUOUS_INPUT(key); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(value); CHECK_DIM(3, query); CHECK_DIM(3, k_buffer); CHECK_DIM(3, v_buffer); + CHECK_DIM(3, key); + CHECK_DIM(3, value); + CHECK_DIM(1, loc); int64_t num_seqs = seq_lens.size(0); int64_t max_num_reqs = req_to_token.size(0); @@ -1037,6 +1105,7 @@ void decode_attention_cpu( int64_t num_kv_splits = attn_logits.size(2); + CHECK_EQ(loc.numel(), num_seqs); CHECK_EQ(attn_logits.size(0), num_seqs); CHECK_EQ(attn_logits.size(1), num_heads); CHECK_EQ(attn_logits.size(3), head_size_v + 1); @@ -1047,6 +1116,11 @@ void decode_attention_cpu( int64_t k_strideH = k_buffer.stride(1); int64_t v_strideN = v_buffer.stride(0); int64_t v_strideH = v_buffer.stride(1); + // strides for new key and value + int64_t nk_strideN = key.stride(0); + int64_t nk_strideH = key.stride(1); + int64_t nv_strideN = value.stride(0); + int64_t nv_strideH = value.stride(1); // check index data types const auto index_dtype = req_to_token.scalar_type(); @@ -1070,6 +1144,9 @@ void decode_attention_cpu( query.data_ptr(), k_buffer.data_ptr(), v_buffer.data_ptr(), + key.data_ptr(), + value.data_ptr(), + loc.data_ptr(), req_to_token.data_ptr(), req_pool_indices.data_ptr(), seq_lens.data_ptr(), @@ -1082,6 +1159,10 @@ void decode_attention_cpu( k_strideH, v_strideN, v_strideH, + nk_strideN, + nv_strideH, + nv_strideN, + nv_strideH, sm_scale, logit_cap, max_num_reqs, @@ -1095,6 +1176,9 @@ void decode_attention_cpu( query.data_ptr(), k_buffer.data_ptr(), v_buffer.data_ptr(), + key.data_ptr(), + value.data_ptr(), + loc.data_ptr(), req_to_token.data_ptr(), req_pool_indices.data_ptr(), seq_lens.data_ptr(), @@ -1108,6 +1192,10 @@ void decode_attention_cpu( k_strideH, v_strideN, v_strideH, + nk_strideN, + nk_strideH, + nv_strideN, + nv_strideH, sm_scale, logit_cap, max_num_reqs, diff --git a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp index efaa12aca..aa28c7ed8 100644 --- a/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp +++ b/sgl-kernel/csrc/cpu/torch_extension_cpu.cpp @@ -49,9 +49,12 @@ std::tuple biased_grouped_topk_cpu( // attention void decode_attention_cpu( at::Tensor& query, - at::Tensor& output, at::Tensor& k_cache, - at::Tensor& v_cahce, + at::Tensor& v_cache, + at::Tensor& output, + at::Tensor& key, + at::Tensor& value, + at::Tensor& loc, at::Tensor& attn_logits, at::Tensor& req_to_token, at::Tensor& req_pool_indices, diff --git a/test/srt/cpu/test_decode.py b/test/srt/cpu/test_decode.py new file mode 100644 index 000000000..1ab1bfae8 --- /dev/null +++ b/test/srt/cpu/test_decode.py @@ -0,0 +1,167 @@ +import unittest + +import torch +from sgl_kernel.common_ops import decode_attention_cpu as decode_attention +from torch.nn.functional import scaled_dot_product_attention + +from sglang.test.test_utils import CustomTestCase + + +class TestDecodeAttention(CustomTestCase): + def _run_sdpa_forward_decode( + self, + query: torch.Tensor, + output: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + scaling=None, + enable_gqa=False, + causal=False, + ): + # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] + query = query.movedim(0, query.dim() - 2) + + start_q, start_kv = 0, 0 + for seq_idx in range(seq_lens.shape[0]): + seq_len_q = 1 + seq_len_kv = seq_lens[seq_idx] + end_q = start_q + seq_len_q + end_kv = start_kv + seq_len_kv + + per_req_query = query[:, start_q:end_q, :] + + # get key and value from cache. per_req_tokens contains the kv cache + # index for each token in the sequence. + req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out = ( + scaled_dot_product_attention( + per_req_query.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out + start_q, start_kv = end_q, end_kv + + return output + + def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device): + dtype = torch.bfloat16 + # This represents the number of tokens already in the sequence + seq_len = 1024 + total_tokens = B * seq_len + sm_scale = 1.0 / (D**0.5) + logit_cap = 0.0 + num_kv_splits = 8 + enable_gqa = H_Q != H_KV + + # q represents the new token being generated, one per batch + q = torch.randn(B, H_Q, D, dtype=dtype, device=device) + + # k_buffer and v_buffer represent all previous tokens + k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device) + v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device) + + key = torch.randn(B, H_KV, D, dtype=dtype) + value = torch.randn(B, H_KV, D_V, dtype=dtype) + loc = torch.randint(0, 10, (B,)).to(torch.int64) + + # set kv cache + k_buffer[loc] = key + v_buffer[loc] = value + + # o will have the same shape as q + o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) + o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device) + + req_to_token = ( + torch.arange(total_tokens, device=device) + .reshape(B, seq_len) + .to(torch.int32) + ) + b_req_idx = torch.arange(B, device=device).to(torch.int64) + b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64) + + attn_logits = torch.empty( + (B, H_Q, num_kv_splits, D_V + 1), + dtype=torch.float32, + device=device, + ) + + # k_buffer, v_buffer, key and value supports non-contiguous tensors + k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) + v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) + key = key.transpose(0, 1).contiguous().transpose(0, 1) + value = value.transpose(0, 1).contiguous().transpose(0, 1) + decode_attention( + q, + k_buffer, + v_buffer, + o, + key, + value, + loc, + attn_logits, + req_to_token, + b_req_idx, + b_seq_len, + sm_scale, + logit_cap, + ) + + self._run_sdpa_forward_decode( + q, + o_grouped, + k_buffer, + v_buffer, + req_to_token, + b_req_idx, + b_seq_len, + scaling=sm_scale, + enable_gqa=enable_gqa, + ) + + cos_sim = torch.nn.functional.cosine_similarity( + o.flatten(), o_grouped.flatten(), dim=0 + ) + self.assertGreater(cos_sim.item(), 0.99) + torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6) + + def _test_grouped_decode_attention(self, device="cuda"): + configs = [ + (2, 16, 16, 64, 64), + (2, 16, 1, 16, 16), + (2, 32, 8, 33, 55), + (2, 16, 1, 64, 64), + (2, 64, 1, 13, 13), + (2, 128, 1, 80, 80), + (2, 128, 2, 512, 512), + (1, 16, 1, 576, 512), + (1, 16, 16, 576, 512), + (1, 22, 1, 576, 512), + (1, 40, 8, 128, 128), + ] + + for B, H_Q, H_KV, D, D_V in configs: + self._test_grouped_decode_attention_once( + B, H_Q, H_KV, D, D_V, device=device + ) + + def test_grouped_decode_attention(self): + self._test_grouped_decode_attention("cpu") + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/cpu/test_extend.py b/test/srt/cpu/test_extend.py new file mode 100644 index 000000000..35fbfc184 --- /dev/null +++ b/test/srt/cpu/test_extend.py @@ -0,0 +1,187 @@ +import unittest + +import torch +from sgl_kernel.common_ops import extend_attention_cpu as extend_attention +from torch.nn.functional import scaled_dot_product_attention + +from sglang.test.test_utils import CustomTestCase + + +class TestExtendAttention(CustomTestCase): + + def _run_sdpa_forward_extend( + self, + query: torch.Tensor, + output: torch.Tensor, + k_cache: torch.Tensor, + v_cache: torch.Tensor, + req_to_token: torch.Tensor, + req_pool_indices: torch.Tensor, + seq_lens: torch.Tensor, + extend_prefix_lens: torch.Tensor, + extend_seq_lens: torch.Tensor, + scaling=None, + enable_gqa=False, + causal=False, + ): + + assert seq_lens.shape[0] == extend_prefix_lens.shape[0] + assert seq_lens.shape[0] == extend_seq_lens.shape[0] + + # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size] + query = query.movedim(0, query.dim() - 2) + + start_q, start_kv = 0, 0 + for seq_idx in range(seq_lens.shape[0]): + + extend_seq_len_q = extend_seq_lens[seq_idx] + prefill_seq_len_q = extend_prefix_lens[seq_idx] + + seq_len_kv = seq_lens[seq_idx] + end_q = start_q + extend_seq_len_q + end_kv = start_kv + seq_len_kv + + per_req_query = query[:, start_q:end_q, :] + per_req_query_redudant = torch.empty( + (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]), + dtype=per_req_query.dtype, + device=per_req_query.device, + ) + + per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query + + # get key and value from cache. per_req_tokens contains the kv cache + # index for each token in the sequence. + req_pool_idx = req_pool_indices[seq_idx] + per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv] + per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2) + per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2) + + per_req_out_redudant = ( + scaled_dot_product_attention( + per_req_query_redudant.unsqueeze(0), + per_req_key.unsqueeze(0), + per_req_value.unsqueeze(0), + enable_gqa=enable_gqa, + scale=scaling, + is_causal=causal, + ) + .squeeze(0) + .movedim(query.dim() - 2, 0) + ) + output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :] + start_q, start_kv = end_q, end_kv + return output + + def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False): + dtype = torch.bfloat16 + + b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) + if mla: + b_seq_len_prefix.zero_() + b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32) + b_seq_len = b_seq_len_prefix + b_seq_len_extend + max_len_in_batch = torch.max(b_seq_len, 0)[0].item() + + b_req_idx = torch.arange(B, dtype=torch.int32) + req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32) + b_start_loc = torch.zeros((B,), dtype=torch.int32) + b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0) + b_start_loc_extend = torch.zeros((B,), dtype=torch.int32) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + + for i in range(B): + req_to_tokens[i, : b_seq_len[i]] = torch.arange( + b_start_loc[i], b_start_loc[i] + b_seq_len[i] + ) + + total_token_num = torch.sum(b_seq_len).item() + extend_token_num = torch.sum(b_seq_len_extend).item() + + H_BUF = 1 if mla else H_KV + k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype) + v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype) + + k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype) + v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype) + q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype) + + for i in range(B): + extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i] + extend_end_in_buffer = b_start_loc[i] + b_seq_len[i] + extend_start = b_start_loc_extend[i] + extend_end = b_start_loc_extend[i] + b_seq_len_extend[i] + k_extend[extend_start:extend_end] = k_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + v_extend[extend_start:extend_end] = v_buffer[ + extend_start_in_buffer:extend_end_in_buffer + ] + q_extend[extend_start:extend_end] = torch.randn( + (b_seq_len_extend[i], H_Q, D), dtype=dtype + ) + + # k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors + k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1) + v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1) + k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) + v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) + + b_seq_len_extend = b_seq_len - b_seq_len_prefix + b_start_loc_extend = torch.zeros_like(b_seq_len) + b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0) + max_len_extend = torch.max(b_seq_len_extend, 0)[0].item() + + sm_scale = 1.0 / (D**0.5) + logit_cap = 0.0 + + # handle index type + b_req_idx = b_req_idx.to(torch.int64) + b_seq_len = b_seq_len.to(torch.int64) + + enable_gqa = H_Q != H_KV + o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) + self._run_sdpa_forward_extend( + q_extend, + o_ref, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_prefix, + b_seq_len_extend, + scaling=sm_scale, + enable_gqa=enable_gqa, + causal=True, + ) + + o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype) + extend_attention( + q_extend, + k_extend, + v_extend, + o_extend, + k_buffer, + v_buffer, + req_to_tokens, + b_req_idx, + b_seq_len, + b_seq_len_extend, + b_start_loc_extend, + max_len_extend, + sm_scale, + logit_cap, + ) + + torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2) + + def test_extend_attention(self): + for is_mla in [True, False]: + self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla) + self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla) + self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla) + + +if __name__ == "__main__": + unittest.main()