metal : improve MUL_MAT_ID (#15541)
* metal : mul_mm_id remove hdst * metal : remove mul_mm_id hsrc1 * metal : mul_mm_id simplify + add test * metal : opt mul_mm_id map0 * metal : optimize mul_mm_id id gathering * metal : mul/div opt * metal : optimize mul_mm_id_map0 ggml-ci
This commit is contained in:
@@ -398,8 +398,12 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16,
|
||||
GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16,
|
||||
@@ -1428,8 +1432,12 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ1_M_F32, mul_mm_iq1_m_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_NL_F32, mul_mm_iq4_nl_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_IQ4_XS_F32, mul_mm_iq4_xs_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16, mul_mm_id_map0_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32, mul_mm_id_map1_f32, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1, mul_mm_id_map0_f16_ne20_1, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2, mul_mm_id_map0_f16_ne20_2, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4, mul_mm_id_map0_f16_ne20_4, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6, mul_mm_id_map0_f16_ne20_6, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8, mul_mm_id_map0_f16_ne20_8, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16, mul_mm_id_map0_f16_ne20_16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F32_F16, mul_mm_id_f32_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_F16_F16, mul_mm_id_f16_f16, has_simdgroup_mm);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_MUL_MM_ID_BF16_F16, mul_mm_id_bf16_f16, has_simdgroup_mm && use_bfloat);
|
||||
@@ -3908,38 +3916,6 @@ static int ggml_metal_encode_node(
|
||||
default: break;
|
||||
}
|
||||
|
||||
const int64_t neh10 = ne10; // n_embd
|
||||
const int64_t neh11 = ne21; // n_tokens
|
||||
const int64_t neh12 = ne02; // n_expert
|
||||
|
||||
const uint64_t nbh10 = ggml_type_size(GGML_TYPE_F16);
|
||||
const uint64_t nbh11 = nbh10*neh10;
|
||||
const uint64_t nbh12 = nbh11*neh11;
|
||||
const uint64_t nbh13 = nbh12*neh12;
|
||||
|
||||
const size_t s_src1 = ggml_type_size(GGML_TYPE_F16)*neh10*neh11*neh12;
|
||||
id<MTLBuffer> h_src1 = ggml_metal_mem_pool_alloc(mem_pool, s_src1);
|
||||
if (!h_src1) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_src1);
|
||||
return 0;
|
||||
}
|
||||
|
||||
const int64_t neh0 = ne0;
|
||||
const int64_t neh1 = ne21;
|
||||
const int64_t neh2 = ne02;
|
||||
|
||||
const uint64_t nbh0 = ggml_type_size(GGML_TYPE_F32);
|
||||
const uint64_t nbh1 = nbh0*neh0;
|
||||
const uint64_t nbh2 = nbh1*neh1;
|
||||
//const uint64_t nbh3 = nbh2*neh2;
|
||||
|
||||
const size_t s_dst = ggml_type_size(GGML_TYPE_F32)*neh0*neh1*neh2;
|
||||
id<MTLBuffer> h_dst = ggml_metal_mem_pool_alloc(mem_pool, s_dst);
|
||||
if (!h_dst) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_dst);
|
||||
return 0;
|
||||
}
|
||||
|
||||
// tokens per expert
|
||||
const size_t s_tpe = ggml_type_size(GGML_TYPE_I32)*ne02;
|
||||
id<MTLBuffer> h_tpe = ggml_metal_mem_pool_alloc(mem_pool, s_tpe);
|
||||
@@ -3949,8 +3925,8 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
// id map
|
||||
// [n_expert_used, n_tokens]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne20*ne21;
|
||||
// [n_tokens, n_expert]
|
||||
const size_t s_ids = ggml_type_size(GGML_TYPE_I32)*ne21*ne02;
|
||||
id<MTLBuffer> h_ids = ggml_metal_mem_pool_alloc(mem_pool, s_ids);
|
||||
if (!h_ids) {
|
||||
GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, s_ids);
|
||||
@@ -3958,32 +3934,45 @@ static int ggml_metal_encode_node(
|
||||
}
|
||||
|
||||
{
|
||||
const int nth = MIN(1024, ne10/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map0 args = {
|
||||
ne02,
|
||||
ne10,
|
||||
ne11, // n_expert_used (bcast)
|
||||
ne11, // n_expert_used (bcast)
|
||||
nb11,
|
||||
nb12,
|
||||
neh11, // n_tokens
|
||||
nbh11,
|
||||
ne20, // n_expert_used
|
||||
ne21, // n_tokens
|
||||
ne20, // n_expert_used
|
||||
nb21,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16].pipeline;
|
||||
pipeline = nil;
|
||||
|
||||
switch (ne20) {
|
||||
case 1: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_1 ].pipeline; break;
|
||||
case 2: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_2 ].pipeline; break;
|
||||
case 4: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_4 ].pipeline; break;
|
||||
case 6: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_6 ].pipeline; break;
|
||||
case 8: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_8 ].pipeline; break;
|
||||
case 16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP0_F16_NE20_16].pipeline; break;
|
||||
default: GGML_ABORT("missing specialization for ne20 = %d", (int) ne20);
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne02 <= (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
|
||||
const size_t smem = ne02*ne20*sizeof(uint16_t);
|
||||
|
||||
GGML_ASSERT(smem <= device.maxThreadgroupMemoryLength);
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:2];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:5];
|
||||
[encoder setBuffer:id_src2 offset:offs_src2 atIndex:1];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:2];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:3];
|
||||
[encoder setThreadgroupMemoryLength:smem atIndex:0];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne02, 1, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(1, 1, 1) threadsPerThreadgroup:MTLSizeMake(ne02, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
@@ -4022,13 +4011,15 @@ static int ggml_metal_encode_node(
|
||||
/*.nb01 =*/ nb01,
|
||||
/*.nb02 =*/ nb02,
|
||||
/*.nb03 =*/ nb03,
|
||||
/*.neh12 =*/ neh12,
|
||||
/*.nbh10 =*/ nbh10,
|
||||
/*.nbh11 =*/ nbh11,
|
||||
/*.nbh12 =*/ nbh12,
|
||||
/*.nbh13 =*/ nbh13,
|
||||
/*.neh0 =*/ neh0,
|
||||
/*.neh1 =*/ neh1,
|
||||
/*.ne11 =*/ ne11, // n_expert_used (bcast)
|
||||
/*.nb10 =*/ nb10,
|
||||
/*.nb11 =*/ nb11,
|
||||
/*.nb12 =*/ nb12,
|
||||
/*.nb13 =*/ nb13,
|
||||
/*.ne20 =*/ ne20, // n_expert_used
|
||||
/*.ne21 =*/ ne21, // n_tokens
|
||||
/*.ne0 =*/ ne0,
|
||||
/*.ne1 =*/ ne1,
|
||||
/*.r2 =*/ r2,
|
||||
/*.r3 =*/ r3,
|
||||
};
|
||||
@@ -4036,42 +4027,14 @@ static int ggml_metal_encode_node(
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer: h_src1 offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_src1 offset:offs_src1 atIndex:2];
|
||||
[encoder setBuffer: h_tpe offset:0 atIndex:3];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:4];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:4];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:5];
|
||||
|
||||
[encoder setThreadgroupMemoryLength:8192 atIndex:0];
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne21 + 31)/32, (ne01 + 63)/64, ne02) threadsPerThreadgroup:MTLSizeMake(128, 1, 1)];
|
||||
}
|
||||
|
||||
{
|
||||
GGML_ASSERT(ne0 % 4 == 0);
|
||||
|
||||
const int nth = MIN(1024, ne0/4);
|
||||
|
||||
ggml_metal_kargs_mul_mm_id_map1 args = {
|
||||
ne20, // n_expert_used
|
||||
neh0,
|
||||
neh1,
|
||||
nbh1,
|
||||
nbh2,
|
||||
ne0,
|
||||
nb1,
|
||||
nb2,
|
||||
};
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_MUL_MM_ID_MAP1_F32].pipeline;
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
[encoder setBytes:&args length:sizeof(args) atIndex:0];
|
||||
[encoder setBuffer: h_dst offset:0 atIndex:1];
|
||||
[encoder setBuffer: h_ids offset:0 atIndex:2];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:3];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne20, ne21, 1) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
}
|
||||
} else {
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
|
||||
Reference in New Issue
Block a user