Rewrite the stride logic for the mask tensor in the FA shader to force the stride to be aligned, to allow using more efficient loads.
384 lines
14 KiB
Plaintext
384 lines
14 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_control_flow_attributes : enable
|
|
#extension GL_EXT_shader_16bit_storage : require
|
|
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
|
|
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
|
|
|
|
#extension GL_KHR_memory_scope_semantics : enable
|
|
#extension GL_KHR_cooperative_matrix : enable
|
|
#extension GL_NV_cooperative_matrix2 : enable
|
|
#extension GL_EXT_buffer_reference : enable
|
|
#extension GL_KHR_shader_subgroup_ballot : enable
|
|
#extension GL_KHR_shader_subgroup_vote : enable
|
|
#extension GL_EXT_null_initializer : enable
|
|
|
|
#include "types.comp"
|
|
#include "dequant_funcs_cm2.comp"
|
|
|
|
// Workgroup size: x comes from specialization constant 0, y and z are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;

// Tile shape and clamp mode, overridable through specialization constants 1-4.
//   Br    - rows of the Q / S / O tiles handled by one workgroup
//   Bc    - number of K/V rows consumed per loop iteration (columns of S)
//   D     - width of the Q, K, V and O tiles (second dimension of their layouts)
//   Clamp - clamp mode for the K, V and mask tensor layouts; main() treats any
//           value other than gl_CooperativeMatrixClampModeConstantNV as the
//           "aligned" shader variant
layout (constant_id = 1) const uint32_t Br = 32;
layout (constant_id = 2) const uint32_t Bc = 32;
layout (constant_id = 3) const uint32_t D = 32;
layout (constant_id = 4) const uint32_t Clamp = gl_CooperativeMatrixClampModeConstantNV;
|
|
|
|
// Per-dispatch parameters. Member order and types define the push-constant
// layout (presumably mirrored by a host-side struct - keep both in sync).
layout (push_constant) uniform parameter {
    uint32_t N;   // number of valid Q rows (first dimension of the Q layout)
    uint32_t KV;  // number of K/V rows; also the column count of the mask

    // Output extents: used for the output tensor layout and the split_k offsets.
    uint32_t ne1;
    uint32_t ne2;
    uint32_t ne3; // not referenced in this shader file

    // Extents of dimensions 2 and 3 of Q, K and V; their ratios give the
    // broadcast factors that map a Q index to a K/V index.
    uint32_t neq2;
    uint32_t neq3;
    uint32_t nek2;
    uint32_t nek3;
    uint32_t nev2;
    uint32_t nev3;
    uint32_t nem1; // first dimension of the mask layout

    // Strides. nb01/nb11/nb21 are used as row strides in elements; nb02 is
    // divided by 4 when used as a row stride (it is in bytes); the remaining
    // nb?2/nb?3 values feed the buffer offsets of Q, K and V.
    uint32_t nb01;
    uint32_t nb02;
    uint32_t nb03;
    uint32_t nb11;
    uint32_t nb12;
    uint32_t nb13;
    uint32_t nb21;
    uint32_t nb22;
    uint32_t nb23;
    uint32_t nb31; // not referenced in this shader file

    float scale;         // multiplied into Q before Q*K^T
    float max_bias;      // > 0 enables the ALiBi slope matrix
    float logit_softcap; // != 0 enables logit_softcap * tanh(S)

    uint32_t mask;        // != 0 means a mask is bound and is added to S
    uint32_t n_head_log2; // ALiBi: heads below this use m0, the others use m1
    float m0;             // ALiBi slope base for h <  n_head_log2
    float m1;             // ALiBi slope base for h >= n_head_log2

    uint32_t gqa_ratio; // > 1 selects grouped query attention (rows index Q's dimension 2)
    uint32_t split_kv;  // number of KV rows covered by one split_k slice
    uint32_t k_num;     // > 1 enables split_k (gl_WorkGroupID.x selects the slice)
} p;
|
|
|
|
// Storage buffers. Q, K, V and the mask are declared as raw bytes and are read
// through coopMatLoadTensorNV in main(); the output is an array of D_TYPE.
layout (binding = 0) readonly  buffer Q {uint8_t data_q[];};
layout (binding = 1) readonly  buffer K {uint8_t data_k[];};
layout (binding = 2) readonly  buffer V {uint8_t data_v[];};
layout (binding = 3) readonly  buffer M {uint8_t data_m[];};
layout (binding = 4) writeonly buffer O {D_TYPE data_o[];};
|
|
|
|
// Integer division of a by b, rounded up. Both arguments are evaluated more
// than once, so pass side-effect-free expressions; (a) + (b) - 1 must not overflow.
#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))
|
|
|
|
// Combine callback for coopMatReduceNV: keeps the larger of the two operands.
// main() uses it to compute the per-row maximum of S.
ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    const ACC_TYPE larger = max(x, y);
    return larger;
}
|
|
|
|
// Combine callback for coopMatReduceNV that ignores its second operand and
// returns the first one unchanged. main() uses it to "resize" a matrix whose
// rows hold a repeated value (eM, L) to a different column count.
ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) {
    return x;
}
|
|
|
|
// Per-element callback: elements whose row index is >= numRows or whose column
// index is >= numCols are replaced by 'replace'; all others pass through.
ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) {
    const bool insideValidRegion = (row < numRows) && (col < numCols);
    return insideValidRegion ? elem : replace;
}
|
|
|
|
// Per-element callback for coopMatPerElementNV: e^elem.
// The row and column indices are part of the callback signature but are unused.
ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem)
{
    const ACC_TYPE result = exp(elem);
    return result;
}
|
|
|
|
// Per-element callback for coopMatPerElementNV: element-wise maximum of two
// matrices. The row and column indices are unused.
ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1)
{
    const ACC_TYPE larger = max(elem0, elem1);
    return larger;
}
|
|
|
|
// DECODEFUNC is spliced after the last argument of the K and V
// coopMatLoadTensorNV calls in main(). When BLOCK_SIZE is defined it appends
// ", DEQUANTFUNC" as an extra argument (DEQUANTFUNC is not defined in this
// file - presumably it comes from dequant_funcs_cm2.comp); otherwise it
// expands to nothing.
#ifdef BLOCK_SIZE
#define DECODEFUNC , DEQUANTFUNC
#else
#define DECODEFUNC
#endif
|
|
|
|
// Per-element store callback used for grouped query attention and split_k.
// Matrix rows index Q's dimension 2; only the first N rows and the first D
// columns are written. Element (r, c) goes to
// data_o[o_offset + (iq2 + r) * D + c]. The element is returned unchanged.
D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Skip padding rows/columns.
    if (r >= N || c >= D) {
        return elem;
    }

    const uint32_t rel = (iq2 + r) * D + c;
    data_o[o_offset + rel] = D_TYPE(elem);
    return elem;
}
|
|
|
|
// Per-element store callback that writes only column zero of the first N rows:
// element (r, 0) goes to data_o[o_offset + iq2 + r]. main() uses it to save the
// per-row m and L values for split_k. The element is returned unchanged.
ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N)
{
    // Everything outside column zero of the valid rows is left untouched.
    if (r >= N || c != 0) {
        return elem;
    }

    const uint32_t rel = iq2 + r;
    data_o[o_offset + rel] = D_TYPE(elem);
    return elem;
}
|
|
|
|
// Compute the ALiBi slope for one matrix element; rows are indexed by Q's
// dimension 2.
//   r    - row of the element; selects the head within the current GQA group
//   c    - column of the element (unused)
//   elem - incoming element value (unused; the result replaces it)
//   iq2  - first head handled by this workgroup
// Returns m0^(h + 1) for heads h < n_head_log2 and
// m1^(2 * (h - n_head_log2) + 1) for the remaining heads.
ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2)
{
    // Wrap the row into [0, gqa_ratio). A true modulo is required here: the
    // mask form "r & (gqa_ratio - 1)" is only equivalent when gqa_ratio is a
    // power of two and selects the wrong head for ratios such as 3, 6 or 7.
    // For gqa_ratio == 1 this is always 0, i.e. every row uses head iq2.
    const uint32_t h = iq2 + (r % p.gqa_ratio);

    const ACC_TYPE base = ACC_TYPE(h < p.n_head_log2 ? p.m0 : p.m1);
    const int exph = int(h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1);

    return ACC_TYPE(pow(base, ACC_TYPE(exph)));
}
|
|
|
|
// Shader entry point. Each workgroup loads one Br x D tile of Q, then walks
// over Bc-row tiles of K and V, keeping a running row maximum (M), a running
// row sum (L) and a running output accumulator (O). At the end it either
// stores O / L, or - when split_k is active - stores the unnormalised O
// together with the per-row M and L values for a later resolve pass.
void main() {
#ifdef NEEDS_INIT_IQ_SHMEM
    init_iq_shmem(gl_WorkGroupSize);
#endif

    const uint32_t N = p.N;
    const uint32_t KV = p.KV;

    // Without split_k, gl_WorkGroupID.x selects the row tile. With split_k it
    // selects the KV slice instead and only row tile 0 is processed.
    uint32_t i;
    uint32_t split_k_index;
    if (p.k_num > 1) {
        i = 0;
        split_k_index = gl_WorkGroupID.x;
    } else {
        i = gl_WorkGroupID.x;
        split_k_index = 0;
    }

    // Number of row tiles (not referenced again in this function).
    const uint32_t Tr = CEIL_DIV(N, Br);

    // Range of KV tiles [start_j, end_j) covered by this workgroup.
    const uint32_t start_j = split_k_index * p.split_kv / Bc;
    const uint32_t end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc);

    // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y.
    // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2.
    const uint32_t iq2 = gl_WorkGroupID.y * p.gqa_ratio;
    const uint32_t iq3 = gl_WorkGroupID.z;

    // Broadcast factors from Q indices to K indices and to V indices.
    const uint32_t rk2 = p.neq2/p.nek2;
    const uint32_t rk3 = p.neq3/p.nek3;
    const uint32_t rv2 = p.neq2/p.nev2;
    const uint32_t rv3 = p.neq3/p.nev3;

    // K indices
    const uint32_t ik2 = iq2 / rk2;
    const uint32_t ik3 = iq3 / rk3;

    // V indices
    const uint32_t iv2 = iq2 / rv2;
    const uint32_t iv3 = iq3 / rv3;

    // Q always uses the constant clamp mode; K and V use the specialization constant.
    tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV);
    tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp);
    tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp);

    // View used to load K transposed.
    tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0);

#if defined(BLOCK_SIZE)
    tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE);
    tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE);
#endif

    tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, D);
    tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, D);
    tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, D);

    // nb?1 are already divided by the type size and are in units of elements.
    // With grouped query attention the rows of Q are indexed by iq2, so the row
    // stride is nb02 instead; nb02 is in bytes, hence the division by 4.
    uint32_t q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01;
    uint32_t k_stride = p.nb11;
    uint32_t v_stride = p.nb21;

    // When using grouped query attention, all rows use the same mask (stride 0).
    // "p.gqa_ratio >> 16" is just a roundabout way of writing zero
    // that prevents the compiler from folding the "&" through the select
    // and breaking the alignment detection.
    uint32_t m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV;

    // Hint to the compiler that strides are aligned for the aligned variant of
    // the shader (any clamp mode other than Constant).
    if (Clamp != gl_CooperativeMatrixClampModeConstantNV)
    {
        q_stride &= ~7;
#if !defined(BLOCK_SIZE)
        k_stride &= ~7;
        v_stride &= ~7;
#endif
        m_stride &= ~7;
    }

    tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1);
    tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1);
    tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1);

    // Load the Q tile, convert it to fp16 and pre-multiply it by the scale.
    coopmat<Q_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Q;

    uint32_t q_offset = iq2*p.nb02+iq3*p.nb03;
    coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, D));

    coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA> Qf16 = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseA>(Q);
    Qf16 *= float16_t(p.scale);

    // Running output accumulator.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);

    // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M.
    const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF);

    // Running row sum (L) and running row maximum (M).
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> L = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> M = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(NEG_FLT_MAX_OVER_2);

    // Multiplier applied to the mask; all ones unless ALiBi is enabled.
    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> slopeMat = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(1.0);

    // ALiBi
    if (p.max_bias > 0.0f) {
        coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
    }

    [[dont_unroll]]
    for (uint32_t j = start_j; j < end_j; ++j) {

        // S = Q * K^T for the current KV tile.
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> S = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0);

        coopmat<float16_t, gl_ScopeWorkgroup, D, Bc, gl_MatrixUseB> K_T;

        uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13;
        coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, D), tensorViewTranspose DECODEFUNC);
        S = coopMatMulAdd(Qf16, K_T, S);

        if (p.logit_softcap != 0.0f) {
            [[unroll]]
            for (int k = 0; k < S.length(); ++k) {
                S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]);
            }
        }

        // Add the (slope-scaled) mask tile.
        if (p.mask != 0) {
            tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp);
            tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV);
            tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1);

            coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;

            coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));

            S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
        }

        // True when this tile extends past the last valid row or column and the
        // clamp mode is non-zero, i.e. padding elements have to be overwritten.
        const bool tileHasPadding = Clamp != 0 &&
                                    ((j + 1) * Bc > KV ||
                                     (i + 1) * Br > N);

        // Set padding elements to a large negative value, so they don't contribute to rowmax
        if (tileHasPadding) {
            uint validRows = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint validCols = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), validRows, validCols);
        }

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> rowmax, P, rowsum, eM;

        coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce);

        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> Mold = M;

        // M  = max(rowmax, Mold)
        // P  = e^(S - M)
        // eM = e^(Mold - M)
        coopMatPerElementNV(M, rowmax, Max, Mold);
        coopMatPerElementNV(P, S - M, Exp);
        coopMatPerElementNV(eM, Mold - M, Exp);

        // Clear padding elements to 0, so they don't contribute to rowsum
        if (tileHasPadding) {
            uint validRows = ((i + 1) * Br > N) ? (N % Br) : Br;
            uint validCols = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc;

            coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), validRows, validCols);
        }

        coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA> P_A = coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseA>(P);

        // Compute rowsum by multiplying by a matrix of all ones.
        coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB> One = coopmat<float16_t, gl_ScopeWorkgroup, Bc, Bc, gl_MatrixUseB>(1.0);

        rowsum = coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(0.0);
        rowsum = coopMatMulAdd(P_A, One, rowsum);

        coopmat<float16_t, gl_ScopeWorkgroup, Bc, D, gl_MatrixUseB> V;
        uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23;
        coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, D) DECODEFUNC);

        // Rescale the running sum by e^(Mold - M) and add this tile's row sums.
        L = eM*L + rowsum;

        // This is the "diagonal" matrix in the paper, but since we do componentwise
        // multiply rather than matrix multiply it has the diagonal element smeared
        // across the row
        coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> eMdiag;

        // Resize eM from Br x Bc to Br x D by using smear/reduce.
        coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce);

        // Multiply with fp16 accumulation, then add to O.
        coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> PV = coopmat<float16_t, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(0);
        PV = coopMatMulAdd(P_A, V, PV);

        O = eMdiag * O + coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(PV);
    }

    // If there is split_k, then the split_k resolve shader does the final
    // division by L. Store the intermediate O value and per-row m and L values.
    if (p.k_num > 1) {
        coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);

        // Intermediate O values of this slice.
        uint32_t o_offset = D * p.ne1 * split_k_index;
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);

        // Per-row L values, followed ne1 elements later by the per-row M values.
        o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
        coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
        coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
        return;
    }

    coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> Ldiag;

    // Resize L from Br x Bc to Br x D by using smear/reduce.
    coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce);

    // Invert element-wise so the normalisation below is a multiply.
    [[unroll]]
    for (int k = 0; k < Ldiag.length(); ++k) {
        Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k];
    }

    O = Ldiag*O;

    uint32_t o_offset = iq3*p.ne2*p.ne1;

    coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
    if (p.gqa_ratio > 1) {
        // Grouped query attention: rows are heads, stored element by element.
        coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
    } else {
        tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV);
        tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, D);

        // permute dimensions
        tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2);

        coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, D), tensorViewPermute);
    }
}
|