Update extend/decode attention kernel for CPU in sgl-kernel and add UTs (#6405)
Co-authored-by: mingfeima <mingfei.ma@intel.com>
This commit is contained in:
@@ -34,6 +34,19 @@ inline void copy_stub(scalar_t* __restrict__ out, const float* __restrict__ acc,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
template <typename scalar_t>
|
||||||
|
inline void copy_stub(scalar_t* __restrict__ out, const scalar_t* __restrict__ src, int64_t size) {
|
||||||
|
using bVec = at::vec::Vectorized<scalar_t>;
|
||||||
|
int64_t d = 0;
|
||||||
|
for (; d <= size - bVec::size(); d += bVec::size()) {
|
||||||
|
bVec out_bvec = bVec::loadu(src + d);
|
||||||
|
out_bvec.store(out + d);
|
||||||
|
}
|
||||||
|
for (; d < size; ++d) {
|
||||||
|
out[d] = src[d];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GEMM handles query @ key (indexed) x scale
|
// GEMM handles query @ key (indexed) x scale
|
||||||
// A : [M, K]
|
// A : [M, K]
|
||||||
// B : [N, K] indexed
|
// B : [N, K] indexed
|
||||||
@@ -611,8 +624,11 @@ void decode_attention_kernel_impl(
|
|||||||
scalar_t* __restrict__ output,
|
scalar_t* __restrict__ output,
|
||||||
float* __restrict__ attn_logits,
|
float* __restrict__ attn_logits,
|
||||||
const scalar_t* __restrict__ query,
|
const scalar_t* __restrict__ query,
|
||||||
const scalar_t* __restrict__ k_buffer,
|
scalar_t* __restrict__ k_buffer,
|
||||||
const scalar_t* __restrict__ v_buffer,
|
scalar_t* __restrict__ v_buffer,
|
||||||
|
const scalar_t* __restrict__ key,
|
||||||
|
const scalar_t* __restrict__ value,
|
||||||
|
const int64_t* __restrict__ loc,
|
||||||
const index_t* __restrict__ req_to_token,
|
const index_t* __restrict__ req_to_token,
|
||||||
const int64_t* __restrict__ req_pool_indices,
|
const int64_t* __restrict__ req_pool_indices,
|
||||||
const int64_t* __restrict__ seq_lens,
|
const int64_t* __restrict__ seq_lens,
|
||||||
@@ -625,11 +641,33 @@ void decode_attention_kernel_impl(
|
|||||||
int64_t k_strideH,
|
int64_t k_strideH,
|
||||||
int64_t v_strideN,
|
int64_t v_strideN,
|
||||||
int64_t v_strideH,
|
int64_t v_strideH,
|
||||||
|
int64_t nk_strideN,
|
||||||
|
int64_t nk_strideH,
|
||||||
|
int64_t nv_strideN,
|
||||||
|
int64_t nv_strideH,
|
||||||
float scaling,
|
float scaling,
|
||||||
float logit_cap,
|
float logit_cap,
|
||||||
int64_t max_num_reqs,
|
int64_t max_num_reqs,
|
||||||
int64_t max_context_len,
|
int64_t max_context_len,
|
||||||
int64_t max_total_num_tokens) {
|
int64_t max_total_num_tokens) {
|
||||||
|
at::parallel_for(0, batches * num_heads, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
int64_t bs{0}, head_id{0};
|
||||||
|
data_index_init(begin, bs, batches, head_id, num_heads);
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; i++) {
|
||||||
|
int64_t loc_val = loc[bs];
|
||||||
|
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_id * k_strideH;
|
||||||
|
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_id * v_strideH;
|
||||||
|
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_id * nk_strideH;
|
||||||
|
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_id * nv_strideH;
|
||||||
|
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
|
||||||
|
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
|
||||||
|
|
||||||
|
// move to the next index
|
||||||
|
data_index_step(bs, batches, head_id, num_heads);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
using Vec = at::vec::Vectorized<float>;
|
using Vec = at::vec::Vectorized<float>;
|
||||||
|
|
||||||
// block length for k_buffer and v_buffer
|
// block length for k_buffer and v_buffer
|
||||||
@@ -791,8 +829,11 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
scalar_t* __restrict__ output,
|
scalar_t* __restrict__ output,
|
||||||
float* __restrict__ attn_logits,
|
float* __restrict__ attn_logits,
|
||||||
const scalar_t* __restrict__ query,
|
const scalar_t* __restrict__ query,
|
||||||
const scalar_t* __restrict__ k_buffer,
|
scalar_t* __restrict__ k_buffer,
|
||||||
const scalar_t* __restrict__ v_buffer,
|
scalar_t* __restrict__ v_buffer,
|
||||||
|
const scalar_t* __restrict__ key,
|
||||||
|
const scalar_t* __restrict__ value,
|
||||||
|
const int64_t* __restrict__ loc,
|
||||||
const index_t* __restrict__ req_to_token,
|
const index_t* __restrict__ req_to_token,
|
||||||
const int64_t* __restrict__ req_pool_indices,
|
const int64_t* __restrict__ req_pool_indices,
|
||||||
const int64_t* __restrict__ seq_lens,
|
const int64_t* __restrict__ seq_lens,
|
||||||
@@ -806,11 +847,33 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
int64_t k_strideH,
|
int64_t k_strideH,
|
||||||
int64_t v_strideN,
|
int64_t v_strideN,
|
||||||
int64_t v_strideH,
|
int64_t v_strideH,
|
||||||
|
int64_t nk_strideN,
|
||||||
|
int64_t nk_strideH,
|
||||||
|
int64_t nv_strideN,
|
||||||
|
int64_t nv_strideH,
|
||||||
float scaling,
|
float scaling,
|
||||||
float logit_cap,
|
float logit_cap,
|
||||||
int64_t max_num_reqs,
|
int64_t max_num_reqs,
|
||||||
int64_t max_context_len,
|
int64_t max_context_len,
|
||||||
int64_t max_total_num_tokens) {
|
int64_t max_total_num_tokens) {
|
||||||
|
at::parallel_for(0, batches * num_heads_kv, 0, [&](int64_t begin, int64_t end) {
|
||||||
|
int64_t bs{0}, head_kv_id{0};
|
||||||
|
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv);
|
||||||
|
|
||||||
|
for (int64_t i = begin; i < end; i++) {
|
||||||
|
int64_t loc_val = loc[bs];
|
||||||
|
scalar_t* k_buffer_ptr = k_buffer + loc_val * k_strideN + head_kv_id * k_strideH;
|
||||||
|
scalar_t* v_buffer_ptr = v_buffer + loc_val * v_strideN + head_kv_id * v_strideH;
|
||||||
|
const scalar_t* new_key_ptr = key + bs * nk_strideN + head_kv_id * nk_strideH;
|
||||||
|
const scalar_t* new_value_ptr = value + bs * nv_strideN + head_kv_id * nv_strideH;
|
||||||
|
copy_stub<scalar_t>(k_buffer_ptr, new_key_ptr, head_size);
|
||||||
|
copy_stub<scalar_t>(v_buffer_ptr, new_value_ptr, head_size_v);
|
||||||
|
|
||||||
|
// move to the next index
|
||||||
|
data_index_step(bs, batches, head_kv_id, num_heads_kv);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
using Vec = at::vec::Vectorized<float>;
|
using Vec = at::vec::Vectorized<float>;
|
||||||
|
|
||||||
// block length for k_buffer and v_buffer
|
// block length for k_buffer and v_buffer
|
||||||
@@ -833,14 +896,12 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
|
|
||||||
// partition the heads into blocks for parallel
|
// partition the heads into blocks for parallel
|
||||||
const int64_t num_groups = num_heads / num_heads_kv;
|
const int64_t num_groups = num_heads / num_heads_kv;
|
||||||
const int64_t num_blocks = div_up(num_heads, std::min(BLOCK_H, num_groups));
|
const int64_t num_blocks = div_up(num_groups, BLOCK_H);
|
||||||
const int64_t num_groups_per_block = div_up(num_groups, BLOCK_H);
|
|
||||||
const int64_t num_heads_per_block = std::min(num_groups, BLOCK_H);
|
|
||||||
|
|
||||||
// parallel on [batches, num_blocks, num_kv_splits]
|
// parallel on [batches, num_heads_kv, num_blocks, num_kv_splits]
|
||||||
at::parallel_for(0, batches * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
|
at::parallel_for(0, batches * num_heads_kv * num_blocks * num_kv_splits, 0, [&](int64_t begin, int64_t end) {
|
||||||
int64_t bs{0}, head_id{0}, kv_id{0};
|
int64_t bs{0}, head_kv_id{0}, block_id{0}, kv_id{0};
|
||||||
data_index_init(begin, bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
|
data_index_init(begin, bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
|
||||||
|
|
||||||
alignas(64) float s_i[BLOCK_H * BLOCK_N];
|
alignas(64) float s_i[BLOCK_H * BLOCK_N];
|
||||||
float* __restrict__ s_delta = s_i;
|
float* __restrict__ s_delta = s_i;
|
||||||
@@ -850,15 +911,13 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
alignas(64) float m_delta[BLOCK_H];
|
alignas(64) float m_delta[BLOCK_H];
|
||||||
|
|
||||||
for (int64_t i = begin; i < end; ++i) {
|
for (int64_t i = begin; i < end; ++i) {
|
||||||
const int64_t h_start = head_id * num_heads_per_block;
|
const int64_t h_start = head_kv_id * num_groups + block_id * BLOCK_H;
|
||||||
const int64_t h_end = std::min(h_start + num_heads_per_block, num_heads);
|
const int64_t h_end = head_kv_id * num_groups + std::min(block_id * BLOCK_H + BLOCK_H, num_groups);
|
||||||
const int64_t h_size = h_end - h_start;
|
const int64_t h_size = h_end - h_start;
|
||||||
|
|
||||||
// get query
|
// get query
|
||||||
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
|
const scalar_t* __restrict__ q_ptr = query + bs * q_strideM + h_start * q_strideH;
|
||||||
|
|
||||||
// kv head id and valid block head size
|
|
||||||
int64_t head_kv_id = head_id / num_groups_per_block;
|
|
||||||
int64_t seq_len_kv = seq_lens[bs];
|
int64_t seq_len_kv = seq_lens[bs];
|
||||||
int64_t req_pool_id = req_pool_indices[bs];
|
int64_t req_pool_id = req_pool_indices[bs];
|
||||||
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
|
TORCH_CHECK(seq_len_kv <= max_context_len, "seq_len_kv out of scope!");
|
||||||
@@ -952,7 +1011,7 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
}
|
}
|
||||||
|
|
||||||
// move to the next index
|
// move to the next index
|
||||||
data_index_step(bs, batches, head_id, num_blocks, kv_id, num_kv_splits);
|
data_index_step(bs, batches, head_kv_id, num_heads_kv, block_id, num_blocks, kv_id, num_kv_splits);
|
||||||
}
|
}
|
||||||
});
|
});
|
||||||
|
|
||||||
@@ -1004,9 +1063,12 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
//
|
//
|
||||||
void decode_attention_cpu(
|
void decode_attention_cpu(
|
||||||
at::Tensor& query,
|
at::Tensor& query,
|
||||||
at::Tensor& output,
|
|
||||||
at::Tensor& k_buffer,
|
at::Tensor& k_buffer,
|
||||||
at::Tensor& v_buffer,
|
at::Tensor& v_buffer,
|
||||||
|
at::Tensor& output,
|
||||||
|
at::Tensor& key,
|
||||||
|
at::Tensor& value,
|
||||||
|
at::Tensor& loc,
|
||||||
at::Tensor& attn_logits,
|
at::Tensor& attn_logits,
|
||||||
at::Tensor& req_to_token,
|
at::Tensor& req_to_token,
|
||||||
at::Tensor& req_pool_indices,
|
at::Tensor& req_pool_indices,
|
||||||
@@ -1021,9 +1083,15 @@ void decode_attention_cpu(
|
|||||||
CHECK_INPUT(query);
|
CHECK_INPUT(query);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||||||
|
// for MLA, key and value shares the same storage and value could be non-contiguous
|
||||||
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(key);
|
||||||
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(value);
|
||||||
CHECK_DIM(3, query);
|
CHECK_DIM(3, query);
|
||||||
CHECK_DIM(3, k_buffer);
|
CHECK_DIM(3, k_buffer);
|
||||||
CHECK_DIM(3, v_buffer);
|
CHECK_DIM(3, v_buffer);
|
||||||
|
CHECK_DIM(3, key);
|
||||||
|
CHECK_DIM(3, value);
|
||||||
|
CHECK_DIM(1, loc);
|
||||||
|
|
||||||
int64_t num_seqs = seq_lens.size(0);
|
int64_t num_seqs = seq_lens.size(0);
|
||||||
int64_t max_num_reqs = req_to_token.size(0);
|
int64_t max_num_reqs = req_to_token.size(0);
|
||||||
@@ -1037,6 +1105,7 @@ void decode_attention_cpu(
|
|||||||
|
|
||||||
int64_t num_kv_splits = attn_logits.size(2);
|
int64_t num_kv_splits = attn_logits.size(2);
|
||||||
|
|
||||||
|
CHECK_EQ(loc.numel(), num_seqs);
|
||||||
CHECK_EQ(attn_logits.size(0), num_seqs);
|
CHECK_EQ(attn_logits.size(0), num_seqs);
|
||||||
CHECK_EQ(attn_logits.size(1), num_heads);
|
CHECK_EQ(attn_logits.size(1), num_heads);
|
||||||
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
||||||
@@ -1047,6 +1116,11 @@ void decode_attention_cpu(
|
|||||||
int64_t k_strideH = k_buffer.stride(1);
|
int64_t k_strideH = k_buffer.stride(1);
|
||||||
int64_t v_strideN = v_buffer.stride(0);
|
int64_t v_strideN = v_buffer.stride(0);
|
||||||
int64_t v_strideH = v_buffer.stride(1);
|
int64_t v_strideH = v_buffer.stride(1);
|
||||||
|
// strides for new key and value
|
||||||
|
int64_t nk_strideN = key.stride(0);
|
||||||
|
int64_t nk_strideH = key.stride(1);
|
||||||
|
int64_t nv_strideN = value.stride(0);
|
||||||
|
int64_t nv_strideH = value.stride(1);
|
||||||
|
|
||||||
// check index data types
|
// check index data types
|
||||||
const auto index_dtype = req_to_token.scalar_type();
|
const auto index_dtype = req_to_token.scalar_type();
|
||||||
@@ -1070,6 +1144,9 @@ void decode_attention_cpu(
|
|||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
k_buffer.data_ptr<scalar_t>(),
|
k_buffer.data_ptr<scalar_t>(),
|
||||||
v_buffer.data_ptr<scalar_t>(),
|
v_buffer.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
loc.data_ptr<int64_t>(),
|
||||||
req_to_token.data_ptr<index_t>(),
|
req_to_token.data_ptr<index_t>(),
|
||||||
req_pool_indices.data_ptr<int64_t>(),
|
req_pool_indices.data_ptr<int64_t>(),
|
||||||
seq_lens.data_ptr<int64_t>(),
|
seq_lens.data_ptr<int64_t>(),
|
||||||
@@ -1082,6 +1159,10 @@ void decode_attention_cpu(
|
|||||||
k_strideH,
|
k_strideH,
|
||||||
v_strideN,
|
v_strideN,
|
||||||
v_strideH,
|
v_strideH,
|
||||||
|
nk_strideN,
|
||||||
|
nv_strideH,
|
||||||
|
nv_strideN,
|
||||||
|
nv_strideH,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
max_num_reqs,
|
max_num_reqs,
|
||||||
@@ -1095,6 +1176,9 @@ void decode_attention_cpu(
|
|||||||
query.data_ptr<scalar_t>(),
|
query.data_ptr<scalar_t>(),
|
||||||
k_buffer.data_ptr<scalar_t>(),
|
k_buffer.data_ptr<scalar_t>(),
|
||||||
v_buffer.data_ptr<scalar_t>(),
|
v_buffer.data_ptr<scalar_t>(),
|
||||||
|
key.data_ptr<scalar_t>(),
|
||||||
|
value.data_ptr<scalar_t>(),
|
||||||
|
loc.data_ptr<int64_t>(),
|
||||||
req_to_token.data_ptr<index_t>(),
|
req_to_token.data_ptr<index_t>(),
|
||||||
req_pool_indices.data_ptr<int64_t>(),
|
req_pool_indices.data_ptr<int64_t>(),
|
||||||
seq_lens.data_ptr<int64_t>(),
|
seq_lens.data_ptr<int64_t>(),
|
||||||
@@ -1108,6 +1192,10 @@ void decode_attention_cpu(
|
|||||||
k_strideH,
|
k_strideH,
|
||||||
v_strideN,
|
v_strideN,
|
||||||
v_strideH,
|
v_strideH,
|
||||||
|
nk_strideN,
|
||||||
|
nk_strideH,
|
||||||
|
nv_strideN,
|
||||||
|
nv_strideH,
|
||||||
sm_scale,
|
sm_scale,
|
||||||
logit_cap,
|
logit_cap,
|
||||||
max_num_reqs,
|
max_num_reqs,
|
||||||
|
|||||||
@@ -49,9 +49,12 @@ std::tuple<at::Tensor, at::Tensor> biased_grouped_topk_cpu(
|
|||||||
// attention
|
// attention
|
||||||
void decode_attention_cpu(
|
void decode_attention_cpu(
|
||||||
at::Tensor& query,
|
at::Tensor& query,
|
||||||
at::Tensor& output,
|
|
||||||
at::Tensor& k_cache,
|
at::Tensor& k_cache,
|
||||||
at::Tensor& v_cahce,
|
at::Tensor& v_cache,
|
||||||
|
at::Tensor& output,
|
||||||
|
at::Tensor& key,
|
||||||
|
at::Tensor& value,
|
||||||
|
at::Tensor& loc,
|
||||||
at::Tensor& attn_logits,
|
at::Tensor& attn_logits,
|
||||||
at::Tensor& req_to_token,
|
at::Tensor& req_to_token,
|
||||||
at::Tensor& req_pool_indices,
|
at::Tensor& req_pool_indices,
|
||||||
|
|||||||
167
test/srt/cpu/test_decode.py
Normal file
167
test/srt/cpu/test_decode.py
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.common_ops import decode_attention_cpu as decode_attention
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestDecodeAttention(CustomTestCase):
|
||||||
|
def _run_sdpa_forward_decode(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
scaling=None,
|
||||||
|
enable_gqa=False,
|
||||||
|
causal=False,
|
||||||
|
):
|
||||||
|
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||||
|
query = query.movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
start_q, start_kv = 0, 0
|
||||||
|
for seq_idx in range(seq_lens.shape[0]):
|
||||||
|
seq_len_q = 1
|
||||||
|
seq_len_kv = seq_lens[seq_idx]
|
||||||
|
end_q = start_q + seq_len_q
|
||||||
|
end_kv = start_kv + seq_len_kv
|
||||||
|
|
||||||
|
per_req_query = query[:, start_q:end_q, :]
|
||||||
|
|
||||||
|
# get key and value from cache. per_req_tokens contains the kv cache
|
||||||
|
# index for each token in the sequence.
|
||||||
|
req_pool_idx = req_pool_indices[seq_idx]
|
||||||
|
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||||
|
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
per_req_out = (
|
||||||
|
scaled_dot_product_attention(
|
||||||
|
per_req_query.unsqueeze(0),
|
||||||
|
per_req_key.unsqueeze(0),
|
||||||
|
per_req_value.unsqueeze(0),
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
scale=scaling,
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.movedim(query.dim() - 2, 0)
|
||||||
|
)
|
||||||
|
output[start_q:end_q, :, :] = per_req_out
|
||||||
|
start_q, start_kv = end_q, end_kv
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _test_grouped_decode_attention_once(self, B, H_Q, H_KV, D, D_V, device):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
# This represents the number of tokens already in the sequence
|
||||||
|
seq_len = 1024
|
||||||
|
total_tokens = B * seq_len
|
||||||
|
sm_scale = 1.0 / (D**0.5)
|
||||||
|
logit_cap = 0.0
|
||||||
|
num_kv_splits = 8
|
||||||
|
enable_gqa = H_Q != H_KV
|
||||||
|
|
||||||
|
# q represents the new token being generated, one per batch
|
||||||
|
q = torch.randn(B, H_Q, D, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
# k_buffer and v_buffer represent all previous tokens
|
||||||
|
k_buffer = torch.randn(total_tokens, H_KV, D, dtype=dtype, device=device)
|
||||||
|
v_buffer = torch.randn(total_tokens, H_KV, D_V, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
key = torch.randn(B, H_KV, D, dtype=dtype)
|
||||||
|
value = torch.randn(B, H_KV, D_V, dtype=dtype)
|
||||||
|
loc = torch.randint(0, 10, (B,)).to(torch.int64)
|
||||||
|
|
||||||
|
# set kv cache
|
||||||
|
k_buffer[loc] = key
|
||||||
|
v_buffer[loc] = value
|
||||||
|
|
||||||
|
# o will have the same shape as q
|
||||||
|
o = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
|
||||||
|
o_grouped = torch.zeros(B, H_Q, D_V, dtype=dtype, device=device)
|
||||||
|
|
||||||
|
req_to_token = (
|
||||||
|
torch.arange(total_tokens, device=device)
|
||||||
|
.reshape(B, seq_len)
|
||||||
|
.to(torch.int32)
|
||||||
|
)
|
||||||
|
b_req_idx = torch.arange(B, device=device).to(torch.int64)
|
||||||
|
b_seq_len = torch.full((B,), seq_len, device=device).to(torch.int64)
|
||||||
|
|
||||||
|
attn_logits = torch.empty(
|
||||||
|
(B, H_Q, num_kv_splits, D_V + 1),
|
||||||
|
dtype=torch.float32,
|
||||||
|
device=device,
|
||||||
|
)
|
||||||
|
|
||||||
|
# k_buffer, v_buffer, key and value supports non-contiguous tensors
|
||||||
|
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
decode_attention(
|
||||||
|
q,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
o,
|
||||||
|
key,
|
||||||
|
value,
|
||||||
|
loc,
|
||||||
|
attn_logits,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
self._run_sdpa_forward_decode(
|
||||||
|
q,
|
||||||
|
o_grouped,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
req_to_token,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
scaling=sm_scale,
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
)
|
||||||
|
|
||||||
|
cos_sim = torch.nn.functional.cosine_similarity(
|
||||||
|
o.flatten(), o_grouped.flatten(), dim=0
|
||||||
|
)
|
||||||
|
self.assertGreater(cos_sim.item(), 0.99)
|
||||||
|
torch.testing.assert_close(o, o_grouped, atol=3e-2, rtol=1e-6)
|
||||||
|
|
||||||
|
def _test_grouped_decode_attention(self, device="cuda"):
|
||||||
|
configs = [
|
||||||
|
(2, 16, 16, 64, 64),
|
||||||
|
(2, 16, 1, 16, 16),
|
||||||
|
(2, 32, 8, 33, 55),
|
||||||
|
(2, 16, 1, 64, 64),
|
||||||
|
(2, 64, 1, 13, 13),
|
||||||
|
(2, 128, 1, 80, 80),
|
||||||
|
(2, 128, 2, 512, 512),
|
||||||
|
(1, 16, 1, 576, 512),
|
||||||
|
(1, 16, 16, 576, 512),
|
||||||
|
(1, 22, 1, 576, 512),
|
||||||
|
(1, 40, 8, 128, 128),
|
||||||
|
]
|
||||||
|
|
||||||
|
for B, H_Q, H_KV, D, D_V in configs:
|
||||||
|
self._test_grouped_decode_attention_once(
|
||||||
|
B, H_Q, H_KV, D, D_V, device=device
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_grouped_decode_attention(self):
|
||||||
|
self._test_grouped_decode_attention("cpu")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
187
test/srt/cpu/test_extend.py
Normal file
187
test/srt/cpu/test_extend.py
Normal file
@@ -0,0 +1,187 @@
|
|||||||
|
import unittest
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from sgl_kernel.common_ops import extend_attention_cpu as extend_attention
|
||||||
|
from torch.nn.functional import scaled_dot_product_attention
|
||||||
|
|
||||||
|
from sglang.test.test_utils import CustomTestCase
|
||||||
|
|
||||||
|
|
||||||
|
class TestExtendAttention(CustomTestCase):
|
||||||
|
|
||||||
|
def _run_sdpa_forward_extend(
|
||||||
|
self,
|
||||||
|
query: torch.Tensor,
|
||||||
|
output: torch.Tensor,
|
||||||
|
k_cache: torch.Tensor,
|
||||||
|
v_cache: torch.Tensor,
|
||||||
|
req_to_token: torch.Tensor,
|
||||||
|
req_pool_indices: torch.Tensor,
|
||||||
|
seq_lens: torch.Tensor,
|
||||||
|
extend_prefix_lens: torch.Tensor,
|
||||||
|
extend_seq_lens: torch.Tensor,
|
||||||
|
scaling=None,
|
||||||
|
enable_gqa=False,
|
||||||
|
causal=False,
|
||||||
|
):
|
||||||
|
|
||||||
|
assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
|
||||||
|
assert seq_lens.shape[0] == extend_seq_lens.shape[0]
|
||||||
|
|
||||||
|
# [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
|
||||||
|
query = query.movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
start_q, start_kv = 0, 0
|
||||||
|
for seq_idx in range(seq_lens.shape[0]):
|
||||||
|
|
||||||
|
extend_seq_len_q = extend_seq_lens[seq_idx]
|
||||||
|
prefill_seq_len_q = extend_prefix_lens[seq_idx]
|
||||||
|
|
||||||
|
seq_len_kv = seq_lens[seq_idx]
|
||||||
|
end_q = start_q + extend_seq_len_q
|
||||||
|
end_kv = start_kv + seq_len_kv
|
||||||
|
|
||||||
|
per_req_query = query[:, start_q:end_q, :]
|
||||||
|
per_req_query_redudant = torch.empty(
|
||||||
|
(per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
|
||||||
|
dtype=per_req_query.dtype,
|
||||||
|
device=per_req_query.device,
|
||||||
|
)
|
||||||
|
|
||||||
|
per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
|
||||||
|
|
||||||
|
# get key and value from cache. per_req_tokens contains the kv cache
|
||||||
|
# index for each token in the sequence.
|
||||||
|
req_pool_idx = req_pool_indices[seq_idx]
|
||||||
|
per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
|
||||||
|
per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
|
||||||
|
|
||||||
|
per_req_out_redudant = (
|
||||||
|
scaled_dot_product_attention(
|
||||||
|
per_req_query_redudant.unsqueeze(0),
|
||||||
|
per_req_key.unsqueeze(0),
|
||||||
|
per_req_value.unsqueeze(0),
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
scale=scaling,
|
||||||
|
is_causal=causal,
|
||||||
|
)
|
||||||
|
.squeeze(0)
|
||||||
|
.movedim(query.dim() - 2, 0)
|
||||||
|
)
|
||||||
|
output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
|
||||||
|
start_q, start_kv = end_q, end_kv
|
||||||
|
return output
|
||||||
|
|
||||||
|
def _test_extend_attention_once(self, B, N_CTX, H_Q, H_KV, D, DV, mla=False):
|
||||||
|
dtype = torch.bfloat16
|
||||||
|
|
||||||
|
b_seq_len_prefix = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
|
||||||
|
if mla:
|
||||||
|
b_seq_len_prefix.zero_()
|
||||||
|
b_seq_len_extend = torch.randint(1, N_CTX // 2, (B,), dtype=torch.int32)
|
||||||
|
b_seq_len = b_seq_len_prefix + b_seq_len_extend
|
||||||
|
max_len_in_batch = torch.max(b_seq_len, 0)[0].item()
|
||||||
|
|
||||||
|
b_req_idx = torch.arange(B, dtype=torch.int32)
|
||||||
|
req_to_tokens = torch.empty((B, max_len_in_batch), dtype=torch.int32)
|
||||||
|
b_start_loc = torch.zeros((B,), dtype=torch.int32)
|
||||||
|
b_start_loc[1:] = torch.cumsum(b_seq_len[:-1], 0)
|
||||||
|
b_start_loc_extend = torch.zeros((B,), dtype=torch.int32)
|
||||||
|
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
req_to_tokens[i, : b_seq_len[i]] = torch.arange(
|
||||||
|
b_start_loc[i], b_start_loc[i] + b_seq_len[i]
|
||||||
|
)
|
||||||
|
|
||||||
|
total_token_num = torch.sum(b_seq_len).item()
|
||||||
|
extend_token_num = torch.sum(b_seq_len_extend).item()
|
||||||
|
|
||||||
|
H_BUF = 1 if mla else H_KV
|
||||||
|
k_buffer = torch.randn((total_token_num, H_BUF, D), dtype=dtype)
|
||||||
|
v_buffer = torch.randn((total_token_num, H_BUF, DV), dtype=dtype)
|
||||||
|
|
||||||
|
k_extend = torch.empty((extend_token_num, H_KV, D), dtype=dtype)
|
||||||
|
v_extend = torch.empty((extend_token_num, H_KV, DV), dtype=dtype)
|
||||||
|
q_extend = torch.empty((extend_token_num, H_Q, D), dtype=dtype)
|
||||||
|
|
||||||
|
for i in range(B):
|
||||||
|
extend_start_in_buffer = b_start_loc[i] + b_seq_len_prefix[i]
|
||||||
|
extend_end_in_buffer = b_start_loc[i] + b_seq_len[i]
|
||||||
|
extend_start = b_start_loc_extend[i]
|
||||||
|
extend_end = b_start_loc_extend[i] + b_seq_len_extend[i]
|
||||||
|
k_extend[extend_start:extend_end] = k_buffer[
|
||||||
|
extend_start_in_buffer:extend_end_in_buffer
|
||||||
|
]
|
||||||
|
v_extend[extend_start:extend_end] = v_buffer[
|
||||||
|
extend_start_in_buffer:extend_end_in_buffer
|
||||||
|
]
|
||||||
|
q_extend[extend_start:extend_end] = torch.randn(
|
||||||
|
(b_seq_len_extend[i], H_Q, D), dtype=dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
# k_extend, v_extend, k_buffer and v_buffer supports non-contiguous tensors
|
||||||
|
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
|
||||||
|
b_seq_len_extend = b_seq_len - b_seq_len_prefix
|
||||||
|
b_start_loc_extend = torch.zeros_like(b_seq_len)
|
||||||
|
b_start_loc_extend[1:] = torch.cumsum(b_seq_len_extend[:-1], 0)
|
||||||
|
max_len_extend = torch.max(b_seq_len_extend, 0)[0].item()
|
||||||
|
|
||||||
|
sm_scale = 1.0 / (D**0.5)
|
||||||
|
logit_cap = 0.0
|
||||||
|
|
||||||
|
# handle index type
|
||||||
|
b_req_idx = b_req_idx.to(torch.int64)
|
||||||
|
b_seq_len = b_seq_len.to(torch.int64)
|
||||||
|
|
||||||
|
enable_gqa = H_Q != H_KV
|
||||||
|
o_ref = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
||||||
|
self._run_sdpa_forward_extend(
|
||||||
|
q_extend,
|
||||||
|
o_ref,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
req_to_tokens,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
b_seq_len_prefix,
|
||||||
|
b_seq_len_extend,
|
||||||
|
scaling=sm_scale,
|
||||||
|
enable_gqa=enable_gqa,
|
||||||
|
causal=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
o_extend = torch.empty((extend_token_num, H_Q, DV), dtype=dtype)
|
||||||
|
extend_attention(
|
||||||
|
q_extend,
|
||||||
|
k_extend,
|
||||||
|
v_extend,
|
||||||
|
o_extend,
|
||||||
|
k_buffer,
|
||||||
|
v_buffer,
|
||||||
|
req_to_tokens,
|
||||||
|
b_req_idx,
|
||||||
|
b_seq_len,
|
||||||
|
b_seq_len_extend,
|
||||||
|
b_start_loc_extend,
|
||||||
|
max_len_extend,
|
||||||
|
sm_scale,
|
||||||
|
logit_cap,
|
||||||
|
)
|
||||||
|
|
||||||
|
torch.testing.assert_close(o_ref, o_extend, atol=1e-2, rtol=1e-2)
|
||||||
|
|
||||||
|
def test_extend_attention(self):
|
||||||
|
for is_mla in [True, False]:
|
||||||
|
self._test_extend_attention_once(1, 123, 1, 1, 128, 96, is_mla)
|
||||||
|
self._test_extend_attention_once(1, 123, 16, 1, 128, 96, is_mla)
|
||||||
|
self._test_extend_attention_once(4, 1230, 16, 4, 128, 96, is_mla)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
unittest.main()
|
||||||
Reference in New Issue
Block a user