From 6fb29ffd9e7bf1f7d783c238cce814601ab4e105 Mon Sep 17 00:00:00 2001 From: Baizhou Zhang Date: Thu, 17 Apr 2025 01:43:33 -0700 Subject: [PATCH] Deprecate enable-flashinfer-mla and enable-flashmla (#5480) --- docs/backend/server_arguments.md | 1 - docs/references/deepseek.md | 2 +- python/sglang/srt/managers/schedule_batch.py | 3 +-- .../sglang/srt/model_executor/model_runner.py | 18 +++++++----------- python/sglang/srt/server_args.py | 14 ++++++-------- scripts/playground/bench_speculative.py | 11 +++-------- 6 files changed, 18 insertions(+), 31 deletions(-) diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index bdd708cd0..560b4deb2 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -192,6 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. -* `enable_flashinfer_mla`: Use the attention backend with FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention_backend flashinfer` instead to enable FlashfIner MLA.** * `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. * `disable_chunked_prefix_cache`: Disable the use of chunked prefix cache for DeepSeek models. Only use it when FA3 is attention backend. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 91d6d78fa..77a4a38fb 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton backends. It can be set with `--attention-backend` argument. +- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), and [Triton](https://github.com/triton-lang/triton) backends. It can be set with `--attention-backend` argument. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption. diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 7d3afa824..e50f74dfa 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -76,7 +76,6 @@ global_server_args_dict = { "device": ServerArgs.device, "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, - "enable_flashmla": ServerArgs.enable_flashmla, "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, "chunked_prefill_size": ServerArgs.chunked_prefill_size, @@ -1480,7 +1479,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): global_server_args_dict["use_mla_backend"] and global_server_args_dict["attention_backend"] == "flashinfer" ) - or global_server_args_dict["enable_flashmla"] + or global_server_args_dict["attention_backend"] == "flashmla" or global_server_args_dict["attention_backend"] == "fa3" ): seq_lens_cpu = self.seq_lens.cpu() diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index df443c599..7fa9c05c7 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -156,7 +156,6 @@ class ModelRunner: "device": server_args.device, "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, - "enable_flashmla": server_args.enable_flashmla, "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder, @@ -225,15 +224,7 @@ class ModelRunner: def model_specific_adjustment(self): server_args = self.server_args - if server_args.enable_flashinfer_mla: - # TODO: remove this branch after enable_flashinfer_mla is deprecated - logger.info("MLA optimization is turned on. Use flashinfer backend.") - server_args.attention_backend = "flashinfer" - elif server_args.enable_flashmla: - # TODO: remove this branch after enable_flashmla is deprecated - logger.info("MLA optimization is turned on. Use flashmla decode.") - server_args.attention_backend = "flashmla" - elif server_args.attention_backend is None: + if server_args.attention_backend is None: # By default, use flashinfer for non-mla attention and triton for mla attention if not self.use_mla_backend: if ( @@ -259,7 +250,12 @@ class ModelRunner: elif self.use_mla_backend: # TODO: add MLA optimization on CPU if server_args.device != "cpu": - if server_args.attention_backend in ["flashinfer", "fa3", "triton"]: + if server_args.attention_backend in [ + "flashinfer", + "fa3", + "triton", + "flashmla", + ]: logger.info( f"MLA optimization is turned on. Use {server_args.attention_backend} backend." ) diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index e172e66bb..436c0f306 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -179,8 +179,6 @@ class ServerArgs: tool_call_parser: Optional[str] = None enable_hierarchical_cache: bool = False hicache_ratio: float = 2.0 - enable_flashinfer_mla: bool = False # TODO: remove this argument - enable_flashmla: bool = False flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None n_share_experts_fusion: int = 0 @@ -254,7 +252,7 @@ class ServerArgs: assert self.chunked_prefill_size % self.page_size == 0 - if self.enable_flashmla is True: + if self.attention_backend == "flashmla": logger.warning( "FlashMLA only supports a page_size of 64, change page_size to 64." ) @@ -823,7 +821,7 @@ class ServerArgs: parser.add_argument( "--attention-backend", type=str, - choices=["flashinfer", "triton", "torch_native", "fa3"], + choices=["flashinfer", "triton", "torch_native", "fa3", "flashmla"], default=ServerArgs.attention_backend, help="Choose the kernels for attention layers.", ) @@ -843,13 +841,13 @@ class ServerArgs: ) parser.add_argument( "--enable-flashinfer-mla", - action="store_true", - help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead for switching on flashfiner mla!", + action=DeprecatedAction, + help="--enable-flashinfer-mla is deprecated. Please use '--attention-backend flashinfer' instead.", ) parser.add_argument( "--enable-flashmla", - action="store_true", - help="Enable FlashMLA decode optimization", + action=DeprecatedAction, + help="--enable-flashmla is deprecated. Please use '--attention-backend flashmla' instead.", ) parser.add_argument( "--flashinfer-mla-disable-ragged", diff --git a/scripts/playground/bench_speculative.py b/scripts/playground/bench_speculative.py index 72faf706a..ff2e97262 100644 --- a/scripts/playground/bench_speculative.py +++ b/scripts/playground/bench_speculative.py @@ -176,16 +176,11 @@ def main(args, server_args): ] ) - if server_args.enable_flashinfer_mla: + if server_args.attention_backend: other_args.extend( [ - "--enable-flashinfer-mla", - ] - ) - if server_args.enable_flashmla: - other_args.extend( - [ - "--enable-flashmla", + "--attention-backend", + server_args.attention_backend, ] )