Refactor flashinfer logic for deepseek v3 and fix accuracy bug (#3785)
@@ -34,6 +34,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -113,9 +114,9 @@ class ModelRunner:
             if self.server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
-                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
+                        "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "flashinfer"
+                    self.server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
                     self.server_args.attention_backend = "triton"
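
The hunk above changes which backend an MLA model (such as DeepSeek V3) ends up on: when enable_flashinfer_mla is set, the runner now selects the new dedicated "flashinfer_mla" backend instead of the generic "flashinfer" one, with triton as the fallback. A minimal standalone sketch of that selection, using a simplified ServerArgs stand-in (hypothetical, not sglang's real server-args class):

# Illustrative reproduction of the selection logic above; ServerArgs here is a
# simplified stand-in for sglang's real server-args object.
from dataclasses import dataclass


@dataclass
class ServerArgs:
    enable_flashinfer_mla: bool = False
    attention_backend: str = "flashinfer"


def resolve_mla_backend(server_args: ServerArgs) -> str:
    # After this commit: MLA models use the dedicated flashinfer_mla backend
    # when enable_flashinfer_mla is set, and fall back to triton otherwise.
    if server_args.enable_flashinfer_mla:
        server_args.attention_backend = "flashinfer_mla"
    else:
        server_args.attention_backend = "triton"
    return server_args.attention_backend


assert resolve_mla_backend(ServerArgs(enable_flashinfer_mla=True)) == "flashinfer_mla"
assert resolve_mla_backend(ServerArgs()) == "triton"
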
@@ -703,6 +704,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
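
The final hunk wires the new backend name into backend construction. A hedged sketch of the same dispatch, written as a lookup table with stub classes so it runs on its own (the real backends take the ModelRunner instance and live under sglang.srt.layers.attention):

# Stub classes standing in for sglang's real attention backends, which are
# constructed with the ModelRunner instance in the actual code.
class TritonAttnBackend: ...
class TorchNativeAttnBackend: ...
class FlashInferMLAAttnBackend: ...


_ATTN_BACKENDS = {
    "triton": TritonAttnBackend,
    "torch_native": TorchNativeAttnBackend,
    "flashinfer_mla": FlashInferMLAAttnBackend,
}


def create_attn_backend(name: str):
    # Dict-based equivalent of the elif chain above; unknown names raise the
    # same ValueError as in the diff.
    try:
        return _ATTN_BACKENDS[name]()
    except KeyError:
        raise ValueError(f"Invalid attention backend: {name}") from None


assert isinstance(create_attn_backend("flashinfer_mla"), FlashInferMLAAttnBackend)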