ggml : fix FA mask dim 2 and 3 (#14505)

* ggml : fix FA mask dim 2 and 3

ggml-ci

* backends : unsupport batched FA in CUDA and Vulkan

ggml-ci

* vulkan : disable FA for mask->ne[2] != 1
This commit is contained in:
Georgi Gerganov
2025-07-03 10:46:57 +03:00
committed by GitHub
parent d4cdd9c1c3
commit 9067487c44
9 changed files with 26 additions and 15 deletions

View File

@@ -230,8 +230,10 @@ typedef struct {
uint64_t nb22;
uint64_t nb23;
int32_t ne32;
int32_t ne33;
uint64_t nb31;
uint64_t nb32;
uint64_t nb33;
int32_t ne1;
int32_t ne2;
float scale;

View File

@@ -5018,8 +5018,10 @@ static bool ggml_metal_encode_node(
/*.nb22 =*/ nb22,
/*.nb23 =*/ nb23,
/*.ne32 =*/ ne32,
/*.ne33 =*/ ne33,
/*.nb31 =*/ nb31,
/*.nb32 =*/ nb32,
/*.nb33 =*/ nb33,
/*.ne1 =*/ ne1,
/*.ne2 =*/ ne2,
/*.scale =*/ scale,

View File

@@ -3857,7 +3857,7 @@ kernel void kernel_flash_attn_ext(
// load the mask in shared memory
#pragma unroll(Q)
for (short j = 0; j < Q; ++j) {
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq3%args.ne32)*args.nb32);
device const half * pm = (device const half *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
const float m = pm[ic + tiisg];
@@ -4343,7 +4343,7 @@ kernel void kernel_flash_attn_ext_vec(
const bool has_mask = mask != q;
// pointer to the mask
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq3%args.ne32)*args.nb32);
device const half * pm = (device const half *) (mask + iq1*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33);
float slope = 1.0f;