diff --git a/docs/advanced_features/server_arguments.md b/docs/advanced_features/server_arguments.md
index 5960cfb2d..af526ee88 100644
--- a/docs/advanced_features/server_arguments.md
+++ b/docs/advanced_features/server_arguments.md
@@ -228,6 +228,8 @@ Please consult the documentation below and [server_args.py](https://github.com/s
 | `--sampling-backend` | Choose the kernels for sampling layers. | None |
 | `--grammar-backend` | Choose the backend for grammar-guided decoding. | None |
 | `--mm-attention-backend` | Set multimodal attention backend. | None |
+| `--nsa-prefill-backend` | Prefill attention implementation for nsa backend. | `flashmla_sparse` |
+| `--nsa-decode-backend` | Decode attention implementation for nsa backend. | `flashmla_kv` |
 
 ## Speculative decoding
 
diff --git a/python/sglang/srt/layers/attention/nsa_backend.py b/python/sglang/srt/layers/attention/nsa_backend.py
index 66b32b2c6..7da15cc47 100644
--- a/python/sglang/srt/layers/attention/nsa_backend.py
+++ b/python/sglang/srt/layers/attention/nsa_backend.py
@@ -140,9 +140,7 @@ def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
     )
 
 
-_NSA_IMPL_T: TypeAlias = Literal[
-    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
-]
+_NSA_IMPL_T: TypeAlias = Literal["flashmla_sparse", "flashmla_kv", "fa3", "tilelang"]
 NSA_PREFILL_IMPL: _NSA_IMPL_T
 NSA_DECODE_IMPL: _NSA_IMPL_T
 
@@ -181,8 +179,8 @@ class NativeSparseAttnBackend(AttentionBackend):
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
 
         global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
-        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
-        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill_backend
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode_backend
 
         self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
 
@@ -336,7 +334,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                     cache_seqlens=nsa_cache_seqlens_int32,
                     seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
             nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
@@ -383,7 +381,7 @@ class NativeSparseAttnBackend(AttentionBackend):
                     ),
                     seq_len_q=1,
                 )
-                if NSA_DECODE_IMPL == "flashmla_decode"
+                if NSA_DECODE_IMPL == "flashmla_kv"
                 else None
             ),
         }
@@ -421,7 +419,7 @@ class NativeSparseAttnBackend(AttentionBackend):
             seqlens_expanded = cache_seqlens_int32
         nsa_extend_seq_lens_list = [1] * num_tokens
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, num_tokens + 1))
@@ -478,7 +476,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         nsa_extend_seq_lens_list = [1] * bs * self.speculative_num_draft_tokens
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -534,7 +532,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         nsa_extend_seq_lens_list = [1] * bs
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = self.decode_cuda_graph_metadata[
                 "flashmla_metadata"
             ].slice(slice(0, bs * self.speculative_num_draft_tokens + 1))
@@ -712,7 +710,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         else:
             assert metadata.real_page_table is metadata.page_table_1
 
-        if NSA_DECODE_IMPL == "flashmla_decode":
+        if NSA_DECODE_IMPL == "flashmla_kv":
             flashmla_metadata = metadata.flashmla_metadata.slice(
                 slice(0, seqlens_expanded_size + 1)
             )
@@ -803,20 +801,20 @@ class NativeSparseAttnBackend(AttentionBackend):
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+        elif NSA_PREFILL_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_PREFILL_IMPL == "flashmla_decode":
+        elif NSA_PREFILL_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -897,20 +895,20 @@ class NativeSparseAttnBackend(AttentionBackend):
             page_size=1,
         )
 
-        if NSA_DECODE_IMPL == "flashmla_prefill":
+        if NSA_DECODE_IMPL == "flashmla_sparse":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_prefill(
+            return self._forward_flashmla_sparse(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 page_table_1=page_table_1,
                 sm_scale=layer.scaling,
                 v_head_dim=layer.v_head_dim,
             )
-        elif NSA_DECODE_IMPL == "flashmla_decode":
+        elif NSA_DECODE_IMPL == "flashmla_kv":
             if q_rope is not None:
                 q_all = torch.cat([q_nope, q_rope], dim=-1)
-            return self._forward_flashmla_decode(
+            return self._forward_flashmla_kv(
                 q_all=q_all,
                 kv_cache=kv_cache,
                 sm_scale=layer.scaling,
@@ -998,7 +996,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o  # type: ignore
 
-    def _forward_flashmla_prefill(
+    def _forward_flashmla_sparse(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
@@ -1017,7 +1015,7 @@ class NativeSparseAttnBackend(AttentionBackend):
         )
         return o
 
-    def _forward_flashmla_decode(
+    def _forward_flashmla_kv(
         self,
         q_all: torch.Tensor,
         kv_cache: torch.Tensor,
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index ff5e58dc2..a2076a203 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -128,7 +128,7 @@ DETERMINISTIC_ATTENTION_BACKEND_CHOICES = ["flashinfer", "fa3", "triton"]
 
 DEFAULT_LORA_EVICTION_POLICY = "lru"
 
-NSA_CHOICES = ["flashmla_prefill", "flashmla_decode", "fa3", "tilelang", "aiter"]
+NSA_CHOICES = ["flashmla_sparse", "flashmla_kv", "fa3", "tilelang", "aiter"]
 
 RADIX_EVICTION_POLICY_CHOICES = ["lru", "lfu"]
 
@@ -324,8 +324,8 @@ class ServerArgs:
     sampling_backend: Optional[str] = None
     grammar_backend: Optional[str] = None
    mm_attention_backend: Optional[str] = None
-    nsa_prefill: str = "flashmla_prefill"
-    nsa_decode: str = "fa3"
+    nsa_prefill_backend: str = "flashmla_sparse"
+    nsa_decode_backend: str = "fa3"
 
     # Speculative decoding
     enable_beta_spec: bool = False
@@ -1024,10 +1024,10 @@ class ServerArgs:
             logger.warning("Setting KV cache dtype to fp8.")
 
         if self.kv_cache_dtype == "fp8_e4m3":
-            self.nsa_prefill = "flashmla_decode"
-            self.nsa_decode = "flashmla_decode"
+            self.nsa_prefill_backend = "flashmla_kv"
+            self.nsa_decode_backend = "flashmla_kv"
             logger.warning(
-                "Setting NSA backend to flashmla_decode for FP8 KV Cache."
+                "Setting NSA backend to flashmla_kv for FP8 KV Cache."
             )
 
         # Logging env vars for NSA
@@ -2356,14 +2356,14 @@ class ServerArgs:
             help="Set multimodal attention backend.",
         )
         parser.add_argument(
-            "--nsa-prefill",
-            default=ServerArgs.nsa_prefill,
+            "--nsa-prefill-backend",
+            default=ServerArgs.nsa_prefill_backend,
             type=str,
             choices=NSA_CHOICES,
         )
         parser.add_argument(
-            "--nsa-decode",
-            default=ServerArgs.nsa_decode,
+            "--nsa-decode-backend",
+            default=ServerArgs.nsa_decode_backend,
             type=str,
             choices=NSA_CHOICES,
         )
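
For reference, a launch command using the renamed flags might look like the sketch below. This is illustrative only: the model path is a placeholder, while the flag names, choices, and defaults come from the diff above.

    python -m sglang.launch_server --model-path <model> \
        --nsa-prefill-backend flashmla_sparse \
        --nsa-decode-backend flashmla_kv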