metal : batch rows copy in a single threadgroup (#14384)
* metal : batch rows copy in a single threadgroup ggml-ci * metal : handle some edge cases when threadgroup size is not a power of 2 ggml-ci
This commit is contained in:
@@ -2450,6 +2450,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00);
|
||||
|
||||
ggml_metal_kargs_sum_rows args = {
|
||||
@@ -3780,6 +3781,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_rms_norm args = {
|
||||
@@ -3816,6 +3818,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_l2_norm args = {
|
||||
@@ -3888,6 +3891,7 @@ static bool ggml_metal_encode_node(
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
nth = MIN(nth, ne00/4);
|
||||
|
||||
ggml_metal_kargs_norm args = {
|
||||
@@ -4974,8 +4978,39 @@ static bool ggml_metal_encode_node(
|
||||
default: GGML_ABORT("not implemented");
|
||||
}
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
|
||||
// TODO: support
|
||||
//const int32_t nk00 = ne00/ggml_blck_size(dst->type);
|
||||
const int32_t nk00 = ne00;
|
||||
|
||||
int nth = 32; // SIMD width
|
||||
|
||||
while (nth < nk00 && nth < (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nth *= 2;
|
||||
}
|
||||
|
||||
nth = MIN(nth, (int) pipeline.maxTotalThreadsPerThreadgroup);
|
||||
|
||||
// when rows are small, we can batch them together in a single threadgroup
|
||||
int nrptg = 1;
|
||||
|
||||
// TODO: relax this constraint in the future
|
||||
if (ggml_blck_size(src0->type) == 1 && ggml_blck_size(dst->type) == 1) {
|
||||
if (nth > nk00) {
|
||||
nrptg = (nth + nk00 - 1)/nk00;
|
||||
nth = nk00;
|
||||
|
||||
if (nrptg*nth > (int) pipeline.maxTotalThreadsPerThreadgroup) {
|
||||
nrptg--;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
nth = MIN(nth, nk00);
|
||||
|
||||
ggml_metal_kargs_cpy args = {
|
||||
/*.ne00 =*/ ne00,
|
||||
/*.ne00 =*/ nk00,
|
||||
/*.ne01 =*/ ne01,
|
||||
/*.ne02 =*/ ne02,
|
||||
/*.ne03 =*/ ne03,
|
||||
@@ -4998,11 +5033,7 @@ static bool ggml_metal_encode_node(
|
||||
[encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
|
||||
[encoder setBuffer:id_dst offset:offs_dst atIndex:2];
|
||||
|
||||
GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
|
||||
int nth = MIN(1024, ne00/ggml_blck_size(src0->type));
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, 1, 1)];
|
||||
|
||||
[encoder dispatchThreadgroups:MTLSizeMake((ne01 + nrptg - 1)/nrptg, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth, nrptg, 1)];
|
||||
} break;
|
||||
case GGML_OP_SET:
|
||||
{
|
||||
|
||||
Reference in New Issue
Block a user