diff --git a/docs/backend/server_arguments.md b/docs/backend/server_arguments.md index 3c96a6816..983c7316c 100644 --- a/docs/backend/server_arguments.md +++ b/docs/backend/server_arguments.md @@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma ## Kernel backend -* `attention_backend`: The backend for attention computation and KV cache management. +* `attention_backend`: This argument specifies the backend for attention computation and KV cache management, which can be `fa3`, `flashinfer`, `triton`, or `torch_native`. When deploying DeepSeek models, use this argument to specify the MLA backend. * `sampling_backend`: The backend for sampling. ## Constrained Decoding @@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you. * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-, fp8wo, fp8dq-per_tensor, fp8dq-per_row. * `triton_attention_num_kv_splits`: Use to adjust the number of KV splits in triton kernels. Default is 8. -* `enable_flashinfer_mla`: Use the attention backend with flashinfer MLA wrapper for deepseek models. When providing this argument, `attention_backend` argument is overridden. -* `flashinfer_mla_disable_ragged`: Disable usage of ragged prefill wrapper for flashinfer mla attention backend. Should be used when `enable_flashinfer_mla` is turned on. +* `enable_flashinfer_mla`: Use the attention backend with the FlashInfer MLA wrapper for DeepSeek models. **This argument will be deprecated in the next release. Please use `--attention-backend flashinfer` instead to enable FlashInfer MLA.** +* `flashinfer_mla_disable_ragged`: Disable the use of the ragged prefill wrapper for the FlashInfer MLA attention backend. Only use it when FlashInfer is being used as the MLA backend. diff --git a/docs/references/deepseek.md b/docs/references/deepseek.md index 9a079b9e7..4a1ed37d2 100644 --- a/docs/references/deepseek.md +++ b/docs/references/deepseek.md @@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase. -- **Flashinfer MLA Wrapper**: By providing `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be referred to [this document](https://docs.flashinfer.ai/api/mla.html). Under long input scenarios, flashinfer mla can improve performance significantly. Optimized triton kernels will be used when flashinfer mla is turned off. +- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including the FlashAttention3, [FlashInfer](https://docs.flashinfer.ai/api/mla.html), and Triton backends. The backend can be selected with the `--attention-backend` argument. - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
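The two documentation hunks above describe selecting an MLA backend with `--attention-backend`. As a minimal usage sketch (not taken verbatim from either doc), the launch command below shows the new flag replacing the deprecated `--enable-flashinfer-mla`; the `--trust-remote-code` and `--tp 8` values are illustrative assumptions for a typical multi-GPU DeepSeek deployment.

```bash
# Old (scheduled for deprecation):  --enable-flashinfer-mla
# New: pick the MLA backend explicitly; the other MLA-capable values in this patch are fa3 and triton.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3-0324 \
  --attention-backend flashinfer \
  --trust-remote-code \
  --tp 8
```

If `--attention-backend` is left unset, the defaults chosen in `model_runner.py` below apply: Triton for MLA models, and FlashInfer (falling back to Triton when FlashInfer is unavailable) for non-MLA models.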
@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec ``` - The draft models are available on Hugging Face: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). They can also be exported from the original DeepSeek-V3/R1 models with the [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script. - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` can be searched for with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes. -- Currently when using flashinfer mla wrapper (`--enable-flashinfer-mla`) and speculative decoding together, the `--speculative-eagle-topk` parameter should be set to `1`. +- When using the FlashInfer MLA wrapper (`--attention-backend flashinfer`) with speculative decoding, set the `--speculative-eagle-topk` parameter to `1`. The FlashAttention 3 backend also only supports `--speculative-eagle-topk 1`. - To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)): - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, you should increase it beyond the default. - Set `--cuda-graph-bs`. It's a list of batch sizes for CUDA graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes in it.
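A companion sketch for the speculative-decoding note above: DeepSeek MTP combined with the FlashInfer MLA backend, keeping `--speculative-eagle-topk` at `1` as required. The `--speculative-algorithm EAGLE`, `--speculative-draft-model-path`, `--trust-remote-code`, and `--tp 8` arguments are assumptions about a typical setup rather than values quoted from this patch.

```bash
# DeepSeek MTP with the FlashInfer MLA backend; this backend requires --speculative-eagle-topk 1.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3-0324 \
  --attention-backend flashinfer \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path lmsys/DeepSeek-V3-0324-NextN \
  --speculative-num-steps 1 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 2 \
  --trust-remote-code \
  --tp 8
```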
diff --git a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py index 65bcdf513..81afcb9da 100644 --- a/python/sglang/srt/layers/attention/flashinfer_mla_backend.py +++ b/python/sglang/srt/layers/attention/flashinfer_mla_backend.py @@ -71,8 +71,6 @@ class FlashInferMLAAttnBackend(AttentionBackend): self.device = model_runner.device self.skip_prefill = skip_prefill - global_config.enable_flashinfer_mla = True - # Allocate buffers global global_workspace_buffer if global_workspace_buffer is None: diff --git a/python/sglang/srt/managers/schedule_batch.py b/python/sglang/srt/managers/schedule_batch.py index 2c47a6ed2..26b6b91b7 100644 --- a/python/sglang/srt/managers/schedule_batch.py +++ b/python/sglang/srt/managers/schedule_batch.py @@ -76,7 +76,6 @@ global_server_args_dict = { "device": ServerArgs.device, "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single, "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc, - "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla, "enable_flashmla": ServerArgs.enable_flashmla, "disable_radix_cache": ServerArgs.disable_radix_cache, "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged, @@ -1437,7 +1436,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin): # Create seq_lens_cpu when needed if ( - global_server_args_dict["enable_flashinfer_mla"] + ( + global_server_args_dict["use_mla_backend"] + and global_server_args_dict["attention_backend"] == "flashinfer" + ) or global_server_args_dict["enable_flashmla"] or global_server_args_dict["attention_backend"] == "fa3" ): diff --git a/python/sglang/srt/model_executor/model_runner.py b/python/sglang/srt/model_executor/model_runner.py index 0f345daba..65a3d5fed 100644 --- a/python/sglang/srt/model_executor/model_runner.py +++ b/python/sglang/srt/model_executor/model_runner.py @@ -75,6 +75,7 @@ from sglang.srt.utils import ( get_available_gpu_memory, init_custom_process_group, is_cuda, + is_flashinfer_available, is_hip, monkey_patch_p2p_access_check, monkey_patch_vllm_gguf_config, @@ -123,6 +124,10 @@ class ModelRunner: self.page_size = server_args.page_size self.req_to_token_pool = req_to_token_pool self.token_to_kv_pool_allocator = token_to_kv_pool_allocator + self.use_mla_backend = ( + self.model_config.attention_arch == AttentionArch.MLA + and not server_args.disable_mla + ) # Model-specific adjustment self.model_specific_adjustment() @@ -151,7 +156,6 @@ class ModelRunner: "device": server_args.device, "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single, "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc, - "enable_flashinfer_mla": server_args.enable_flashinfer_mla, "enable_flashmla": server_args.enable_flashmla, "disable_radix_cache": server_args.disable_radix_cache, "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged, @@ -159,6 +163,7 @@ class ModelRunner: "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject, "n_share_experts_fusion": server_args.n_share_experts_fusion, "disable_shared_experts_fusion": server_args.disable_shared_experts_fusion, + "use_mla_backend": self.use_mla_backend, } ) @@ -219,27 +224,38 @@ class ModelRunner: def model_specific_adjustment(self): server_args = self.server_args - if ( - self.model_config.attention_arch == AttentionArch.MLA - and not server_args.disable_mla - ): + if server_args.enable_flashinfer_mla: + # 
TODO: remove this branch after enable_flashinfer_mla is deprecated + logger.info("MLA optimization is turned on. Use flashinfer backend.") + server_args.attention_backend = "flashinfer" + elif server_args.enable_flashmla: + # TODO: remove this branch after enable_flashmla is deprecated + logger.info("MLA optimization is turned on. Use flashmla decode.") + server_args.attention_backend = "flashmla" + elif server_args.attention_backend is None: + # By default, use flashinfer for non-mla attention and triton for mla attention + if not self.use_mla_backend: + server_args.attention_backend = ( + "flashinfer" if is_flashinfer_available() else "triton" + ) + else: + server_args.attention_backend = "triton" + logger.info( + f"Attention backend not set. Use {server_args.attention_backend} backend by default." + ) + elif self.use_mla_backend: # TODO: add MLA optimization on CPU if server_args.device != "cpu": - if server_args.enable_flashinfer_mla: + if server_args.attention_backend in ["flashinfer", "fa3", "triton"]: logger.info( - "MLA optimization is turned on. Use flashinfer mla backend." - ) - server_args.attention_backend = "flashinfer_mla" - elif server_args.enable_flashmla: - logger.info("MLA optimization is turned on. Use flashmla decode.") - server_args.attention_backend = "flashmla" - elif server_args.attention_backend == "fa3": - logger.info( - f"MLA optimization is turned on. Use flash attention 3 backend." + f"MLA optimization is turned on. Use {server_args.attention_backend} backend." ) else: - logger.info("MLA optimization is turned on. Use triton backend.") - server_args.attention_backend = "triton" + raise ValueError( + f"Invalid attention backend for MLA: {server_args.attention_backend}" + ) + else: + raise ValueError(f"MLA optimization not supported on CPU.") if server_args.enable_double_sparsity: logger.info( @@ -637,10 +653,7 @@ class ModelRunner: available_gpu_memory = get_available_gpu_memory( self.device, self.gpu_id, distributed=self.tp_size > 1 ) - if ( - self.model_config.attention_arch == AttentionArch.MLA - and not self.server_args.disable_mla - ): + if self.use_mla_backend: cell_size = ( (self.model_config.kv_lora_rank + self.model_config.qk_rope_head_dim) * self.model_config.num_hidden_layers @@ -751,10 +764,7 @@ class ModelRunner: # Draft worker shares req_to_token_pool with the target worker. 
assert self.is_draft_worker - if ( - self.model_config.attention_arch == AttentionArch.MLA - and not self.server_args.disable_mla - ): + if self.use_mla_backend: self.token_to_kv_pool = MLATokenToKVPool( self.max_total_num_tokens, page_size=self.page_size, @@ -825,14 +835,21 @@ class ModelRunner: def init_attention_backend(self): """Init attention kernel backend.""" if self.server_args.attention_backend == "flashinfer": - from sglang.srt.layers.attention.flashinfer_backend import ( - FlashInferAttnBackend, - ) + if not self.use_mla_backend: + from sglang.srt.layers.attention.flashinfer_backend import ( + FlashInferAttnBackend, + ) - # Init streams - if self.server_args.speculative_algorithm == "EAGLE": - self.plan_stream_for_flashinfer = torch.cuda.Stream() - self.attn_backend = FlashInferAttnBackend(self) + # Init streams + if self.server_args.speculative_algorithm == "EAGLE": + self.plan_stream_for_flashinfer = torch.cuda.Stream() + self.attn_backend = FlashInferAttnBackend(self) + else: + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAAttnBackend, + ) + + self.attn_backend = FlashInferMLAAttnBackend(self) elif self.server_args.attention_backend == "triton": assert self.sliding_window_size is None, ( "Window attention is not supported in the triton attention backend. " @@ -858,12 +875,6 @@ class ModelRunner: ) self.attn_backend = TorchNativeAttnBackend(self) - elif self.server_args.attention_backend == "flashinfer_mla": - from sglang.srt.layers.attention.flashinfer_mla_backend import ( - FlashInferMLAAttnBackend, - ) - - self.attn_backend = FlashInferMLAAttnBackend(self) elif self.server_args.attention_backend == "flashmla": from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend diff --git a/python/sglang/srt/models/deepseek_v2.py b/python/sglang/srt/models/deepseek_v2.py index 91b3495f8..fa6de84b2 100644 --- a/python/sglang/srt/models/deepseek_v2.py +++ b/python/sglang/srt/models/deepseek_v2.py @@ -686,7 +686,6 @@ class DeepseekV2AttentionMLA(nn.Module): self.w_vc = None self.w_scale = None - self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"] self.flashinfer_mla_disable_ragged = global_server_args_dict[ "flashinfer_mla_disable_ragged" ] @@ -694,7 +693,7 @@ class DeepseekV2AttentionMLA(nn.Module): self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1" def no_absorb(self, forward_batch: ForwardBatch) -> bool: - if self.enable_flashinfer_mla: + if self.attention_backend == "flashinfer": # Flashinfer MLA: Do not absorb when enabling ragged prefill return ( not self.flashinfer_mla_disable_ragged diff --git a/python/sglang/srt/server_args.py b/python/sglang/srt/server_args.py index 1ed3d6880..71b7910ee 100644 --- a/python/sglang/srt/server_args.py +++ b/python/sglang/srt/server_args.py @@ -179,7 +179,7 @@ class ServerArgs: tool_call_parser: Optional[str] = None enable_hierarchical_cache: bool = False hicache_ratio: float = 2.0 - enable_flashinfer_mla: bool = False + enable_flashinfer_mla: bool = False # TODO: remove this argument enable_flashmla: bool = False flashinfer_mla_disable_ragged: bool = False warmups: Optional[str] = None @@ -267,15 +267,11 @@ class ServerArgs: else: self.cuda_graph_max_bs = 160 - # Choose kernel backends + # Set kernel backends for hpu device if self.device == "hpu": self.attention_backend = "torch_native" self.sampling_backend = "pytorch" - if self.attention_backend is None: - self.attention_backend = ( - "flashinfer" if is_flashinfer_available() else "triton" - ) 
if self.sampling_backend is None: self.sampling_backend = ( "flashinfer" if is_flashinfer_available() else "pytorch" ) @@ -842,7 +838,7 @@ class ServerArgs: parser.add_argument( "--enable-flashinfer-mla", action="store_true", - help="Enable FlashInfer MLA optimization", + help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead to enable FlashInfer MLA!", ) parser.add_argument( "--enable-flashmla", diff --git a/python/sglang/srt/speculative/eagle_worker.py b/python/sglang/srt/speculative/eagle_worker.py index ffec4a0ad..514603424 100644 --- a/python/sglang/srt/speculative/eagle_worker.py +++ b/python/sglang/srt/speculative/eagle_worker.py @@ -11,7 +11,11 @@ from sglang.srt.distributed import GroupCoordinator, patch_tensor_parallel_group from sglang.srt.layers.dp_attention import disable_dp_size from sglang.srt.layers.logits_processor import LogitsProcessorOutput from sglang.srt.layers.sampler import get_token_ids_logprobs, get_top_logprobs -from sglang.srt.managers.schedule_batch import ScheduleBatch, get_last_loc +from sglang.srt.managers.schedule_batch import ( + ScheduleBatch, + get_last_loc, + global_server_args_dict, +) from sglang.srt.managers.tp_worker import TpModelWorker from sglang.srt.model_executor.forward_batch_info import ( CaptureHiddenMode, @@ -146,15 +150,26 @@ class EAGLEWorker(TpModelWorker): def init_attention_backend(self): # Create multi-step attn backends and cuda graph runners if self.server_args.attention_backend == "flashinfer": - from sglang.srt.layers.attention.flashinfer_backend import ( - FlashInferMultiStepDraftBackend, - ) + if not global_server_args_dict["use_mla_backend"]: + from sglang.srt.layers.attention.flashinfer_backend import ( + FlashInferMultiStepDraftBackend, + ) - self.draft_attn_backend = FlashInferMultiStepDraftBackend( - self.draft_model_runner, - self.topk, - self.speculative_num_steps, - ) + self.draft_attn_backend = FlashInferMultiStepDraftBackend( + self.draft_model_runner, + self.topk, + self.speculative_num_steps, + ) + else: + from sglang.srt.layers.attention.flashinfer_mla_backend import ( + FlashInferMLAMultiStepDraftBackend, + ) + + self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( + self.draft_model_runner, + self.topk, + self.speculative_num_steps, + ) self.draft_extend_attn_backend = None self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = True @@ -171,19 +186,6 @@ class EAGLEWorker(TpModelWorker): self.draft_extend_attn_backend = None self.padded_static_len = self.speculative_num_steps + 1 self.has_prefill_wrapper_verify = False - elif self.server_args.attention_backend == "flashinfer_mla": - from sglang.srt.layers.attention.flashinfer_mla_backend import ( - FlashInferMLAMultiStepDraftBackend, - ) - - self.draft_attn_backend = FlashInferMLAMultiStepDraftBackend( - self.draft_model_runner, - self.topk, - self.speculative_num_steps, - ) - self.draft_extend_attn_backend = None - self.padded_static_len = self.speculative_num_steps + 1 - self.has_prefill_wrapper_verify = True elif self.server_args.attention_backend == "fa3": from sglang.srt.layers.attention.flashattention_backend import ( FlashAttentionMultiStepBackend, diff --git a/test/srt/test_mla_flashinfer.py b/test/srt/test_mla_flashinfer.py index 81de94d3a..4f0953e6a 100644 --- a/test/srt/test_mla_flashinfer.py +++ b/test/srt/test_mla_flashinfer.py @@ -26,7 +26,8 @@ class TestFlashinferMLA(CustomTestCase): "--enable-torch-compile",
"--cuda-graph-max-bs", "2", - "--enable-flashinfer-mla", + "--attention-backend", + "flashinfer", ] ) cls.process = popen_launch_server( @@ -69,8 +70,8 @@ class TestFlashinferMLANoRagged(CustomTestCase): "--disable-cuda-graph", "--cuda-graph-max-bs", "4", - "--enable-flashinfer-mla", - "--flashinfer-mla-disable-ragged", + "--attention-backend", + "flashinfer", ] ) cls.process = popen_launch_server( @@ -125,7 +126,8 @@ class TestFlashinferMLAMTP(CustomTestCase): "1", "--speculative-num-draft-tokens", "4", - "--enable-flashinfer-mla", + "--attention-backend", + "flashinfer", ] ) cls.process = popen_launch_server(