From 0753ef831eebd8d261ad136718aaa4bb62a65e33 Mon Sep 17 00:00:00 2001
From: Yineng Zhang
Date: Mon, 22 Sep 2025 16:41:42 -0700
Subject: [PATCH] [Auto Sync] Update flashattention_backend.py (20250922)
 (#10762)

Co-authored-by: github-actions[bot]
Co-authored-by: Gordon Gustafson
---
 .../srt/layers/attention/flashattention_backend.py | 9 +++++++--
 1 file changed, 7 insertions(+), 2 deletions(-)

diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index be7fed8de..67cad8d23 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -692,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
         k_descale, v_descale = None, None
         # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
         # has corresponding quantization method so that layer.k_scale is not None,
-        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
-        if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+        # 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+        # 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+        if (
+            self.kv_cache_dtype_str != "auto"
+            and layer.head_dim <= 256
+            and self.fa_impl_ver != 4
+        ):
             if layer.k_scale is not None:
                 descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
                 k_descale = layer.k_scale.expand(descale_shape)