vulkan: optimize mul_mat_id loading row ids into shared memory (#15427)

- Spread the work across the whole workgroup. Using more threads seems to
far outweigh the synchronization overhead.
- Specialize the code for when the division is by a power of two.
This commit is contained in:
Jeff Bolz
2025-08-23 01:31:54 -05:00
committed by GitHub
parent e92734d51b
commit 330c3d2d21
3 changed files with 133 additions and 81 deletions

View File

@@ -2168,9 +2168,9 @@ static void ggml_vk_load_shaders(vk_device& device) {
s_mmq_wg_denoms_k = { 32, 64, 1 };
// spec constants and tile sizes for quant matmul_id
l_warptile_mmqid = { 256, 128, 128, 16, 0 };
m_warptile_mmqid = { 256, 128, 64, 16, 0 };
s_warptile_mmqid = { 256, 128, 64, 16, 0 };
l_warptile_mmqid = { 256, 128, 128, 16, 0, device->subgroup_size };
m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size };
l_mmqid_wg_denoms = { 128, 128, 1 };
m_mmqid_wg_denoms = { 128, 64, 1 };
s_mmqid_wg_denoms = { 128, 64, 1 };