From 2bc61dd19460908dcf0d94b44deed9f2b07e2460 Mon Sep 17 00:00:00 2001
From: li-kesen
Date: Tue, 30 Sep 2025 10:16:16 +0800
Subject: [PATCH] Remove hybrid_linear_attn attention backend and refactor
 attention registry (#10816)

Co-authored-by: Yi Zhang <1109276519@qq.com>
---
 .../layers/attention/attention_registry.py    | 60 ++++++++++---------
 .../sglang/srt/model_executor/model_runner.py | 16 ++---
 python/sglang/srt/server_args.py              |  3 +-
 3 files changed, 43 insertions(+), 36 deletions(-)

diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py
index eb1a69d9e..658ad1f0f 100644
--- a/python/sglang/srt/layers/attention/attention_registry.py
+++ b/python/sglang/srt/layers/attention/attention_registry.py
@@ -1,3 +1,7 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
 ATTENTION_BACKENDS = {}
 
 
@@ -158,35 +162,37 @@ def create_dual_chunk_flash_attn_backend(runner):
     return DualChunkFlashAttentionBackend(runner)
 
 
-@register_attention_backend("hybrid_linear_attn")
-def create_hybrid_linear_attn_backend(runner):
-    assert (
-        runner.is_hybrid_gdn
-    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-        HybridLinearAttnBackend,
-        MambaAttnBackend,
-    )
-    from sglang.srt.utils import is_blackwell, is_npu
+def attn_backend_wrapper(runner, full_attn_backend):
+    """
+    Wrap the full attention backend for special models such as hybrid GDN,
+    so the original attention backends do not need any changes.
+    """
+    assert not (
+        runner.is_hybrid_gdn and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
 
-    if is_npu():
-        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+    # wrap for hybrid GDN models
+    if runner.is_hybrid_gdn:
+        from sglang.srt.utils import is_blackwell, is_npu
 
-        full_attn_backend = AscendAttnBackend(runner)
-    elif is_blackwell():
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        full_attn_backend = TritonAttnBackend(runner)
-    else:
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
+        if is_blackwell():
+            assert (
+                runner.server_args.attention_backend == "triton"
+            ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models; use --attention-backend triton to specify the backend."
+        if is_npu():
+            assert (
+                runner.server_args.attention_backend == "ascend"
+            ), "ascend backend is the only supported backend on NPU for hybrid GDN models; use --attention-backend ascend to specify the backend."
+        logger.info("Using hybrid linear attention backend for hybrid GDN models.")
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            HybridLinearAttnBackend,
+            MambaAttnBackend,
         )
-        full_attn_backend = FlashAttentionBackend(runner)
+        linear_attn_backend = MambaAttnBackend(runner)
+        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
 
-    linear_attn_backend = MambaAttnBackend(runner)
-    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
-
-    return HybridLinearAttnBackend(
-        full_attn_backend, linear_attn_backend, full_attn_layers
-    )
+    return full_attn_backend
 
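Note on the new return path: for hybrid GDN models the wrapper composes the
already-constructed full-attention backend with a MambaAttnBackend, keyed by
the model's full_attention_layer_ids. A minimal sketch of that per-layer
dispatch idea (stub strings and a hypothetical backend_for_layer helper, not
the real HybridLinearAttnBackend):

    # Sketch: layers listed in full_attention_layer_ids use the wrapped
    # full-attention backend; every other layer uses the linear (Mamba) one.
    class HybridDispatchSketch:
        def __init__(self, full_attn_backend, linear_attn_backend, full_attn_layer_ids):
            self.full = full_attn_backend
            self.linear = linear_attn_backend
            self.full_ids = set(full_attn_layer_ids)

        def backend_for_layer(self, layer_id: int):
            # Route purely by layer id; the real backend's dispatch is
            # internal, so this helper is illustrative only.
            return self.full if layer_id in self.full_ids else self.linear

    hybrid = HybridDispatchSketch("full_attn", "mamba", full_attn_layer_ids=[3, 7])
    assert hybrid.backend_for_layer(3) == "full_attn"
    assert hybrid.backend_for_layer(4) == "mamba"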
diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py
index 4f1a2efad..8df5dffb6 100644
--- a/python/sglang/srt/model_executor/model_runner.py
+++ b/python/sglang/srt/model_executor/model_runner.py
@@ -60,7 +60,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
-from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -347,7 +350,6 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             logger.warning("Hybrid GDN model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
         if self.server_args.max_mamba_cache_size is None:
             if self.server_args.max_running_requests is not None:
                 self.server_args.max_mamba_cache_size = (
@@ -1648,10 +1650,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if _is_npu and self.server_args.attention_backend in [
-                "ascend",
-                "hybrid_linear_attn",
-            ]:
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1764,7 +1765,8 @@ class ModelRunner:
     def _get_attention_backend_from_str(self, backend_str: str):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
-        return ATTENTION_BACKENDS[backend_str](self)
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
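The refactored _get_attention_backend_from_str is now a two-step pipeline:
registry lookup, then wrapper. A self-contained sketch of that composition
(FakeRunner and the "demo" backend are made-up stand-ins, not sglang code):

    # Sketch: look the factory up in the registry, then let the wrapper
    # decide whether the result needs to be wrapped for hybrid GDN models.
    ATTENTION_BACKENDS = {}

    def register_attention_backend(name):
        def decorator(fn):
            ATTENTION_BACKENDS[name] = fn
            return fn
        return decorator

    @register_attention_backend("demo")
    def create_demo_backend(runner):
        return f"DemoBackend(hybrid={runner.is_hybrid_gdn})"

    class FakeRunner:
        is_hybrid_gdn = False  # a hybrid GDN model would be wrapped instead

    def attn_backend_wrapper(runner, full_attn_backend):
        if runner.is_hybrid_gdn:
            return ("HybridLinearAttnBackend", full_attn_backend)  # wrapped
        return full_attn_backend  # ordinary models pass through unchanged

    def get_attention_backend_from_str(runner, backend_str):
        if backend_str not in ATTENTION_BACKENDS:
            raise ValueError(f"Invalid attention backend: {backend_str}")
        return attn_backend_wrapper(runner, ATTENTION_BACKENDS[backend_str](runner))

    print(get_attention_backend_from_str(FakeRunner(), "demo"))
    # -> DemoBackend(hybrid=False)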
diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py
index 6f571f39b..a32547120 100644
--- a/python/sglang/srt/server_args.py
+++ b/python/sglang/srt/server_args.py
@@ -100,7 +100,6 @@ ATTENTION_BACKEND_CHOICES = [
     "trtllm_mla",
     "trtllm_mha",
     "dual_chunk_flash_attn",
-    "hybrid_linear_attn",
     # AMD specific
     "aiter",
     "wave",
@@ -801,7 +800,7 @@ class ServerArgs:
             self.speculative_algorithm is None
         ), "Speculative decoding is currently not supported with Flex Attention backend"
 
-        if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
+        if is_npu() and self.attention_backend == "ascend":
             logger.warning(
                 "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
             )
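Migration note: with "hybrid_linear_attn" removed from
ATTENTION_BACKEND_CHOICES, launch commands that passed it should switch to an
ordinary backend; the hybrid wrapping now happens automatically whenever a
hybrid GDN model is detected. Illustrative invocations (the model path is a
placeholder), matching the platform assertions in attn_backend_wrapper:

    # Blackwell GPUs: triton is the only supported full-attention backend
    python -m sglang.launch_server --model-path <hybrid-gdn-model> --attention-backend triton

    # NPU: ascend is the only supported full-attention backend
    python -m sglang.launch_server --model-path <hybrid-gdn-model> --attention-backend ascend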