ggml: allow casting between f32 and i32 (#15783)

* ggml: allow casting between f32 and i32

* fix cuda

* add vulkan

* fix CPU non-cont

* add non-cont test case

* add note

* extend test number range

* correct note

* add cont version for vulkan
This commit is contained in:
Xuan-Son Nguyen
2025-09-08 17:33:01 +07:00
committed by GitHub
parent 5ef22d281d
commit 9fcb29f22f
12 changed files with 247 additions and 3 deletions

View File

@@ -583,6 +583,8 @@ enum ggml_metal_kernel_type {
GGML_METAL_KERNEL_TYPE_CPY_F16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_F32,
GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16,
GGML_METAL_KERNEL_TYPE_CPY_F32_I32,
GGML_METAL_KERNEL_TYPE_CPY_I32_F32,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0,
GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1,
@@ -1616,6 +1618,8 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F16_F16, cpy_f16_f16, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_F32, cpy_bf16_f32, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_BF16_BF16, cpy_bf16_bf16, use_bfloat);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_I32, cpy_f32_i32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_I32_F32, cpy_i32_f32, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0, cpy_f32_q8_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_0, cpy_f32_q4_0, true);
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_CPY_F32_Q4_1, cpy_f32_q4_1, true);
@@ -1945,6 +1949,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
case GGML_TYPE_Q5_0:
case GGML_TYPE_Q5_1:
case GGML_TYPE_IQ4_NL:
case GGML_TYPE_I32:
return true;
default:
return false;
@@ -1977,6 +1982,8 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
default:
return false;
}
case GGML_TYPE_I32:
return op->type == GGML_TYPE_F32;
default:
return false;
};
@@ -5680,6 +5687,7 @@ static int ggml_metal_encode_node(
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline; break;
case GGML_TYPE_I32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_I32].pipeline; break;
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F16].pipeline; break;
case GGML_TYPE_BF16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_BF16].pipeline; break;
case GGML_TYPE_Q8_0: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_Q8_0].pipeline; break;
@@ -5691,6 +5699,13 @@ static int ggml_metal_encode_node(
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_I32:
{
switch (dstt) {
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_I32_F32].pipeline; break;
default: GGML_ABORT("not implemented");
};
} break;
case GGML_TYPE_F16:
{
switch (dstt) {

View File

@@ -5338,6 +5338,8 @@ typedef decltype(kernel_cpy<float, float>) kernel_cpy_t;
template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy<float, float>;
template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy<float, half>;
template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy<float, int32_t>;
template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy<int32_t, float>;
#if defined(GGML_METAL_USE_BF16)
template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy<float, bfloat>;
#endif