metal : extend mat-mat multiplication support (#16225)
* metal : support mul_mm with src1->type == GGML_TYPE_F16 * metal : support mul_mm_id with src1->type == GGML_TYPE_F16 [no ci] * metal : mul_mm support ne00 % 32 != 0 * metal : support mul_mm_id with ne00 % 32 != 0 * cont : remove unnecessary unrolls * cont : simplify data loading * metal : optimize mul_mm when output bounds checks are not needed
This commit is contained in:
@@ -438,21 +438,35 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mv_ext(ggml_metal_libr
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
const bool bc_out = op->ne[0] % 64 != 0 || op->ne[1] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d_bco=%d", base, bc_inp, bc_out);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
ggml_metal_cv_set_bool(cv, bc_out, FC_MUL_MM + 1);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
// when the output size is not multiple of 64x32, we need extra smem to prevent out-of-bounds writes
|
||||
ggml_metal_pipeline_set_smem(res, bc_out ? 8192 : 4096 + 2048);
|
||||
|
||||
return res;
|
||||
}
|
||||
@@ -659,19 +673,30 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id_map0(ggml_metal_
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, ggml_type tsrc0, ggml_type tsrc1) {
|
||||
ggml_metal_pipeline_t ggml_metal_library_get_pipeline_mul_mm_id(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
const ggml_type tsrc0 = op->src[0]->type;
|
||||
const ggml_type tsrc1 = op->src[1]->type;
|
||||
|
||||
const bool bc_inp = op->src[0]->ne[0] % 32 != 0;
|
||||
|
||||
snprintf(base, 256, "kernel_mul_mm_id_%s_%s", ggml_type_name(tsrc0), ggml_type_name(tsrc1));
|
||||
snprintf(name, 256, "%s", base);
|
||||
snprintf(name, 256, "%s_bci=%d", base, bc_inp);
|
||||
|
||||
ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (res) {
|
||||
return res;
|
||||
}
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_bool(cv, bc_inp, FC_MUL_MM + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
|
||||
ggml_metal_pipeline_set_smem(res, 8192);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user