Support non-contiguous query input for extend/decode attention (#7462)
This commit is contained in:
@@ -240,6 +240,8 @@ void extend_attention_kernel_impl(
|
||||
int num_heads_kv,
|
||||
int head_size,
|
||||
int head_size_v,
|
||||
int q_strideM,
|
||||
int q_strideH,
|
||||
int ke_strideN,
|
||||
int ke_strideH,
|
||||
int ve_strideN,
|
||||
@@ -259,8 +261,6 @@ void extend_attention_kernel_impl(
|
||||
using Vec = at::vec::Vectorized<float>;
|
||||
|
||||
// strides
|
||||
const int q_strideM = num_heads * head_size;
|
||||
const int q_strideH = head_size;
|
||||
const int o_strideM = num_heads * head_size_v;
|
||||
const int o_strideH = head_size_v;
|
||||
|
||||
@@ -606,7 +606,7 @@ void extend_attention_cpu(
|
||||
extend_seq_lens,
|
||||
extend_start_loc}));
|
||||
|
||||
CHECK_INPUT(q_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(q_extend);
|
||||
CHECK_INPUT(o_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(k_extend);
|
||||
CHECK_LAST_DIM_CONTIGUOUS_INPUT(v_extend);
|
||||
@@ -623,7 +623,9 @@ void extend_attention_cpu(
|
||||
int head_size = q_extend.size(2);
|
||||
int head_size_v = v_extend.size(2);
|
||||
|
||||
// strides for k_extend and v_extend
|
||||
// strides for q_extend, k_extend and v_extend
|
||||
int q_strideM = q_extend.stride(0);
|
||||
int q_strideH = q_extend.stride(1);
|
||||
int ke_strideN = k_extend.stride(0);
|
||||
int ke_strideH = k_extend.stride(1);
|
||||
int ve_strideN = v_extend.stride(0);
|
||||
@@ -698,6 +700,8 @@ void extend_attention_cpu(
|
||||
num_heads_kv,
|
||||
head_size,
|
||||
head_size_v,
|
||||
q_strideM,
|
||||
q_strideH,
|
||||
ke_strideN,
|
||||
ke_strideH,
|
||||
ve_strideN,
|
||||
|
||||
Reference in New Issue
Block a user