feat: support fa cute in sgl-kernel (#10205)
Co-authored-by: cicirori <32845984+cicirori@users.noreply.github.com>
This commit is contained in:
@@ -9,6 +9,11 @@ try:
|
||||
except:
|
||||
raise ImportError("Can not import sgl_kernel. Please check your installation.")
|
||||
|
||||
try:
|
||||
from ._fa4_interface import flash_attn_varlen_func as flash_attn_varlen_func_v4
|
||||
except ImportError:
|
||||
flash_attn_varlen_func_v4 = None
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def is_fa3_supported(device=None) -> bool:
|
||||
@@ -61,6 +66,7 @@ def flash_attn_with_kvcache(
|
||||
sm_margin=0, # Can be tuned if some SMs are used for communication
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
ver=3,
|
||||
):
|
||||
"""
|
||||
If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
|
||||
@@ -147,6 +153,9 @@ def flash_attn_with_kvcache(
|
||||
logsumexp of each row of the matrix QK^T * scaling (e.g., log of the softmax
|
||||
normalization factor).
|
||||
"""
|
||||
if ver == 4:
|
||||
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
|
||||
|
||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||
if softmax_scale is None:
|
||||
@@ -237,7 +246,40 @@ def flash_attn_varlen_func(
|
||||
sm_margin=0,
|
||||
return_softmax_lse=False,
|
||||
sinks=None,
|
||||
ver=3,
|
||||
):
|
||||
if ver == 4:
|
||||
assert (
|
||||
flash_attn_varlen_func_v4 is not None
|
||||
), "FA4 is not available, please check your installation."
|
||||
# FA4 gives incorrect results when `(-1, -1)` is passed to mean "no sliding window", so translate it to `(None, None)`.
|
||||
if window_size == (-1, -1):
|
||||
window_size = (None, None)
|
||||
return flash_attn_varlen_func_v4(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens_q,
|
||||
cu_seqlens_k,
|
||||
# max_seqlen_q,
|
||||
# max_seqlen_k,
|
||||
seqused_q=seqused_q,
|
||||
seqused_k=seqused_k,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
# qv=qv,
|
||||
# q_descale=q_descale,
|
||||
# k_descale=k_descale,
|
||||
# v_descale=v_descale,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
# num_splits=num_splits,
|
||||
pack_gqa=pack_gqa,
|
||||
# sm_margin=sm_margin,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
learnable_sink=sinks,
|
||||
)
|
||||
|
||||
if not is_fa3_supported():
|
||||
raise NotImplementedError(
|
||||
"flash_attn at sgl-kernel is only supported on sm90 and above"
|
||||
|
||||
Reference in New Issue
Block a user