metal : fuse NORM + MUL + ADD, support non-multiples of 4 (#16220)
* metal : fuse NORM + MUL + ADD
* metal : support norms of non-multiple of 4
* cont : fix comment [no ci]
This commit is contained in:
@@ -1090,36 +1090,6 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_bin(
|
||||
return res;
|
||||
}
|
||||
|
||||
// Returns the pipeline for a GGML_OP_RMS_NORM op, compiling it on first use.
//
//   lib    - library that is queried for (and, on a miss, asked to compile) the pipeline
//   op     - the RMS_NORM tensor op; src[0]->ne[0] must be a multiple of 4 and
//            src[0] must satisfy ggml_is_contiguous_rows()
//   n_fuse - picks the kernel name:
//              1 -> kernel_rms_norm_f32
//              2 -> kernel_rms_norm_mul_f32
//              3 -> kernel_rms_norm_mul_add_f32
//            any other value aborts
//
// A pipeline already known to the library is returned as-is; only a freshly
// compiled pipeline gets its smem size set here.
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rms_norm(ggml_metal_library_t lib, const ggml_tensor * op, int32_t n_fuse) {
    assert(op->op == GGML_OP_RMS_NORM);

    const ggml_tensor * src0 = op->src[0];

    GGML_ASSERT(src0->ne[0] % 4 == 0);
    GGML_ASSERT(ggml_is_contiguous_rows(src0));

    // resolve the kernel name from the fusion count
    const char * kernel = nullptr;

    if (n_fuse == 1) {
        kernel = "kernel_rms_norm_f32";
    } else if (n_fuse == 2) {
        kernel = "kernel_rms_norm_mul_f32";
    } else if (n_fuse == 3) {
        kernel = "kernel_rms_norm_mul_add_f32";
    } else {
        GGML_ABORT("fatal error");
    }

    char base[256];
    char name[256];

    // the pipeline name is identical to the base kernel name for this op
    snprintf(base, sizeof(base), "%s", kernel);
    snprintf(name, sizeof(name), "%s", base);

    ggml_metal_pipeline_t cached = ggml_metal_library_get_pipeline(lib, name);
    if (cached) {
        return cached;
    }

    ggml_metal_pipeline_t compiled = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);

    // room for 32 floats (NOTE(review): presumably one partial sum per SIMD group - confirm against the kernel)
    ggml_metal_pipeline_set_smem(compiled, 32*sizeof(float));

    return compiled;
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_l2_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_L2_NORM);
|
||||
|
||||
@@ -1167,16 +1137,37 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_group_norm(ggml_metal_libr
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_NORM);
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_norm(ggml_metal_library_t lib, const ggml_tensor * op, int n_fuse) {
|
||||
assert(op->op == GGML_OP_NORM || op->op == GGML_OP_RMS_NORM);
|
||||
|
||||
GGML_ASSERT(op->src[0]->ne[0] % 4 == 0);
|
||||
GGML_ASSERT(ggml_is_contiguous_1(op->src[0]));
|
||||
GGML_ASSERT(ggml_is_contiguous_rows(op->src[0]));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_norm_f32");
|
||||
const char * suffix = "";
|
||||
if (op->ne[0] % 4 == 0) {
|
||||
suffix = "_4";
|
||||
}
|
||||
|
||||
switch (op->op) {
|
||||
case GGML_OP_NORM:
|
||||
switch (n_fuse) {
|
||||
case 1: snprintf(base, 256, "kernel_norm_f32%s", suffix); break;
|
||||
case 2: snprintf(base, 256, "kernel_norm_mul_f32%s", suffix); break;
|
||||
case 3: snprintf(base, 256, "kernel_norm_mul_add_f32%s", suffix); break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
case GGML_OP_RMS_NORM:
|
||||
switch (n_fuse) {
|
||||
case 1: snprintf(base, 256, "kernel_rms_norm_f32%s", suffix); break;
|
||||
case 2: snprintf(base, 256, "kernel_rms_norm_mul_f32%s", suffix); break;
|
||||
case 3: snprintf(base, 256, "kernel_rms_norm_mul_add_f32%s", suffix); break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
} break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
}
|
||||
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
|
||||
Reference in New Issue
Block a user