metal : extend mat-mat multiplication support (#16225)
* metal : support mul_mm with src1->type == GGML_TYPE_F16 * metal : support mul_mm_id with src1->type == GGML_TYPE_F16 [no ci] * metal : mul_mm support ne00 % 32 != 0 * metal : support mul_mm_id with ne00 % 32 != 0 * cont : remove unnecessary unrolls * cont : simplify data loading * metal : optimize mul_mm when output bounds checks are not needed
This commit is contained in:
@@ -1520,22 +1520,20 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) {
|
||||
!ggml_is_transposed(op->src[1]) &&
|
||||
// for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs
|
||||
// AMD GPU and older A-chips will reuse matrix-vector multiplication kernel
|
||||
props_dev->has_simdgroup_mm &&
|
||||
op->src[1]->type == GGML_TYPE_F32 &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
props_dev->has_simdgroup_mm && ne00 >= 64 &&
|
||||
(ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) {
|
||||
//printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12);
|
||||
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op->src[0]->type, op->src[1]->type);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
@@ -1655,8 +1653,6 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
GGML_ASSERT(!ggml_is_transposed(op->src[0]));
|
||||
GGML_ASSERT(!ggml_is_transposed(op->src[1]));
|
||||
|
||||
GGML_ASSERT(op->src[1]->type == GGML_TYPE_F32);
|
||||
|
||||
GGML_ASSERT(ne03 == 1);
|
||||
GGML_ASSERT(ne13 == 1);
|
||||
|
||||
@@ -1674,19 +1670,15 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
// ne21 = n_rows (batch size)
|
||||
const int ne21_mm_id_min = 32;
|
||||
|
||||
if (props_dev->has_simdgroup_mm &&
|
||||
ne00 % 32 == 0 && ne00 >= 64 &&
|
||||
(ne21 >= ne21_mm_id_min)) {
|
||||
GGML_ASSERT(ne00 % 4 == 0);
|
||||
|
||||
if (props_dev->has_simdgroup_mm && ne00 >= 64 && (ne21 >= ne21_mm_id_min)) {
|
||||
// some Metal matrix data types require aligned pointers
|
||||
// ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5)
|
||||
switch (op->src[0]->type) {
|
||||
case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
default: break;
|
||||
}
|
||||
//switch (op->src[0]->type) {
|
||||
// case GGML_TYPE_F32: GGML_ASSERT(nb01 % 16 == 0); break;
|
||||
// case GGML_TYPE_F16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// case GGML_TYPE_BF16: GGML_ASSERT(nb01 % 8 == 0); break;
|
||||
// default: break;
|
||||
//}
|
||||
|
||||
// extra buffers for intermediate id mapping
|
||||
ggml_metal_buffer_id bid_tpe = bid_dst;
|
||||
@@ -1730,7 +1722,7 @@ int ggml_metal_op_mul_mat_id(ggml_metal_op_t ctx, int idx) {
|
||||
ggml_metal_op_concurrency_reset(ctx);
|
||||
|
||||
{
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op->src[0]->type, GGML_TYPE_F16);
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_mul_mm_id(lib, op);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
|
||||
Reference in New Issue
Block a user