feat: support flashinfer mla attention for deepseek v3 (#3550)
@@ -67,6 +67,7 @@ from sglang.srt.utils import (
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cpu_offload_max_bytes,
+    set_cuda_arch,
 )
 
 logger = logging.getLogger(__name__)
@@ -110,8 +111,14 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if self.server_args.device != "cpu":
-                logger.info("MLA optimization is turned on. Use triton backend.")
-                self.server_args.attention_backend = "triton"
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    )
+                    self.server_args.attention_backend = "flashinfer"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    self.server_args.attention_backend = "triton"
 
         if self.server_args.enable_double_sparsity:
             logger.info(
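The hunk above is the core behavior change: when MLA is active and the model runs on a GPU, the enable_flashinfer_mla server argument now routes attention to the FlashInfer backend instead of the default Triton backend, while the CPU path is left untouched. A minimal standalone sketch of that decision, handy for unit-testing the branching; pick_attention_backend and its parameters are illustrative names, not part of the sglang API:

import logging

logger = logging.getLogger(__name__)


def pick_attention_backend(
    device: str, enable_flashinfer_mla: bool, current_backend: str
) -> str:
    # Mirrors the branch above: CPU keeps its current backend (MLA on CPU
    # is still a TODO upstream), flashinfer wins when the flag is set,
    # and triton remains the default MLA backend otherwise.
    if device == "cpu":
        return current_backend
    if enable_flashinfer_mla:
        logger.info("FlashInfer MLA optimization is turned on.")
        return "flashinfer"
    logger.info("MLA optimization is turned on. Use triton backend.")
    return "triton"


assert pick_attention_backend("cuda", True, "triton") == "flashinfer"
assert pick_attention_backend("cuda", False, "triton") == "triton"
assert pick_attention_backend("cpu", False, "torch_native") == "torch_native"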
@@ -169,6 +176,7 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
             }
         )
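Mirroring the flag into global_server_args_dict follows the pattern the surrounding keys already use: code that sits far from ModelRunner (for example, a model's attention layers) can read the setting from a module-level dict instead of having a ServerArgs object threaded through every constructor. A sketch of that pattern under this assumption; use_flashinfer_mla is a hypothetical accessor, not sglang code:

# Module-level settings dict, populated once at server startup
# (sketch of the pattern, not sglang's actual module layout).
global_server_args_dict = {
    "enable_flashinfer_mla": False,
}


def use_flashinfer_mla() -> bool:
    # Hypothetical accessor an attention layer might call to decide
    # which MLA kernel path to take.
    return global_server_args_dict["enable_flashinfer_mla"]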
@@ -292,6 +300,8 @@ class ModelRunner:
             if torch.cuda.get_device_capability()[1] < 5:
                 raise RuntimeError("SGLang only supports sm75 and above.")
 
+        set_cuda_arch()
+
         # Prepare the model config
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
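The new set_cuda_arch() call runs before weights are loaded. Its body is not shown in this diff; a plausible reading, consistent with FlashInfer JIT-compiling kernels at runtime, is that it pins the target CUDA architecture to the running GPU so compiled extensions match the hardware. A hedged stand-in only (the real helper lives in sglang.srt.utils and may differ):

import os

import torch


def set_cuda_arch_sketch() -> None:
    # Pin TORCH_CUDA_ARCH_LIST to the GPU we are actually running on, so
    # any JIT-compiled CUDA extensions target the correct SM architecture.
    if torch.cuda.is_available():
        major, minor = torch.cuda.get_device_capability()
        os.environ.setdefault("TORCH_CUDA_ARCH_LIST", f"{major}.{minor}")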