GGML WebGPU: Support for ADD, MUL, RMS_NORM, GET_ROWS operators (#16018)
* Add paramater buffer pool, batching of submissions, refactor command building/submission * Add header for linux builds * Free staged parameter buffers at once * Format with clang-format * Fix thread-safe implementation * Use device implicit synchronization * Update workflow to use custom release * Remove testing branch workflow * some f32 tests passing * Disable set_rows until it's implemented * f32 add all tests passing * Begin work on set_rows * Work on set rows * Add error buffers for reporting unsupported SET_ROWS indices * Remove extra comments * Add templated addition, clean up code * Get addition and multiplication working * Implement rms_norm * Add get_rows implementation * Add new get_rows files * Refactor use of wg size entry * Fix compilation * Try manually unrolled q4_0 quant * Revert "Try manually unrolled q4_0 quant" This reverts commit 77f8b96515f7e640ae4b0e44f066321fbc4a6166. * Move to constant max wg size * Check for tensor size in supports_op * Vectorize f32 and change default workgroup size * Move f32 get_rows from < 4 to % 4 != 0 * fix linter errors * Add in-place tests --------- Co-authored-by: Neha Abbas <nehaabbas@ReeseLevines-MacBook-Pro.local>
This commit is contained in:
44
ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl
Normal file
44
ggml/src/ggml-webgpu/wgsl-shaders/add.tmpl.wgsl
Normal file
@@ -0,0 +1,44 @@
|
||||
#define(VARIANTS)
|
||||
|
||||
[
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f32",
|
||||
}
|
||||
},
|
||||
{
|
||||
"REPLS": {
|
||||
"TYPE" : "f16",
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
#end(VARIANTS)
|
||||
|
||||
#define(SHADER)
|
||||
|
||||
enable f16;
|
||||
|
||||
#include "binary_head.tmpl"
|
||||
|
||||
@group(0) @binding(0)
|
||||
var<storage, read_write> src0: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(1)
|
||||
var<storage, read_write> src1: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(2)
|
||||
var<storage, read_write> dst: array<{{TYPE}}>;
|
||||
|
||||
@group(0) @binding(3)
|
||||
var<uniform> params: Params;
|
||||
|
||||
override wg_size: u32;
|
||||
@compute @workgroup_size(wg_size)
|
||||
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
|
||||
if (gid.x < params.ne) {
|
||||
dst[params.offset_dst + gid.x] = src0[params.offset_src0 + gid.x] + src1[params.offset_src1 + src1_index(gid.x)];
|
||||
}
|
||||
}
|
||||
|
||||
#end(SHADER)
|
||||
Reference in New Issue
Block a user