Support non-contiguous query input for extend/decode attention (#7462)
This commit is contained in:
@@ -874,6 +874,8 @@ void decode_attention_kernel_impl(
|
|||||||
int64_t head_size,
|
int64_t head_size,
|
||||||
int64_t head_size_v,
|
int64_t head_size_v,
|
||||||
int64_t num_kv_splits,
|
int64_t num_kv_splits,
|
||||||
|
int64_t q_strideM,
|
||||||
|
int64_t q_strideH,
|
||||||
int64_t k_strideN,
|
int64_t k_strideN,
|
||||||
int64_t k_strideH,
|
int64_t k_strideH,
|
||||||
int64_t v_strideN,
|
int64_t v_strideN,
|
||||||
@@ -886,8 +888,6 @@ void decode_attention_kernel_impl(
|
|||||||
using Vec = at::vec::Vectorized<float>;
|
using Vec = at::vec::Vectorized<float>;
|
||||||
|
|
||||||
// strides
|
// strides
|
||||||
const int64_t q_strideM = num_heads * head_size;
|
|
||||||
const int64_t q_strideH = head_size;
|
|
||||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||||
const int64_t l_stride2 = head_size_v + 1;
|
const int64_t l_stride2 = head_size_v + 1;
|
||||||
|
|
||||||
@@ -1017,6 +1017,8 @@ void decode_attention_mla_kernel_impl(
|
|||||||
int64_t head_size,
|
int64_t head_size,
|
||||||
int64_t head_size_v,
|
int64_t head_size_v,
|
||||||
int64_t num_kv_splits,
|
int64_t num_kv_splits,
|
||||||
|
int64_t q_strideM,
|
||||||
|
int64_t q_strideH,
|
||||||
int64_t k_strideN,
|
int64_t k_strideN,
|
||||||
int64_t k_strideH,
|
int64_t k_strideH,
|
||||||
int64_t v_strideN,
|
int64_t v_strideN,
|
||||||
@@ -1033,8 +1035,6 @@ void decode_attention_mla_kernel_impl(
|
|||||||
const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
|
const int64_t BLOCK_H = batches == 1 ? 6 : (batches > 16 ? 22 : 11);
|
||||||
|
|
||||||
// strides
|
// strides
|
||||||
const int64_t q_strideM = num_heads * head_size;
|
|
||||||
const int64_t q_strideH = head_size;
|
|
||||||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||||
const int64_t l_stride2 = head_size_v + 1;
|
const int64_t l_stride2 = head_size_v + 1;
|
||||||
@@ -1209,6 +1209,8 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
int64_t head_size,
|
int64_t head_size,
|
||||||
int64_t head_size_v,
|
int64_t head_size_v,
|
||||||
int64_t num_kv_splits,
|
int64_t num_kv_splits,
|
||||||
|
int64_t q_strideM,
|
||||||
|
int64_t q_strideH,
|
||||||
int64_t k_strideN,
|
int64_t k_strideN,
|
||||||
int64_t k_strideH,
|
int64_t k_strideH,
|
||||||
int64_t v_strideN,
|
int64_t v_strideN,
|
||||||
@@ -1227,8 +1229,6 @@ void decode_attention_grouped_kernel_impl(
|
|||||||
const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
|
const int64_t BLOCK_H = std::min(4 * batches, kBLOCK_H);
|
||||||
|
|
||||||
// strides
|
// strides
|
||||||
const int64_t q_strideM = num_heads * head_size;
|
|
||||||
const int64_t q_strideH = head_size;
|
|
||||||
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
const int64_t l_stride0 = num_heads * num_kv_splits * (head_size_v + 1);
|
||||||
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
const int64_t l_stride1 = num_kv_splits * (head_size_v + 1);
|
||||||
const int64_t l_stride2 = head_size_v + 1;
|
const int64_t l_stride2 = head_size_v + 1;
|
||||||
@@ -1391,7 +1391,7 @@ void decode_attention_cpu(
|
|||||||
std::vector<c10::IValue>(
|
std::vector<c10::IValue>(
|
||||||
{query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
|
{query, output, k_buffer, v_buffer, attn_logits, req_to_token, req_pool_indices, seq_lens}));
|
||||||
|
|
||||||
CHECK_INPUT(query);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(query);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_buffer);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_buffer);
|
||||||
// for MLA, key and value share the same storage and value could be non-contiguous
|
// for MLA, key and value share the same storage and value could be non-contiguous
|
||||||
@@ -1422,6 +1422,10 @@ void decode_attention_cpu(
|
|||||||
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
CHECK_EQ(attn_logits.size(3), head_size_v + 1);
|
||||||
CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
|
CHECK_EQ(attn_logits.scalar_type(), at::kFloat);
|
||||||
|
|
||||||
|
// strides for query
|
||||||
|
int64_t q_strideM = query.stride(0);
|
||||||
|
int64_t q_strideH = query.stride(1);
|
||||||
|
|
||||||
// strides for k_buffer and v_buffer
|
// strides for k_buffer and v_buffer
|
||||||
int64_t k_strideN = k_buffer.stride(0);
|
int64_t k_strideN = k_buffer.stride(0);
|
||||||
int64_t k_strideH = k_buffer.stride(1);
|
int64_t k_strideH = k_buffer.stride(1);
|
||||||
@@ -1497,6 +1501,8 @@ void decode_attention_cpu(
|
|||||||
head_size,
|
head_size,
|
||||||
head_size_v,
|
head_size_v,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
q_strideM,
|
||||||
|
q_strideH,
|
||||||
k_strideN,
|
k_strideN,
|
||||||
k_strideH,
|
k_strideH,
|
||||||
v_strideN,
|
v_strideN,
|
||||||
@@ -1523,6 +1529,8 @@ void decode_attention_cpu(
|
|||||||
head_size,
|
head_size,
|
||||||
head_size_v,
|
head_size_v,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
q_strideM,
|
||||||
|
q_strideH,
|
||||||
k_strideN,
|
k_strideN,
|
||||||
k_strideH,
|
k_strideH,
|
||||||
v_strideN,
|
v_strideN,
|
||||||
@@ -1550,6 +1558,8 @@ void decode_attention_cpu(
|
|||||||
head_size,
|
head_size,
|
||||||
head_size_v,
|
head_size_v,
|
||||||
num_kv_splits,
|
num_kv_splits,
|
||||||
|
q_strideM,
|
||||||
|
q_strideH,
|
||||||
k_strideN,
|
k_strideN,
|
||||||
k_strideH,
|
k_strideH,
|
||||||
v_strideN,
|
v_strideN,
|
||||||
|
|||||||
@@ -240,6 +240,8 @@ void extend_attention_kernel_impl(
|
|||||||
int num_heads_kv,
|
int num_heads_kv,
|
||||||
int head_size,
|
int head_size,
|
||||||
int head_size_v,
|
int head_size_v,
|
||||||
|
int q_strideM,
|
||||||
|
int q_strideH,
|
||||||
int ke_strideN,
|
int ke_strideN,
|
||||||
int ke_strideH,
|
int ke_strideH,
|
||||||
int ve_strideN,
|
int ve_strideN,
|
||||||
@@ -259,8 +261,6 @@ void extend_attention_kernel_impl(
|
|||||||
using Vec = at::vec::Vectorized<float>;
|
using Vec = at::vec::Vectorized<float>;
|
||||||
|
|
||||||
// strides
|
// strides
|
||||||
const int q_strideM = num_heads * head_size;
|
|
||||||
const int q_strideH = head_size;
|
|
||||||
const int o_strideM = num_heads * head_size_v;
|
const int o_strideM = num_heads * head_size_v;
|
||||||
const int o_strideH = head_size_v;
|
const int o_strideH = head_size_v;
|
||||||
|
|
||||||
@@ -606,7 +606,7 @@ void extend_attention_cpu(
|
|||||||
extend_seq_lens,
|
extend_seq_lens,
|
||||||
extend_start_loc}));
|
extend_start_loc}));
|
||||||
|
|
||||||
CHECK_INPUT(q_extend);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
|
||||||
CHECK_INPUT(o_extend);
|
CHECK_INPUT(o_extend);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
|
||||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
|
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
|
||||||
@@ -623,7 +623,9 @@ void extend_attention_cpu(
|
|||||||
int head_size = q_extend.size(2);
|
int head_size = q_extend.size(2);
|
||||||
int head_size_v = v_extend.size(2);
|
int head_size_v = v_extend.size(2);
|
||||||
|
|
||||||
// strides for k_extend and v_extend
|
// strides for q_extend, k_extend and v_extend
|
||||||
|
int q_strideM = q_extend.stride(0);
|
||||||
|
int q_strideH = q_extend.stride(1);
|
||||||
int ke_strideN = k_extend.stride(0);
|
int ke_strideN = k_extend.stride(0);
|
||||||
int ke_strideH = k_extend.stride(1);
|
int ke_strideH = k_extend.stride(1);
|
||||||
int ve_strideN = v_extend.stride(0);
|
int ve_strideN = v_extend.stride(0);
|
||||||
@@ -698,6 +700,8 @@ void extend_attention_cpu(
|
|||||||
num_heads_kv,
|
num_heads_kv,
|
||||||
head_size,
|
head_size,
|
||||||
head_size_v,
|
head_size_v,
|
||||||
|
q_strideM,
|
||||||
|
q_strideH,
|
||||||
ke_strideN,
|
ke_strideN,
|
||||||
ke_strideH,
|
ke_strideH,
|
||||||
ve_strideN,
|
ve_strideN,
|
||||||
|
|||||||
@@ -102,9 +102,10 @@ class TestDecodeAttention(CustomTestCase):
|
|||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
|
|
||||||
# k_buffer, v_buffer, key and value support non-contiguous tensors
|
# k_buffer, v_buffer, query, key and value support non-contiguous tensors
|
||||||
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
v_buffer = v_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
q = q.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
key = key.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
value = value.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
torch.ops.sgl_kernel.decode_attention_cpu(
|
torch.ops.sgl_kernel.decode_attention_cpu(
|
||||||
|
|||||||
@@ -123,7 +123,8 @@ class TestExtendAttention(CustomTestCase):
|
|||||||
(b_seq_len_extend[i], H_Q, D), dtype=dtype
|
(b_seq_len_extend[i], H_Q, D), dtype=dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
# k_extend, v_extend, k_buffer and v_buffer support non-contiguous tensors
|
# q_extend, k_extend, v_extend, k_buffer and v_buffer support non-contiguous tensors
|
||||||
|
q_extend = q_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
k_extend = k_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
v_extend = v_extend.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
k_buffer = k_buffer.transpose(0, 1).contiguous().transpose(0, 1)
|
||||||
|
|||||||
Reference in New Issue
Block a user