Revert "Replace enable_flashinfer_mla argument with attention_backend" (#5048)

This commit is contained in:
Lianmin Zheng
2025-04-03 13:30:56 -07:00
committed by GitHub
parent b8b6008f47
commit 74885a848b
8 changed files with 20 additions and 21 deletions

View File

@@ -71,6 +71,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
self.device = model_runner.device
self.skip_prefill = skip_prefill
global_config.enable_flashinfer_mla = True
# Allocate buffers
global global_workspace_buffer
if global_workspace_buffer is None:

View File

@@ -76,6 +76,7 @@ global_server_args_dict = {
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -1434,7 +1435,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
# Create seq_lens_cpu when needed
if (
global_server_args_dict["attention_backend"] == "flashinfer_mla"
global_server_args_dict["enable_flashinfer_mla"]
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "fa3"
):

View File

@@ -151,6 +151,7 @@ class ModelRunner:
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashinfer_mla": server_args.enable_flashinfer_mla,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -222,14 +223,10 @@ class ModelRunner:
):
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if (
server_args.attention_backend == "flashinfer"
or server_args.enable_flashinfer_mla
):
if server_args.enable_flashinfer_mla:
logger.info(
"MLA optimization is turned on. Use flashinfer backend."
"MLA optimization is turned on. Use flashinfer mla backend."
)
# Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
server_args.attention_backend = "flashinfer_mla"
elif server_args.enable_flashmla:
logger.info("MLA optimization is turned on. Use flashmla decode.")

View File

@@ -684,6 +684,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.w_vc = None
self.w_scale = None
self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
self.flashinfer_mla_disable_ragged = global_server_args_dict[
"flashinfer_mla_disable_ragged"
]
@@ -691,7 +692,7 @@ class DeepseekV2AttentionMLA(nn.Module):
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
if self.attention_backend == "flashinfer_mla":
if self.enable_flashinfer_mla:
# Flashinfer MLA: Do not absorb when enabling ragged prefill
return (
not self.flashinfer_mla_disable_ragged

View File

@@ -179,7 +179,7 @@ class ServerArgs:
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False # TODO: remove this argument
enable_flashinfer_mla: bool = False
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
@@ -836,7 +836,7 @@ class ServerArgs:
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashinfer mla!",
help="Enable FlashInfer MLA optimization",
)
parser.add_argument(
"--enable-flashmla",