Remove hybrid_linear_attn attention backend and refactor attention registry (#10816)

Co-authored-by: Yi Zhang <1109276519@qq.com>
This commit is contained in:
li-kesen
2025-09-30 10:16:16 +08:00
committed by GitHub
parent 6535fda127
commit 2bc61dd194
3 changed files with 43 additions and 36 deletions

View File

@@ -60,7 +60,10 @@ from sglang.srt.eplb.expert_location import (
set_global_expert_location_metadata,
)
from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
from sglang.srt.layers.attention.attention_registry import (
ATTENTION_BACKENDS,
attn_backend_wrapper,
)
from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
from sglang.srt.layers.dp_attention import (
get_attention_tp_group,
@@ -347,7 +350,6 @@ class ModelRunner:
if self.is_hybrid_gdn:
logger.warning("Hybrid GDN model detected, disable radix cache")
self.server_args.disable_radix_cache = True
self.server_args.attention_backend = "hybrid_linear_attn"
if self.server_args.max_mamba_cache_size is None:
if self.server_args.max_running_requests is not None:
self.server_args.max_mamba_cache_size = (
@@ -1648,10 +1650,9 @@ class ModelRunner:
# Initialize token_to_kv_pool_allocator
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
if self.token_to_kv_pool_allocator is None:
if _is_npu and self.server_args.attention_backend in [
"ascend",
"hybrid_linear_attn",
]:
if _is_npu and (
self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
):
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -1764,7 +1765,8 @@ class ModelRunner:
def _get_attention_backend_from_str(self, backend_str: str):
    """Instantiate the attention backend registered under ``backend_str``.

    Looks the name up in the global ``ATTENTION_BACKENDS`` registry,
    constructs the backend with this runner, and passes it through
    ``attn_backend_wrapper`` so the registry can wrap the full-attention
    backend when needed (e.g. for hybrid linear-attention models).

    Args:
        backend_str: Registry key of the attention backend to build.

    Returns:
        The (possibly wrapped) attention backend instance.

    Raises:
        ValueError: If ``backend_str`` is not a registered backend name.
    """
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    # NOTE: the diff scrape had merged the pre-change `return
    # ATTENTION_BACKENDS[backend_str](self)` with the lines below,
    # leaving the wrapper call unreachable; only the wrapped path is kept.
    full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
    return attn_backend_wrapper(self, full_attention_backend)
def init_double_sparsity_channel_config(self, selected_channel):
selected_channel = "." + selected_channel + "_proj"