Files
enginex-ascend-910-llama.cpp/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp

168 lines
7.8 KiB
Plaintext
Raw Normal View History

#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
// vulkan: further optimize mul_mat_vec using larger loads (#10387)
//
// * vulkan: Use pipeline_robustness to disable robustness in mul_mat_vec.
//   Add some early returns for nonexistent rows in mul_mat_vec shaders. These
//   can only be hit when dispatching a 2D grid of workgroups. Fix the logic
//   for the 2D grid of workgroups to round up. Enable the pipeline robustness
//   extension if it's available, and use it to disable robustness for these
//   pipelines. The instructions to do the bounds checking contend for the
//   same ALU resources as the bit twiddling dequant instructions.
//
// * vulkan: Add GLSL structure aliases for quant types to allow larger loads
//   In Vulkan it's not possible to cast pointer types, so instead you have to
//   declare an aliased binding for the memory with a different type. This
//   commit adds aliases for the quant formats using 16b ints, and in a few
//   places where the struct size is a multiple of 4 also using 32b ints.
//   Currently only q4_k's aliases are used, but others will be used in
//   subsequent commits.
//
// * vulkan: use larger loads in q5_k and q6_k shaders.
//   Similar to the optimization I did in q4_k recently, this vectorizes some
//   loads and reduces the number of bit twiddling instructions.
//
// * vulkan: use larger K step per iteration in mul_mat_vec.
//   Add vec4 dequantization functions, and use them to do K=8 per iteration
//   in mul_mat_vec. This uses 16b loads for the quant values and 128b loads
//   for B which helps reduce the load on the memory system. The K_PER_ITER==2
//   logic is still there, just for F16/F32, and really only because they
//   support unaligned sizes. Tweak the num_iters/unrolling logic to be simpler
//   and catch a couple missed unrolling opportunities.
//
// 2024-11-20 01:11:00 -06:00
#include "mul_mat_vec_base.comp"
// Workgroup size along x is taken from specialization constant 0; the y and z
// sizes are fixed at 1.
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
// Partial sums indexed as temp[column][row]. The array is not declared
// `shared`, so every invocation works on its own copy. compute_outputs() zeroes
// it, calc_superblock() accumulates into it, and it is finally handed to
// reduce_result().
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
// Accumulates into temp[j][n] this invocation's share of the dot product
// between superblock `i` of each processed row of A and the matching slice of
// every column j of B.
//
// Parameters:
//   a_offset           - element offset into A; divided by QUANT_K to obtain a block index.
//   b_offset           - element offset into B, added before the index is halved for data_b_v2.
//   v_im               - selects which scale words are read and how far qh is shifted (2*v_im bits).
//   l0                 - selects the qh words that are read (index l0 / 2 and l0 / 2 + 8).
//   q_offset           - offset into qs; halved to index the packed16 view.
//   y_offset           - offset of this invocation's B elements inside the superblock.
//   i                  - index of the superblock within the row.
//   num_blocks_per_row - number of superblocks per row of A.
//   first_row          - first row of A handled by this workgroup.
//   num_rows           - number of consecutive rows to process (loop bound for n).
//
// NOTE(review): data_a, data_a_packed16, data_b_v2, p and unpack8 are declared
// outside this file (presumably via mul_mat_vec_base.comp); the comments below
// describe them only as far as their use here shows.
void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
// Element offsets into B for the two halves this invocation touches; the
// second half lies 128 elements after the first.
const uint y1_idx = i * QUANT_K + y_offset;
const uint y2_idx = y1_idx + 128;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
// Index of the first block of row (first_row + n); block i of that row is ib0 + i.
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
// d.x scales the quantized values (dall), d.y scales the minimums (dmin).
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
// Three words of the packed scale table, widened to 32 bits.
// NOTE(review): the packed16 view appears to expose 16-bit words - confirm
// against the struct declaration.
const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ];
const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2];
const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4];
// scale0 goes in the low half and scale4 in the high half of one 32-bit word.
const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32;
// The top two bits of every byte, moved down to bit positions 4..5.
const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2;
// The low six bits of every byte give four values directly (sc0..sc3).
const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F));
// (x << 12 | x) & 0x0F0F0F0F places one nibble of scale8 in the low nibble of
// each byte; OR-ing in scale_0_4_h supplies bits 4..5, giving sc4..sc7.
const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h));
// sc0, sc1, sc4, sc5 multiply the dot products below; sc2, sc3, sc6, sc7
// multiply the B sums that feed the dmin term.
const FLOAT_TYPE sc0 = scale_0_4_l_f.x;
const FLOAT_TYPE sc1 = scale_0_4_l_f.y;
const FLOAT_TYPE sc2 = scale_0_4_l_f.z;
const FLOAT_TYPE sc3 = scale_0_4_l_f.w;
const FLOAT_TYPE sc4 = scale8_f.x;
const FLOAT_TYPE sc5 = scale8_f.y;
const FLOAT_TYPE sc6 = scale8_f.z;
const FLOAT_TYPE sc7 = scale8_f.w;
// Two qs words (8 indices apart) are fused into one 32-bit value so that four
// bytes can be processed at once; the second pair sits 32 indices further on.
const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16);
// Split each byte into its low and high nibble (four 4-bit values per word).
uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F;
uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F;
uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F;
uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F;
// High-bit plane: two qh words packed into 32 bits, one byte per quant position.
const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], data_a_packed16[ib0 + i].qh[l0 / 2 + 8]));
// After shifting by 2*v_im, bits 0, 1, 4 and 5 of every byte are each moved
// to bit 4 (value 16) so they can serve as the fifth bit of a quant.
const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4;
const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3;
const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010);
const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1;
// Bit 4 of every byte is clear after the 0x0F masks above, so the additions
// cannot carry between bytes.
qs0_16_u32_lo4 += qs0_16_lo4_offset16;
qs0_16_u32_hi4 += qs0_16_hi4_offset16;
qs64_80_u32_lo4 += qs64_80_lo4_offset16;
qs64_80_u32_hi4 += qs64_80_hi4_offset16;
// Convert the byte-packed 5-bit quants to floats, four at a time.
const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4));
const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4));
const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4));
const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4));
// q4_0..q4_3 pair with sc0, q4_4..q4_7 with sc1, q4_8..q4_11 with sc4 and
// q4_12..q4_15 with sc5 in the accumulation below.
const FLOAT_TYPE q4_0 = qs0_16_lo4.x;
const FLOAT_TYPE q4_1 = qs0_16_lo4.y;
const FLOAT_TYPE q4_2 = qs0_16_lo4.z;
const FLOAT_TYPE q4_3 = qs0_16_lo4.w;
const FLOAT_TYPE q4_4 = qs0_16_hi4.x;
const FLOAT_TYPE q4_5 = qs0_16_hi4.y;
const FLOAT_TYPE q4_6 = qs0_16_hi4.z;
const FLOAT_TYPE q4_7 = qs0_16_hi4.w;
const FLOAT_TYPE q4_8 = qs64_80_lo4.x;
const FLOAT_TYPE q4_9 = qs64_80_lo4.y;
const FLOAT_TYPE q4_10 = qs64_80_lo4.z;
const FLOAT_TYPE q4_11 = qs64_80_lo4.w;
const FLOAT_TYPE q4_12 = qs64_80_hi4.x;
const FLOAT_TYPE q4_13 = qs64_80_hi4.y;
const FLOAT_TYPE q4_14 = qs64_80_hi4.z;
const FLOAT_TYPE q4_15 = qs64_80_hi4.w;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
// Eight two-element loads from column j of B. The index is halved because
// data_b_v2 is addressed in pairs, so the +8/+16/+24 steps correspond to
// element offsets 16/32/48 from y1_idx (and likewise from y2_idx).
vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]);
vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]);
vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]);
vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]);
vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]);
vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]);
vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]);
vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]);
// Four unscaled dot products of B values with quants, one per scale group.
const FLOAT_TYPE sx =
fma(FLOAT_TYPE(by10.x), q4_0,
fma(FLOAT_TYPE(by10.y), q4_1,
fma(FLOAT_TYPE(by116.x), q4_2,
FLOAT_TYPE(by116.y) * q4_3)));
const FLOAT_TYPE sy =
fma(FLOAT_TYPE(by132.x), q4_4,
fma(FLOAT_TYPE(by132.y), q4_5,
fma(FLOAT_TYPE(by148.x), q4_6,
FLOAT_TYPE(by148.y) * q4_7)));
const FLOAT_TYPE sz =
fma(FLOAT_TYPE(by20.x), q4_8,
fma(FLOAT_TYPE(by20.y), q4_9,
fma(FLOAT_TYPE(by216.x), q4_10,
FLOAT_TYPE(by216.y) * q4_11)));
const FLOAT_TYPE sw =
fma(FLOAT_TYPE(by232.x), q4_12,
fma(FLOAT_TYPE(by232.y), q4_13,
fma(FLOAT_TYPE(by248.x), q4_14,
FLOAT_TYPE(by248.y) * q4_15)));
// Sums of the same B values per group, weighted by sc2, sc3, sc6 and sc7;
// this is the part that gets multiplied by dmin.
const FLOAT_TYPE smin =
fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2,
fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3,
fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6,
(FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7)));
// temp += dall * (sx*sc0 + sy*sc1 + sz*sc4 + sw*sc5) - dmin * smin
temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n]));
}
}
}
// Computes the outputs for `num_rows` consecutive rows of A beginning at
// `first_row`: derives this invocation's position inside a superblock, clears
// the accumulators, walks over the superblocks assigned to it and finally
// passes the partial sums to reduce_result().
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
    uint a_offset, b_offset, d_offset;
    get_offsets(a_offset, b_offset, d_offset);

    const uint blocks_in_row = p.ncols / QUANT_K;

    // Every superblock is shared by 16 invocations, so the workgroup handles
    // gl_WorkGroupSize.x / 16 superblocks per pass.
    const uint tid         = gl_LocalInvocationID.x;
    const uint block_step  = gl_WorkGroupSize.x / 16;
    const uint first_block = tid / 16;
    const uint lane        = tid % 16;   // 0...15 within the superblock

    const uint quarter = lane / 4;       // 0...3
    const uint sub     = lane % 4;       // 0...3

    // v_im is 0 or 1: 0 computes 0,32 + 128,160 and 1 computes 64,96 + 192,224.
    const uint v_im = quarter / 2;
    const uint v_in = quarter % 2;

    const uint l0       = 4 * sub + 2 * v_in;   // 0...15
    const uint q_offset = 32 * v_im + l0;
    const uint y_offset = 64 * v_im + l0;

    // Start every accumulator from zero.
    [[unroll]] for (uint col = 0; col < NUM_COLS; ++col) {
        [[unroll]] for (uint row = 0; row < NUM_ROWS; ++row) {
            temp[col][row] = FLOAT_TYPE(0);
        }
    }

    // Visit the superblocks owned by this group of 16 invocations.
    [[unroll]] for (uint blk = first_block; blk < blocks_in_row; blk += block_step) {
        calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, blk, blocks_in_row, first_row, num_rows);
    }

    reduce_result(temp, d_offset, first_row, num_rows, tid);
}
// Entry point: each workgroup takes a tile of NUM_ROWS rows, with the tile
// index built from the x and z workgroup coordinates. p.stride_d serves as the
// row bound.
void main() {
    const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);

    // Common case: a complete tile fits. NUM_ROWS is passed as a literal so the
    // row count is a compile-time constant on this path.
    const bool whole_tile_fits = first_row + NUM_ROWS <= p.stride_d;
    if (whole_tile_fits) {
        compute_outputs(first_row, NUM_ROWS);
        return;
    }

    // Partial tile: handle whatever rows are left. A workgroup that starts at
    // or beyond the bound has nothing to do.
    if (first_row < p.stride_d) {
        compute_outputs(first_row, p.stride_d - first_row);
    }
}