diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e09c911ec..4dfb2660c 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -1071,6 +1071,16 @@ class ServerArgs: self.enable_mixed_chunk = False self.disable_radix_cache = True + if self.attention_backend == "fa4" or self.decode_attention_backend == "fa4": + raise ValueError( + "FA4 backend is only supported for prefill. Please use `--prefill-attention-backend fa4` instead." + ) + if self.prefill_attention_backend == "fa4": + logger.warning( + f"FA4 backend only supports page size 128, changing page_size from {self.page_size} to 128." + ) + self.page_size = 128 + def _handle_page_size(self): if self.page_size is None: self.page_size = 1 diff --git a/python/sglang/test/test_utils.py b/python/sglang/test/test_utils.py index a38df4962..8fd557688 100644 --- a/python/sglang/test/test_utils.py +++ b/python/sglang/test/test_utils.py @@ -129,6 +129,11 @@ def is_in_amd_ci(): return get_bool_env_var("SGLANG_IS_IN_CI_AMD") +def is_blackwell_system(): + """Return whether it is running on a Blackwell (B200) system.""" + return get_bool_env_var("IS_BLACKWELL") + + def _use_cached_default_models(model_repo: str): cache_dir = os.getenv("DEFAULT_MODEL_CACHE_DIR") if cache_dir and model_repo: @@ -151,6 +156,9 @@ DEFAULT_URL_FOR_TEST = f"http://127.0.0.1:{DEFAULT_PORT_FOR_SRT_TEST_RUNNER + 10 if is_in_amd_ci(): DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000 +if is_blackwell_system(): + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH = 3000 + def call_generate_lightllm(prompt, temperature, max_tokens, stop=None, url=None): assert url is not None diff --git a/sgl-kernel/CMakeLists.txt b/sgl-kernel/CMakeLists.txt index 7c4b61171..85e247452 100644 --- a/sgl-kernel/CMakeLists.txt +++ b/sgl-kernel/CMakeLists.txt @@ -91,7 +91,7 @@ FetchContent_Populate(repo-flashinfer) FetchContent_Declare( repo-flash-attention GIT_REPOSITORY https://github.com/sgl-project/sgl-attn - GIT_TAG f9af0c2a1d82ab1812e6987e9338363cc2bf0f8d + GIT_TAG ff87110aad048bb8c4e6effea4c563ddae88b0eb GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention) @@ -100,7 +100,7 @@ FetchContent_Populate(repo-flash-attention) FetchContent_Declare( repo-flash-attention-origin GIT_REPOSITORY https://github.com/Dao-AILab/flash-attention.git - GIT_TAG 203b9b3dba39d5d08dffb49c09aa622984dff07d + GIT_TAG 04adaf0e9028d4bec7073f69e4dfa3f6d3357189 GIT_SHALLOW OFF ) FetchContent_Populate(repo-flash-attention-origin) diff --git a/sgl-kernel/csrc/flash_extension.cc b/sgl-kernel/csrc/flash_extension.cc index f80db673f..df6024dfa 100644 --- a/sgl-kernel/csrc/flash_extension.cc +++ b/sgl-kernel/csrc/flash_extension.cc @@ -23,40 +23,43 @@ TORCH_LIBRARY_FRAGMENT(sgl_kernel, m) { * From flash-attention */ m.def( - "fwd(Tensor! q," - " Tensor k," - " Tensor v," - " Tensor? k_new," - " Tensor? v_new," - " Tensor? q_v," - " Tensor!? out," - " Tensor? cu_seqlens_q," - " Tensor? cu_seqlens_k," - " Tensor? cu_seqlens_k_new," - " Tensor? seqused_q," - " Tensor? seqused_k," + "fwd(Tensor q," // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q + " Tensor k," // (b_k, s_k, h_k, d) or (total_k, h_k, d) or paged + " Tensor v," // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) or paged + " Tensor? k_new," // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) + " Tensor? v_new," // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) + " Tensor? q_v," // (b, s_q, h, dv) or (total_q_new, h, dv) + " Tensor? out," // (b, s_q, h, dv) or (total_q, h, dv) + " Tensor? 
cu_seqlens_q,"        // b+1
+      " Tensor? cu_seqlens_k,"        // b+1
+      " Tensor? cu_seqlens_k_new,"    // b+1
+      " Tensor? seqused_q,"           // b
+      " Tensor? seqused_k,"           // b
       " int? max_seqlen_q,"
-      " int? max_seqlen_k,"
-      " Tensor? page_table,"
-      " Tensor? kv_batch_idx,"
-      " Tensor? leftpad_k,"
-      " Tensor? rotary_cos,"
-      " Tensor? rotary_sin,"
-      " Tensor? seqlens_rotary,"
-      " Tensor? q_descale,"
-      " Tensor? k_descale,"
-      " Tensor? v_descale,"
-      " float softmax_scale,"
+      " int? max_seqlen_k,"           // TODO: check if needed
+      " Tensor? page_table,"          // (b_k, max_num_pages_per_seq)
+      " Tensor? kv_batch_idx,"        // b
+      " Tensor? leftpad_k,"           // b
+      " Tensor? rotary_cos,"          // seqlen_ro x (rotary_dim / 2)
+      " Tensor? rotary_sin,"          // seqlen_ro x (rotary_dim / 2)
+      " Tensor? seqlens_rotary,"      // b
+      " Tensor? q_descale,"           // (b, h_k)
+      " Tensor? k_descale,"           // (b, h_k)
+      " Tensor? v_descale,"           // (b, h_k)
+      " float? softmax_scale,"        // now optional
       " bool is_causal,"
       " int window_size_left,"
       " int window_size_right,"
-      " float softcap,"
+      " int attention_chunk,"         // NEW
+      " float softcap,"               // promoted to double in C++; schema float is fine
       " bool is_rotary_interleaved,"
-      " Tensor? scheduler_metadata,"
+      " Tensor? scheduler_metadata,"  // (b + 1)
       " int num_splits,"
       " bool? pack_gqa,"
       " int sm_margin,"
-      " Tensor? sinks) -> Tensor[]");
+      " Tensor? sinks"
+      ") -> (Tensor, Tensor, Tensor, Tensor)");  // NEW return type: tuple of 4 tensors
+
   m.impl("fwd", torch::kCUDA, make_pytorch_shim(&mha_fwd));
 }
diff --git a/sgl-kernel/include/sgl_flash_kernel_ops.h b/sgl-kernel/include/sgl_flash_kernel_ops.h
index 383e207c3..b36af6b69 100644
--- a/sgl-kernel/include/sgl_flash_kernel_ops.h
+++ b/sgl-kernel/include/sgl_flash_kernel_ops.h
@@ -42,45 +42,44 @@ limitations under the License.
 /*
  * From flash-attention
  */
-std::vector<at::Tensor> mha_fwd(
-    at::Tensor& q,        // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
-    const at::Tensor& k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
-                          // h_k, d) if there is page_table.
-    const at::Tensor& v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
-                          // page_size, h_k, dv) if there is page_table.
-    std::optional<const at::Tensor>&
-        k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
-    std::optional<const at::Tensor>&
-        v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
-    std::optional<const at::Tensor>& q_v_,  // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
-    std::optional<at::Tensor>& out_,        // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
-    std::optional<const at::Tensor>& cu_seqlens_q_,      // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_,      // b+1
-    std::optional<const at::Tensor>& cu_seqlens_k_new_,  // b+1
-    std::optional<const at::Tensor>&
+std::tuple<at::Tensor, at::Tensor, at::Tensor, at::Tensor> mha_fwd(
+    at::Tensor q,  // (b, s_q, h, d) or (total_q, h, d) if there is cu_seqlens_q
+    at::Tensor k,  // (b_k, s_k, h_k, d) or (total_k, h_k, d) if there is cu_seqlens_k or (num_pages, page_size,
+                   // h_k, d) if there is page_table.
+    at::Tensor v,  // (b_k, s_k, h_k, dv) or (total_k, h_k, dv) if there is cu_seqlens_k or (num_pages,
+                   // page_size, h_k, dv) if there is page_table.
+    std::optional<at::Tensor> k_new_,  // (b, s_k_new, h_k, d) or (total_k_new, h_k, d) if there is cu_seqlens_k_new
+    std::optional<at::Tensor> v_new_,  // (b, s_k_new, h_k, dv) or (total_k_new, h_k, dv) if there is cu_seqlens_k_new
+    std::optional<at::Tensor> q_v_,    // (b, s_q, h, dv) or (total_q_new, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor> out_,    // (b, s_q, h, dv) or (total_q, h, dv) if there is cu_seqlens_q
+    std::optional<at::Tensor> cu_seqlens_q_,      // b+1
+    std::optional<at::Tensor> cu_seqlens_k_,      // b+1
+    std::optional<at::Tensor> cu_seqlens_k_new_,  // b+1
+    std::optional<at::Tensor>
         seqused_q_,  // b. If given, only this many elements of each batch element's queries and outputs are used.
-    std::optional<const at::Tensor>&
+    std::optional<at::Tensor>
         seqused_k_,  // b. If given, only this many elements of each batch element's keys are used.
-    std::optional<int> max_seqlen_q_,
+    std::optional<int64_t> max_seqlen_q_,
     // TODO: check if we need max_seqlen_k
-    std::optional<int> max_seqlen_k_,
-    std::optional<const at::Tensor>& page_table_,    // (b_k, max_num_pages_per_seq)
-    std::optional<const at::Tensor>& kv_batch_idx_,  // b. indices to index into the KV cache
-    std::optional<const at::Tensor>& leftpad_k_,     // b
-    std::optional<const at::Tensor>& rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
-    std::optional<const at::Tensor>& rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
-    std::optional<const at::Tensor>& seqlens_rotary_,  // b
-    std::optional<at::Tensor>& q_descale_,  // (b, h_k), not (b, h)
-    std::optional<at::Tensor>& k_descale_,  // (b, h_k)
-    std::optional<at::Tensor>& v_descale_,  // (b, h_k)
-    float const softmax_scale,
+    std::optional<int64_t> max_seqlen_k_,
+    std::optional<at::Tensor> page_table_,    // (b_k, max_num_pages_per_seq)
+    std::optional<at::Tensor> kv_batch_idx_,  // b. indices to index into the KV cache
+    std::optional<at::Tensor> leftpad_k_,     // b
+    std::optional<at::Tensor> rotary_cos_,    // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor> rotary_sin_,    // seqlen_ro x (rotary_dim / 2)
+    std::optional<at::Tensor> seqlens_rotary_,  // b
+    std::optional<at::Tensor> q_descale_,  // (b, h_k), not (b, h)
+    std::optional<at::Tensor> k_descale_,  // (b, h_k)
+    std::optional<at::Tensor> v_descale_,  // (b, h_k)
+    std::optional<double> softmax_scale_,
     bool is_causal,
-    int window_size_left,
-    int window_size_right,
-    float const softcap,
-    bool const is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
-    std::optional<at::Tensor>& scheduler_metadata_,  // (b + 1)
-    int num_splits,
+    int64_t window_size_left,
+    int64_t window_size_right,
+    int64_t attention_chunk,
+    double softcap,
+    bool is_rotary_interleaved,  // if true, rotary combines indices 0 & 1, else indices 0 & rotary_dim / 2
+    std::optional<at::Tensor> scheduler_metadata_,  // (b + 1)
+    int64_t num_splits,
     std::optional<bool> pack_gqa_,
-    int const sm_margin,
-    std::optional<at::Tensor>& sinks_);
+    int64_t sm_margin,
+    std::optional<at::Tensor>& sinks_);  // (h)
diff --git a/sgl-kernel/python/sgl_kernel/_fa4_interface.py b/sgl-kernel/python/sgl_kernel/_fa4_interface.py
index 684b4b25e..da1adeec4 100644
--- a/sgl-kernel/python/sgl_kernel/_fa4_interface.py
+++ b/sgl-kernel/python/sgl_kernel/_fa4_interface.py
@@ -1,14 +1,14 @@
-# Adapted from https://github.com/Dao-AILab/flash-attention/blob/203b9b3dba39d5d08dffb49c09aa622984dff07d/flash_attn/cute/interface.py
+# Adapted from https://github.com/Dao-AILab/flash-attention/blob/54d8aa6751fc9d5f0357854079261913d5df1f9d/flash_attn/cute/interface.py
 # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
-# [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
+# [2025-10-14] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.2.1.
import copy import gc import logging import math -from typing import Optional, Tuple +from typing import Callable, Optional, Tuple logger = logging.getLogger(__name__) @@ -18,6 +18,7 @@ import cutlass import cutlass.cute as cute import torch from cutlass.cute.runtime import from_dlpack +from flash_attn.cute import utils from flash_attn.cute.flash_fwd import FlashAttentionForwardSm90 from flash_attn.cute.flash_fwd_sm100 import FlashAttentionForwardSm100 @@ -26,22 +27,6 @@ def maybe_contiguous(x): return x.contiguous() if x is not None and x.stride(-1) != 1 else x -def _reason_recompile(compile_key, jit_func): - compile_cache = jit_func.compile_cache - compile_key_map = jit_func.compile_key_map - if not compile_cache: - return "not compiled yet" - for k, v in compile_cache.items(): - if k == compile_key: - continue - if len(k) != len(compile_key): - continue - for i in range(len(k)): - if k[i] != compile_key[i]: - return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]} " - return "unknown reason" - - torch2cute_dtype_map = { torch.float16: cutlass.Float16, torch.bfloat16: cutlass.BFloat16, @@ -72,7 +57,11 @@ def _flash_attn_fwd( num_threads: int = 384, pack_gqa: Optional[bool] = None, _compute_capability: Optional[int] = None, - return_softmax_lse: Optional[bool] = False, + score_mod: Callable | None = None, + return_lse: bool = False, + out: Optional[torch.Tensor] = None, + lse: Optional[torch.Tensor] = None, + buffers: Optional[list[torch.Tensor]] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: q, k, v = [maybe_contiguous(t) for t in (q, k, v)] num_head, head_dim = q.shape[-2:] @@ -169,23 +158,51 @@ def _flash_attn_fwd( q_batch_seqlen_shape = ( (batch_size, seqlen_q) if cu_seqlens_q is None else (total_q,) ) - out = torch.empty( - *q_batch_seqlen_shape, - num_head, - head_dim_v, - dtype=out_torch_dtype, - device=device, - ) lse_shape = ( (batch_size, num_head, seqlen_q) if cu_seqlens_q is None else (num_head, total_q) ) - lse = ( - torch.empty(lse_shape, dtype=torch.float32, device=device) - if return_softmax_lse - else None - ) + requires_grad = q.requires_grad or k.requires_grad or v.requires_grad + + if out is None: + out = torch.empty( + *q_batch_seqlen_shape, + num_head, + head_dim_v, + dtype=out_torch_dtype, + device=device, + ) + else: + expected_out_shape = (*q_batch_seqlen_shape, num_head, head_dim_v) + assert ( + out.shape == expected_out_shape + ), f"out tensor shape {out.shape} does not match expected shape {expected_out_shape}" + assert ( + out.dtype == out_torch_dtype + ), f"out tensor dtype {out.dtype} does not match expected dtype {out_torch_dtype}" + assert ( + out.device == device + ), f"out tensor device {out.device} does not match input device {device}" + assert out.is_cuda, "out tensor must be on CUDA device" + + if lse is None: + lse = ( + torch.empty(lse_shape, dtype=torch.float32, device=device) + if requires_grad or return_lse + else None + ) + elif lse is not None: + assert ( + lse.shape == lse_shape + ), f"lse tensor shape {lse.shape} does not match expected shape {lse_shape}" + assert ( + lse.dtype == torch.float32 + ), f"lse tensor dtype {lse.dtype} does not match expected dtype torch.float32" + assert ( + lse.device == device + ), f"lse tensor device {lse.device} does not match input device {device}" + assert lse.is_cuda, "lse tensor must be on CUDA device" dtype = torch2cute_dtype_map[q.dtype] q_tensor, k_tensor, v_tensor, o_tensor = [ @@ -242,6 +259,7 @@ def _flash_attn_fwd( current_stream = cuda.CUstream(torch.cuda.current_stream().cuda_stream) 
if compute_capability == 9: # TODO: tune block size according to hdim + # Perf heuristic from upstream: hdim=128, noncausal, non-local benefits from larger n_block if head_dim == head_dim_v == 128 and not causal and not local: n_block_size = 192 if compute_capability == 10: @@ -253,13 +271,34 @@ def _flash_attn_fwd( ): pack_gqa = False + if softcap is not None: + assert score_mod is None, "softcap and score_mod cannot be used together" + score_mod = utils.create_softcap_scoremod(softcap) + + if score_mod is not None: + is_varlen = ( + cu_seqlens_q is not None + or cu_seqlens_k is not None + or seqused_q is not None + or seqused_k is not None + ) + if is_varlen: + raise NotImplementedError( + "score_mod with buffers is not yet supported for varlen sequences. This will be fixed in a future PR." + ) + + cute_buffers = None + if buffers is not None: + cute_buffers = [from_dlpack(buf) for buf in buffers] + compile_key = ( dtype, head_dim, head_dim_v, qhead_per_kvhead, causal, - softcap is not None, + utils.hash_callable(score_mod) if score_mod is not None else None, + buffers is not None, lse is None, cu_seqlens_q is None, cu_seqlens_k is None, @@ -276,9 +315,6 @@ def _flash_attn_fwd( compute_capability, ) if compile_key not in _flash_attn_fwd.compile_cache: - logger.info( - f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}" - ) if compute_capability == 9: assert page_table is None, "paged KV not supported on SM 9.0" # fa_fwd = FlashAttentionForwardSm80( @@ -290,12 +326,14 @@ def _flash_attn_fwd( is_causal=causal, is_local=local, pack_gqa=pack_gqa, - m_block_size=m_block_size, - n_block_size=n_block_size, + tile_m=m_block_size, + tile_n=n_block_size, # num_stages=1, num_stages=2, num_threads=num_threads, Q_in_regs=False, + score_mod=score_mod, + has_buffers=buffers is not None, ) elif compute_capability == 10: assert page_size in [ @@ -313,12 +351,15 @@ def _flash_attn_fwd( and not local and cu_seqlens_q is None and seqused_q is None, + score_mod=score_mod, + has_buffers=buffers is not None, ) else: raise ValueError( f"Unsupported compute capability: {compute_capability}. 
Supported: 9.x, 10.x" ) # TODO: check @can_implement + # TODO caching for buffers; cute_buffers _flash_attn_fwd.compile_cache[compile_key] = cute.compile( fa_fwd, q_tensor, @@ -333,10 +374,10 @@ def _flash_attn_fwd( seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, learnable_sink_tensor, + cute_buffers, ) _flash_attn_fwd.compile_cache[compile_key]( q_tensor, @@ -351,46 +392,29 @@ def _flash_attn_fwd( seqused_q_tensor, seqused_k_tensor, page_table_tensor, - softcap, window_size_left, window_size_right, learnable_sink_tensor, + cute_buffers, ) return out, lse _flash_attn_fwd.compile_cache = {} -_flash_attn_fwd.compile_key_map = [ - "dtype", - "head_dim", - "head_dim_v", - "qhead_per_kvhead", - "causal", - "softcap is not None", - "lse is None", - "cu_seqlens_q is None", - "cu_seqlens_k is None", - "seqused_q is None", - "seqused_k is None", - "page_table is not None", - "window_size_left is not None", - "window_size_right is not None", - "learnable_sink is not None", - "m_block_size", - "n_block_size", - "num_threads", - "pack_gqa", - "compute_capability", -] def warmup_flash_attn(f): """ Decorator for flash_attn_varlen_func: - - On the first call, run several warmup passes with different flag combinations - - Warmups are executed sequentially to minimize peak GPU memory usage - - Does not modify user-provided tensors (clones data) - - Easy to extend with more compile-key dimensions + - On first call, run several warmup passes with different flag combinations: + * return_softmax_lse in {False, True} + * global noncausal (window_size=(None,None)) + * causal (window_size=(None,0)) + * local sliding window (window_size=(64,64)) + * optionally pack_gqa=True if qheads > kvheads and allowed + - No score_mod / softcap (not supported for varlen yet) + - Executes sequentially to minimize peak GPU mem + - Does not modify user tensors (clones) """ done = False @@ -399,30 +423,78 @@ def warmup_flash_attn(f): def maybe_clone(x): if isinstance(x, torch.Tensor): - return x.clone() + return x.detach().clone() # detach to avoid autograd edges return copy.deepcopy(x) return tuple(maybe_clone(a) for a in args), { k: maybe_clone(v) for k, v in kwargs.items() } + def _infer_heads(args, kwargs): + """Infer q and kv head counts from arguments.""" + # Expect signature: (q, k, v, cu_seqlens_q, cu_seqlens_k, ...) 
+ q = args[0] if len(args) > 0 else kwargs.get("q") + k = args[1] if len(args) > 1 else kwargs.get("k") + try: + qh = int(q.shape[-2]) + kvh = int(k.shape[-2]) + return qh, kvh + except Exception: + return None, None + def _run_warmups(args, kwargs): """Run warmup calls sequentially and release memory after each.""" base_args, base_kwargs = _clone_args(args, kwargs) - # Warmup combinations for return_softmax_lse and causal - combos = [ - dict(return_softmax_lse=False, causal=False), - dict(return_softmax_lse=False, causal=True), - dict(return_softmax_lse=True, causal=False), - dict(return_softmax_lse=True, causal=True), + qh, kvh = _infer_heads(base_args, base_kwargs) + can_pack_gqa = ( + qh is not None and kvh is not None and qh % kvh == 0 and qh // kvh > 1 + ) + has_page_table = ( + "page_table" in base_kwargs and base_kwargs["page_table"] is not None + ) + + # Window presets covering global, causal, and local + window_presets = [ + (None, None), # global noncausal + (None, 0), # causal + (64, 64), # local sliding window ] + lse_flags = [False, True] + + # Base combo list + combos = [] + for ws in window_presets: + for return_lse_flag in lse_flags: + combos.append(dict(window_size=ws, return_softmax_lse=return_lse_flag)) + + # Optionally add a pack_gqa=True variant (FA4 may disable it internally for some varlen shapes/SMs) + if can_pack_gqa: + for ws in window_presets: + combos.append( + dict(window_size=ws, return_softmax_lse=False, pack_gqa=True) + ) + + # If page_table is present, warm one combo with it (page_table in compile key for SM100) + if has_page_table: + combos.append(dict(window_size=(None, None), return_softmax_lse=False)) + + # Run sequentially for combo in combos: wa, wk = _clone_args(base_args, base_kwargs) + # Keep user-provided softcap/score_mod OUT (varlen+score_mod unsupported) + wk.pop("score_mod", None) + if "softcap" in wk and wk["softcap"]: + wk["softcap"] = 0.0 + # Apply combo wk.update(combo) with torch.cuda.stream(torch.cuda.current_stream()): - f(*wa, **wk) + try: + f(*wa, **wk) + except Exception as e: + # Some combos can be invalid for specific head dims / arch. Ignore and continue. + logger.debug("Warmup combo skipped: %s", e) del wa, wk torch.cuda.empty_cache() gc.collect() @@ -430,7 +502,9 @@ def warmup_flash_attn(f): def wrapper(*args, **kwargs): nonlocal done if not done: - logger.info("Running flash_attn_varlen_func warmup passes...") + logger.info( + "Running FA4 warmup (global/causal/local, LSE on/off, optional GQA pack)..." 
+ ) _run_warmups(args, kwargs) done = True return f(*args, **kwargs) @@ -472,7 +546,7 @@ def flash_attn_varlen_func( learnable_sink=learnable_sink, softcap=softcap, pack_gqa=pack_gqa, - return_softmax_lse=return_softmax_lse, + return_lse=return_softmax_lse, ) return (out, lse) if return_softmax_lse else out diff --git a/sgl-kernel/python/sgl_kernel/flash_attn.py b/sgl-kernel/python/sgl_kernel/flash_attn.py index f2f6d895f..c3ffbc540 100644 --- a/sgl-kernel/python/sgl_kernel/flash_attn.py +++ b/sgl-kernel/python/sgl_kernel/flash_attn.py @@ -45,7 +45,7 @@ def flash_attn_with_kvcache( qv=None, rotary_cos=None, rotary_sin=None, - cache_seqlens: Optional[Union[(int, torch.Tensor)]] = None, + cache_seqlens: Optional[Union[int, torch.Tensor]] = None, cache_batch_idx: Optional[torch.Tensor] = None, cache_leftpad: Optional[torch.Tensor] = None, page_table: Optional[torch.Tensor] = None, @@ -59,6 +59,7 @@ def flash_attn_with_kvcache( softmax_scale=None, causal=False, window_size=(-1, -1), # -1 means infinite context window + attention_chunk: Optional[int] = None, softcap=0.0, # 0.0 means deactivated rotary_interleaved=True, scheduler_metadata=None, @@ -137,6 +138,7 @@ def flash_attn_with_kvcache( Default to 1 / sqrt(headdim). causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). window_size: (left, right). If not (-1, -1), implements sliding window local attention. + attention_chunk: Optional[int]. If not None, splits the query into chunks of this size to save memory. softcap: float. Anything > 0 activates softcapping attention. rotary_interleaved: bool. Only applicable if rotary_cos and rotary_sin are passed in. If True, rotary embedding will combine dimensions 0 & 1, 2 & 3, etc. If False, @@ -216,6 +218,7 @@ def flash_attn_with_kvcache( ] rotary_cos, rotary_sin = [maybe_contiguous(x) for x in (rotary_cos, rotary_sin)] rotary_seqlens = maybe_contiguous(rotary_seqlens) + attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -245,6 +248,7 @@ def flash_attn_with_kvcache( causal, window_size[0], window_size[1], + attention_chunk, softcap, rotary_interleaved, scheduler_metadata, @@ -263,10 +267,11 @@ def flash_attn_varlen_func( v, cu_seqlens_q, cu_seqlens_k, - max_seqlen_q, - max_seqlen_k, + max_seqlen_q=None, + max_seqlen_k=None, seqused_q=None, seqused_k=None, + page_table=None, softmax_scale=None, causal=False, qv=None, @@ -274,6 +279,7 @@ def flash_attn_varlen_func( k_descale=None, v_descale=None, window_size=(-1, -1), + attention_chunk=0, softcap=0.0, num_splits=1, pack_gqa=None, @@ -293,25 +299,18 @@ def flash_attn_varlen_func( q, k, v, - cu_seqlens_q, - cu_seqlens_k, - # max_seqlen_q, - # max_seqlen_k, + cu_seqlens_q=cu_seqlens_q, + cu_seqlens_k=cu_seqlens_k, seqused_q=seqused_q, seqused_k=seqused_k, + page_table=page_table, softmax_scale=softmax_scale, causal=causal, - # qv=qv, - # q_descale=q_descale, - # k_descale=k_descale, - # v_descale=v_descale, window_size=window_size, softcap=softcap, - # num_splits=num_splits, pack_gqa=pack_gqa, - # sm_margin=sm_margin, - return_softmax_lse=return_softmax_lse, learnable_sink=sinks, + return_softmax_lse=return_softmax_lse, ) if not is_fa3_supported(): @@ -319,10 +318,15 @@ def flash_attn_varlen_func( "flash_attn at sgl-kernel is only supported on sm90 and above" ) + # FA3 requires max_seqlen_q and max_seqlen_k + if max_seqlen_q is None or max_seqlen_k is None: + raise ValueError("max_seqlen_q and max_seqlen_k are required for FA3") + 
if softmax_scale is None: softmax_scale = (q.shape[-1] + (qv.shape[-1] if qv is not None else 0)) ** ( -0.5 ) + attention_chunk = 0 if attention_chunk is None else int(attention_chunk) out, softmax_lse, *rest = torch.ops.sgl_kernel.fwd.default( q, @@ -352,6 +356,7 @@ def flash_attn_varlen_func( causal, window_size[0], window_size[1], + attention_chunk, softcap, is_rotary_interleaved=False, scheduler_metadata=None, diff --git a/sgl-kernel/tests/test_flash_attention_4.py b/sgl-kernel/tests/test_flash_attention_4.py index 30e2134de..2296d71aa 100644 --- a/sgl-kernel/tests/test_flash_attention_4.py +++ b/sgl-kernel/tests/test_flash_attention_4.py @@ -1,4 +1,4 @@ -# Adapted from https://github.com/Dao-AILab/flash-attention/blob/b31ae1e4cd22cf5f820a2995b74b7cd3bd54355a/tests/cute/test_flash_attn.py +# Adapted from https://github.com/Dao-AILab/flash-attention/blob/8ecf128f683266735ba68e3c106ff67a2611886e/tests/cute/test_flash_attn.py # Copyright (c) 2025, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao. @@ -10,12 +10,25 @@ import pytest import torch import torch.nn.functional as F from einops import rearrange, repeat -from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache -from utils import is_hopper +try: + from flash_attn.layers.rotary import apply_rotary_emb +except ImportError: + apply_rotary_emb = None + +from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache +from sgl_kernel.testing.rotary_embedding import _apply_rotary_emb as apply_rotary_emb + +# from utils import is_hopper # Not used in this test + +# Force sgl_kernel.flash_attn wrappers to use FA4 (Cute-DSL) implementations. +# The wrappers accept a superset of args; for FA4, extra args are ignored. flash_attn_varlen_func = partial(flash_attn_varlen_func, ver=4) flash_attn_with_kvcache = partial(flash_attn_with_kvcache, ver=4) +# Skip this test on Hopper machine +skip_condition = torch.cuda.get_device_capability() < (10, 0) + def unpad_input(hidden_states, attention_mask, unused_mask=None): """ @@ -88,6 +101,11 @@ def generate_random_padding_mask( lengths = torch.randint( max_seqlen // 3, max_seqlen + 1, (batch_size, 1), device=device ) + else: + # This should never happen due to the assertion above, but for linter + lengths = torch.full( + (batch_size, 1), max_seqlen, device=device, dtype=torch.int32 + ) if zero_lengths: # Generate zero-lengths every 5 batches and the last batch. @@ -482,8 +500,7 @@ def attention_ref( @pytest.mark.skipif( - is_hopper(), - reason="skip on hopper", + skip_condition, reason="FA4 Requires compute capability of 10 or above." 
) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -497,8 +514,8 @@ def attention_ref( @pytest.mark.parametrize("deterministic", [False]) # @pytest.mark.parametrize("softcap", [0.0, 15.0]) @pytest.mark.parametrize("softcap", [0.0]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) @pytest.mark.parametrize("causal", [False, True]) # @pytest.mark.parametrize("causal", [False]) # @pytest.mark.parametrize("add_unused_qkv", [False, True]) @@ -522,11 +539,11 @@ def attention_ref( (64, 128), (128, 128), (256, 256), - (113, 203), - (128, 217), - (113, 211), - (108, 256), - (256, 512), + # (113, 203), + # (128, 217), + # (113, 211), + # (108, 256), + # (256, 512), (307, 256), (640, 128), (512, 256), @@ -658,25 +675,7 @@ def test_flash_attn_varlen_output( if causal or local: key_padding_mask = query_padding_mask - ( - q_unpad, - k_unpad, - v_unpad, - qv_unpad, - cu_seqlens_q, - cu_seqlens_k, - seqused_q, - seqused_k, - max_seqlen_q, - max_seqlen_k, - q, - k, - v, - qv, - output_pad_fn, - dq_pad_fn, - dk_pad_fn, - ) = generate_qkv( + result = generate_qkv( q, k, v, @@ -687,6 +686,25 @@ def test_flash_attn_varlen_output( query_unused_mask=query_unused_mask, key_unused_mask=key_unused_mask, ) + ( + q_unpad, # 0 + k_unpad, # 1 + v_unpad, # 2 + qv_unpad, # 3 + cu_seqlens_q, # 4 + cu_seqlens_k, # 5 + seqused_q, # 6 + seqused_k, # 7 + max_seqlen_q, # 8 + max_seqlen_k, # 9 + q, # 10 + k, # 11 + v, # 12 + qv, # 13 + output_pad_fn, # 14 + dq_pad_fn, # 15 + dk_pad_fn, # 16 + ) = result q_unpad, k_unpad, v_unpad = [ x.detach().to(dtype).requires_grad_() for x in (q_unpad, k_unpad, v_unpad) ] @@ -746,20 +764,16 @@ def test_flash_attn_varlen_output( v_unpad, cu_seqlens_q=cu_seqlens_q, cu_seqlens_k=cu_seqlens_k, - max_seqlen_q=None, - max_seqlen_k=None, - # seqused_q=seqused_q, - # seqused_k=seqused_k, + # max_seqlen_q and max_seqlen_k not needed for FA4 + seqused_q=seqused_q, + seqused_k=seqused_k, causal=causal, - # qv=qv_unpad, - # q_descale=q_descale, - # k_descale=k_descale, v_descale=v_descale, window_size=window_size, - # attention_chunk=attention_chunk, - sinks=learnable_sink, softcap=softcap, + sinks=learnable_sink, # FA4 uses learnable_sink, not sinks pack_gqa=pack_gqa, return_softmax_lse=True, + ver=4, # Use FA4 ) out = output_pad_fn(out_unpad) if query_unused_mask is not None: @@ -875,8 +889,7 @@ def test_flash_attn_varlen_output( @pytest.mark.skipif( - is_hopper(), - reason="skip on hopper", + skip_condition, reason="FA4 Requires compute capability of 10 or above." 
) # @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float8_e4m3fn]) @pytest.mark.parametrize("dtype", [torch.bfloat16]) @@ -887,8 +900,8 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize("has_learnable_sink", [False]) # @pytest.mark.parametrize("new_kv", [False, True]) @pytest.mark.parametrize("new_kv", [False]) -@pytest.mark.parametrize("local", [False, True]) -# @pytest.mark.parametrize("local", [False]) +# @pytest.mark.parametrize("local", [False, True]) +@pytest.mark.parametrize("local", [False]) # @pytest.mark.parametrize("causal", [False, True]) @pytest.mark.parametrize("causal", [True]) # @pytest.mark.parametrize("seqlen_new_eq_seqlen_q", [True, False]) @@ -900,8 +913,8 @@ def test_flash_attn_varlen_output( # @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0]) @pytest.mark.parametrize("rotary_fraction", [0.0]) # @pytest.mark.parametrize("page_size", [None] + ([1, 4, 128])) -@pytest.mark.parametrize("page_size", [None, 128]) -# @pytest.mark.parametrize("page_size", [128]) +# @pytest.mark.parametrize("page_size", [None, 128]) +@pytest.mark.parametrize("page_size", [128]) # @pytest.mark.parametrize("has_leftpad", [False, True]) @pytest.mark.parametrize("has_leftpad", [False]) # @pytest.mark.parametrize("has_batch_idx", [False, True]) @@ -1085,6 +1098,7 @@ def test_flash_attn_kvcache( .to(dtype_ref) ) page_table = None + num_blocks = None else: ( k_cache, @@ -1301,31 +1315,24 @@ def test_flash_attn_kvcache( else: k_cache_paged.copy_(k_cache_saved) v_cache_paged.copy_(v_cache_saved) - # out, lse, *rest = flash_attn_with_kvcache( - out, lse, *rest = flash_attn_with_kvcache( + # For FA4, use flash_attn_varlen_func directly instead of flash_attn_with_kvcache + # This matches the pattern from the original FA4 test + out, lse = flash_attn_varlen_func( q if not varlen_q else q_unpad, k_cache if page_size is None else k_cache_paged, v_cache if page_size is None else v_cache_paged, - # k if not new_kv or not varlen_q else k_unpad, - # v if not new_kv or not varlen_q else v_unpad, - # qv=qv if not varlen_q else qv_unpad, - # rotary_cos=cos, - # rotary_sin=sin, - cache_seqlens=cache_seqlens, - # cache_batch_idx=cache_batch_idx, - # cache_leftpad=cache_leftpad, - page_table=page_table, cu_seqlens_q=cu_seqlens_q, - # cu_seqlens_k_new=cu_seqlens_k_new, - # rotary_seqlens=rotary_seqlens, + cu_seqlens_k=None, # FA4 doesn't use cu_seqlens_k for KV cache + # max_seqlen_q and max_seqlen_k not needed for FA4 + seqused_k=cache_seqlens, # Use cache_seqlens as seqused_k + page_table=page_table, causal=causal, window_size=window_size, - sinks=learnable_sink, - # attention_chunk=attention_chunk, - # rotary_interleaved=rotary_interleaved, - # scheduler_metadata=scheduler_metadata, - # num_splits=num_splits, + sinks=learnable_sink, # FA4 uses learnable_sink, not sinks + softcap=0.0, + pack_gqa=None, return_softmax_lse=True, + ver=4, # Use FA4 ) if varlen_q: out = output_pad_fn(out) diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index e6b1419d7..e14faec2b 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -169,6 +169,7 @@ suites = { TestFile("test_disaggregation_pp.py", 140), ], "per-commit-4-gpu-b200": [ + # TestFile("test_flash_attention_4.py"), # TestFile("test_gpt_oss_4gpu.py", 600), # TestFile("test_deepseek_v3_fp4_4gpu.py", 3600), ], diff --git a/test/srt/test_flash_attention_4.py b/test/srt/test_flash_attention_4.py new file mode 100644 index 000000000..8bba922b6 --- /dev/null +++ b/test/srt/test_flash_attention_4.py @@ -0,0 +1,56 @@ 
+import unittest +from types import SimpleNamespace + +from sglang.srt.environ import envs +from sglang.srt.utils import get_device_sm, kill_process_tree +from sglang.test.few_shot_gsm8k import run_eval as run_eval_few_shot_gsm8k +from sglang.test.test_utils import ( + DEFAULT_MODEL_NAME_FOR_TEST, + DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + DEFAULT_URL_FOR_TEST, + popen_launch_server, +) + + +@unittest.skipIf(get_device_sm() < 100, "Test requires CUDA SM 100 or higher") +class TestFlashAttention4(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model = DEFAULT_MODEL_NAME_FOR_TEST + cls.base_url = DEFAULT_URL_FOR_TEST + other_args = [ + "--trust-remote-code", + "--mem-fraction-static", + "0.8", + "--prefill-attention-backend", + "fa4", + ] + cls.process = popen_launch_server( + cls.model, + cls.base_url, + timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, + other_args=other_args, + ) + + @classmethod + def tearDownClass(cls): + kill_process_tree(cls.process.pid) + + def test_gsm8k(self): + args = SimpleNamespace( + num_shots=4, + data_path=None, + num_questions=100, + max_new_tokens=512, + parallel=128, + host="http://127.0.0.1", + port=int(self.base_url.split(":")[-1]), + ) + metrics = run_eval_few_shot_gsm8k(args) + print(metrics) + + self.assertGreater(metrics["accuracy"], 0.65) + + +if __name__ == "__main__": + unittest.main()
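
Usage note (illustrative only, not part of the diff): a minimal sketch of the FA4 (Cute-DSL) varlen path that the updated `flash_attn_varlen_func` wrapper dispatches to when `ver=4`, mirroring the pattern used in the kvcache test above. It assumes an SM100 GPU, bf16 inputs, a 128-token page size, and made-up shapes; `cu_seqlens_k` is left as None and per-sequence KV lengths come from `seqused_k`, as in the test.

import torch

from sgl_kernel.flash_attn import flash_attn_varlen_func

batch, seqlen, heads, kv_heads, head_dim, page_size = 2, 256, 8, 2, 128, 128
device, dtype = "cuda", torch.bfloat16

# Packed (varlen) queries: (total_q, h, d) with cu_seqlens_q.
total_q = batch * seqlen
q = torch.randn(total_q, heads, head_dim, device=device, dtype=dtype)

# Paged KV cache: (num_pages, page_size, h_k, d), addressed through page_table.
pages_per_seq = seqlen // page_size
num_pages = batch * pages_per_seq
k_cache = torch.randn(num_pages, page_size, kv_heads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
page_table = torch.arange(num_pages, device=device, dtype=torch.int32).reshape(batch, pages_per_seq)

cu_seqlens_q = torch.arange(0, (batch + 1) * seqlen, seqlen, device=device, dtype=torch.int32)
seqused_k = torch.full((batch,), seqlen, device=device, dtype=torch.int32)

out, lse = flash_attn_varlen_func(
    q,
    k_cache,
    v_cache,
    cu_seqlens_q=cu_seqlens_q,
    cu_seqlens_k=None,        # FA4 takes per-sequence KV lengths from seqused_k instead
    seqused_k=seqused_k,
    page_table=page_table,
    causal=True,
    return_softmax_lse=True,
    ver=4,                    # dispatch to the FA4 (Cute-DSL) implementation
)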
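
Similarly, a sketch of the decode-side `flash_attn_with_kvcache` wrapper with the new `attention_chunk` argument threaded through to `torch.ops.sgl_kernel.fwd` (the wrapper converts it to 0 when unset). The single-token query and all shapes below are illustrative assumptions, and an SM90+ GPU is assumed.

import torch

from sgl_kernel.flash_attn import flash_attn_with_kvcache

batch, heads, kv_heads, head_dim, page_size, pages_per_seq = 2, 8, 2, 128, 128, 4
device, dtype = "cuda", torch.bfloat16

q = torch.randn(batch, 1, heads, head_dim, device=device, dtype=dtype)  # one new token per sequence
k_cache = torch.randn(batch * pages_per_seq, page_size, kv_heads, head_dim, device=device, dtype=dtype)
v_cache = torch.randn_like(k_cache)
page_table = torch.arange(batch * pages_per_seq, device=device, dtype=torch.int32).reshape(batch, pages_per_seq)
cache_seqlens = torch.full((batch,), 300, device=device, dtype=torch.int32)  # tokens already in the cache

out = flash_attn_with_kvcache(
    q,
    k_cache,
    v_cache,
    cache_seqlens=cache_seqlens,
    page_table=page_table,
    causal=True,
    attention_chunk=None,  # new argument; forwarded to the fwd op as 0 when unset
)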