feat: support flashinfer mla attention for deepseek v3 (#3550)

Author: Yineng Zhang
Date: 2025-02-14 08:50:14 +08:00
Committed by: GitHub
Parent: 368de3661e
Commit: 70f894b810
12 changed files with 299 additions and 135 deletions


@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
 
 logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
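Read on its own, the new branch is a three-way decision on device and flag. A minimal sketch of that decision as a standalone function (the helper name and signature are illustrative, not part of the commit; the commit itself mutates self.server_args.attention_backend in place):

import logging

logger = logging.getLogger(__name__)


def select_mla_attention_backend(
    current_backend: str, device: str, enable_flashinfer_mla: bool
) -> str:
    # Hypothetical helper mirroring the branch added in the hunk above.
    if device == "cpu":
        # Per the TODO in the source, MLA optimization is not yet added
        # on CPU, so the configured backend is left untouched.
        return current_backend
    if enable_flashinfer_mla:
        logger.info(
            "FlashInfer MLA optimization is turned on. "
            "Use flashinfer backend for DeepseekV3ForCausalLM."
        )
        return "flashinfer"
    logger.info("MLA optimization is turned on. Use triton backend.")
    return "triton"

Assuming the new ServerArgs field follows SGLang's usual argparse naming, the feature would be enabled from the CLI with --enable-flashinfer-mla; the flag spelling is inferred from the field name and is not shown in this hunk.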
@@ -169,6 +176,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
             }
         )
 
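The dict updated here (presumably SGLang's process-wide global_server_args_dict) is how model code sees runtime flags without threading ServerArgs through every constructor. A hedged sketch of a model-side check (only the key name comes from this hunk; the import path and accessor are assumptions for illustration):

from sglang.srt.managers.schedule_batch import global_server_args_dict


def mla_uses_flashinfer() -> bool:
    # Assumption: returns True when the FlashInfer MLA path should be
    # taken, keyed on the entry added in the hunk above.
    return bool(global_server_args_dict.get("enable_flashinfer_mla", False))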
@@ -292,6 +300,8 @@ class ModelRunner:
         if torch.cuda.get_device_capability()[1] < 5:
             raise RuntimeError("SGLang only supports sm75 and above.")
 
+        set_cuda_arch()
+
         # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
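set_cuda_arch() now runs before model load. Its body is not part of this diff; a plausible sketch, assuming it pins TORCH_CUDA_ARCH_LIST to the detected compute capability so JIT-compiled kernels (e.g. FlashInfer's) target the running GPU (the real implementation lives in sglang.srt.utils and may differ):

import os

import torch


def set_cuda_arch():
    # Assumption: export the current GPU's compute capability so that
    # torch extension builds compile for this architecture only.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        os.environ["TORCH_CUDA_ARCH_LIST"] = f"{major}.{minor}"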