vulkan: move common FA code to flash_attn_base.comp (#13556)
* vulkan: move common FA code to flash_attn_base.comp * vulkan: move common FA index/stride setup code to flash_attn_base.comp * build fix
This commit is contained in:
@@ -18,62 +18,12 @@
|
||||
|
||||
#include "types.comp"
|
||||
#include "dequant_funcs_cm2.comp"
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
layout (constant_id = 1) const uint32_t Br = 32;
|
||||
layout (constant_id = 2) const uint32_t Bc = 32;
|
||||
layout (constant_id = 3) const uint32_t D = 32;
|
||||
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
||||
|
||||
layout (push_constant) uniform parameter {
|
||||
uint32_t N;
|
||||
uint32_t KV;
|
||||
|
||||
uint32_t ne1;
|
||||
uint32_t ne2;
|
||||
uint32_t ne3;
|
||||
|
||||
uint32_t neq2;
|
||||
uint32_t neq3;
|
||||
uint32_t nek2;
|
||||
uint32_t nek3;
|
||||
uint32_t nev2;
|
||||
uint32_t nev3;
|
||||
uint32_t nem1;
|
||||
|
||||
uint32_t nb01;
|
||||
uint32_t nb02;
|
||||
uint32_t nb03;
|
||||
uint32_t nb11;
|
||||
uint32_t nb12;
|
||||
uint32_t nb13;
|
||||
uint32_t nb21;
|
||||
uint32_t nb22;
|
||||
uint32_t nb23;
|
||||
uint32_t nb31;
|
||||
|
||||
float scale;
|
||||
float max_bias;
|
||||
float logit_softcap;
|
||||
|
||||
uint32_t mask;
|
||||
uint32_t n_head_log2;
|
||||
float m0;
|
||||
float m1;
|
||||
|
||||
uint32_t gqa_ratio;
|
||||
uint32_t split_kv;
|
||||
uint32_t k_num;
|
||||
} p;
|
||||
#include "flash_attn_base.comp"
|
||||
|
||||
layout (binding = 0) readonly buffer Q {uint8_t data_q[];};
|
||||
layout (binding = 1) readonly buffer K {uint8_t data_k[];};
|
||||
layout (binding = 2) readonly buffer V {uint8_t data_v[];};
|
||||
layout (binding = 3) readonly buffer M {uint8_t data_m[];};
|
||||
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
||||
|
||||
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
||||
|
||||
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
|
||||
return max(x, y);
|
||||
@@ -118,67 +68,12 @@ D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TY
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Store column zero. This is used to save per-row m and L values for split_k.
|
||||
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
|
||||
{
|
||||
if (r < N && c == 0) {
|
||||
uint32_t offset = iq2 + r;
|
||||
data_o[o_offset + offset] = D_TYPE(elem);
|
||||
}
|
||||
return elem;
|
||||
}
|
||||
|
||||
// Load the slope matrix, indexed by Q's dimension 2.
|
||||
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
|
||||
{
|
||||
const uint32_t h = iq2 + (r % p.gqa_ratio);
|
||||
|
||||
const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
|
||||
const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);
|
||||
|
||||
return ACC_TYPE(pow(base, ACC_TYPE(exph)));
|
||||
}
|
||||
|
||||
void main() {
|
||||
#ifdef NEEDS_INIT_IQ_SHMEM
|
||||
init_iq_shmem(gl_WorkGroupSize);
|
||||
#endif
|
||||
|
||||
const uint32_t N = p.N;
|
||||
const uint32_t KV = p.KV;
|
||||
|
||||
uint32_t i = gl_WorkGroupID.x;
|
||||
uint32_t split_k_index = 0;
|
||||
|
||||
if (p.k_num > 1) {
|
||||
i = 0;
|
||||
split_k_index = gl_WorkGroupID.x;
|
||||
}
|
||||
|
||||
const uint32_t Tr = CEIL_DIV(N, Br);
|
||||
|
||||
const uint32_t start_j = split_k_index * p.split_kv / Bc;
|
||||
const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);
|
||||
|
||||
// When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
|
||||
// When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
|
||||
const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
|
||||
const uint32_t iq3 = gl_WorkGroupID.z;
|
||||
|
||||
// broadcast factors
|
||||
const uint32_t rk2 = p.neq2/p.nek2;
|
||||
const uint32_t rk3 = p.neq3/p.nek3;
|
||||
|
||||
const uint32_t rv2 = p.neq2/p.nev2;
|
||||
const uint32_t rv3 = p.neq3/p.nev3;
|
||||
|
||||
// k indices
|
||||
const uint32_t ik3 = iq3 / rk3;
|
||||
const uint32_t ik2 = iq2 / rk2;
|
||||
|
||||
// v indices
|
||||
const uint32_t iv3 = iq3 / rv3;
|
||||
const uint32_t iv2 = iq2 / rv2;
|
||||
init_indices();
|
||||
|
||||
tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
|
||||
tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
|
||||
@@ -195,17 +90,6 @@ void main() {
|
||||
tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
|
||||
tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);
|
||||
|
||||
// nb?1 are already divided by the type size and are in units of elements.
|
||||
// When using grouped query attention, Q is indexed by iq2, so the stride
|
||||
// should be nb02 (which is in bytes).
|
||||
uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
|
||||
uint32_t k_stride = p.nb11;
|
||||
uint32_t v_stride = p.nb21;
|
||||
// When using grouped query attention, all rows use the same mask (stride 0).
|
||||
// "p.gqa_ratio >> 16" is just a roundabout way of writing zero
|
||||
// that prevents the compiler from folding the "&" through the select
|
||||
// and breaking the alignment detection.
|
||||
uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;
|
||||
// hint to the compiler that strides are aligned for the aligned variant of the shader
|
||||
if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user