Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)

Author: Baizhou Zhang
Date: 2025-02-24 04:07:25 -08:00
Committed by: GitHub
Parent: 27a46317b6
Commit: b110084654
4 changed files with 565 additions and 19 deletions

@@ -34,6 +34,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -113,9 +114,9 @@ class ModelRunner:
         if self.server_args.device != "cpu":
             if server_args.enable_flashinfer_mla:
                 logger.info(
-                    "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                    "MLA optimization is turned on. Use flashinfer mla backend."
                 )
-                self.server_args.attention_backend = "flashinfer"
+                self.server_args.attention_backend = "flashinfer_mla"
             else:
                 logger.info("MLA optimization is turned on. Use triton backend.")
                 self.server_args.attention_backend = "triton"
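
For reference, a minimal standalone sketch (not the actual sglang code; the CPU and disable-MLA branches of ModelRunner are omitted) of what this adjustment now does: when MLA optimization is active, the enable_flashinfer_mla flag selects the dedicated "flashinfer_mla" backend name instead of overloading "flashinfer".

# Simplified sketch of the adjusted selection logic; not the exact ModelRunner code.
def pick_mla_attention_backend(enable_flashinfer_mla: bool) -> str:
    """Return the attention backend name used for MLA models on non-CPU devices."""
    if enable_flashinfer_mla:
        # New behavior: a dedicated backend name rather than reusing "flashinfer".
        return "flashinfer_mla"
    # MLA without the flashinfer flag keeps using the triton backend.
    return "triton"

On the CLI this corresponds to whatever flag populates server_args.enable_flashinfer_mla (presumably --enable-flashinfer-mla under the usual dash-for-underscore mapping; the flag definition is not part of this diff).
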
@@ -703,6 +704,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"