[1/2] Support FA4 for MHA Prefill in sgl-kernel (#10940)
This commit is contained in:
@@ -153,7 +153,43 @@ def flash_attn_with_kvcache(
|
||||
normalization factor).
|
||||
"""
|
||||
if ver == 4:
|
||||
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
|
||||
assert (
|
||||
flash_attn_varlen_func_v4 is not None
|
||||
), "FA4 is not available, please check your installation."
|
||||
# Using `(-1, -1)` to mean "no sliding window" causes correctness issues for FA4, so it is converted to `(None, None)` before the FA4 call below.
|
||||
assert (
|
||||
k is None and v is None
|
||||
), "FA4 does not support updating KV cache in-place."
|
||||
assert (
|
||||
rotary_cos is None
|
||||
and rotary_sin is None
|
||||
and rotary_interleaved is None
|
||||
and rotary_seqlens is None
|
||||
), "FA4 does not support rotary embedding."
|
||||
assert (
|
||||
cache_batch_idx is None and cache_leftpad is None
|
||||
), "FA4 does not support non-consecutive batch indices or left padding."
|
||||
assert (
|
||||
q_descale is None and k_descale is None and v_descale is None
|
||||
), "FA4 does not support descale."
|
||||
|
||||
if window_size == (-1, -1):
|
||||
window_size = (None, None)
|
||||
return flash_attn_varlen_func_v4(
|
||||
q=q,
|
||||
k=k_cache,
|
||||
v=v_cache,
|
||||
cu_seqlens_q=cu_seqlens_q,
|
||||
seqused_k=cache_seqlens,
|
||||
softmax_scale=softmax_scale,
|
||||
causal=causal,
|
||||
window_size=window_size,
|
||||
softcap=softcap,
|
||||
pack_gqa=pack_gqa,
|
||||
return_softmax_lse=return_softmax_lse,
|
||||
learnable_sink=sinks,
|
||||
page_table=page_table,
|
||||
)
|
||||
|
||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||
|
||||
@@ -1 +1 @@
|
||||
__version__ = "0.3.12"
|
||||
__version__ = "0.3.13"
|
||||
|
||||
Reference in New Issue
Block a user