feat: support fa cute in sgl-kernel (#10205)

Co-authored-by: cicirori <32845984+cicirori@users.noreply.github.com>
This commit is contained in:
Yineng Zhang
2025-09-09 00:14:39 -07:00
committed by GitHub
parent d1d4074c4e
commit 94fb4e9e54
5 changed files with 1315 additions and 0 deletions

View File

@@ -9,6 +9,11 @@ try:
except:
raise ImportError("Can not import sgl_kernel. Please check your installation.")
# Optional FA4 backend: expose the varlen entry point under a versioned alias.
# If the _fa4_interface module is absent (e.g. build without FA4 support),
# leave the alias as None so callers can detect unavailability at runtime
# instead of failing at import time.
try:
    from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
except ImportError:
    flash_attn_varlen_func_v4 = None
@lru_cache(maxsize=1)
def is_fa3_supported(device=None) -> bool:
@@ -61,6 +66,7 @@ def flash_attn_with_kvcache(
sm_margin=0, # Can be tuned if some SMs are used for communication
return_softmax_lse=False,
sinks=None,
ver=3,
):
"""
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -147,6 +153,9 @@ def flash_attn_with_kvcache(
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
normalization factor).
"""
if ver == 4:
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
if softmax_scale is None:
@@ -237,7 +246,40 @@ def flash_attn_varlen_func(
sm_margin=0,
return_softmax_lse=False,
sinks=None,
ver=3,
):
if ver == 4:
assert (
flash_attn_varlen_func_v4 is not None
), "FA4 is not available, please check your installation."
# Using `(-1, -1)` as no sliding window causes correctness issues for FA4.
if window_size == (-1, -1):
window_size = (None, None)
return flash_attn_varlen_func_v4(
q,
k,
v,
cu_seqlens_q,
cu_seqlens_k,
# max_seqlen_q,
# max_seqlen_k,
seqused_q=seqused_q,
seqused_k=seqused_k,
softmax_scale=softmax_scale,
causal=causal,
# qv=qv,
# q_descale=q_descale,
# k_descale=k_descale,
# v_descale=v_descale,
window_size=window_size,
softcap=softcap,
# num_splits=num_splits,
pack_gqa=pack_gqa,
# sm_margin=sm_margin,
return_softmax_lse=return_softmax_lse,
learnable_sink=sinks,
)
if not is_fa3_supported():
raise NotImplementedError(
"flash_attn at sgl-kernel is only supported on sm90 and above"