|
|
|
|
@@ -7743,7 +7743,7 @@ kernel void kernel_get_rows_i32(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
|
|
|
|
template<typename TI, typename block_q, void (*quantize_func)(device const float *, device block_q &)>
|
|
|
|
|
kernel void kernel_set_rows_q32(
|
|
|
|
|
constant ggml_metal_kargs_set_rows & args,
|
|
|
|
|
device const void * src0,
|
|
|
|
|
@@ -7764,7 +7764,7 @@ kernel void kernel_set_rows_q32(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int32_t i10 = i01;
|
|
|
|
|
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
|
|
|
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
|
|
|
|
|
|
|
|
device block_q * dst_row = ( device block_q *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
|
|
|
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
|
|
|
@@ -7774,7 +7774,7 @@ kernel void kernel_set_rows_q32(
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
template<typename T>
|
|
|
|
|
template<typename T, typename TI>
|
|
|
|
|
kernel void kernel_set_rows_f(
|
|
|
|
|
constant ggml_metal_kargs_set_rows & args,
|
|
|
|
|
device const void * src0,
|
|
|
|
|
@@ -7795,7 +7795,7 @@ kernel void kernel_set_rows_f(
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
const int32_t i10 = i01;
|
|
|
|
|
const int64_t i1 = ((const device int64_t *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
|
|
|
const TI i1 = ((const device TI *) ((const device char *) src1 + i10*args.nb10 + i11*args.nb11 + i12*args.nb12))[0];
|
|
|
|
|
|
|
|
|
|
device T * dst_row = ( device T *) (( device char *) dst + i1*args.nb1 + i02*args.nb2 + i03*args.nb3);
|
|
|
|
|
const device float * src_row = (const device float *) ((const device char *) src0 + i01*args.nb01 + i02*args.nb02 + i03*args.nb03);
|
|
|
|
|
@@ -8218,22 +8218,31 @@ template [[host_name("kernel_get_rows_iq4_xs")]] kernel get_rows_q_t kernel_get
|
|
|
|
|
// set rows
|
|
|
|
|
//
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_set_rows_f<float>) set_rows_f_t;
|
|
|
|
|
typedef decltype(kernel_set_rows_f<float, int64_t>) set_rows_f_t;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_set_rows_f32")]] kernel set_rows_f_t kernel_set_rows_f<float>;
|
|
|
|
|
template [[host_name("kernel_set_rows_f16")]] kernel set_rows_f_t kernel_set_rows_f<half>;
|
|
|
|
|
template [[host_name("kernel_set_rows_f32_i64")]] kernel set_rows_f_t kernel_set_rows_f<float, int64_t>;
|
|
|
|
|
template [[host_name("kernel_set_rows_f32_i32")]] kernel set_rows_f_t kernel_set_rows_f<float, int32_t>;
|
|
|
|
|
template [[host_name("kernel_set_rows_f16_i64")]] kernel set_rows_f_t kernel_set_rows_f<half, int64_t>;
|
|
|
|
|
template [[host_name("kernel_set_rows_f16_i32")]] kernel set_rows_f_t kernel_set_rows_f<half, int32_t>;
|
|
|
|
|
#if defined(GGML_METAL_HAS_BF16)
|
|
|
|
|
template [[host_name("kernel_set_rows_bf16")]] kernel set_rows_f_t kernel_set_rows_f<bfloat>;
|
|
|
|
|
template [[host_name("kernel_set_rows_bf16_i64")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int64_t>;
|
|
|
|
|
template [[host_name("kernel_set_rows_bf16_i32")]] kernel set_rows_f_t kernel_set_rows_f<bfloat, int32_t>;
|
|
|
|
|
#endif
|
|
|
|
|
|
|
|
|
|
typedef decltype(kernel_set_rows_q32<block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
|
|
|
|
typedef decltype(kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>) set_rows_q32_t;
|
|
|
|
|
|
|
|
|
|
template [[host_name("kernel_set_rows_q8_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q8_0, quantize_q8_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_0, quantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q4_1, quantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_0")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_0, quantize_q5_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_1")]] kernel set_rows_q32_t kernel_set_rows_q32<block_q5_1, quantize_q5_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_iq4_nl")]] kernel set_rows_q32_t kernel_set_rows_q32<block_iq4_nl, quantize_iq4_nl>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q8_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q8_0, quantize_q8_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q8_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q8_0, quantize_q8_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_0, quantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_0, quantize_q4_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q4_1, quantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q4_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q4_1, quantize_q4_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_0_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_0, quantize_q5_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_0_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_0, quantize_q5_0>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_1_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_q5_1, quantize_q5_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_q5_1_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_q5_1, quantize_q5_1>;
|
|
|
|
|
template [[host_name("kernel_set_rows_iq4_nl_i64")]] kernel set_rows_q32_t kernel_set_rows_q32<int64_t, block_iq4_nl, quantize_iq4_nl>;
|
|
|
|
|
template [[host_name("kernel_set_rows_iq4_nl_i32")]] kernel set_rows_q32_t kernel_set_rows_q32<int32_t, block_iq4_nl, quantize_iq4_nl>;
|
|
|
|
|
|
|
|
|
|
//
|
|
|
|
|
// matrix-matrix multiplication
|
|
|
|
|
|