vulkan: Remove splitting for mul_mat_id (#15568)

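row_ids only needs to hold the BN rows for the current tile, so the shared array shrinks from 4096 entries to BN. With the 4096-id limit gone, the host no longer has to split mul_mat_id into multiple dispatches by number of ids.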
This commit is contained in:
Jeff Bolz
2025-08-25 23:42:44 -05:00
committed by GitHub
parent 74f52f77f2
commit 34bdbbd7c2
4 changed files with 35 additions and 56 deletions

View File

@@ -2090,10 +2090,11 @@ static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vec
const uint32_t warps = warptile[0] / warptile[10]; const uint32_t warps = warptile[0] / warptile[10];
const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size; const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
const uint32_t mmid_row_ids = mul_mat_id ? (4096 * sizeof(uint32_t) + 4/*_ne1*/) : 0; const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0; const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size; const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), " VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
@@ -6288,7 +6289,6 @@ static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context&
const uint64_t nei0 = ids->ne[0]; const uint64_t nei0 = ids->ne[0];
const uint64_t nei1 = ids->ne[1]; const uint64_t nei1 = ids->ne[1];
GGML_ASSERT(nei0 * nei1 <= 4096);
const uint32_t nbi1 = ids->nb[1]; const uint32_t nbi1 = ids->nb[1];
const uint32_t nbi2 = ids->nb[2]; const uint32_t nbi2 = ids->nb[2];
@@ -6728,37 +6728,7 @@ static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx
if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) {
ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else { } else {
// Split based on number of ids, to fit in shared memory ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
const uint32_t nei0 = (uint32_t)src2->ne[0];
const uint32_t nei1 = (uint32_t)src2->ne[1];
GGML_ASSERT(nei0 <= 4096);
const uint32_t split_size = std::min(nei1, 4096u / nei0);
if (split_size == nei1) {
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun);
} else {
ggml_tensor src1_copy = *src1;
ggml_tensor src2_copy = *src2;
ggml_tensor dst_copy = *dst;
for (uint32_t token_start = 0; token_start < nei1; token_start += split_size) {
const uint32_t n_tokens = std::min(split_size, nei1 - token_start);
src1_copy.view_offs = src1->view_offs + token_start * src1_copy.nb[2];
src2_copy.view_offs = src2->view_offs + token_start * src2_copy.nb[1];
dst_copy.view_offs = dst->view_offs + token_start * dst_copy.nb[2];
src1_copy.ne[2] = n_tokens;
src2_copy.ne[1] = n_tokens;
dst_copy.ne[2] = n_tokens;
ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, &src1_copy, &src2_copy, &dst_copy, dryrun);
// invalidate cached prealloc_y, can't cache based on the copy of the ggml_tensor
ctx->prealloc_y_last_pipeline_used = {};
ctx->prealloc_y_last_tensor_used = nullptr;
}
}
} }
} }

View File

@@ -109,13 +109,13 @@ shared FLOAT_TYPE buf_b[BN * SHMEM_STRIDE];
#define NUM_WARPS (BLOCK_SIZE / WARP) #define NUM_WARPS (BLOCK_SIZE / WARP)
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
shared u16vec2 row_ids[4096]; shared u16vec2 row_ids[BN];
uint _ne1; uint _ne1;
#ifdef MUL_MAT_ID_USE_SUBGROUPS #ifdef MUL_MAT_ID_USE_SUBGROUPS
shared uvec4 ballots_sh[NUM_WARPS]; shared uvec4 ballots_sh[NUM_WARPS];
void load_row_ids(uint expert_idx, bool nei0_is_pow2) { void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0; _ne1 = 0;
uint num_elements = p.nei1 * p.nei0; uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0); uint nei0shift = findLSB(p.nei0);
@@ -165,11 +165,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier(); barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) { if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx] = u16vec2(ii0, ii1); row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1);
} }
_ne1 += total; _ne1 += total;
iter &= 15; iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
} }
barrier(); barrier();
} }
@@ -242,16 +245,18 @@ void main() {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
#ifdef MUL_MAT_ID_USE_SUBGROUPS #ifdef MUL_MAT_ID_USE_SUBGROUPS
if (bitCount(p.nei0) == 1) { if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true); load_row_ids(expert_idx, true, ic);
} else { } else {
load_row_ids(expert_idx, false); load_row_ids(expert_idx, false, ic);
} }
#else #else
_ne1 = 0; _ne1 = 0;
for (uint ii1 = 0; ii1 < p.nei1; ii1++) { for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) {
for (uint ii0 = 0; ii0 < p.nei0; ii0++) { for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) {
if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) {
row_ids[_ne1] = u16vec2(ii0, ii1); if (_ne1 >= ic * BN) {
row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1);
}
_ne1++; _ne1++;
} }
} }
@@ -797,7 +802,7 @@ void main() {
[[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) {
#if LOAD_VEC_B == 8 #if LOAD_VEC_B == 8
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -813,7 +818,7 @@ void main() {
buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w); buf_b[buf_idx + 7] = FLOAT_TYPE(data_b[idx][1].w);
#elif LOAD_VEC_B == 4 #elif LOAD_VEC_B == 4
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; const u16vec2 row_idx = row_ids[loadc_b + l];
const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b;
#else #else
const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b; const uint idx = pos_b + (loadc_b + l) * p.stride_b / LOAD_VEC_B + loadr_b;
@@ -832,7 +837,7 @@ void main() {
#else #else
const uint row_i = ic * BN + loadc_b + l; const uint row_i = ic * BN + loadc_b + l;
if (row_i < _ne1 && block + loadr_b < end_k) { if (row_i < _ne1 && block + loadr_b < end_k) {
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[loadc_b + l];
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = TO_FLOAT_TYPE(data_b[pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + loadr_b]);
} else { } else {
buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f); buf_b[(loadc_b + l) * SHMEM_STRIDE + loadr_b] = FLOAT_TYPE(0.0f);
@@ -903,7 +908,7 @@ void main() {
const uint row_i = dc + cm_col * TN + col + store_c; const uint row_i = dc + cm_col * TN + col + store_c;
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i - ic * BN];
if (dr + cm_row * TM + store_r < p.M) { if (dr + cm_row * TM + store_r < p.M) {
data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]);
@@ -953,7 +958,7 @@ void main() {
const uint row_i = dc_warp + cc; const uint row_i = dc_warp + cc;
if (row_i >= _ne1) break; if (row_i >= _ne1) break;
const u16vec2 row_idx = row_ids[row_i]; const u16vec2 row_idx = row_ids[row_i - ic * BN];
#endif // MUL_MAT_ID #endif // MUL_MAT_ID
[[unroll]] for (uint cr = 0; cr < TM; cr++) { [[unroll]] for (uint cr = 0; cr < TM; cr++) {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID

View File

@@ -93,7 +93,7 @@ layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];}; layout (binding = 3) readonly buffer IDS {int data_ids[];};
shared u16vec4 row_ids[4096]; shared u16vec4 row_ids[BN];
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB {
B_TYPE b[]; B_TYPE b[];
@@ -111,7 +111,7 @@ B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const i
return B_TYPE(0.0); return B_TYPE(0.0);
} }
const u16vec4 row_idx = row_ids[row_i]; const u16vec4 row_idx = row_ids[row_i & (BN - 1)];
B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]];
return ret; return ret;
@@ -123,14 +123,14 @@ D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem
uint dc = ic * BN + c; uint dc = ic * BN + c;
if (dr < p.M && dc < _ne1) { if (dr < p.M && dc < _ne1) {
uint row_i = dc; uint row_i = c;
const u16vec4 row_idx = row_ids[row_i]; const u16vec4 row_idx = row_ids[row_i];
data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem;
} }
return elem; return elem;
} }
void load_row_ids(uint expert_idx, bool nei0_is_pow2) { void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) {
_ne1 = 0; _ne1 = 0;
uint num_elements = p.nei1 * p.nei0; uint num_elements = p.nei1 * p.nei0;
uint nei0shift = findLSB(p.nei0); uint nei0shift = findLSB(p.nei0);
@@ -180,11 +180,14 @@ void load_row_ids(uint expert_idx, bool nei0_is_pow2) {
barrier(); barrier();
uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot);
if (in_range && id == expert_idx) { if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) {
row_ids[_ne1 + idx] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0);
} }
_ne1 += total; _ne1 += total;
iter &= 15; iter &= 15;
if (_ne1 >= (ic + 1) * BN) {
break;
}
} }
barrier(); barrier();
} }
@@ -218,9 +221,9 @@ void main() {
#ifdef MUL_MAT_ID #ifdef MUL_MAT_ID
if (bitCount(p.nei0) == 1) { if (bitCount(p.nei0) == 1) {
load_row_ids(expert_idx, true); load_row_ids(expert_idx, true, ic);
} else { } else {
load_row_ids(expert_idx, false); load_row_ids(expert_idx, false, ic);
} }
// Workgroup has no work // Workgroup has no work

View File

@@ -6017,6 +6017,7 @@ static std::vector<std::unique_ptr<test_case>> make_test_cases_eval() {
// test large experts*tokens // test large experts*tokens
for (bool b : {false, true}) { for (bool b : {false, true}) {
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16)); test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 16, 16, b, 32, 1024, 16));
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 2, 2, b, 32, 8192, 64));
} }
test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1)); test_cases.emplace_back(new test_mul_mat_id(GGML_TYPE_F16, GGML_TYPE_F32, 1, 1, false, 8, 16, 1));