[Auto Sync] Update flashattention_backend.py (20250922) (#10762)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com> Co-authored-by: Gordon Gustafson <ggustafson@together.ai>
This commit is contained in:
@@ -692,8 +692,13 @@ class FlashAttentionBackend(AttentionBackend):
|
|||||||
k_descale, v_descale = None, None
|
k_descale, v_descale = None, None
|
||||||
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
# only use kv scaling if: 1) fp8 kv is explicitly enabled, 2) RadixAttention
|
||||||
# has corresponding quantization method so that layer.k_scale is not None,
|
# has corresponding quantization method so that layer.k_scale is not None,
|
||||||
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case.
|
# 3) layer.head_dim <= 256 since fa3 kernel require fp16 and bf16 data type in this case,
|
||||||
if self.kv_cache_dtype_str != "auto" and layer.head_dim <= 256:
|
# 4) fa_impl_ver != 4 since fa4 does not currently support fp8 queries and keys.
|
||||||
|
if (
|
||||||
|
self.kv_cache_dtype_str != "auto"
|
||||||
|
and layer.head_dim <= 256
|
||||||
|
and self.fa_impl_ver != 4
|
||||||
|
):
|
||||||
if layer.k_scale is not None:
|
if layer.k_scale is not None:
|
||||||
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
descale_shape = (forward_batch.batch_size, layer.tp_k_head_num)
|
||||||
k_descale = layer.k_scale.expand(descale_shape)
|
k_descale = layer.k_scale.expand(descale_shape)
|
||||||
|
|||||||
Reference in New Issue
Block a user