* Begin work on set_rows * Work on set rows * Add error buffers for reporting unsupported SET_ROWS indices * Remove extra comments * Work on templating for different types in shaders * Work on shader type generation * Working q4_0 mul_mat and some templating for different types * Add q4_0_f16 matmul and fix device init * Add matmul support for basic quantization types * Add q2_k and q3_k quantization * Add rest of k-quants * Get first i-quant working * Closer to supporting all i-quants * Support rest of i-quants * Clean up code * Fix python formatting * debug * Bugfix for memset * Add padding to end of buffers on creation * Simplify bit-shifting * Update usage of StringView
41 lines
1.3 KiB
WebGPU Shading Language
41 lines
1.3 KiB
WebGPU Shading Language
// Destination buffer, viewed as an array of 32-bit words. The fill range in
// `params` is expressed in bytes; `main` converts byte addresses to word
// indices with `>> 2`.
@group(0) @binding(0)
var<storage, read_write> output_buffer: array<u32>;
// Uniform parameters describing one fill of output_buffer.
// NOTE(review): field order must match the order in which the host uploads the
// three u32 values — confirm against the host-side code.
struct Params {
    offset: u32, // first byte to write, in bytes from the start of output_buffer
    size: u32,   // number of bytes to write
    // 4 8-bit values, which are either repeating (memset_tensor) or may be separate
    // (cleaning up unaligned set_tensor operations). The byte written at address a
    // is bits [8 * (a & 3), 8 * (a & 3) + 8) of this word.
    value: u32,
};
// Fill parameters (byte offset, byte size, fill word) for the current dispatch.
@group(0) @binding(1)
var<uniform> params: Params;
// Pipeline-overridable constants; they have no default, so they must be supplied
// when the compute pipeline is created.

// Number of invocations per workgroup (x dimension).
override wg_size: u32;

// Number of bytes each invocation is responsible for. `main` steps through this
// span 4 bytes at a time, so it is presumably a multiple of 4 — TODO confirm on host.
override bytes_per_thread: u32;
// Fills bytes [params.offset, params.offset + params.size) of output_buffer.
//
// Invocation gid.x owns the `bytes_per_thread`-byte slice that starts
// gid.x * bytes_per_thread bytes past the 4-byte-aligned word containing
// params.offset. It visits that slice one 32-bit word at a time:
//   * words lying entirely inside the range are stored directly;
//   * partially covered words (the head/tail of the range) are
//     read-modify-written, so bytes outside the range are preserved.
// The byte at address a receives bits [8 * (a & 3), 8 * (a & 3) + 8) of
// params.value.
//
// Starting from the aligned base rather than params.offset itself is what makes
// unaligned offsets work. Indexing words as (offset + ...) >> 2 would let a
// whole-word store overwrite bytes below the range and leave the final bytes of
// the range unwritten.
//
// NOTE(review): measured from the aligned base, the dispatch must cover at least
// params.size + (params.offset & 3) bytes (size + 3 in the worst case), and
// bytes_per_thread is assumed to be a multiple of 4 — confirm against the
// host-side dispatch computation.
@compute @workgroup_size(wg_size)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    // An empty range writes nothing. Returning here also guarantees that every
    // visited word below has at least one in-range byte, so the mask width is
    // never zero (which would require an invalid shift by 32).
    if (params.size == 0u) {
        return;
    }

    let start = params.offset;
    let end = params.offset + params.size;
    let aligned_start = start & ~3u;
    let first_byte = aligned_start + gid.x * bytes_per_thread;

    for (var j: u32 = 0u; j < bytes_per_thread; j += 4u) {
        // Byte address of the current word (4-byte aligned when bytes_per_thread is).
        let word_byte = first_byte + j;
        if (word_byte >= end) {
            // Every later word of this slice lies past the range as well.
            break;
        }

        // Byte lanes [lo, hi) of this word that fall inside [start, end).
        // Since aligned_start <= word_byte < end and word_byte + 4 > start,
        // 0 <= lo < hi <= 4.
        let lo = max(start, word_byte) - word_byte;
        let hi = min(end, word_byte + 4u) - word_byte;
        let lane_mask = (0xffffffffu >> (32u - 8u * (hi - lo))) << (8u * lo);

        let word_idx = word_byte >> 2u;
        if (lane_mask == 0xffffffffu) {
            // Whole word inside the range: plain store.
            output_buffer[word_idx] = params.value;
        } else {
            // Partial word: keep the bytes outside the range.
            output_buffer[word_idx] = (output_buffer[word_idx] & ~lane_mask) | (params.value & lane_mask);
        }
    }
}