Support non-contiguous query input for extend/decode attention (#7462)
This commit is contained in:
@@ -874,6 +874,8 @@ void decode_attention_kernel_impl(
|
||||
int64_t head_size,
|
||||
int64_t head_size_v,
|
||||
int64_t num_kv_splits,
|
||||
int64_t q_strideM,
|
||||
int64_t q_strideH,
|
||||
int64_t k_strideN,
|
||||
int64_t k_strideH,
|
||||
int64_t v_strideN,
|
||||
@@ -886,8 +888,6 @@ void decode_attention_kernel_impl(
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
// strides
|
||||
const int64_t q_strideM = num_heads * head_size;
|
||||
const int64_t q_strideH = head_size;
|
||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||
const int64_t l_stride2 = head_size_v + 1;
|
||||
|
||||
@@ -1017,6 +1017,8 @@ void decode_attention_mla_kernel_impl(
|
||||
int64_t head_size,
|
||||
int64_t head_size_v,
|
||||
int64_t num_kv_splits,
|
||||
int64_t q_strideM,
|
||||
int64_t q_strideH,
|
||||
int64_t k_strideN,
|
||||
int64_t k_strideH,
|
||||
int64_t v_strideN,
|
||||
@@ -1033,8 +1035,6 @@ void decode_attention_mla_kernel_impl(
|
||||
const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
|
||||
|
||||
// strides
|
||||
const int64_t q_strideM = num_heads * head_size;
|
||||
const int64_t q_strideH = head_size;
|
||||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||
const int64_t l_stride2 = head_size_v + 1;
|
||||
@@ -1209,6 +1209,8 @@ void decode_attention_grouped_kernel_impl(
|
||||
int64_t head_size,
|
||||
int64_t head_size_v,
|
||||
int64_t num_kv_splits,
|
||||
int64_t q_strideM,
|
||||
int64_t q_strideH,
|
||||
int64_t k_strideN,
|
||||
int64_t k_strideH,
|
||||
int64_t v_strideN,
|
||||
@@ -1227,8 +1229,6 @@ void decode_attention_grouped_kernel_impl(
|
||||
const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
|
||||
|
||||
// strides
|
||||
const int64_t q_strideM = num_heads * head_size;
|
||||
const int64_t q_strideH = head_size;
|
||||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||
const int64_t l_stride2 = head_size_v + 1;
|
||||
@@ -1391,7 +1391,7 @@ void decode_attention_cpu(
|
||||
std::vector<c10::IValue>(
|
||||
{query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
|
||||
|
||||
CHECK_INPUT(query);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||||
// for MLA, key and value shares the same storage and value could be non-contiguous
|
||||
@@ -1422,6 +1422,10 @@ void decode_attention_cpu(
|
||||
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
||||
CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
|
||||
|
||||
// strides for query
|
||||
int64_t q_strideM = query.stride(0);
|
||||
int64_t q_strideH = query.stride(1);
|
||||
|
||||
// strides for k_buffer and v_buffer
|
||||
int64_t k_strideN = k_buffer.stride(0);
|
||||
int64_t k_strideH = k_buffer.stride(1);
|
||||
@@ -1497,6 +1501,8 @@ void decode_attention_cpu(
|
||||
head_size,
|
||||
head_size_v,
|
||||
num_kv_splits,
|
||||
q_strideM,
|
||||
q_strideH,
|
||||
k_strideN,
|
||||
k_strideH,
|
||||
v_strideN,
|
||||
@@ -1523,6 +1529,8 @@ void decode_attention_cpu(
|
||||
head_size,
|
||||
head_size_v,
|
||||
num_kv_splits,
|
||||
q_strideM,
|
||||
q_strideH,
|
||||
k_strideN,
|
||||
k_strideH,
|
||||
v_strideN,
|
||||
@@ -1550,6 +1558,8 @@ void decode_attention_cpu(
|
||||
head_size,
|
||||
head_size_v,
|
||||
num_kv_splits,
|
||||
q_strideM,
|
||||
q_strideH,
|
||||
k_strideN,
|
||||
k_strideH,
|
||||
v_strideN,
|
||||
|
||||
Reference in New Issue
Block a user