metal : restore im2col perf (#16219)
This commit is contained in:
@@ -2768,7 +2768,6 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
const uint64_t ofs0 = op->src[1]->nb[is_2D ? 3 : 2] / 4;
|
||||
const uint64_t ofs1 = op->src[1]->nb[is_2D ? 2 : 1] / 4;
|
||||
|
||||
|
||||
ggml_metal_kargs_im2col args = {
|
||||
/*.ofs0 =*/ ofs0,
|
||||
/*.ofs1 =*/ ofs1,
|
||||
@@ -2789,15 +2788,16 @@ int ggml_metal_op_im2col(ggml_metal_op_t ctx, int idx) {
|
||||
|
||||
ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_im2col(lib, op);
|
||||
|
||||
const uint64_t n_threads = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), N);
|
||||
const int64_t quotient = N / n_threads + (N % n_threads > 0 ? 1 : 0);
|
||||
GGML_ASSERT(KH*KW <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline));
|
||||
|
||||
const uint64_t ntptg0 = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)/(KH*KW), N);
|
||||
|
||||
ggml_metal_encoder_set_pipeline(enc, pipeline);
|
||||
ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 1);
|
||||
ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2);
|
||||
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, quotient * CHW, OH, OW, n_threads, 1, 1);
|
||||
ggml_metal_encoder_dispatch_threadgroups(enc, IC, OH, OW, ntptg0, KH, KW);
|
||||
|
||||
return 1;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user