mat vec double buffer (#12188)
This commit is contained in:
@@ -6,20 +6,21 @@
|
||||
|
||||
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
|
||||
|
||||
shared FLOAT_TYPE sccache[BLOCK_SIZE/16][16];
|
||||
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16];
|
||||
|
||||
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
|
||||
uint csel = 0;
|
||||
|
||||
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
|
||||
const uint y_idx = i * QUANT_K + y_offset;
|
||||
|
||||
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
|
||||
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
|
||||
csel ^= 1;
|
||||
|
||||
if (!all_threads) { // when we don't have enough blocks to use all threads
|
||||
barrier();
|
||||
if (i < num_blocks_per_row)
|
||||
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
barrier();
|
||||
|
||||
if (i >= num_blocks_per_row)
|
||||
@@ -51,8 +52,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
const vec4 q3 = vec4(unpack8(q3_u32)) - 32;
|
||||
|
||||
if (all_threads) {
|
||||
barrier();
|
||||
sccache[ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]);
|
||||
barrier();
|
||||
}
|
||||
|
||||
@@ -71,7 +71,7 @@ void calc_superblock(const uint a_offset, const uint b_offset, const uint itid,
|
||||
sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]);
|
||||
sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]);
|
||||
}
|
||||
temp[j][n] = fma(fma(sum[0], sccache[ix][s_offset], fma(sum[1], sccache[ix][s_offset + 2], fma(sum[2], sccache[ix][s_offset + 4], sum[3] * sccache[ix][s_offset + 6]))), d, temp[j][n]);
|
||||
temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user