ggml: allow casting between f32 and i32 (#15783)
* ggml: allow casting between f32 and i32
* fix cuda
* add vulkan
* fix CPU non-cont
* add non-cont test case
* add note
* extend test number range
* correct note
* add cont version for vulkan
This commit is contained in:
@@ -583,6 +583,8 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
|
||||
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
|
||||
@@ -1616,6 +1618,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
|
||||
@@ -1945,6 +1949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_TYPE_Q5_0:
|
||||
case GGML_TYPE_Q5_1:
|
||||
case GGML_TYPE_IQ4_NL:
|
||||
case GGML_TYPE_I32:
|
||||
return true;
|
||||
default:
|
||||
return false;
|
||||
@@ -1977,6 +1982,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
default:
|
||||
return false;
|
||||
}
|
||||
case GGML_TYPE_I32:
|
||||
return op->type == GGML_TYPE_F32;
|
||||
default:
|
||||
return false;
|
||||
};
|
||||
@@ -5680,6 +5687,7 @@ static int ggml_metal_encode_node(
|
||||
|
||||
switch (dstt) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
|
||||
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
|
||||
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
|
||||
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
|
||||
@@ -5691,6 +5699,13 @@ static int ggml_metal_encode_node(
|
||||
default: GGML_ABORT("not implemented");
|
||||
};
|
||||
} break;
|
||||
case GGML_TYPE_I32:
|
||||
{
|
||||
switch (dstt) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
|
||||
default: GGML_ABORT("not implemented");
|
||||
};
|
||||
} break;
|
||||
case GGML_TYPE_F16:
|
||||
{
|
||||
switch (dstt) {
|
||||
|
||||
@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
|
||||
|
||||
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
|
||||
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
|
||||
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
|
||||
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
|
||||
#if defined(GGML_METAL_USE_BF16)
|
||||
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
|
||||
#endif
|
||||
|
||||
Reference in New Issue
Block a user