Support FlashMLA backend (#4472)

Co-authored-by: yinfan98 <1106310035@qq.com>
This commit is contained in:
lukec
2025-03-17 00:07:06 +08:00
committed by GitHub
parent 1b859295f4
commit a53fe428f9
6 changed files with 209 additions and 1 deletions

View File

@@ -149,6 +149,7 @@ class ModelRunner:
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -223,6 +224,9 @@ class ModelRunner:
"MLA optimization is turned on. Use flashinfer mla backend."
)
server_args.attention_backend = "flashinfer_mla"
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
else:
logger.info("MLA optimization is turned on. Use triton backend.")
server_args.attention_backend = "triton"
@@ -840,6 +844,10 @@ class ModelRunner:
)
self.attn_backend = FlashInferMLAAttnBackend(self)
elif self.server_args.attention_backend == "flashmla":
from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
self.attn_backend = FlashMLABackend(self)
else:
raise ValueError(
f"Invalid attention backend: {self.server_args.attention_backend}"