[1/2] Support FA4 for MHA Prefill in sgl-kernel (#10940)
This commit is contained in:
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.3.12"
|
version = "0.3.13"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.3.12"
|
version = "0.3.13"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"
|
|||||||
|
|
||||||
[project]
|
[project]
|
||||||
name = "sgl-kernel"
|
name = "sgl-kernel"
|
||||||
version = "0.3.12"
|
version = "0.3.13"
|
||||||
description = "Kernel Library for SGLang"
|
description = "Kernel Library for SGLang"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@@ -153,7 +153,43 @@ def flash_attn_with_kvcache(
|
|||||||
normalization factor).
|
normalization factor).
|
||||||
"""
|
"""
|
||||||
if ver == 4:
|
if ver == 4:
|
||||||
raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4")
|
assert (
|
||||||
|
flash_attn_varlen_func_v4 is not None
|
||||||
|
), "FA4 is not available, please check your installation."
|
||||||
|
# NOTE: passing `(-1, -1)` to mean "no sliding window" causes correctness issues for FA4, so it is mapped to `(None, None)` further down, just before the FA4 call.
|
||||||
|
assert (
|
||||||
|
k is None and v is None
|
||||||
|
), "FA4 does not support updating KV cache in-place."
|
||||||
|
assert (
|
||||||
|
rotary_cos is None
|
||||||
|
and rotary_sin is None
|
||||||
|
and rotary_interleaved is None
|
||||||
|
and rotary_seqlens is None
|
||||||
|
), "FA4 does not support rotary embedding."
|
||||||
|
assert (
|
||||||
|
cache_batch_idx is None and cache_leftpad is None
|
||||||
|
), "FA4 does not support non-consecutive batch indices or left padding."
|
||||||
|
assert (
|
||||||
|
q_descale is None and k_descale is None and v_descale is None
|
||||||
|
), "FA4 does not support descale."
|
||||||
|
|
||||||
|
if window_size == (-1, -1):
|
||||||
|
window_size = (None, None)
|
||||||
|
return flash_attn_varlen_func_v4(
|
||||||
|
q=q,
|
||||||
|
k=k_cache,
|
||||||
|
v=v_cache,
|
||||||
|
cu_seqlens_q=cu_seqlens_q,
|
||||||
|
seqused_k=cache_seqlens,
|
||||||
|
softmax_scale=softmax_scale,
|
||||||
|
causal=causal,
|
||||||
|
window_size=window_size,
|
||||||
|
softcap=softcap,
|
||||||
|
pack_gqa=pack_gqa,
|
||||||
|
return_softmax_lse=return_softmax_lse,
|
||||||
|
learnable_sink=sinks,
|
||||||
|
page_table=page_table,
|
||||||
|
)
|
||||||
|
|
||||||
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension"
|
||||||
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension"
|
||||||
|
|||||||
@@ -1 +1 @@
|
|||||||
__version__ = "0.3.12"
|
__version__ = "0.3.13"
|
||||||
|
|||||||
Reference in New Issue
Block a user