[Auto Sync] Update flashattention_backend.py (20250922) (#10762)

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Gordon Gustafson <ggustafson@together.ai>
This commit is contained in:
Yineng Zhang
2025-09-22 16:41:42 -07:00
committed by GitHub
parent 662393f27d
commit 0753ef831e

View File

@@ -692,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
k_descale, v_descale = None, None
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
# has corresponding quantization method so that layer.k_scale is not None,
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
if (
self.kv_cache_dtype_str != "auto"
and layer.head_dim <= 256
and self.fa_impl_ver != 4
):
if layer.k_scale is not None:
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
k_descale = layer.k_scale.expand(descale_shape)