Remove hybrid_linear_attn attention backend and refactor attention registry (#10816)
Co-authored-by: Yi Zhang <1109276519@qq.com>
@@ -1,3 +1,7 @@
+import logging
+
+logger = logging.getLogger(__name__)
+
 ATTENTION_BACKENDS = {}
 
 
@@ -158,35 +162,37 @@ def create_dual_chunk_flash_attn_backend(runner):
     return DualChunkFlashAttentionBackend(runner)
 
 
-@register_attention_backend("hybrid_linear_attn")
-def create_hybrid_linear_attn_backend(runner):
-    assert (
-        runner.is_hybrid_gdn
-    ), "hybrid_linear_attn backend can only be used with hybrid GDN models."
-    from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
-        HybridLinearAttnBackend,
-        MambaAttnBackend,
-    )
-    from sglang.srt.utils import is_blackwell, is_npu
-
-    if is_npu():
-        from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-        full_attn_backend = AscendAttnBackend(runner)
-    elif is_blackwell():
-        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-        full_attn_backend = TritonAttnBackend(runner)
-    else:
-        from sglang.srt.layers.attention.flashattention_backend import (
-            FlashAttentionBackend,
-        )
-
-        full_attn_backend = FlashAttentionBackend(runner)
-
-    linear_attn_backend = MambaAttnBackend(runner)
-    full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
-
-    return HybridLinearAttnBackend(
-        full_attn_backend, linear_attn_backend, full_attn_layers
-    )
+def attn_backend_wrapper(runner, full_attn_backend):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.is_hybrid_gdn and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    # wrap for hybrid GDN models
+    if runner.is_hybrid_gdn:
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if is_blackwell():
+            assert (
+                runner.server_args.attention_backend == "triton"
+            ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+        if is_npu():
+            assert (
+                runner.server_args.attention_backend == "ascend"
+            ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+        logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            HybridLinearAttnBackend,
+            MambaAttnBackend,
+        )
+
+        linear_attn_backend = MambaAttnBackend(runner)
+        full_attn_layers = runner.model_config.hf_config.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+
+    return full_attn_backend
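The refactor above splits backend construction into two steps: a registry lookup that builds the chosen full-attention backend, and a wrapper that layers the Mamba/linear-attention part on top for hybrid GDN models. Below is a standalone sketch of that composition; it is illustrative only and not part of the commit, with a string stub and a simplified _Runner standing in for sglang's real backend classes and ModelRunner.

# Illustrative sketch of the registry + wrapper composition (not from the commit).
ATTENTION_BACKENDS = {}


def register_attention_backend(name):
    def decorator(fn):
        ATTENTION_BACKENDS[name] = fn
        return fn

    return decorator


@register_attention_backend("triton")
def create_triton_backend(runner):
    # The real registry entry returns TritonAttnBackend(runner); a stub suffices here.
    return "TritonAttnBackend"


def attn_backend_wrapper(runner, full_attn_backend):
    # Hybrid GDN models keep their chosen full-attention backend and get the
    # Mamba/linear-attention part layered on top; other models pass through.
    if getattr(runner, "is_hybrid_gdn", False):
        return ("HybridLinearAttnBackend", full_attn_backend, "MambaAttnBackend")
    return full_attn_backend


def get_attention_backend(runner, backend_str):
    # Mirrors the shape of ModelRunner._get_attention_backend_from_str after this change.
    if backend_str not in ATTENTION_BACKENDS:
        raise ValueError(f"Invalid attention backend: {backend_str}")
    full_attn_backend = ATTENTION_BACKENDS[backend_str](runner)
    return attn_backend_wrapper(runner, full_attn_backend)


class _Runner:
    is_hybrid_gdn = True


print(get_attention_backend(_Runner(), "triton"))
# ('HybridLinearAttnBackend', 'TritonAttnBackend', 'MambaAttnBackend')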
@@ -60,7 +60,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
-from sglang.srt.layers.attention.attention_registry import ATTENTION_BACKENDS
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -347,7 +350,6 @@ class ModelRunner:
         if self.is_hybrid_gdn:
             logger.warning("Hybrid GDN model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            self.server_args.attention_backend = "hybrid_linear_attn"
             if self.server_args.max_mamba_cache_size is None:
                 if self.server_args.max_running_requests is not None:
                     self.server_args.max_mamba_cache_size = (
@@ -1648,10 +1650,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if _is_npu and self.server_args.attention_backend in [
-                "ascend",
-                "hybrid_linear_attn",
-            ]:
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1764,7 +1765,8 @@ class ModelRunner:
     def _get_attention_backend_from_str(self, backend_str: str):
         if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
-        return ATTENTION_BACKENDS[backend_str](self)
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)
 
     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -100,7 +100,6 @@ ATTENTION_BACKEND_CHOICES = [
     "trtllm_mla",
     "trtllm_mha",
     "dual_chunk_flash_attn",
-    "hybrid_linear_attn",
     # AMD specific
     "aiter",
     "wave",
@@ -801,7 +800,7 @@ class ServerArgs:
                 self.speculative_algorithm is None
             ), "Speculative decoding is currently not supported with Flex Attention backend"
 
-        if is_npu() and self.attention_backend in ["ascend", "hybrid_linear_attn"]:
+        if is_npu() and self.attention_backend in ["ascend"]:
             logger.warning(
                 "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
             )
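Usage note (illustrative, not part of the commit): with hybrid_linear_attn removed from ATTENTION_BACKEND_CHOICES, a hybrid GDN model is started with one of the remaining full-attention backends and attn_backend_wrapper attaches the linear-attention side automatically. Assuming the usual sglang launch entry point, a Blackwell GPU run would look like:

python -m sglang.launch_server --model-path <hybrid-gdn-model> --attention-backend triton

On NPU the new assertions require --attention-backend ascend instead; other GPUs keep whichever supported full-attention backend they already use.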