* tests: add mul_mat perf/functional tests for p021/nc vulkan shaders

* vulkan: Optimize mul_mat_vec p021 and nc shaders.

  These shaders are used in attention calculations, and when the KV cache grows large they start to dominate the run time. For the nc shader (which is called with a large 'k' dimension), use unrolling and vector loads. For the p021 shader (which is called with a large 'm' and a small 'k' dimension), take advantage of grouped query attention to reuse loads from the A matrix for the whole group, and reduce the number of workgroups (tiny dispatches carry too much overhead). Using subgroupAdd in the p021 shader also helps, so use it conditionally.
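The nc shader source below shows the unrolled vec4 load path; the p021 changes (grouped-query load reuse and the conditional subgroupAdd) are not part of that listing. The following is a rough, hypothetical sketch of what a conditionally compiled subgroupAdd reduction can look like in a Vulkan GLSL compute shader. The USE_SUBGROUP_ADD flag, buffer layout, and kernel are illustrative stand-ins, not the actual p021 code:

#version 450

#extension GL_EXT_control_flow_attributes : enable
#ifdef USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif

#define BLOCK_SIZE 32

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer X {float data_x[];};
layout (binding = 1) writeonly buffer D {float dst[];};

layout (push_constant) uniform parameter
{
    uint n;
} p;

shared float tmp[BLOCK_SIZE];

void main() {
    const uint tid = gl_LocalInvocationID.x;

    // each invocation accumulates a strided partial sum
    float temp = 0.0f;
    for (uint i = tid; i < p.n; i += BLOCK_SIZE) {
        temp += data_x[i];
    }

#ifdef USE_SUBGROUP_ADD
    // assumes the device's subgroup size covers the whole workgroup, so a
    // single subgroupAdd replaces the shared-memory tree reduction below
    temp = subgroupAdd(temp);
#else
    // fallback: shared-memory tree reduction, same pattern as the nc shader
    tmp[tid] = temp;
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        barrier();
    }
    temp = tmp[0];
#endif

    if (tid == 0) {
        dst[0] = temp;
    }
}

When the compile-time flag is off (for example, on devices whose subgroup size does not cover the workgroup), the shared-memory tree reduction is used instead, which is the same pattern that appears at the end of the nc shader listed below.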
#version 450

#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require

#define BLOCK_SIZE 32
#define FLOAT_TYPE float

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};

layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};

layout (push_constant) uniform parameter
{
    uint ncols_x;
    uint nrows_x;
    uint row_stride_x;
    uint channel_stride_x;
    uint channel_x_divisor;
    uint b_offset;
    uint d_offset;
} p;

shared FLOAT_TYPE tmp[BLOCK_SIZE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint row_x = gl_GlobalInvocationID.y;
    const uint channel = gl_GlobalInvocationID.z;
    const uint channel_x = channel / p.channel_x_divisor;

    const uint nrows_y = p.ncols_x;
    const uint nrows_dst = p.nrows_x;
    const uint row_dst = row_x;

    const uint idst = channel*nrows_dst + row_dst;

    FLOAT_TYPE temp = 0.0f;

    // Detect alignment for vector loads
    bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;

    for (uint col_x0 = 0; col_x0 < p.ncols_x;) {

        // Unroll 2x and do vec4 loads if aligned
        const uint unroll_count = 2;
        if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
            [[unroll]] for (uint i = 0; i < unroll_count; ++i) {
                const uint col_x = col_x0 + 4*tid;

                const uint row_y = col_x;

                const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
                const uint iy = channel*nrows_y + row_y;

                const vec4 av4 = vec4(data_a_v4[ix / 4]);
                const vec4 bv4 = vec4(data_b_v4[iy / 4]);

                temp += dot(av4, bv4);

                col_x0 += 4*BLOCK_SIZE;
            }
        // do vec4 loads if aligned
        } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
            const uint col_x = col_x0 + 4*tid;

            const uint row_y = col_x;

            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
            const uint iy = channel*nrows_y + row_y;

            const vec4 av4 = vec4(data_a_v4[ix / 4]);
            const vec4 bv4 = vec4(data_b_v4[iy / 4]);

            temp += dot(av4, bv4);

            col_x0 += 4*BLOCK_SIZE;
        } else {
            const uint col_x = col_x0 + tid;
            if (col_x >= p.ncols_x) {
                break;
            }

            const uint row_y = col_x;

            const uint ix = channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
            const uint iy = channel*nrows_y + row_y;

            const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);

            temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
            col_x0 += BLOCK_SIZE;
        }
    }

    tmp[tid] = temp;

    // sum up partial sums and write back result
    barrier();
    [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            tmp[tid] += tmp[tid + s];
        }
        barrier();
    }

    if (tid == 0) {
        dst[idst] = tmp[0];
    }
}
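The grouped-query-attention reuse described for the p021 shader is also not shown above. A simplified, hypothetical sketch of the idea follows: one workgroup produces the outputs for all query heads that share a KV channel, so each KV vector is loaded once and reused GQA_RATIO times. GQA_RATIO, the buffer layout, and the indexing are illustrative only, not the real p021 shader:

#version 450

#extension GL_EXT_control_flow_attributes : enable

#define BLOCK_SIZE 32
// illustrative only: number of query heads sharing one KV channel
#define GQA_RATIO 4

layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;

layout (binding = 0) readonly buffer A {vec4 data_a[];};   // KV data, one channel per workgroup
layout (binding = 1) readonly buffer B {vec4 data_b[];};   // query data, GQA_RATIO channels per workgroup
layout (binding = 2) writeonly buffer D {float dst[];};

layout (push_constant) uniform parameter
{
    uint ncols;    // row length, in vec4 units
    uint b_stride; // distance between consecutive query channels, in vec4 units
} p;

shared float tmp[GQA_RATIO][BLOCK_SIZE];

void main() {
    const uint tid = gl_LocalInvocationID.x;
    const uint channel_x = gl_WorkGroupID.z; // KV channel shared by the whole group

    float temp[GQA_RATIO];
    [[unroll]] for (uint g = 0; g < GQA_RATIO; ++g) {
        temp[g] = 0.0f;
    }

    for (uint col = tid; col < p.ncols; col += BLOCK_SIZE) {
        // load the KV vector once...
        const vec4 av4 = data_a[channel_x*p.ncols + col];
        // ...and reuse it for every query head in the group
        [[unroll]] for (uint g = 0; g < GQA_RATIO; ++g) {
            temp[g] += dot(av4, data_b[(channel_x*GQA_RATIO + g)*p.b_stride + col]);
        }
    }

    // shared-memory reduction, one row of partial sums per head in the group
    [[unroll]] for (uint g = 0; g < GQA_RATIO; ++g) {
        tmp[g][tid] = temp[g];
    }
    barrier();
    [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
        if (tid < s) {
            [[unroll]] for (uint g = 0; g < GQA_RATIO; ++g) {
                tmp[g][tid] += tmp[g][tid + s];
            }
        }
        barrier();
    }

    if (tid == 0) {
        [[unroll]] for (uint g = 0; g < GQA_RATIO; ++g) {
            dst[channel_x*GQA_RATIO + g] = tmp[g][0];
        }
    }
}

Handling the whole group in one workgroup also divides the number of dispatched workgroups by the GQA factor, which addresses the "too much overhead from tiny dispatches" point in the change description.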