diff --git a/sgl-kernel/csrc/flash_extension.cc b/sgl-kernel/csrc/flash_extension.cc
index c4fbe0092..f80db673f 100644
--- a/sgl-kernel/csrc/flash_extension.cc
+++ b/sgl-kernel/csrc/flash_extension.cc
@@ -55,7 +55,8 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) {
       "    Tensor? scheduler_metadata,"
       "    int num_splits,"
       "    bool? pack_gqa,"
-      "    int sm_margin) -> Tensor[]");
+      "    int sm_margin,"
+      "    Tensor? sinks) -> Tensor[]");
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
diff --git a/sgl-kernel/include/sgl_flash_kernel_ops.h b/sgl-kernel/include/sgl_flash_kernel_ops.h
index c406fa9f3..383e207c3 100644
--- a/sgl-kernel/include/sgl_flash_kernel_ops.h
+++ b/sgl-kernel/include/sgl_flash_kernel_ops.h
@@ -82,4 +82,5 @@ std::vector<at::Tensor> mha_fwd(
     std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
     int num_splits,
     std::optional<bool> pack_gqa_,
-    int const sm_margin);
+    int const sm_margin,
+    std::optional<at::Tensor>& sinks_);
diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py
index fbf0b0d3f..36951325e 100644
--- a/sgl-kernel/python/sgl_kernel/flash_attn.py
+++ b/sgl-kernel/python/sgl_kernel/flash_attn.py
@@ -58,6 +58,7 @@ def flash_attn_with_kvcache(
     pack_gqa=None,  # Can be tuned for speed
     sm_margin=0,  # Can be tuned if some SMs are used for communication
     return_softmax_lse=False,
+    sinks=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -205,6 +206,7 @@ def flash_attn_with_kvcache(
         num_splits,
         pack_gqa,
         sm_margin,
+        sinks,
     )
     # return (out, softmax_lse) if return_softmax_lse else out
     return (out, softmax_lse, *rest) if return_softmax_lse else out
@@ -232,6 +234,7 @@ def flash_attn_varlen_func(
     pack_gqa=None,
     sm_margin=0,
     return_softmax_lse=False,
+    sinks=None,
 ):
     if not is_fa3_supported():
         raise NotImplementedError(
@@ -277,6 +280,7 @@ def flash_attn_varlen_func(
         num_splits=num_splits,
         pack_gqa=pack_gqa,
         sm_margin=sm_margin,
+        sinks=sinks,
     )
     return (out, softmax_lse, *rest) if return_softmax_lse else out
diff --git a/sgl-kernel/tests/test_flash_attention.py b/sgl-kernel/tests/test_flash_attention.py
index 0c7e854b9..0900e5940 100644
--- a/sgl-kernel/tests/test_flash_attention.py
+++ b/sgl-kernel/tests/test_flash_attention.py
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/Dao-AILab/flash-attention/blob/main/hopper/test_flash_attn.py
 import itertools
 import math
-import os
+from typing import Optional
 
 import pytest
 import torch
@@ -45,12 +45,12 @@
 DISABLE_BACKWARD = True
 # or torch.cuda.get_device_capability("cuda")[0] < 9
 # )
-DISABLE_SPLIT = True
+DISABLE_SPLIT = False
 DISABLE_PAGEDKV = True
-DISABLE_APPENDKV = True
-DISABLE_LOCAL = True
+DISABLE_APPENDKV = False
+DISABLE_LOCAL = False
 DISABLE_SOFTCAP = True
-DISABLE_PACKGQA = True
+DISABLE_PACKGQA = False
 DISABLE_FP16 = True
 DISABLE_FP8 = True
@@ -199,6 +199,7 @@ def attention_ref(
     v_descale=None,
     window_size=(-1, -1),  # -1 means infinite window size
     sink_token_length=0,
+    sinks: Optional[torch.Tensor] = None,
     softcap=0.0,
     upcast=True,
     reorder_ops=False,
@@ -271,7 +272,18 @@
         scores.masked_fill_(local_mask, float("-inf"))
     if attn_bias is not None:
         scores = scores + attn_bias
-    attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    if sinks is None:
+        attention = torch.softmax(scores, dim=-1).to(v.dtype)
+    else:
+        scores_fp32 = scores.to(torch.float32)
+        logits_max = torch.amax(scores_fp32, dim=-1, keepdim=True)
+        sinks = rearrange(sinks, "h -> h 1 1")
+        logits_or_sinks_max = torch.maximum(sinks, logits_max)
+        unnormalized_scores = torch.exp(scores_fp32 - logits_or_sinks_max)
+        normalizer = unnormalized_scores.sum(dim=-1, keepdim=True) + torch.exp(
+            sinks - logits_or_sinks_max
+        )
+        attention = (unnormalized_scores / normalizer).to(v.dtype)
     # We want to mask here so that the attention matrix doesn't have any NaNs
     # Otherwise we'll get NaN in dV
     if query_padding_mask is not None:
@@ -459,8 +471,10 @@ def generate_qkv(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 @pytest.mark.parametrize("new_kv", [False] + ([True] if not DISABLE_APPENDKV else []))
 # @pytest.mark.parametrize("new_kv", [True])
 # @pytest.mark.parametrize(
@@ -540,6 +554,7 @@ def test_flash_attn_kvcache(
     new_kv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_with_kvcache
@@ -565,6 +580,12 @@ def test_flash_attn_kvcache(
     assert nheads % nheads_k == 0
     dtype_ref = torch.bfloat16 if dtype == torch.float8_e4m3fn else dtype
     dv_vals = [128, d] if d > 128 and d <= 192 else ([256, 512, d] if d <= 64 else [d])
+
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
+
     if dtype == torch.float8_e4m3fn or not is_hopper():
         # for fp8 and ampere arch, we not support v head dim != qk head dim
         dv_vals = [d]
@@ -820,6 +841,7 @@ def test_flash_attn_kvcache(
         qv=qv,
         window_size=window_size,
         key_leftpad=cache_leftpad,
+        sinks=sinks,
     )
     out_pt, _ = attention_ref(
         q_ro,
@@ -834,6 +856,7 @@ def test_flash_attn_kvcache(
         reorder_ops=True,
         key_leftpad=cache_leftpad,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     q = q.to(dtype)
     q_unpad = q_unpad.to(dtype) if varlen_q else None
@@ -888,6 +911,7 @@ def test_flash_attn_kvcache(
         scheduler_metadata=scheduler_metadata,
         num_splits=num_splits,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     if varlen_q:
         out = output_pad_fn(out)
@@ -1019,8 +1043,10 @@ def _generate_block_kvcache(
 )
 # @pytest.mark.parametrize("dtype", [torch.bfloat16])
 # @pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
-# @pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
-@pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["mha"])
+@pytest.mark.parametrize("has_sink", [False, True])
+# @pytest.mark.parametrize("has_sink", [False])
 # @pytest.mark.parametrize("has_qv", [False, True])
 @pytest.mark.parametrize("has_qv", [False])
 # @pytest.mark.parametrize("deterministic", [False, True])
@@ -1078,6 +1104,7 @@ def test_flash_attn_varlen_output(
     has_qv,
     mha_type,
     dtype,
+    has_sink,
 ):
     from sgl_kernel.flash_attn import flash_attn_varlen_func
@@ -1131,6 +1158,12 @@ def test_flash_attn_varlen_output(
     qv_ref = None
     # Put window_size after QKV randn so that window_size changes from test to test
     window_size = (-1, -1) if not local else torch.randint(0, seqlen_k, (2,))
+
+    if has_sink:
+        sinks = torch.randn(nheads, dtype=torch.bfloat16, device=device)
+    else:
+        sinks = None
+
     if dtype == torch.float8_e4m3fn:
         q_descale, k_descale, v_descale = [
             torch.rand(batch_size, nheads_kv, device=device, dtype=torch.float32)
@@ -1209,6 +1242,7 @@
         v_descale=v_descale,
         window_size=window_size,
         softcap=softcap,
+        sinks=sinks,
     )
     out_pt, attn_pt = attention_ref(
         q_ref,
@@ -1226,6 +1260,7 @@
         upcast=False,
         reorder_ops=True,
         intermediate_dtype=dtype if dtype == torch.float8_e4m3fn else None,
+        sinks=sinks,
     )
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
@@ -1258,6 +1293,7 @@
         window_size=window_size,
         softcap=softcap,
         return_softmax_lse=True,
+        sinks=sinks,
     )
     out = output_pad_fn(out_unpad)
     if query_unused_mask is not None:
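
For reference, the sink-augmented softmax that the patched attention_ref computes can be exercised in isolation. Below is a minimal CPU-only sketch, not part of the patch; the function name softmax_with_sinks and the tensor shapes are illustrative assumptions. It mirrors the reference math above: each head gets one extra "sink" logit that absorbs probability mass in the normalizer but contributes nothing to the output.

import torch

def softmax_with_sinks(scores: torch.Tensor, sinks: torch.Tensor) -> torch.Tensor:
    # scores: (batch, nheads, seqlen_q, seqlen_k); sinks: (nheads,)
    scores = scores.float()
    logits_max = torch.amax(scores, dim=-1, keepdim=True)
    sinks = sinks.float().view(-1, 1, 1)  # broadcast over (seqlen_q, seqlen_k)
    # Stabilize with the max over both the real logits and the sink logit.
    m = torch.maximum(sinks, logits_max)
    exp_scores = torch.exp(scores - m)
    # The sink enters the normalizer only, so each row sums to less than 1.
    normalizer = exp_scores.sum(dim=-1, keepdim=True) + torch.exp(sinks - m)
    return exp_scores / normalizer

scores = torch.randn(1, 2, 3, 4)  # (batch, nheads, seqlen_q, seqlen_k)
sinks = torch.randn(2)            # one sink logit per head
attn = softmax_with_sinks(scores, sinks)
print(attn.sum(dim=-1))           # < 1 per row; the remaining mass went to the sink

The kernel-side change is not shown in this diff; the updated tests only compare the flash_attn_* outputs against this attention_ref behavior when has_sink is enabled.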