metal : make the backend async (#15906)
* metal : make the backend async ggml-ci * cont : add comments, extend op offload, clean up ggml-ci * metal : fix batch size for MUL_MAT_ID * metal : remove deprecated ggml_backend_metal_buffer_from_ptr * metal : create only metal buffers, no wrapping of host memory ggml-ci * metal : restore .alloc_buffer for buffer_from_ptr_type ggml-ci * metal : remove broken implementation of GGML_OP_SET ggml-ci * metal : clean-up loose ends, ready for tests ggml-ci * metal : support both private and shared buffers ggml-ci * metal : enable private buffers + add global device queue * metal : disable host buffer to prevent races ggml-ci * metal : avoid extra copy during set_tensor ggml-ci * metal : use separate buffer types for shread and private Metal buffers ggml-ci * metal : simplify synchronization logic ggml-ci * metal : fix build ggml-ci * metal : do not implement cpy_tensor ggml-ci * metal : separate implementations for shared and private buffers ggml-ci
This commit is contained in:
File diff suppressed because it is too large
Load Diff
@@ -5571,38 +5571,6 @@ kernel void kernel_flash_attn_ext_vec_reduce(
|
||||
#undef DV
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_set(
|
||||
constant ggml_metal_kargs_set & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device char * dst,
|
||||
uint3 tgpig[[threadgroup_position_in_grid]],
|
||||
ushort3 tpitg[[thread_position_in_threadgroup]],
|
||||
ushort3 ntg[[threads_per_threadgroup]]) {
|
||||
const int i13 = tgpig[2];
|
||||
const int i12 = tgpig[1];
|
||||
const int i11 = tgpig[0];
|
||||
|
||||
const int64_t n = i13*args.ne12*args.ne11*args.ne10 + i12*args.ne11*args.ne10 + i11*args.ne10;
|
||||
|
||||
const int64_t i3 = n / (args.ne12*args.ne11*args.ne10);
|
||||
const int64_t i2 = (n - i3*args.ne12*args.ne11*args.ne10) / (args.ne11*args.ne10);
|
||||
const int64_t i1 = (n - i3*args.ne12*args.ne11*args.ne10 - i2*args.ne11*args.ne10) / args.ne10;
|
||||
|
||||
device T * dst_data = (device T *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + args.offs);
|
||||
|
||||
for (int64_t i10 = tpitg.x; i10 < args.ne10; i10 += ntg.x) {
|
||||
device const T * src = (device T *) (src1 + i13*args.nb13 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10);
|
||||
dst_data[i10] = (T) src[0];
|
||||
}
|
||||
}
|
||||
|
||||
typedef decltype(kernel_set<float>) kernel_set_t;
|
||||
|
||||
template [[host_name("kernel_set_f32")]] kernel kernel_set_t kernel_set<float>;
|
||||
template [[host_name("kernel_set_i32")]] kernel kernel_set_t kernel_set<int32_t>;
|
||||
|
||||
template<typename T0, typename T1>
|
||||
kernel void kernel_cpy(
|
||||
constant ggml_metal_kargs_cpy & args,
|
||||
|
||||
Reference in New Issue
Block a user