metal : dynamic simdgroups for MV kernels (#16340)

* metal : dynamic simdgroups for MV kernels

* cont : minor
This commit is contained in:
Georgi Gerganov
2025-09-30 11:03:23 +03:00
committed by GitHub
parent 3c62aed89f
commit 35fb82497e
4 changed files with 119 additions and 96 deletions

View File

@@ -1565,6 +1565,12 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_kargs_mul_mv args = {
/*.ne00 =*/ ne00,
/*.ne01 =*/ ne01,
@@ -1582,16 +1588,11 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
/*.nb13 =*/ nb13,
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nr0 =*/ nr0,
/*.r2 =*/ r2,
/*.r3 =*/ r3,
};
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_encoder_set_pipeline(enc, pipeline);
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1);
@@ -1758,6 +1759,14 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_encoder_dispatch_threadgroups(enc, (ne21 + 31)/32, (ne01 + 63)/64, ne02, 128, 1, 1);
}
} else {
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
ggml_metal_kargs_mul_mv_id args = {
/*.nei0 =*/ ne20,
/*.nei1 =*/ ne21,
@@ -1778,16 +1787,9 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
/*.ne0 =*/ ne0,
/*.ne1 =*/ ne1,
/*.nb1 =*/ nb1,
/*.nr0 =*/ nr0,
};
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mv_id(lib, op);
const int nr0 = ggml_metal_pipeline_get_nr0(pipeline);
const int nr1 = ggml_metal_pipeline_get_nr1(pipeline);
const int nsg = ggml_metal_pipeline_get_nsg(pipeline);
const size_t smem = ggml_metal_pipeline_get_smem(pipeline);
if (ggml_is_quantized(op->src[0]->type)) {
GGML_ASSERT(ne00 >= nsg*nr0);
}