Support non-contiguous query input for extend/decode attention (#7462)

This commit is contained in:
YanbingJiang
2025-07-03 10:59:45 +08:00
committed by GitHub
parent 40e5cb7a9c
commit b044400dd3
4 changed files with 29 additions and 13 deletions

View File

@@ -240,6 +240,8 @@ void extend_attention_kernel_impl(
int num_heads_kv,
int head_size,
int head_size_v,
int q_strideM,
int q_strideH,
int ke_strideN,
int ke_strideH,
int ve_strideN,
@@ -259,8 +261,6 @@ void extend_attention_kernel_impl(
using Vec = at::vec::Vectorized<float>;
// strides
const int q_strideM = num_heads * head_size;
const int q_strideH = head_size;
const int o_strideM = num_heads * head_size_v;
const int o_strideH = head_size_v;
@@ -606,7 +606,7 @@ void extend_attention_cpu(
extend_seq_lens,
extend_start_loc}));
CHECK_INPUT(q_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
CHECK_INPUT(o_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
@@ -623,7 +623,9 @@ void extend_attention_cpu(
int head_size = q_extend.size(2);
int head_size_v = v_extend.size(2);
// strides for k_extend and v_extend
// strides for q_extend, k_extend and v_extend
int q_strideM = q_extend.stride(0);
int q_strideH = q_extend.stride(1);
int ke_strideN = k_extend.stride(0);
int ke_strideH = k_extend.stride(1);
int ve_strideN = v_extend.stride(0);
@@ -698,6 +700,8 @@ void extend_attention_cpu(
num_heads_kv,
head_size,
head_size_v,
q_strideM,
q_strideH,
ke_strideN,
ke_strideH,
ve_strideN,