Revert "Replace enable_flashinfer_mla argument with attention_backend" (#5048)
This commit is contained in:
@@ -71,6 +71,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
|
||||
self.device = model_runner.device
|
||||
self.skip_prefill = skip_prefill
|
||||
|
||||
+ global_config.enable_flashinfer_mla = True
|
||||
|
||||
# Allocate buffers
|
||||
global global_workspace_buffer
|
||||
if global_workspace_buffer is None:
|
||||
|
||||
@@ -76,6 +76,7 @@ global_server_args_dict = {
|
||||
"device": ServerArgs.device,
|
||||
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
|
||||
+ "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
|
||||
"enable_flashmla": ServerArgs.enable_flashmla,
|
||||
"disable_radix_cache": ServerArgs.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
|
||||
@@ -1434,7 +1435,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
|
||||
|
||||
# Create seq_lens_cpu when needed
|
||||
if (
|
||||
- global_server_args_dict["attention_backend"] == "flashinfer_mla"
|
||||
+ global_server_args_dict["enable_flashinfer_mla"]
|
||||
or global_server_args_dict["enable_flashmla"]
|
||||
or global_server_args_dict["attention_backend"] == "fa3"
|
||||
):
|
||||
|
||||
@@ -151,6 +151,7 @@ class ModelRunner:
|
||||
"device": server_args.device,
|
||||
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
|
||||
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
|
||||
+ "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
|
||||
"enable_flashmla": server_args.enable_flashmla,
|
||||
"disable_radix_cache": server_args.disable_radix_cache,
|
||||
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
|
||||
@@ -222,14 +223,10 @@ class ModelRunner:
|
||||
):
|
||||
# TODO: add MLA optimization on CPU
|
||||
if server_args.device != "cpu":
|
||||
- if (
|
||||
- server_args.attention_backend == "flashinfer"
|
||||
- or server_args.enable_flashinfer_mla
|
||||
- ):
|
||||
+ if server_args.enable_flashinfer_mla:
|
||||
logger.info(
|
||||
- "MLA optimization is turned on. Use flashinfer backend."
|
||||
+ "MLA optimization is turned on. Use flashinfer mla backend."
|
||||
)
|
||||
- # Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
|
||||
server_args.attention_backend = "flashinfer_mla"
|
||||
elif server_args.enable_flashmla:
|
||||
logger.info("MLA optimization is turned on. Use flashmla decode.")
|
||||
|
||||
@@ -684,6 +684,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.w_vc = None
|
||||
self.w_scale = None
|
||||
|
||||
+ self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
|
||||
self.flashinfer_mla_disable_ragged = global_server_args_dict[
|
||||
"flashinfer_mla_disable_ragged"
|
||||
]
|
||||
@@ -691,7 +692,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
|
||||
|
||||
def no_absorb(self, forward_batch: ForwardBatch) -> bool:
|
||||
- if self.attention_backend == "flashinfer_mla":
|
||||
+ if self.enable_flashinfer_mla:
|
||||
# Flashinfer MLA: Do not absorb when enabling ragged prefill
|
||||
return (
|
||||
not self.flashinfer_mla_disable_ragged
|
||||
|
||||
@@ -179,7 +179,7 @@ class ServerArgs:
|
||||
tool_call_parser: Optional[str] = None
|
||||
enable_hierarchical_cache: bool = False
|
||||
hicache_ratio: float = 2.0
|
||||
- enable_flashinfer_mla: bool = False # TODO: remove this argument
|
||||
+ enable_flashinfer_mla: bool = False
|
||||
enable_flashmla: bool = False
|
||||
flashinfer_mla_disable_ragged: bool = False
|
||||
warmups: Optional[str] = None
|
||||
@@ -836,7 +836,7 @@ class ServerArgs:
|
||||
parser.add_argument(
|
||||
"--enable-flashinfer-mla",
|
||||
action="store_true",
|
||||
- help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
|
||||
+ help="Enable FlashInfer MLA optimization",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--enable-flashmla",
|
||||
|
||||
Reference in New Issue
Block a user