vulkan: support softmax/FA batch and broadcast (#14449)

This commit is contained in:
Jeff Bolz
2025-07-01 03:32:56 -05:00
committed by Georgi Gerganov
parent ec68e84c32
commit 8875523eb3
7 changed files with 80 additions and 44 deletions

View File

@@ -130,6 +130,11 @@ void main() {
coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2);
}
uint32_t m_offset = 0;
if (p.nem2 != 1) {
m_offset = (iq3 % p.nem2) * p.nem1 * KV * 2 /*sizeof(float16_t)*/;
}
[[dont_unroll]]
for (uint32_t j = start_j; j < end_j; ++j) {
@@ -155,7 +160,7 @@ void main() {
coopmat<float16_t, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator> mv;
coopMatLoadTensorNV(mv, data_m, 0, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc));
S += slopeMat*coopmat<ACC_TYPE, gl_ScopeWorkgroup, Br, Bc, gl_MatrixUseAccumulator>(mv);
}
@@ -229,10 +234,10 @@ void main() {
if (p.k_num > 1) {
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
uint32_t o_offset = D * p.ne1 * split_k_index;
uint32_t o_offset = D * p.ne1 * (split_k_index + iq3 * p.k_num);
coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N);
o_offset = D * p.ne1 * p.k_num + p.ne1 * split_k_index * 2;
o_offset = D * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2;
coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N);
coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N);
return;
@@ -250,7 +255,7 @@ void main() {
O = Ldiag*O;
uint32_t o_offset = iq3*p.ne2*p.ne1;
uint32_t o_offset = iq3*p.ne2*p.ne1*D;
coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator> O_D = coopmat<D_TYPE, gl_ScopeWorkgroup, Br, D, gl_MatrixUseAccumulator>(O);
if (p.gqa_ratio > 1) {