[Auto Sync] Update flashattention_backend.py (20250922) (#10762)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Co-authored-by: Gordon Gustafson <ggustafson@together.ai>
@@ -692,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
 k_descale, v_descale = None, None
 # only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
 # has corresponding quantization method so that layer.k_scale is not None,
-# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
-if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
+# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
+# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
+if (
+    self.kv_cache_dtype_str != "auto"
+    and layer.head_dim <= 256
+    and self.fa_impl_ver != 4
+):
     if layer.k_scale is not None:
         descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
         k_descale = layer.k_scale.expand(descale_shape)
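As a rough illustration of what the unchanged `expand` call in this hunk does, the sketch below broadcasts a per-layer scalar KV scale to the per-(batch, head) descale shape. The batch size, head count, and scale value are hypothetical stand-ins, not values from the SGLang sources.

```python
# Minimal sketch, not the actual SGLang code: broadcasting a per-layer
# k_scale to the (batch_size, tp_k_head_num) descale shape.
import torch

batch_size = 4       # stand-in for forward_batch.batch_size
tp_k_head_num = 8    # stand-in for layer.tp_k_head_num
k_scale = torch.tensor([0.02])  # per-layer scalar KV scale, shape (1,)

descale_shape = (batch_size, tp_k_head_num)
# expand() returns a broadcast view (no data copy): every (batch, head)
# slot sees the same per-layer descale factor.
k_descale = k_scale.expand(descale_shape)
assert k_descale.shape == descale_shape
```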