* vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.

  Add some early returns for nonexistent rows in mul_mat_vec shaders. These can only be hit when dispatching a 2D grid of workgroups. Fix the logic for the 2D grid of workgroups to round up.

  Enable the pipeline robustness extension if it's available, and use it to disable robustness for these pipelines. The instructions to do the bounds checking contend for the same ALU resources as the bit twiddling dequant instructions.

* vulkan: Add GLSL structure aliases for quant types to allow larger loads.

  In Vulkan it's not possible to cast pointer types, so instead you have to declare an aliased binding for the memory with a different type. This commit adds aliases for the quant formats using 16-bit ints, and in a few places where the struct size is a multiple of 4 also using 32-bit ints. Currently only q4_k's aliases are used, but others will be used in subsequent commits.

* vulkan: Use larger loads in q5_k and q6_k shaders.

  Similar to the optimization I did in q4_k recently, this vectorizes some loads and reduces the number of bit twiddling instructions.

* vulkan: Use larger K step per iteration in mul_mat_vec.

  Add vec4 dequantization functions, and use them to do K=8 per iteration in mul_mat_vec. This uses 16-bit loads for the quant values and 128-bit loads for B, which helps reduce the load on the memory system.

  The K_PER_ITER==2 logic is still there, just for F16/F32, and really only because they support unaligned sizes.

  Tweak the num_iters/unrolling logic to be simpler and catch a couple of missed unrolling opportunities.
152 lines
6.9 KiB
Plaintext
152 lines
6.9 KiB
Plaintext
#version 450
|
|
|
|
#extension GL_EXT_shader_explicit_arithmetic_types : require
|
|
|
|
#include "mul_mat_vec_base.comp"
|
|
|
|
// One workgroup is 32 invocations along x; main() maps each workgroup to a single output row.
layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in;
|
|
|
|
// Per-invocation partial sums, indexed by gl_LocalInvocationID.x and reduced into tmp[0] by main().
shared FLOAT_TYPE tmp[32];
|
|
|
|
// Matrix-vector product for one row of a block-quantized matrix A against vector B.
//
// Each workgroup handles one output row. The 32 invocations are split into two
// interleaved groups (ix = 0/1) that walk alternating blocks of the row; inside a
// group, the 16 invocations (tid) each cover a fixed 16-value slice of every block.
// Every invocation accumulates a partial dot product, the partials are tree-reduced
// through shared memory, and invocation 0 writes the result to data_d.
void main() {
    // Row index for a (possibly 2D: x by z) grid of workgroups.
    const uint row = gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z;

    // Rows past the end can occur when the grid is rounded up. `row` depends only on
    // workgroup-level built-ins, so the whole workgroup leaves together and the
    // barrier() calls below stay in uniform control flow.
    if (row >= p.stride_d) {
        return;
    }

    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint num_blocks_per_row = p.ncols / QUANT_K;
    const uint ib0 = a_offset / QUANT_K + row*num_blocks_per_row;

    const uint lid = gl_LocalInvocationID.x;  // 0...31, unique per invocation
    const uint tid = lid/2;                   // 0...15, position inside a block
    const uint ix  = lid%2;                   // 0 or 1, which alternating set of blocks

    const uint il = tid/4;                    // 0...3
    const uint ir = tid - 4*il;               // 0...3

    const uint v_im = il / 2;  // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224
    const uint v_in = il % 2;

    const uint l0 = 4*ir + 2*v_in;            // 0...14, always even
    const uint q_offset = 32*v_im + l0;
    const uint y_offset = 64*v_im + l0;

    // Bit masks selecting this invocation's high bits out of the qh bytes.
    const uint8_t hm1 = uint8_t(1 << (2*v_im));
    const uint8_t hm2 = uint8_t(hm1 << 4);

    FLOAT_TYPE temp = FLOAT_TYPE(0.0); // partial sum for this invocation

    [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += 2) {
        const uint y1_idx = i * QUANT_K + y_offset;
        const uint y2_idx = y1_idx + 128;

        // Per-block scale (d.x) and min scale (d.y).
        f16vec2 d = data_a[ib0 + i].d;
        const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
        const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);

        // Load the packed sub-block scales through the 16-bit alias and split into bytes.
        uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im    ];
        uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
        uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
        uvec4 scale0 = uvec4(unpack8(scale0_u32));
        uvec4 scale4 = uvec4(unpack8(scale4_u32));
        uvec4 scale8 = uvec4(unpack8(scale8_u32));

        // sc0/sc1/sc4/sc5 scale the quant sums below; sc2/sc3/sc6/sc7 scale the
        // B sums that are subtracted through dmin.
        const uint32_t sc0 = (  scale0.x       & 0x3f);
        const uint32_t sc1 = (  scale0.y       & 0x3f);
        const uint32_t sc2 = (  scale4.x       & 0x3f);
        const uint32_t sc3 = (  scale4.y       & 0x3f);
        const uint32_t sc4 = (( scale8.x       & 0x0f) | ((scale0.x & 0xc0) >> 2));
        const uint32_t sc5 = (( scale8.y       & 0x0f) | ((scale0.y & 0xc0) >> 2));
        const uint32_t sc6 = (((scale8.x >> 4) & 0x0f) | ((scale4.x & 0xc0) >> 2));
        const uint32_t sc7 = (((scale8.y >> 4) & 0x0f) | ((scale4.y & 0xc0) >> 2));

        // Combine two 16-bit loads into one 32-bit word so four quant bytes can be
        // masked/shifted at once.
        uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
        uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);

        // Split every byte into its low and high nibble.
        uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
        uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
        uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
        uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;

        uvec4 qs0_16_lo4 = uvec4(unpack8(qs0_16_u32_lo4));
        uvec4 qs64_80_lo4 = uvec4(unpack8(qs64_80_u32_lo4));
        uvec4 qs0_16_hi4 = uvec4(unpack8(qs0_16_u32_hi4));
        uvec4 qs64_80_hi4 = uvec4(unpack8(qs64_80_u32_hi4));

        const uint32_t q4_0 = qs0_16_lo4.x;
        const uint32_t q4_1 = qs0_16_lo4.y;
        const uint32_t q4_2 = qs0_16_lo4.z;
        const uint32_t q4_3 = qs0_16_lo4.w;
        const uint32_t q4_4 = qs0_16_hi4.x;
        const uint32_t q4_5 = qs0_16_hi4.y;
        const uint32_t q4_6 = qs0_16_hi4.z;
        const uint32_t q4_7 = qs0_16_hi4.w;
        const uint32_t q4_8 = qs64_80_lo4.x;
        const uint32_t q4_9 = qs64_80_lo4.y;
        const uint32_t q4_10 = qs64_80_lo4.z;
        const uint32_t q4_11 = qs64_80_lo4.w;
        const uint32_t q4_12 = qs64_80_hi4.x;
        const uint32_t q4_13 = qs64_80_hi4.y;
        const uint32_t q4_14 = qs64_80_hi4.z;
        const uint32_t q4_15 = qs64_80_hi4.w;

        // Matching B values, loaded two at a time through the vec2 alias.
        B_TYPE_VEC2 by10 =  data_b_v2[(b_offset + y1_idx) / 2];
        B_TYPE_VEC2 by116 = data_b_v2[(b_offset + y1_idx) / 2 + 8];
        B_TYPE_VEC2 by132 = data_b_v2[(b_offset + y1_idx) / 2 + 16];
        B_TYPE_VEC2 by148 = data_b_v2[(b_offset + y1_idx) / 2 + 24];
        B_TYPE_VEC2 by20 =  data_b_v2[(b_offset + y2_idx) / 2];
        B_TYPE_VEC2 by216 = data_b_v2[(b_offset + y2_idx) / 2 + 8];
        B_TYPE_VEC2 by232 = data_b_v2[(b_offset + y2_idx) / 2 + 16];
        B_TYPE_VEC2 by248 = data_b_v2[(b_offset + y2_idx) / 2 + 24];

        // Each 16-bit qh load carries two bytes; the second is exposed by shifting.
        // The hm1/hm2 masks only touch the low 8 bits of each value.
        uint32_t qh0 = data_a_packed16[ib0 + i].qh[l0 / 2];
        uint32_t qh1 = qh0 >> 8;
        uint32_t qh16 = data_a_packed16[ib0 + i].qh[l0 / 2 + 8];
        uint32_t qh17 = qh16 >> 8;

        // A set qh bit contributes 16, i.e. the fifth bit on top of the 4-bit nibble.
        const FLOAT_TYPE sx =
          fma(FLOAT_TYPE(by10.x), (q4_0 + (((qh0 & hm1) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by10.y), (q4_1 + (((qh1 & hm1) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by116.x), (q4_2 + (((qh16 & hm1) != 0) ? 16 : 0)),
             FLOAT_TYPE(by116.y) * (q4_3 + (((qh17 & hm1) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sy =
          fma(FLOAT_TYPE(by132.x), (q4_4 + (((qh0 & (hm1 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by132.y), (q4_5 + (((qh1 & (hm1 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by148.x), (q4_6 + (((qh16 & (hm1 << 1)) != 0) ? 16 : 0)),
             FLOAT_TYPE(by148.y) * (q4_7 + (((qh17 & (hm1 << 1)) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sz =
          fma(FLOAT_TYPE(by20.x), (q4_8 + (((qh0 & hm2) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by20.y), (q4_9 + (((qh1 & hm2) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by216.x), (q4_10 + (((qh16 & hm2) != 0) ? 16 : 0)),
             FLOAT_TYPE(by216.y) * (q4_11 + (((qh17 & hm2) != 0) ? 16 : 0)))));
        const FLOAT_TYPE sw =
          fma(FLOAT_TYPE(by232.x), (q4_12 + (((qh0 & (hm2 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by232.y), (q4_13 + (((qh1 & (hm2 << 1)) != 0) ? 16 : 0)),
          fma(FLOAT_TYPE(by248.x), (q4_14 + (((qh16 & (hm2 << 1)) != 0) ? 16 : 0)),
             FLOAT_TYPE(by248.y) * (q4_15 + (((qh17 & (hm2 << 1)) != 0) ? 16 : 0)))));
        const FLOAT_TYPE smin =
          fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
          fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
          fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
             (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));

        // temp += dall * (scaled quant sums) - dmin * (scaled B sums)
        temp = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp));
    }

    tmp[lid] = temp;

    // Sum up partial sums and write back the result.
    // The reduction is indexed by the unique invocation id (lid), not by tid = lid/2:
    // with tid, invocations 2k and 2k+1 would both read-modify-write tmp[k] in the
    // same step, which is only correct if they happen to run in lockstep. The
    // summation tree (and therefore the result) is the same either way.
    barrier();
    [[unroll]] for (uint s = 16; s > 0; s >>= 1) {
        if (lid < s) {
            tmp[lid] += tmp[lid + s];
        }
        barrier();
    }
    if (lid == 0) {
        data_d[d_offset + row] = D_TYPE(tmp[0]);
    }
}
|