metal : use function constants for mul_mv_ext kernels (#16074)
* metal : use function constants for mul_mv_ext kernels ggml-ci * metal : remove NW template argument ggml-ci * metal : adjust constants ggml-ci
This commit is contained in:
@@ -414,19 +414,26 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_rwkv(ggml_metal_library_t
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int r1ptg) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1, int nsg, int nxpsg, int r1ptg) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_ext_%s_%s_r1_%d", ggml_type_name(tsrc0), ggml_type_name(tsrc1), r1ptg);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d_nxpsg=%d", base, nsg, nxpsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, nsg, FC_MUL_MV + 0);
|
||||
ggml_metal_cv_set_int16(cv, nxpsg, FC_MUL_MV + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -608,7 +615,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv(ggml_metal_library_
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -824,7 +831,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_id(ggml_metal_libra
|
||||
};
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mv_id_%s_%s%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1), suffix);
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
@@ -923,11 +930,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
"flash_attn_ext",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv,
|
||||
snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
@@ -985,11 +989,8 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec(
|
||||
dk,
|
||||
dv);
|
||||
|
||||
snprintf(name, 256, "kernel_%s_%s_dk%d_dv%d_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
"flash_attn_ext_vec",
|
||||
ggml_type_name(op->src[1]->type),
|
||||
dk,
|
||||
dv,
|
||||
snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d",
|
||||
base,
|
||||
has_mask,
|
||||
has_sinks,
|
||||
has_bias,
|
||||
@@ -1033,7 +1034,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec_reduce(
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_flash_attn_ext_vec_reduce");
|
||||
snprintf(name, 256, "kernel_flash_attn_ext_vec_reduce_dv=%d_nwg=%d", dv, nwg);
|
||||
snprintf(name, 256, "%s_dv=%d_nwg=%d", base, dv, nwg);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
|
||||
Reference in New Issue
Block a user