Update extend/decode attention kernel for CPU in sgl-kernel and add UTs (#6405)

Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
YanbingJiang
2025-05-20 12:23:17 +08:00
committed by GitHub
parent 83f2d9d4ed
commit 32cc66efa5
4 changed files with 464 additions and 19 deletions

View File

@@ -34,6 +34,19 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc,
}
}
template <typename scalar_t>
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
using bVec = at::vec::Vectorized<scalar_t>;
int64_t d = 0;
for (; d <= size - bVec::size(); d += bVec::size()) {
bVec out_bvec = bVec::loadu(src + d);
out_bvec.store(out + d);
}
for (; d < size; ++d) {
out[d] = src[d];
}
}
// GEMM handles query @ key (indexed) x scale
// A : [M, K]
// B : [N, K] indexed
@@ -611,8 +624,11 @@ void decode_attention_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
scalar_t* __restrict__ k_buffer,
scalar_t* __restrict__ v_buffer,
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
const int64_t* __restrict__ loc,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
@@ -625,11 +641,33 @@ void decode_attention_kernel_impl(
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
int64_t nk_strideN,
int64_t nk_strideH,
int64_t nv_strideN,
int64_t nv_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_id{0};
data_index_init(begin, bs, batches, head_id, num_heads);
for (int64_t i = begin; i < end; i++) {
int64_t loc_val = loc[bs];
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH;
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH;
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH;
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH;
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
// move to the next index
data_index_step(bs, batches, head_id, num_heads);
}
});
using Vec = at::vec::Vectorized<float>;
// block length for k_buffer and v_buffer
@@ -791,8 +829,11 @@ void decode_attention_grouped_kernel_impl(
scalar_t* __restrict__ output,
float* __restrict__ attn_logits,
const scalar_t* __restrict__ query,
const scalar_t* __restrict__ k_buffer,
const scalar_t* __restrict__ v_buffer,
scalar_t* __restrict__ k_buffer,
scalar_t* __restrict__ v_buffer,
const scalar_t* __restrict__ key,
const scalar_t* __restrict__ value,
const int64_t* __restrict__ loc,
const index_t* __restrict__ req_to_token,
const int64_t* __restrict__ req_pool_indices,
const int64_t* __restrict__ seq_lens,
@@ -806,11 +847,33 @@ void decode_attention_grouped_kernel_impl(
int64_t k_strideH,
int64_t v_strideN,
int64_t v_strideH,
int64_t nk_strideN,
int64_t nk_strideH,
int64_t nv_strideN,
int64_t nv_strideH,
float scaling,
float logit_cap,
int64_t max_num_reqs,
int64_t max_context_len,
int64_t max_total_num_tokens) {
at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_kv_id{0};
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
for (int64_t i = begin; i < end; i++) {
int64_t loc_val = loc[bs];
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
// move to the next index
data_index_step(bs, batches, head_kv_id, num_heads_kv);
}
});
using Vec = at::vec::Vectorized<float>;
// block length for k_buffer and v_buffer
@@ -833,14 +896,12 @@ void decode_attention_grouped_kernel_impl(
// partition the heads into blocks for parallel
const int64_t num_groups = num_heads / num_heads_kv;
const int64_t num_blocks = div_up(num_heads, std::min(BLOCK_H, num_groups));
const int64_t num_groups_per_block = div_up(num_groups, BLOCK_H);
const int64_t num_heads_per_block = std::min(num_groups, BLOCK_H);
const int64_t num_blocks = div_up(num_groups, BLOCK_H);
// parallel on [batches, num_blocks, num_kv_splits]
at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_id{0}, kv_id{0};
data_index_init(begin, bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
// parallel on [batches, num_heads_kv, num_blocks, num_kv_splits]
at::parallel_for(0, batches * num_heads_kv * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
int64_t bs{0}, head_kv_id{0}, block_id{0}, kv_id{0};
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
alignas(64) float s_i[BLOCK_H * BLOCK_N];
float* __restrict__ s_delta = s_i;
@@ -850,15 +911,13 @@ void decode_attention_grouped_kernel_impl(
alignas(64) float m_delta[BLOCK_H];
for (int64_t i = begin; i < end; ++i) {
const int64_t h_start = head_id * num_heads_per_block;
const int64_t h_end = std::min(h_start + num_heads_per_block, num_heads);
const int64_t h_start = head_kv_id * num_groups + block_id * BLOCK_H;
const int64_t h_end = head_kv_id * num_groups + std::min(block_id * BLOCK_H + BLOCK_H, num_groups);
const int64_t h_size = h_end - h_start;
// get query
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
// kv head id and valid block head size
int64_t head_kv_id = head_id / num_groups_per_block;
int64_t seq_len_kv = seq_lens[bs];
int64_t req_pool_id = req_pool_indices[bs];
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
@@ -952,7 +1011,7 @@ void decode_attention_grouped_kernel_impl(
}
// move to the next index
data_index_step(bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
data_index_step(bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
}
});
@@ -1004,9 +1063,12 @@ void decode_attention_grouped_kernel_impl(
//
void decode_attention_cpu(
at::Tensor& query,
at::Tensor& output,
at::Tensor& k_buffer,
at::Tensor& v_buffer,
at::Tensor& output,
at::Tensor& key,
at::Tensor& value,
at::Tensor& loc,
at::Tensor& attn_logits,
at::Tensor& req_to_token,
at::Tensor& req_pool_indices,
@@ -1021,9 +1083,15 @@ void decode_attention_cpu(
CHECK_INPUT(query);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
// for MLA, key and value shares the same storage and value could be non-contiguous
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(value);
CHECK_DIM(3, query);
CHECK_DIM(3, k_buffer);
CHECK_DIM(3, v_buffer);
CHECK_DIM(3, key);
CHECK_DIM(3, value);
CHECK_DIM(1, loc);
int64_t num_seqs = seq_lens.size(0);
int64_t max_num_reqs = req_to_token.size(0);
@@ -1037,6 +1105,7 @@ void decode_attention_cpu(
int64_t num_kv_splits = attn_logits.size(2);
CHECK_EQ(loc.numel(), num_seqs);
CHECK_EQ(attn_logits.size(0), num_seqs);
CHECK_EQ(attn_logits.size(1), num_heads);
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
@@ -1047,6 +1116,11 @@ void decode_attention_cpu(
int64_t k_strideH = k_buffer.stride(1);
int64_t v_strideN = v_buffer.stride(0);
int64_t v_strideH = v_buffer.stride(1);
// strides for new key and value
int64_t nk_strideN = key.stride(0);
int64_t nk_strideH = key.stride(1);
int64_t nv_strideN = value.stride(0);
int64_t nv_strideH = value.stride(1);
// check index data types
const auto index_dtype = req_to_token.scalar_type();
@@ -1070,6 +1144,9 @@ void decode_attention_cpu(
query.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
loc.data_ptr<int64_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
@@ -1082,6 +1159,10 @@ void decode_attention_cpu(
k_strideH,
v_strideN,
v_strideH,
nk_strideN,
nv_strideH,
nv_strideN,
nv_strideH,
sm_scale,
logit_cap,
max_num_reqs,
@@ -1095,6 +1176,9 @@ void decode_attention_cpu(
query.data_ptr<scalar_t>(),
k_buffer.data_ptr<scalar_t>(),
v_buffer.data_ptr<scalar_t>(),
key.data_ptr<scalar_t>(),
value.data_ptr<scalar_t>(),
loc.data_ptr<int64_t>(),
req_to_token.data_ptr<index_t>(),
req_pool_indices.data_ptr<int64_t>(),
seq_lens.data_ptr<int64_t>(),
@@ -1108,6 +1192,10 @@ void decode_attention_cpu(
k_strideH,
v_strideN,
v_strideH,
nk_strideN,
nk_strideH,
nv_strideN,
nv_strideH,
sm_scale,
logit_cap,
max_num_reqs,