sync from b7516
This commit is contained in:
@@ -1684,60 +1684,3 @@ ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_opt_step_sgd(ggm
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_memset(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
snprintf(base, 256, "kernel_memset_%s", ggml_type_name(op->type));
|
||||
snprintf(name, 256, "%s", base);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr);
|
||||
}
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
ggml_metal_pipeline_with_params ggml_metal_library_get_pipeline_count_equal(ggml_metal_library_t lib, const ggml_tensor * op) {
|
||||
assert(op->op == GGML_OP_COUNT_EQUAL);
|
||||
|
||||
GGML_TENSOR_LOCALS(int64_t, ne0, op->src[0], ne);
|
||||
|
||||
GGML_ASSERT(op->src[0]->type == op->src[1]->type);
|
||||
GGML_ASSERT(op->src[0]->type == GGML_TYPE_I32);
|
||||
GGML_ASSERT(op->type == GGML_TYPE_I64);
|
||||
|
||||
// note: the kernel only supports i32 output due to metal atomic add only supporting atomic_int
|
||||
GGML_ASSERT(ggml_nelements(op->src[0]) < (1LL << 31));
|
||||
|
||||
char base[256];
|
||||
char name[256];
|
||||
|
||||
int nsg = 1;
|
||||
while (32*nsg < ne00 && nsg < 32) {
|
||||
nsg *= 2;
|
||||
}
|
||||
|
||||
snprintf(base, 256, "kernel_count_equal_%s", ggml_type_name(op->src[0]->type));
|
||||
snprintf(name, 256, "%s_nsg=%d", base, nsg);
|
||||
|
||||
ggml_metal_pipeline_with_params res = ggml_metal_library_get_pipeline(lib, name);
|
||||
if (!res.pipeline) {
|
||||
ggml_metal_cv_t cv = ggml_metal_cv_init();
|
||||
|
||||
ggml_metal_cv_set_int16(cv, nsg, FC_COUNT_EQUAL + 0);
|
||||
|
||||
res = ggml_metal_library_compile_pipeline(lib, base, name, cv);
|
||||
|
||||
ggml_metal_cv_free(cv);
|
||||
}
|
||||
|
||||
res.smem = 32 * sizeof(int32_t);
|
||||
res.nsg = nsg;
|
||||
|
||||
return res;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user