From 2a2ff9a8407f78c3af085fe3d4f76aab0ba58dae Mon Sep 17 00:00:00 2001 From: Yineng Zhang Date: Thu, 18 Sep 2025 16:27:59 -0700 Subject: [PATCH] refactor: use registry for _get_attention_backend_from_str (#10629) --- .../layers/attention/attention_registry.py | 192 ++++++++++++++++++ .../sglang/srt/model_executor/model_runner.py | 151 +------------- 2 files changed, 195 insertions(+), 148 deletions(-) create mode 100644 python/sglang/srt/layers/attention/attention_registry.py diff --git a/python/sglang/srt/layers/attention/attention_registry.py b/python/sglang/srt/layers/attention/attention_registry.py new file mode 100644 index 000000000..eb1a69d9e --- /dev/null +++ b/python/sglang/srt/layers/attention/attention_registry.py @@ -0,0 +1,192 @@ +ATTENTION_BACKENDS = {} + + +def register_attention_backend(name): + def decorator(fn): + ATTENTION_BACKENDS[name] = fn + return fn + + return decorator + + +@register_attention_backend("flashinfer") +def create_flashinfer_backend(runner): + import torch + + if not runner.use_mla_backend: + from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend + + # Init streams + if runner.server_args.speculative_algorithm == "EAGLE": + if ( + not hasattr(runner, "plan_stream_for_flashinfer") + or not runner.plan_stream_for_flashinfer + ): + runner.plan_stream_for_flashinfer = torch.cuda.Stream() + return FlashInferAttnBackend(runner) + else: + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAAttnBackend, + ) + + return FlashInferMLAAttnBackend(runner) + + +@register_attention_backend("trtllm_mla") +def create_trtllm_mla_backend(runner): + if not runner.use_mla_backend: + raise ValueError("trtllm_mla backend can only be used with MLA models.") + from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend + + return TRTLLMMLABackend(runner) + + +@register_attention_backend("aiter") +def create_aiter_backend(runner): + from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend + + return AiterAttnBackend(runner) + + +@register_attention_backend("wave") +def create_wave_backend(runner): + from sglang.srt.layers.attention.wave_backend import WaveAttnBackend + + return WaveAttnBackend(runner) + + +@register_attention_backend("ascend") +def create_ascend_backend(runner): + from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend + + return AscendAttnBackend(runner) + + +@register_attention_backend("triton") +def create_triton_backend(runner): + assert not runner.model_config.is_encoder_decoder, ( + "Cross attention is not supported in the triton attention backend. " + "Please use `--attention-backend flashinfer`." + ) + if runner.server_args.enable_double_sparsity: + from sglang.srt.layers.attention.double_sparsity_backend import ( + DoubleSparseAttnBackend, + ) + + return DoubleSparseAttnBackend(runner) + else: + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + + return TritonAttnBackend(runner) + + +@register_attention_backend("torch_native") +def create_torch_native_backend(runner): + from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend + + return TorchNativeAttnBackend(runner) + + +@register_attention_backend("flex_attention") +def create_flex_attention_backend(runner): + from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend + + return TorchFlexAttnBackend(runner) + + +@register_attention_backend("flashmla") +def create_flashmla_backend(runner): + from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend + + return FlashMLABackend(runner) + + +@register_attention_backend("fa3") +def create_flashattention_v3_backend(runner): + import torch + + assert ( + torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend + ) or torch.cuda.get_device_capability()[0] == 9, ( + "FlashAttention v3 Backend requires SM>=80 and SM<=90. " + "Please use `--attention-backend flashinfer`." + ) + from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend + + return FlashAttentionBackend(runner) + + +@register_attention_backend("fa4") +def create_flashattention_v4_backend(runner): + assert ( + runner.use_mla_backend + ), "FlashAttention v4 Support is at an early stage, only MLA model supported now" + from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend + + return FlashAttentionBackend(runner, fa_impl_ver=4) + + +@register_attention_backend("cutlass_mla") +def create_cutlass_mla_backend(runner): + from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend + + return CutlassMLABackend(runner) + + +@register_attention_backend("trtllm_mha") +def create_trtllm_mha_backend(runner): + if runner.use_mla_backend: + raise ValueError("trtllm_mha backend can only be used with non-MLA models.") + from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend + + return TRTLLMHAAttnBackend(runner) + + +@register_attention_backend("intel_amx") +def create_intel_amx_backend(runner): + from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend + + return IntelAMXAttnBackend(runner) + + +@register_attention_backend("dual_chunk_flash_attn") +def create_dual_chunk_flash_attn_backend(runner): + from sglang.srt.layers.attention.dual_chunk_flashattention_backend import ( + DualChunkFlashAttentionBackend, + ) + + return DualChunkFlashAttentionBackend(runner) + + +@register_attention_backend("hybrid_linear_attn") +def create_hybrid_linear_attn_backend(runner): + assert ( + runner.is_hybrid_gdn + ), "hybrid_linear_attn backend can only be used with hybrid GDN models." + from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( + HybridLinearAttnBackend, + MambaAttnBackend, + ) + from sglang.srt.utils import is_blackwell, is_npu + + if is_npu(): + from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend + + full_attn_backend = AscendAttnBackend(runner) + elif is_blackwell(): + from sglang.srt.layers.attention.triton_backend import TritonAttnBackend + + full_attn_backend = TritonAttnBackend(runner) + else: + from sglang.srt.layers.attention.flashattention_backend import ( + FlashAttentionBackend, + ) + + full_attn_backend = FlashAttentionBackend(runner) + + linear_attn_backend = MambaAttnBackend(runner) + full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids + + return HybridLinearAttnBackend( + full_attn_backend, linear_attn_backend, full_attn_layers + ) diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index f2df9b134..1355353bb 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -60,6 +60,7 @@ from sglang.srt.eplb.expert_location import ( set_global_expert_location_metadata, ) from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater +from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS from sglang.srt.layers.attention.tbo_backend import TboAttnBackend from sglang.srt.layers.dp_attention import ( get_attention_tp_group, @@ -1733,155 +1734,9 @@ class ModelRunner: return attn_backend def _get_attention_backend_from_str(self, backend_str: str): - if backend_str == "flashinfer": - if not self.use_mla_backend: - from sglang.srt.layers.attention.flashinfer_backend import ( - FlashInferAttnBackend, - ) - - # Init streams - if self.server_args.speculative_algorithm == "EAGLE": - if ( - not hasattr(self, "plan_stream_for_flashinfer") - or not self.plan_stream_for_flashinfer - ): - self.plan_stream_for_flashinfer = torch.cuda.Stream() - return FlashInferAttnBackend(self) - else: - from sglang.srt.layers.attention.flashinfer_mla_backend import ( - FlashInferMLAAttnBackend, - ) - - return FlashInferMLAAttnBackend(self) - elif backend_str == "aiter": - from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend - - return AiterAttnBackend(self) - elif self.server_args.attention_backend == "wave": - from sglang.srt.layers.attention.wave_backend import WaveAttnBackend - - return WaveAttnBackend(self) - elif backend_str == "ascend": - from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend - - return AscendAttnBackend(self) - elif backend_str == "triton": - assert not self.model_config.is_encoder_decoder, ( - "Cross attention is not supported in the triton attention backend. " - "Please use `--attention-backend flashinfer`." - ) - if self.server_args.enable_double_sparsity: - from sglang.srt.layers.attention.double_sparsity_backend import ( - DoubleSparseAttnBackend, - ) - - return DoubleSparseAttnBackend(self) - else: - from sglang.srt.layers.attention.triton_backend import TritonAttnBackend - - return TritonAttnBackend(self) - elif backend_str == "torch_native": - from sglang.srt.layers.attention.torch_native_backend import ( - TorchNativeAttnBackend, - ) - - return TorchNativeAttnBackend(self) - elif backend_str == "flex_attention": - from sglang.srt.layers.attention.torch_flex_backend import ( - TorchFlexAttnBackend, - ) - - return TorchFlexAttnBackend(self) - elif backend_str == "flashmla": - from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend - - return FlashMLABackend(self) - elif backend_str == "fa3": - assert ( - torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend - ) or torch.cuda.get_device_capability()[0] == 9, ( - "FlashAttention v3 Backend requires SM>=80 and SM<=90. " - "Please use `--attention-backend flashinfer`." - ) - from sglang.srt.layers.attention.flashattention_backend import ( - FlashAttentionBackend, - ) - - return FlashAttentionBackend(self) - elif backend_str == "fa4": - assert ( - self.use_mla_backend - ), "FlashAttention v4 Support is at an early stage, only MLA model supported now" - from sglang.srt.layers.attention.flashattention_backend import ( - FlashAttentionBackend, - ) - - return FlashAttentionBackend(self, fa_impl_ver=4) - elif backend_str == "cutlass_mla": - from sglang.srt.layers.attention.cutlass_mla_backend import ( - CutlassMLABackend, - ) - - return CutlassMLABackend(self) - elif backend_str == "trtllm_mla": - if not self.use_mla_backend: - raise ValueError("trtllm_mla backend can only be used with MLA models.") - from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend - - return TRTLLMMLABackend(self) - elif backend_str == "trtllm_mha": - if self.use_mla_backend: - raise ValueError( - "trtllm_mha backend can only be used with non-MLA models." - ) - from sglang.srt.layers.attention.trtllm_mha_backend import ( - TRTLLMHAAttnBackend, - ) - - return TRTLLMHAAttnBackend(self) - elif backend_str == "intel_amx": - from sglang.srt.layers.attention.intel_amx_backend import ( - IntelAMXAttnBackend, - ) - - return IntelAMXAttnBackend(self) - elif backend_str == "dual_chunk_flash_attn": - from sglang.srt.layers.attention.dual_chunk_flashattention_backend import ( - DualChunkFlashAttentionBackend, - ) - - return DualChunkFlashAttentionBackend(self) - elif backend_str == "hybrid_linear_attn": - assert ( - self.is_hybrid_gdn - ), "hybrid_linear_attn backend can only be used with hybrid GDN models." - from sglang.srt.layers.attention.hybrid_linear_attn_backend import ( - HybridLinearAttnBackend, - MambaAttnBackend, - ) - - if _is_npu: - from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend - - full_attn_backend = AscendAttnBackend(self) - elif is_blackwell(): - from sglang.srt.layers.attention.triton_backend import TritonAttnBackend - - full_attn_backend = TritonAttnBackend(self) - else: - from sglang.srt.layers.attention.flashattention_backend import ( - FlashAttentionBackend, - ) - - full_attn_backend = FlashAttentionBackend(self) - - linear_attn_backend = MambaAttnBackend(self) - full_attn_layers = self.model_config.hf_config.full_attention_layer_ids - return HybridLinearAttnBackend( - full_attn_backend, linear_attn_backend, full_attn_layers - ) - else: + if backend_str not in ATTENTION_BACKENDS: raise ValueError(f"Invalid attention backend: {backend_str}") + return ATTENTION_BACKENDS[backend_str](self) def init_double_sparsity_channel_config(self, selected_channel): selected_channel = "." + selected_channel + "_proj"