vulkan: In coopmat2 mmq, load q4_k/q5_k scales through shared memory (#12833)
q4_k and q5_k had a lot of redundant global loads where the same 16B of scale information is repeatedly loaded and decoded during each loop iteration. This change restructures the loops to more explicitly iterate over whole blocks in the outer loop (with unrolled inner loop) and to copy/decode the scale data into shared memory once at the start of each outer loop. The copy is pipelined so the scale load from global memory is relatively cheap. This improves q4_k/q5_k model prompt processing performance by around 5-7%. I briefly tried applying this to q6_k and q4_0, and it didn't help for q6_k and hurt for q4_0. The big "else" path in mul_mm_cm2.comp that had all the clamped/unclamped variants isn't used as often as it originally was (e.g. due to the padded_N change), so I trimmed it down to offset some of the new complexity of the semi-manual loop unrolling.
This commit is contained in:
@@ -167,6 +167,101 @@ layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4
|
||||
block_q4_K_packed128 block;
|
||||
};
|
||||
|
||||
#if defined(IS_MUL_MM2)
|
||||
|
||||
// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales
|
||||
// into shared memory and then process the whole tile using those scales.
|
||||
// There is a fetch function that loads into private variables and then a store
|
||||
// function that stores into shared memory.
|
||||
// Q4_K and Q5_K have the same encoding of scales, so everything is shared except
|
||||
// the part that fetches from the structure (which has a different block layout).
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
// Row stride (in vec2 elements) of the shared scale cache.
// NOTE(review): the +2 padding looks like a shared-memory bank-conflict
// avoidance tweak — confirm against the tuning notes for this shader.
const uint shAscales_stride = (BM + 2);
// 1 scale per 32 elements -> 8 scales per block, per row
// Each entry holds the decoded (d, m) pair for one (scale index, tile row).
shared vec2 shAscales[8 * shAscales_stride];
// Per-invocation staging register: holds the 16B of packed scale data loaded
// from global memory by fetch_scales*() until store_scalesQ4_K() decodes it
// into shAscales. Keeping it private lets the global load be pipelined.
uvec4 row_v;
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K)
|
||||
layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];};

// Load the first 16B of one q4_k block (the packed d/dmin pair plus the
// 6-bit scale/min bytes, per the decode in store_scalesQ4_K) into the
// private staging register row_v. Each invocation services one row of the
// BM-row tile; store_scalesQ4_K() later decodes row_v into shAscales.
//   ir_BM     - index of the tile's first row in A
//   pos_a     - base block offset of matrix A
//   stride_a  - row stride of A, in blocks
//   block_k   - element offset along k; block_k / QUANT_K selects the block
//   tid       - invocation index within the workgroup
//   in_bounds - true when the whole tile is known to be in bounds, which
//               skips the per-row clamp against p.M
void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    // BLOCK_SIZE / BM invocations map to each tile row; only the row index
    // matters for the fetch (every invocation on a row loads the same 16B).
    uint tids_per_row = BLOCK_SIZE / BM;
    uint tid_row = tid / tids_per_row;

    uint row = ir_BM + tid_row;
    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q4_k_packed128[block_index].q4k[0];
    }
}
|
||||
#endif
|
||||
#if defined(DATA_A_Q5_K)
|
||||
layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];};

// Load the first 16B of one q5_k block (the packed d/dmin pair plus the
// 6-bit scale/min bytes — q5_k shares the q4_k scale encoding) into the
// private staging register row_v. Each invocation services one row of the
// BM-row tile; store_scalesQ4_K() later decodes row_v into shAscales.
//   ir_BM     - index of the tile's first row in A
//   pos_a     - base block offset of matrix A
//   stride_a  - row stride of A, in blocks
//   block_k   - element offset along k; block_k / QUANT_K selects the block
//   tid       - invocation index within the workgroup
//   in_bounds - true when the whole tile is known to be in bounds, which
//               skips the per-row clamp against p.M
void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds)
{
    // BLOCK_SIZE / BM invocations map to each tile row; only the row index
    // matters for the fetch (every invocation on a row loads the same 16B).
    uint tids_per_row = BLOCK_SIZE / BM;
    uint tid_row = tid / tids_per_row;

    uint row = ir_BM + tid_row;
    uint block_index = pos_a + row * stride_a + (block_k / QUANT_K);
    if (in_bounds || row < p.M) {
        row_v = data_a_q5_k_packed128[block_index].q5k[0];
    }
}
|
||||
#endif
|
||||
|
||||
#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K)
|
||||
// Decode the scale data staged in row_v by fetch_scalesQ4_K/Q5_K into the
// shared shAscales cache. Shared by q4_k and q5_k since both use the same
// scale encoding. Must be called by the whole workgroup (it contains
// barriers). Each invocation decodes is_per_tid of the 8 scale indices for
// its tile row.
void store_scalesQ4_K(uint tid)
{
    // Wait until all readers of the previous tile's shAscales are done.
    barrier();

    // Same tid -> (row, scale-index range) mapping as the fetch functions.
    uint tids_per_row = BLOCK_SIZE / BM;
    uint is_per_tid = 8 / tids_per_row;
    uint is_start = is_per_tid * (tid % tids_per_row);
    uint tid_row = tid / tids_per_row;

    [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) {
        uint is = idx + is_start;
        uvec4 v = row_v;
        // v.x packs the block's two fp16 multipliers (d, dmin).
        const vec2 loadd = vec2(unpackFloat2x16(v.x));

        uint32_t sc;
        uint32_t mbyte;

        // v.y/v.z/v.w are the 12 packed scale/min bytes.
        uint32_t scale0 = v.y;
        uint32_t scale4 = v.z;
        uint32_t scale8 = v.w;

        // Re-pack so that, per byte lane, the low 6 bits hold the scale
        // (sc_*) or min (mb_*) for indices 0..3 (lo) and 4..7 (hi).
        // NOTE(review): bit layout assumed to follow the k-quant 6-bit
        // scale packing — matches the non-shared decode path in this file.
        uint32_t sc_lo = scale0;
        uint32_t mb_lo = scale4;
        uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2);
        uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2);

        // Select the word, then the byte lane (is & 3), then the low 6 bits.
        sc = is < 4 ? sc_lo : sc_hi;
        mbyte = is < 4 ? mb_lo : mb_hi;
        sc = sc >> (8 * (is & 3));
        mbyte = mbyte >> (8 * (is & 3));
        sc &= 0x3F;
        mbyte &= 0x3F;

        // Final per-index scale and min, ready for d * q - m dequantization.
        const float d = loadd.x * float(sc);
        const float m = loadd.y * float(mbyte);
        shAscales[is * shAscales_stride + tid_row] = vec2(d,m);
    }

    // Make the decoded scales visible to all invocations before use.
    barrier();
}
|
||||
#endif
|
||||
|
||||
#endif
|
||||
|
||||
float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2])
|
||||
{
|
||||
decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl);
|
||||
@@ -176,8 +271,12 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
const uint b = (idx & 0x20) >> 5; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q4k[0];
|
||||
|
||||
const vec2 loadd = vec2(unpackFloat2x16(v.x));
|
||||
|
||||
uint32_t sc;
|
||||
@@ -201,6 +300,7 @@ float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2
|
||||
|
||||
const float d = loadd.x * float(sc);
|
||||
const float m = loadd.y * float(mbyte);
|
||||
#endif
|
||||
|
||||
uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]);
|
||||
qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF;
|
||||
@@ -231,6 +331,11 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
const uint b = (idx & 0x20) >> 5; // 0,1
|
||||
const uint is = (idx & 0xE0) >> 5; // 0..7
|
||||
|
||||
#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K)
|
||||
vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)];
|
||||
float d = v.x;
|
||||
float m = v.y;
|
||||
#else
|
||||
uvec4 v = bl128.block.q5k[0];
|
||||
|
||||
const f16vec2 loadd = unpackFloat2x16(v.x);
|
||||
@@ -256,6 +361,7 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
|
||||
const float16_t d = loadd.x * float16_t(sc);
|
||||
const float16_t m = loadd.y * float16_t(mbyte);
|
||||
#endif
|
||||
|
||||
uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]);
|
||||
qh = ((qh >> is) & 0x101) << 4;
|
||||
@@ -264,9 +370,9 @@ float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2
|
||||
qs = (qs >> (b * 4)) & 0x0F0F;
|
||||
qs = unpack8(qs | qh)[idx & 1];
|
||||
|
||||
float16_t ret = d * (float16_t(qs)) - m;
|
||||
float ret = d * float(qs) - m;
|
||||
|
||||
return ret;
|
||||
return float16_t(ret);
|
||||
}
|
||||
|
||||
layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K {
|
||||
@@ -564,8 +670,12 @@ float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoor
|
||||
#define dequantFuncA dequantFuncQ3_K
|
||||
#elif defined(DATA_A_Q4_K)
|
||||
#define dequantFuncA dequantFuncQ4_K
|
||||
#define fetch_scales fetch_scalesQ4_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q5_K)
|
||||
#define dequantFuncA dequantFuncQ5_K
|
||||
#define fetch_scales fetch_scalesQ5_K
|
||||
#define store_scales store_scalesQ4_K
|
||||
#elif defined(DATA_A_Q6_K)
|
||||
#define dequantFuncA dequantFuncQ6_K
|
||||
#elif defined(DATA_A_IQ1_S)
|
||||
|
||||
Reference in New Issue
Block a user