diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index 71c3def5f..f86e9a751 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -666,6 +666,13 @@ def _set_envs_and_config(server_args: ServerArgs):
     if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
         os.environ["TRTLLM_ENABLE_PDL"] = "1"
 
+    if os.environ.get("CUTE_DSL_LOG_LEVEL") is None:
+        # Default to WARNING level (30) to avoid overly verbose logs
+        os.environ["CUTE_DSL_LOG_LEVEL"] = "30"
+    if os.environ.get("CUTE_DSL_LOG_TO_CONSOLE") is None:
+        # Logging must also go to the console, otherwise the log level does not take effect
+        os.environ["CUTE_DSL_LOG_TO_CONSOLE"] = "1"
+
     # Can also be passed as argument
     os.environ["SGLANG_RUN_ID"] = (
         f"sglang-run-{time.time()}-{random.randint(0, 100000000)}"
diff --git a/python/sglang/srt/layers/attention/flashattention_backend.py b/python/sglang/srt/layers/attention/flashattention_backend.py
index f7ca5e203..0ccde7c81 100644
--- a/python/sglang/srt/layers/attention/flashattention_backend.py
+++ b/python/sglang/srt/layers/attention/flashattention_backend.py
@@ -305,6 +305,7 @@ class FlashAttentionBackend(AttentionBackend):
         speculative_step_id=0,
         topk=0,
         speculative_num_steps=0,
+        fa_impl_ver=3,
     ):
         super().__init__()
 
@@ -338,6 +339,8 @@ class FlashAttentionBackend(AttentionBackend):
         )
         self.speculative_step_id = speculative_step_id
 
+        self.fa_impl_ver = fa_impl_ver
+
         # Local attention settings
         self.attention_chunk_size = (
             model_runner.attention_chunk_size
@@ -712,6 +715,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
@@ -738,6 +743,7 @@ class FlashAttentionBackend(AttentionBackend):
 
         # Use Flash Attention for prefill
         if not self.use_mla:
+            assert self.fa_impl_ver in [3], "Only FA3 is supported here"
             # Do multi-head attention
             key_cache, value_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
                 layer.layer_id
@@ -830,6 +836,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=False,
                     return_softmax_lse=True,
+                    **kwargs,
                 )
             else:
                 # MHA for extend part of sequence without attending prefix kv cache
@@ -844,6 +851,7 @@ class FlashAttentionBackend(AttentionBackend):
                     softmax_scale=layer.scaling,
                     causal=True,
                     return_softmax_lse=forward_batch.mha_return_lse,
+                    **kwargs,
                 )
                 if forward_batch.mha_return_lse:
                     output, lse, *rest = output
@@ -851,6 +859,7 @@ class FlashAttentionBackend(AttentionBackend):
                     return output, lse
             return output
         else:
+            assert self.fa_impl_ver in [3], "Only FA3 is supported here"
             # Do absorbed multi-latent attention
             kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(
                 layer.layer_id
@@ -939,6 +948,7 @@ class FlashAttentionBackend(AttentionBackend):
         k_rope: Optional[torch.Tensor] = None,
         sinks: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
+        assert self.fa_impl_ver in [3], "Only FA3 supports decoding"
         if k is not None:
             assert v is not None
             if save_kv_cache:
@@ -985,6 +995,8 @@ class FlashAttentionBackend(AttentionBackend):
 
         # For fa3 interface version compatibility, we put new fields into conditional keyword args
         kwargs = {}
+        if self.fa_impl_ver != 3:
+            kwargs["ver"] = self.fa_impl_ver
         if sinks is not None:
             kwargs["sinks"] = sinks
 
diff --git a/python/sglang/srt/layers/attention/hybrid_attn_backend.py b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
index 580a977ec..ec40100d1 100644
--- a/python/sglang/srt/layers/attention/hybrid_attn_backend.py
+++ b/python/sglang/srt/layers/attention/hybrid_attn_backend.py
@@ -21,6 +21,7 @@ class HybridAttnBackend(AttentionBackend):
         self.model_runner = model_runner
         self.prefill_backend = prefill_backend
         self.decode_backend = decode_backend
+        self.data_type = model_runner.kv_cache_dtype
 
     def _select_backend(self, forward_mode: ForwardMode) -> AttentionBackend:
         """
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 75e493475..a21f392e8 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -516,6 +516,7 @@ class ModelRunner:
                 "aiter",
                 "flashinfer",
                 "fa3",
+                "fa4",
                 "triton",
                 "flashmla",
                 "cutlass_mla",
@@ -1800,6 +1801,15 @@ class ModelRunner:
             )
 
             return FlashAttentionBackend(self)
+        elif backend_str == "fa4":
+            assert (
+                self.use_mla_backend
+            ), "FlashAttention v4 support is at an early stage; only MLA models are supported now"
+            from sglang.srt.layers.attention.flashattention_backend import (
+                FlashAttentionBackend,
+            )
+
+            return FlashAttentionBackend(self, fa_impl_ver=4)
         elif backend_str == "cutlass_mla":
             from sglang.srt.layers.attention.cutlass_mla_backend import (
                 CutlassMLABackend,
diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py
index c46655f56..f905851b6 100644
--- a/python/sglang/srt/models/deepseek_v2.py
+++ b/python/sglang/srt/models/deepseek_v2.py
@@ -1124,6 +1124,9 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "fa4":
+            # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+            return AttnForwardMethod.MHA_CHUNKED_KV
         elif attention_backend == "trtllm_mla":
             original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 60d7f296a..32853b386 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -96,6 +96,7 @@ ATTENTION_BACKEND_CHOICES = [
     # NVIDIA specific
     "cutlass_mla",
     "fa3",
+    "fa4",
     "flashinfer",
     "flashmla",
     "trtllm_mla",
diff --git a/sgl-kernel/python/sgl_kernel/_fa4_interface.py b/sgl-kernel/python/sgl_kernel/_fa4_interface.py
index 512b0aaef..684b4b25e 100644
--- a/sgl-kernel/python/sgl_kernel/_fa4_interface.py
+++ b/sgl-kernel/python/sgl_kernel/_fa4_interface.py
@@ -4,9 +4,15 @@
 # [2025-07-04] Version in Cute-DSL, for Hopper and Blackwell. You'd need to install nvidia-cutlass-dsl==4.1.0.
 
+import copy
+import gc
+import logging
 import math
 from typing import Optional, Tuple
 
+logger = logging.getLogger(__name__)
+
+
 import cuda.bindings.driver as cuda
 import cutlass
 import cutlass.cute as cute
 
@@ -20,6 +26,22 @@ def maybe_contiguous(x):
     return x.contiguous() if x is not None and x.stride(-1) != 1 else x
 
 
+def _reason_recompile(compile_key, jit_func):
+    compile_cache = jit_func.compile_cache
+    compile_key_map = jit_func.compile_key_map
+    if not compile_cache:
+        return "not compiled yet"
+    for k, v in compile_cache.items():
+        if k == compile_key:
+            continue
+        if len(k) != len(compile_key):
+            continue
+        for i in range(len(k)):
+            if k[i] != compile_key[i]:
+                return f"diff at '{compile_key_map[i]}': {k[i]} vs {compile_key[i]}"
+    return "unknown reason"
+
+
 torch2cute_dtype_map = {
     torch.float16: cutlass.Float16,
     torch.bfloat16: cutlass.BFloat16,
@@ -254,6 +276,9 @@ def _flash_attn_fwd(
         compute_capability,
     )
     if compile_key not in _flash_attn_fwd.compile_cache:
+        logger.info(
+            f"Compiling FA4 kernel with reason: {_reason_recompile(compile_key, _flash_attn_fwd)}"
+        )
         if compute_capability == 9:
             assert page_table is None, "paged KV not supported on SM 9.0"
             # fa_fwd = FlashAttentionForwardSm80(
@@ -335,8 +360,85 @@ def _flash_attn_fwd(
 
 
 _flash_attn_fwd.compile_cache = {}
+_flash_attn_fwd.compile_key_map = [
+    "dtype",
+    "head_dim",
+    "head_dim_v",
+    "qhead_per_kvhead",
+    "causal",
+    "softcap is not None",
+    "lse is None",
+    "cu_seqlens_q is None",
+    "cu_seqlens_k is None",
+    "seqused_q is None",
+    "seqused_k is None",
+    "page_table is not None",
+    "window_size_left is not None",
+    "window_size_right is not None",
+    "learnable_sink is not None",
+    "m_block_size",
+    "n_block_size",
+    "num_threads",
+    "pack_gqa",
+    "compute_capability",
+]
 
 
+def warmup_flash_attn(f):
+    """
+    Decorator for flash_attn_varlen_func:
+    - On the first call, run several warmup passes with different flag combinations
+    - Warmups are executed sequentially to minimize peak GPU memory usage
+    - Does not modify user-provided tensors (clones data)
+    - Easy to extend with more compile-key dimensions
+    """
+    done = False
+
+    def _clone_args(args, kwargs):
+        """Clone tensor arguments to avoid sharing storage; deepcopy for others."""
+
+        def maybe_clone(x):
+            if isinstance(x, torch.Tensor):
+                return x.clone()
+            return copy.deepcopy(x)
+
+        return tuple(maybe_clone(a) for a in args), {
+            k: maybe_clone(v) for k, v in kwargs.items()
+        }
+
+    def _run_warmups(args, kwargs):
+        """Run warmup calls sequentially and release memory after each."""
+        base_args, base_kwargs = _clone_args(args, kwargs)
+
+        # Warmup combinations for return_softmax_lse and causal
+        combos = [
+            dict(return_softmax_lse=False, causal=False),
+            dict(return_softmax_lse=False, causal=True),
+            dict(return_softmax_lse=True, causal=False),
+            dict(return_softmax_lse=True, causal=True),
+        ]
+
+        for combo in combos:
+            wa, wk = _clone_args(base_args, base_kwargs)
+            wk.update(combo)
+            with torch.cuda.stream(torch.cuda.current_stream()):
+                f(*wa, **wk)
+            del wa, wk
+            torch.cuda.empty_cache()
+            gc.collect()
+
+    def wrapper(*args, **kwargs):
+        nonlocal done
+        if not done:
+            logger.info("Running flash_attn_varlen_func warmup passes...")
+            _run_warmups(args, kwargs)
+            done = True
+        return f(*args, **kwargs)
+
+    return wrapper
+
+
+@warmup_flash_attn
 def flash_attn_varlen_func(
     q: torch.Tensor,
     k: torch.Tensor,
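
Usage note (illustrative, not part of the patch): with "fa4" registered in ATTENTION_BACKEND_CHOICES and routed to FlashAttentionBackend(self, fa_impl_ver=4), the backend is selected like any other. Below is a minimal sketch with the offline engine, assuming the Engine constructor forwards keyword arguments to ServerArgs as in current SGLang releases, and that an MLA model is loaded (the fa4 path asserts use_mla_backend); the model path is only an example.

import sglang as sgl

# FA4 support is early-stage and MLA-only, so pick an MLA model (e.g. a DeepSeek-V3-family checkpoint).
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V3",  # example path, substitute your own
    attention_backend="fa4",               # routed to FlashAttentionBackend(..., fa_impl_ver=4)
)
print(llm.generate("The capital of France is", {"temperature": 0.0, "max_new_tokens": 8}))

The server entry point takes the same choice via --attention-backend fa4.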
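
The _reason_recompile helper added above is purely diagnostic: the cute-DSL kernel is cached under a tuple of specialization flags, and when a new tuple misses the cache the log names the first field that differs from an already-compiled variant. The standalone sketch below reproduces that pattern outside sgl-kernel; FakeJit, its attributes, and reason_recompile are stand-ins for the jit function object, not real API.

# Illustration only: mimic an object exposing compile_cache / compile_key_map.
class FakeJit:
    def __init__(self, key_map):
        self.compile_cache = {}         # maps compile-key tuples to compiled kernels
        self.compile_key_map = key_map  # human-readable name for each tuple position

def reason_recompile(compile_key, jit_func):
    if not jit_func.compile_cache:
        return "not compiled yet"
    for cached_key in jit_func.compile_cache:
        if cached_key == compile_key or len(cached_key) != len(compile_key):
            continue
        for i, (old, new) in enumerate(zip(cached_key, compile_key)):
            if old != new:
                return f"diff at '{jit_func.compile_key_map[i]}': {old} vs {new}"
    return "unknown reason"

jit = FakeJit(["dtype", "head_dim", "causal"])
jit.compile_cache[("bf16", 128, True)] = object()   # pretend one variant is already compiled
print(reason_recompile(("bf16", 128, False), jit))  # -> diff at 'causal': True vs False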