From b044400dd34ecbb7ff3a9dd629d0faf68abaf1b9 Mon Sep 17 00:00:00 2001 From: YanbingJiang Date: Thu, 3 Jul 2025 10:59:45 +0800 Subject: [PATCH] Support non-contiguous query input for extend/decode attention (#7462) --- sgl-kernel/csrc/cpu/decode.cpp | 24 +++++++++++++++++------- sgl-kernel/csrc/cpu/extend.cpp | 12 ++++++++---- test/srt/cpu/test_decode.py | 3 ++- test/srt/cpu/test_extend.py | 3 ++- 4 files changed, 29 insertions(+), 13 deletions(-) diff --git a/sgl-kernel/csrc/cpu/decode.cpp b/sgl-kernel/csrc/cpu/decode.cpp index 7f55232e8..ae5ac51c8 100644 --- a/sgl-kernel/csrc/cpu/decode.cpp +++ b/sgl-kernel/csrc/cpu/decode.cpp @@ -874,6 +874,8 @@ void decode_attention_kernel_impl( int64_t head_size, int64_t head_size_v, int64_t num_kv_splits, + int64_t q_strideM, + int64_t q_strideH, int64_t k_strideN, int64_t k_strideH, int64_t v_strideN, @@ -886,8 +888,6 @@ void decode_attention_kernel_impl( using Vec = at::vec::Vectorized; // strides - const int64_t q_strideM = num_heads * head_size; - const int64_t q_strideH = head_size; const int64_t l_stride1 = num_kv_splits * (head_size_v + 1); const int64_t l_stride2 = head_size_v + 1; @@ -1017,6 +1017,8 @@ void decode_attention_mla_kernel_impl( int64_t head_size, int64_t head_size_v, int64_t num_kv_splits, + int64_t q_strideM, + int64_t q_strideH, int64_t k_strideN, int64_t k_strideH, int64_t v_strideN, @@ -1033,8 +1035,6 @@ void decode_attention_mla_kernel_impl( const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11); // strides - const int64_t q_strideM = num_heads * head_size; - const int64_t q_strideH = head_size; const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1); const int64_t l_stride1 = num_kv_splits * (head_size_v + 1); const int64_t l_stride2 = head_size_v + 1; @@ -1209,6 +1209,8 @@ void decode_attention_grouped_kernel_impl( int64_t head_size, int64_t head_size_v, int64_t num_kv_splits, + int64_t q_strideM, + int64_t q_strideH, int64_t k_strideN, int64_t k_strideH, int64_t v_strideN, @@ -1227,8 +1229,6 @@ void decode_attention_grouped_kernel_impl( const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H); // strides - const int64_t q_strideM = num_heads * head_size; - const int64_t q_strideH = head_size; const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1); const int64_t l_stride1 = num_kv_splits * (head_size_v + 1); const int64_t l_stride2 = head_size_v + 1; @@ -1391,7 +1391,7 @@ void decode_attention_cpu( std::vector( {query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens})); - CHECK_INPUT(query); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(query); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer); CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer); // for MLA, key and value shares the same storage and value could be non-contiguous @@ -1422,6 +1422,10 @@ void decode_attention_cpu( CHECK_EQ(attn_logits.size(3), head_size_v + 1); CHECK_EQ(attn_logits.scalar_type(), at::kFloat); + // strides for query + int64_t q_strideM = query.stride(0); + int64_t q_strideH = query.stride(1); + // strides for k_buffer and v_buffer int64_t k_strideN = k_buffer.stride(0); int64_t k_strideH = k_buffer.stride(1); @@ -1497,6 +1501,8 @@ void decode_attention_cpu( head_size, head_size_v, num_kv_splits, + q_strideM, + q_strideH, k_strideN, k_strideH, v_strideN, @@ -1523,6 +1529,8 @@ void decode_attention_cpu( head_size, head_size_v, num_kv_splits, + q_strideM, + q_strideH, k_strideN, k_strideH, v_strideN, @@ -1550,6 +1558,8 @@ void decode_attention_cpu( head_size, head_size_v, num_kv_splits, + q_strideM, + q_strideH, k_strideN, k_strideH, v_strideN, diff --git a/sgl-kernel/csrc/cpu/extend.cpp b/sgl-kernel/csrc/cpu/extend.cpp index c9f424634..3162ccea8 100644 --- a/sgl-kernel/csrc/cpu/extend.cpp +++ b/sgl-kernel/csrc/cpu/extend.cpp @@ -240,6 +240,8 @@ void extend_attention_kernel_impl( int num_heads_kv, int head_size, int head_size_v, + int q_strideM, + int q_strideH, int ke_strideN, int ke_strideH, int ve_strideN, @@ -259,8 +261,6 @@ void extend_attention_kernel_impl( using Vec = at::vec::Vectorized; // strides - const int q_strideM = num_heads * head_size; - const int q_strideH = head_size; const int o_strideM = num_heads * head_size_v; const int o_strideH = head_size_v; @@ -606,7 +606,7 @@ void extend_attention_cpu( extend_seq_lens, extend_start_loc})); - CHECK_INPUT(q_extend); + CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend); CHECK_INPUT(o_extend); CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend); CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend); @@ -623,7 +623,9 @@ void extend_attention_cpu( int head_size = q_extend.size(2); int head_size_v = v_extend.size(2); - // strides for k_extend and v_extend + // strides for q_extend, k_extend and v_extend + int q_strideM = q_extend.stride(0); + int q_strideH = q_extend.stride(1); int ke_strideN = k_extend.stride(0); int ke_strideH = k_extend.stride(1); int ve_strideN = v_extend.stride(0); @@ -698,6 +700,8 @@ void extend_attention_cpu( num_heads_kv, head_size, head_size_v, + q_strideM, + q_strideH, ke_strideN, ke_strideH, ve_strideN, diff --git a/test/srt/cpu/test_decode.py b/test/srt/cpu/test_decode.py index 94160b9db..c77378e1a 100644 --- a/test/srt/cpu/test_decode.py +++ b/test/srt/cpu/test_decode.py @@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase): device=device, ) - # k_buffer, v_buffer, key and value supports non-contiguous tensors + # k_buffer, v_buffer, query, key and value supports non-contiguous tensors k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1) v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1) + q = q.transpose(0, 1).contiguous().transpose(0, 1) key = key.transpose(0, 1).contiguous().transpose(0, 1) value = value.transpose(0, 1).contiguous().transpose(0, 1) torch.ops.sgl_kernel.decode_attention_cpu( diff --git a/test/srt/cpu/test_extend.py b/test/srt/cpu/test_extend.py index 6b8429aa5..9c6f5b394 100644 --- a/test/srt/cpu/test_extend.py +++ b/test/srt/cpu/test_extend.py @@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase): (b_seq_len_extend[i], H_Q, D), dtype=dtype ) - # k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors + # q_extend, k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors + q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1) k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1) v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1) k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)