Deprecate enable-flashinfer-mla and enable-flashmla (#5480)

This commit is contained in:
Baizhou Zhang
2025-04-17 01:43:33 -07:00
committed by GitHub
parent 4fb05583ef
commit 6fb29ffd9e
6 changed files with 18 additions and 31 deletions

View File

@@ -192,6 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
* `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
* `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
* `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8.
* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashfIner MLA.**
* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend.
* `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend.

View File

@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument.
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), and [Triton](https://github.com/triton-lang/triton) backends. It can be set with `--attention-backend` argument.
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.

View File

@@ -76,7 +76,6 @@ global_server_args_dict = {
"device": ServerArgs.device,
"speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
"enable_flashmla": ServerArgs.enable_flashmla,
"disable_radix_cache": ServerArgs.disable_radix_cache,
"flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
"chunked_prefill_size": ServerArgs.chunked_prefill_size,
@@ -1480,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
global_server_args_dict["use_mla_backend"]
and global_server_args_dict["attention_backend"] == "flashinfer"
)
or global_server_args_dict["enable_flashmla"]
or global_server_args_dict["attention_backend"] == "flashmla"
or global_server_args_dict["attention_backend"] == "fa3"
):
seq_lens_cpu = self.seq_lens.cpu()

View File

@@ -156,7 +156,6 @@ class ModelRunner:
"device": server_args.device,
"speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
"speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
"enable_flashmla": server_args.enable_flashmla,
"disable_radix_cache": server_args.disable_radix_cache,
"flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
"debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
@@ -225,15 +224,7 @@ class ModelRunner:
def model_specific_adjustment(self):
server_args = self.server_args
if server_args.enable_flashinfer_mla:
# TODO: remove this branch after enable_flashinfer_mla is deprecated
logger.info("MLA optimization is turned on. Use flashinfer backend.")
server_args.attention_backend = "flashinfer"
elif server_args.enable_flashmla:
# TODO: remove this branch after enable_flashmla is deprecated
logger.info("MLA optimization is turned on. Use flashmla decode.")
server_args.attention_backend = "flashmla"
elif server_args.attention_backend is None:
if server_args.attention_backend is None:
# By default, use flashinfer for non-mla attention and triton for mla attention
if not self.use_mla_backend:
if (
@@ -259,7 +250,12 @@ class ModelRunner:
elif self.use_mla_backend:
# TODO: add MLA optimization on CPU
if server_args.device != "cpu":
if server_args.attention_backend in ["flashinfer", "fa3", "triton"]:
if server_args.attention_backend in [
"flashinfer",
"fa3",
"triton",
"flashmla",
]:
logger.info(
f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
)

View File

@@ -179,8 +179,6 @@ class ServerArgs:
tool_call_parser: Optional[str] = None
enable_hierarchical_cache: bool = False
hicache_ratio: float = 2.0
enable_flashinfer_mla: bool = False # TODO: remove this argument
enable_flashmla: bool = False
flashinfer_mla_disable_ragged: bool = False
warmups: Optional[str] = None
n_share_experts_fusion: int = 0
@@ -254,7 +252,7 @@ class ServerArgs:
assert self.chunked_prefill_size % self.page_size == 0
if self.enable_flashmla is True:
if self.attention_backend == "flashmla":
logger.warning(
"FlashMLA only supports a page_size of 64, change page_size to 64."
)
@@ -823,7 +821,7 @@ class ServerArgs:
parser.add_argument(
"--attention-backend",
type=str,
choices=["flashinfer", "triton", "torch_native", "fa3"],
choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"],
default=ServerArgs.attention_backend,
help="Choose the kernels for attention layers.",
)
@@ -843,13 +841,13 @@ class ServerArgs:
)
parser.add_argument(
"--enable-flashinfer-mla",
action="store_true",
help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!",
action=DeprecatedAction,
help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.",
)
parser.add_argument(
"--enable-flashmla",
action="store_true",
help="Enable FlashMLA decode optimization",
action=DeprecatedAction,
help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.",
)
parser.add_argument(
"--flashinfer-mla-disable-ragged",

View File

@@ -176,16 +176,11 @@ def main(args, server_args):
]
)
if server_args.enable_flashinfer_mla:
if server_args.attention_backend:
other_args.extend(
[
"--enable-flashinfer-mla",
]
)
if server_args.enable_flashmla:
other_args.extend(
[
"--enable-flashmla",
"--attention-backend",
server_args.attention_backend,
]
)