vulkan: in flash attention, bounds check against nem1 (don't rely on GGML_KQ_MASK_PAD) (#16316)

This commit is contained in:
Jeff Bolz
2025-10-03 03:33:08 -05:00
committed by GitHub
parent 136bda78c5
commit e308efda8e
4 changed files with 27 additions and 12 deletions

View File

@@ -153,12 +153,13 @@ void main() {
}
if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) {
bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0;
[[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) {
uint32_t c = (idx + tid) % Bc;
uint32_t r = (idx + tid) / Bc;
if (idx + tid < Bc * Br) {
if (!KV_bounds_check || j * Bc + c < KV) {
if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) {
masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]);
} else {
masksh[c][r] = float(0);