metal : extend mat-mat multiplication support (#16225)

* metal : support mul_mm with src1->type == GGML_TYPE_F16

* metal : support mul_mm_id with src1->type == GGML_TYPE_F16

[no ci]

* metal : mul_mm support ne00 % 32 != 0

* metal : support mul_mm_id with ne00 % 32 != 0

* cont : remove unnecessary unrolls

* cont : simplify data loading

* metal : optimize mul_mm when output bounds checks are not needed
This commit is contained in:
Georgi Gerganov
2025-09-28 09:34:44 +03:00
committed by GitHub
parent 3b53634fe3
commit 6a2c6145a0
7 changed files with 247 additions and 120 deletions

View File

@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
!ggml_is_transposed(op->src[1]) &&
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
props_dev->has_simdgroup_mm &&
op->src[1]->type == GGML_TYPE_F32 &&
ne00 % 32 == 0 && ne00 >= 64 &&
props_dev->has_simdgroup_mm && ne00 >= 64 &&
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
ggml_metal_kargs_mul_mm args = {
/*.ne00 =*/ ne00,
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
GGML_ASSERT(ne03 == 1);
GGML_ASSERT(ne13 == 1);
@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
// ne21 = n_rows (batch size)
const int ne21_mm_id_min = 32;
if (props_dev->has_simdgroup_mm &&
ne00 % 32 == 0 && ne00 >= 64 &&
(ne21 >= ne21_mm_id_min)) {
GGML_ASSERT(ne00 % 4 == 0);
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
// some Metal matrix data types require aligned pointers
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
switch (op->src[0]->type) {
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
default: break;
}
//switch (op->src[0]->type) {
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
// default: break;
//}
// extra buffers for intermediate id mapping
ggml_metal_buffer_id bid_tpe = bid_dst;
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
ggml_metal_op_concurrency_reset(ctx);
{
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
ggml_metal_kargs_mul_mm_id args = {
/*.ne00 =*/ ne00,