metal : fuse NORM + MUL + ADD, support non-multiples of 4 (#16220)

* metal : fuse NORM + MUL + ADD

* metal : support norms of non-multiple of 4

* cont : fix comment [no ci]
This commit is contained in:
Georgi Gerganov
2025-09-25 11:30:16 +03:00
committed by GitHub
parent 4ea00794b8
commit dfcd53f7ec
9 changed files with 206 additions and 232 deletions

View File

@@ -661,13 +661,13 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te
case GGML_OP_SOFT_MAX:
case GGML_OP_GROUP_NORM:
return has_simdgroup_reduction && ggml_is_contiguous_rows(op->src[0]);
case GGML_OP_RMS_NORM:
case GGML_OP_L2_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_ARGMAX:
return has_simdgroup_reduction;
case GGML_OP_NORM:
return has_simdgroup_reduction && (op->ne[0] % 4 == 0 && ggml_is_contiguous_1(op->src[0]));
case GGML_OP_RMS_NORM:
return has_simdgroup_reduction && (ggml_is_contiguous_rows(op->src[0]));
case GGML_OP_ROPE:
return true;
case GGML_OP_IM2COL: