[Revision] Replace enable_flashinfer_mla argument with attention_backend (#5052)

This commit is contained in:
Baizhou Zhang
2025-04-05 01:23:02 -07:00
committed by GitHub
parent ca8d02abd5
commit efbae697b3
9 changed files with 92 additions and 82 deletions

View File

@@ -75,6 +75,7 @@ from sglang.srt.utils import (
get_available_gpu_memory,
init_custom_process_group,
is_cuda,
is_flashinfer_available,
is_hip,
monkey_patch_p2p_access_check,
monkey_patch_vllm_gguf_config,
@@ -123,6 +124,10 @@ class ModelRunner:
self.page_size = server_args.page_size
self.req_to_token_pool = req_to_token_pool
self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
self.use_mla_backend = (
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
)
# Model-specific adjustment
self.model_specific_adjustment()
@@ -151,7 +156,6 @@ class ModelRunner:
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -159,6 +163,7 @@ class ModelRunner:
"debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
"n_share_experts_fusion": server_args.n_share_experts_fusion,
"disable_shared_experts_fusion": server_args.disable_shared_experts_fusion,
"use_mla_backend": self.use_mla_backend,
}
)
@@ -219,27 +224,38 @@ class ModelRunner:
def model_specific_adjustment(self):
server_args = self.server_args
if (
self.model_config.attention_arch == AttentionArch.MLA
and not server_args.disable_mla
):
if server_args.enable_flashinfer_mla:
# TODO: remove this branch after enable_flashinfer_mla is deprecated
logger.info("MLA optimization is turned on. Use flashinfer backend.")
server_args.attention_backend = "flashinfer"
elif server_args.enable_flashmla:
# TODO: remove this branch after enable_flashmla is deprecated
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
if not self.use_mla_backend:
server_args.attention_backend = (
"flashinfer" if is_flashinfer_available() else "triton"
)
else:
server_args.attention_backend = "triton"
logger.info(
f"Attention backend not set. Use {server_args.attention_backend} backend by default."
)
elif self.use_mla_backend:
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.enable_flashinfer_mla:
if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
logger.info(
"MLA optimization is turned on. Use flashinfer mla backend."
)
server_args.attention_backend = "flashinfer_mla"
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend == "fa3":
logger.info(
f"MLA optimization is turned on. Use flash attention 3 backend."
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
)
else:
logger.info("MLA optimization is turned on. Use triton backend.")
server_args.attention_backend = "triton"
raise ValueError(
f"Invalid attention backend for MLA: {server_args.attention_backend}"
)
else:
raise ValueError(f"MLA optimization not supported on CPU.")
if server_args.enable_double_sparsity:
logger.info(
@@ -637,10 +653,7 @@ class ModelRunner:
available_gpu_memory = get_available_gpu_memory(
self.device, self.gpu_id, distributed=self.tp_size > 1
)
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
if self.use_mla_backend:
cell_size = (
(self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim)
* self.model_config.num_hidden_layers
@@ -751,10 +764,7 @@ class ModelRunner:
# Draft worker shares req_to_token_pool with the target worker.
assert self.is_draft_worker
if (
self.model_config.attention_arch == AttentionArch.MLA
and not self.server_args.disable_mla
):
if self.use_mla_backend:
self.token_to_kv_pool = MLATokenToKVPool(
self.max_total_num_tokens,
page_size=self.page_size,
@@ -825,14 +835,21 @@ class ModelRunner:
def init_attention_backend(self):
"""Init attention kernel backend."""
if self.server_args.attention_backend == "flashinfer":
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
if not self.use_mla_backend:
from sglang.srt.layers.attention.flashinfer_backend import (
FlashInferAttnBackend,
)
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
# Init streams
if self.server_args.speculative_algorithm == "EAGLE":
self.plan_stream_for_flashinfer = torch.cuda.Stream()
self.attn_backend = FlashInferAttnBackend(self)
else:
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
self.attn_backend = FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "triton":
assert self.sliding_window_size is None, (
"Window attention is not supported in the triton attention backend. "
@@ -858,12 +875,6 @@ class ModelRunner:
)
self.attn_backend = TorchNativeAttnBackend(self)
elif self.server_args.attention_backend == "flashinfer_mla":
from sglang.srt.layers.attention.flashinfer_mla_backend import (
FlashInferMLAAttnBackend,
)
self.attn_backend = FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend