HIP: use v_dot2_f32_f16 instruction for FA (#15884)

This commit is contained in:
Johannes Gäßler
2025-09-09 14:04:43 +02:00
committed by GitHub
parent ed54e32558
commit 17bc5a815f
2 changed files with 26 additions and 6 deletions

View File

@@ -304,12 +304,7 @@ static __global__ void flash_attn_tile(
for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) {
#pragma unroll
for (int j_KQ_0 = 0; j_KQ_0 < ncols; j_KQ_0 += nwarps) {
#ifdef FAST_FP16_AVAILABLE
const float2 tmp = __half22float2(K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps]);
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += tmp.x + tmp.y;
#else
sum[i_KQ_0/warp_size][j_KQ_0/nwarps] += K_k[i_KQ_0/warp_size] * Q_k[j_KQ_0/nwarps];
#endif // FAST_FP16_AVAILABLE
ggml_cuda_mad(sum[i_KQ_0/warp_size][j_KQ_0/nwarps], K_k[i_KQ_0/warp_size], Q_k[j_KQ_0/nwarps]);
}
}
}