vulkan: In coopmat2 mmq, load q4_k/q5_k scales through shared memory (#12833)

q4_k and q5_k had a lot of redundant global loads where the same 16B of
scale information is repeatedly loaded and decoded during each loop iteration.
This change restructures the loops to more explicitly iterate over whole
blocks in the outer loop (with unrolled inner loop) and to copy/decode the
scale data into shared memory once at the start of each outer loop. The copy
is pipelined so the scale load from global memory is relatively cheap.

This improves q4_k/q5_k model prompt processing performance by around 5-7%.
I briefly tried applying this to q6_k and q4_0, and it didn't help for q6_k
and hurt for q4_0.

The big "else" path in mul_mm_cm2.comp that had all the clamped/unclamped
variants isn't used as often as it originally was (e.g. due to the padded_N
change), so I trimmed it down to offset some of the new complexity of the
semi-manual loop unrolling.
This commit is contained in:
Jeff Bolz
2025-04-09 00:25:08 -05:00
committed by GitHub
parent 7ecd780b1a
commit 0090950f67
3 changed files with 235 additions and 42 deletions

View File

@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
block_q4_K_packed128 block;
};
#if defined(IS_MUL_MM2)
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
// into shared memory and then process the whole tile using those scales.
// There is a fetch function that loads into private variables and then a store
// function that stores into shared memory.
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
// the part that fetches from the structure (which has a different block layout).
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
shared vec2 shAscales[8 * shAscales_stride];
uvec4 row_v;
#endif
#if defined(DATA_A_Q4_K)
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q4_k_packed128[block_index].q4k[0];
}
}
#endif
#if defined(DATA_A_Q5_K)
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
uint row = ir_BM + tid_row;
uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
if (in_bounds || row < p.M) {
row_v = data_a_q5_k_packed128[block_index].q5k[0];
}
}
#endif
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
void store_scalesQ4_K(uint tid)
{
barrier();
uint tids_per_row = BLOCK_SIZE / BM;
uint is_per_tid = 8 / tids_per_row;
uint is_start = is_per_tid * (tid % tids_per_row);
uint tid_row = tid / tids_per_row;
[[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
uint is = idx + is_start;
uvec4 v = row_v;
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
uint32_t mbyte;
uint32_t scale0 = v.y;
uint32_t scale4 = v.z;
uint32_t scale8 = v.w;
uint32_t sc_lo = scale0;
uint32_t mb_lo = scale4;
uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);
sc = is < 4 ? sc_lo : sc_hi;
mbyte = is < 4 ? mb_lo : mb_hi;
sc = sc >> (8 * (is & 3));
mbyte = mbyte >> (8 * (is & 3));
sc &= 0x3F;
mbyte &= 0x3F;
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
}
barrier();
}
#endif
#endif
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
{
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q4k[0];
const vec2 loadd = vec2(unpackFloat2x16(v.x));
uint32_t sc;
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
const float d = loadd.x * float(sc);
const float m = loadd.y * float(mbyte);
#endif
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
const uint b = (idx & 0x20) >> 5; // 0,1
const uint is = (idx & 0xE0) >> 5; // 0..7
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
float d = v.x;
float m = v.y;
#else
uvec4 v = bl128.block.q5k[0];
const f16vec2 loadd = unpackFloat2x16(v.x);
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
const float16_t d = loadd.x * float16_t(sc);
const float16_t m = loadd.y * float16_t(mbyte);
#endif
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
qh = ((qh >> is) & 0x101) << 4;
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
qs = (qs >> (b * 4)) & 0x0F0F;
qs = unpack8(qs | qh)[idx & 1];
float16_t ret = d * (float16_t(qs)) - m;
float ret = d * float(qs) - m;
return ret;
return float16_t(ret);
}
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
#define dequantFuncA dequantFuncQ3_K
#elif defined(DATA_A_Q4_K)
#define dequantFuncA dequantFuncQ4_K
#define fetch_scales fetch_scalesQ4_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q5_K)
#define dequantFuncA dequantFuncQ5_K
#define fetch_scales fetch_scalesQ5_K
#define store_scales store_scalesQ4_K
#elif defined(DATA_A_Q6_K)
#define dequantFuncA dequantFuncQ6_K
#elif defined(DATA_A_IQ1_S)