vulkan: Support FA with K/V in F32 (#16543)

This commit is contained in:
Jeff Bolz
2025-10-14 08:53:37 -05:00
committed by GitHub
parent 7ea15bb64c
commit 4258e0cfe7
4 changed files with 49 additions and 8 deletions

View File

@@ -2649,11 +2649,13 @@ static void ggml_vk_load_shaders(vk_device& device) {
} \
}
CREATE_FA(GGML_TYPE_F32, f32, FA_SCALAR, )
CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, )
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, )
#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT)
if (device->coopmat1_fa_support) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1)
CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1)
@@ -2661,6 +2663,7 @@ static void ggml_vk_load_shaders(vk_device& device) {
#endif
#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT)
if (device->coopmat2) {
CREATE_FA(GGML_TYPE_F32, f32, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2)
CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2)
@@ -7457,8 +7460,16 @@ static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx
}
const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type));
const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type));
uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type));
// For F32, the shader treats it as a block of size 4 (for vec4 loads)
if (k->type == GGML_TYPE_F32) {
k_stride /= 4;
}
if (v->type == GGML_TYPE_F32) {
v_stride /= 4;
}
uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows);
bool aligned = (KV % alignment) == 0 &&
@@ -12660,6 +12671,7 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
}
switch (op->src[1]->type) {
case GGML_TYPE_F16:
case GGML_TYPE_F32:
case GGML_TYPE_Q4_0:
case GGML_TYPE_Q8_0:
// supported in scalar and coopmat2 paths