diff --git a/sgl-kernel/pyproject.toml b/sgl-kernel/pyproject.toml index bcbb7498f..469ae5a90 100644 --- a/sgl-kernel/pyproject.toml +++ b/sgl-kernel/pyproject.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.3.12" +version = "0.3.13" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/pyproject_cpu.toml b/sgl-kernel/pyproject_cpu.toml index 2e44837a3..ec5085361 100644 --- a/sgl-kernel/pyproject_cpu.toml +++ b/sgl-kernel/pyproject_cpu.toml @@ -8,7 +8,7 @@ build-backend = "scikit_build_core.build" [project] name = "sgl-kernel" -version = "0.3.12" +version = "0.3.13" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/pyproject_rocm.toml b/sgl-kernel/pyproject_rocm.toml index c58ec3f69..a43b849b0 100644 --- a/sgl-kernel/pyproject_rocm.toml +++ b/sgl-kernel/pyproject_rocm.toml @@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta" [project] name = "sgl-kernel" -version = "0.3.12" +version = "0.3.13" description = "Kernel Library for SGLang" readme = "README.md" requires-python = ">=3.10" diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py index f6b87c311..d2f401b5f 100644 --- a/sgl-kernel/python/sgl_kernel/flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/flash_attn.py @@ -153,7 +153,43 @@ def flash_attn_with_kvcache( normalization factor). """ if ver == 4: - raise NotImplementedError("haven't implemented flash_attn_with_kvcache for fa4") + assert ( + flash_attn_varlen_func_v4 is not None + ), "FA4 is not available, please check your installation." + # Using `(-1, -1)` as no sliding window causes correctness issues for FA4. + assert ( + k is None and v is None + ), "FA4 does not support updating KV cache in-place." + assert ( + rotary_cos is None + and rotary_sin is None + and rotary_interleaved is None + and rotary_seqlens is None + ), "FA4 does not support rotary embedding." + assert ( + cache_batch_idx is None and cache_leftpad is None + ), "FA4 does not support non-consecutive batch indices or left padding." + assert ( + q_descale is None and k_descale is None and v_descale is None + ), "FA4 does not support descale." + + if window_size == (-1, -1): + window_size = (None, None) + return flash_attn_varlen_func_v4( + q=q, + k=k_cache, + v=v_cache, + cu_seqlens_q=cu_seqlens_q, + seqused_k=cache_seqlens, + softmax_scale=softmax_scale, + causal=causal, + window_size=window_size, + softcap=softcap, + pack_gqa=pack_gqa, + return_softmax_lse=return_softmax_lse, + learnable_sink=sinks, + page_table=page_table, + ) assert k_cache.stride(-1) == 1, "k_cache must have contiguous last dimension" assert v_cache.stride(-1) == 1, "v_cache must have contiguous last dimension" diff --git a/sgl-kernel/python/sgl_kernel/version.py b/sgl-kernel/python/sgl_kernel/version.py index df0ed3321..8a3be2e00 100644 --- a/sgl-kernel/python/sgl_kernel/version.py +++ b/sgl-kernel/python/sgl_kernel/version.py @@ -1 +1 @@ -__version__ = "0.3.12" +__version__ = "0.3.13"