Revert "Replace enable_flashinfer_mla argument with attention_backend" (#5048)
@@ -138,7 +138,7 @@ Please consult the documentation below to learn more about the parameters you ma
 
 ## Kernel backend
 
-* `attention_backend`: The backend for attention computation and KV cache management. It can be one of `fa3`, `flashinfer`, `triton` or `torch_native`. When deploying DeepSeek models, this argument specifies the MLA backend to use.
+* `attention_backend`: The backend for attention computation and KV cache management.
 * `sampling_backend`: The backend for sampling.
 
 ## Constrained Decoding
@@ -192,5 +192,5 @@ Please consult the documentation below to learn more about the parameters you ma
 * `cuda_graph_bs`: The batch sizes to capture by `CudaGraphRunner`. By default this is done for you.
 * `torchao_config`: Experimental feature that optimizes the model with [torchao](https://github.com/pytorch/ao). Possible choices are: int8dq, int8wo, int4wo-<group_size>, fp8wo, fp8dq-per_tensor, fp8dq-per_row.
 * `triton_attention_num_kv_splits`: Used to adjust the number of KV splits in triton kernels. Default is 8.
-* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. **This argument will be deprecated soon! Please use `--attention-backend flashinfer` instead to switch on flashinfer MLA!**
-* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper for the flashinfer MLA attention backend. Should be used when flashinfer is used as the MLA backend.
+* `enable_flashinfer_mla`: Use the attention backend with the flashinfer MLA wrapper for DeepSeek models. When this argument is provided, the `attention_backend` argument is overridden.
+* `flashinfer_mla_disable_ragged`: Disable the ragged prefill wrapper for the flashinfer MLA attention backend. Should be used when `enable_flashinfer_mla` is turned on.
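To illustrate the two server arguments restored above, a launch command might look like the following sketch; the model path, `--tp`, and `--trust-remote-code` values are illustrative assumptions, not part of this diff.

```bash
# Sketch: enable the flashinfer MLA wrapper and optionally disable its ragged
# prefill path. Model path and parallelism settings are placeholders.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3 \
  --tp 8 \
  --trust-remote-code \
  --enable-flashinfer-mla \
  --flashinfer-mla-disable-ragged
```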
@@ -86,7 +86,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
 
 - **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
 
-- **MLA Attention Backends**: SGLang currently supports several optimized MLA attention backends, including FlashAttention3, [Flashinfer](https://docs.flashinfer.ai/api/mla.html) and Triton. The backend can be selected with the `--attention-backend` argument.
+- **Flashinfer MLA Wrapper**: By providing the `--enable-flashinfer-mla` argument, the server will use MLA kernels customized by Flashinfer. More details can be found in [this document](https://docs.flashinfer.ai/api/mla.html). Under long-input scenarios, flashinfer MLA can improve performance significantly. Optimized Triton kernels are used when flashinfer MLA is turned off.
 
 - **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented a Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
 
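As a rough sketch of the restored bullet above: omitting `--enable-flashinfer-mla` leaves the server on the optimized Triton MLA kernels, while adding the flag switches to the flashinfer MLA wrapper, which tends to help on long-input workloads. The model path and the other flags are assumptions for illustration only.

```bash
# Default: optimized Triton MLA kernels (model path and --tp are placeholders)
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code

# With the flashinfer MLA wrapper restored by this revert
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code --enable-flashinfer-mla
```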
@@ -149,7 +149,7 @@ python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --spec
 ```
 - The draft models are available on Hugging Face: [lmsys/DeepSeek-V3-0324-NextN](https://huggingface.co/lmsys/DeepSeek-V3-0324-NextN), [lmsys/DeepSeek-R1-NextN](https://huggingface.co/lmsys/DeepSeek-R1-NextN). They can also be exported from the original DeepSeek-V3/R1 models with the [export_deepseek_nextn.py](https://github.com/sgl-project/sglang/blob/main/scripts/export_deepseek_nextn.py) script.
 - The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve a speedup for larger batch sizes.
-- Currently, when using the flashinfer MLA wrapper (`--attention-backend flashinfer`) together with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. The MTP feature on the FlashAttention 3 backend is still in beta.
+- Currently, when using the flashinfer MLA wrapper (`--enable-flashinfer-mla`) together with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`.
 - To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (see [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
   - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default.
   - Set `--cuda-graph-bs`. It is a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes in it.
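Putting the restored MTP notes together, a speculative-decoding launch consistent with them might look like the sketch below. The flags not quoted in the hunk above (`--speculative-algorithm`, `--speculative-draft-model-path`) and the concrete batch-size values are assumptions, not part of this diff.

```bash
# Sketch: DeepSeek MTP with the flashinfer MLA wrapper. --speculative-eagle-topk
# stays at 1 per the note above; larger-batch settings follow the linked discussion.
python3 -m sglang.launch_server \
  --model-path deepseek-ai/DeepSeek-V3-0324 \
  --speculative-algorithm EAGLE \
  --speculative-draft-model-path lmsys/DeepSeek-V3-0324-NextN \
  --speculative-num-steps 1 \
  --speculative-eagle-topk 1 \
  --speculative-num-draft-tokens 2 \
  --enable-flashinfer-mla \
  --max-running-requests 64 \
  --cuda-graph-bs 1 2 4 8 16 32 64
```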
@@ -71,6 +71,8 @@ class FlashInferMLAAttnBackend(AttentionBackend):
         self.device = model_runner.device
         self.skip_prefill = skip_prefill
 
+        global_config.enable_flashinfer_mla = True
+
         # Allocate buffers
         global global_workspace_buffer
         if global_workspace_buffer is None:
@@ -76,6 +76,7 @@ global_server_args_dict = {
     "device": ServerArgs.device,
     "speculative_accept_threshold_single": ServerArgs.speculative_accept_threshold_single,
     "speculative_accept_threshold_acc": ServerArgs.speculative_accept_threshold_acc,
+    "enable_flashinfer_mla": ServerArgs.enable_flashinfer_mla,
     "enable_flashmla": ServerArgs.enable_flashmla,
     "disable_radix_cache": ServerArgs.disable_radix_cache,
     "flashinfer_mla_disable_ragged": ServerArgs.flashinfer_mla_disable_ragged,
@@ -1434,7 +1435,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
         # Create seq_lens_cpu when needed
         if (
-            global_server_args_dict["attention_backend"] == "flashinfer_mla"
+            global_server_args_dict["enable_flashinfer_mla"]
             or global_server_args_dict["enable_flashmla"]
             or global_server_args_dict["attention_backend"] == "fa3"
         ):
@@ -151,6 +151,7 @@ class ModelRunner:
                 "device": server_args.device,
                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
+                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "enable_flashmla": server_args.enable_flashmla,
                 "disable_radix_cache": server_args.disable_radix_cache,
                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
@@ -222,14 +223,10 @@ class ModelRunner:
         ):
             # TODO: add MLA optimization on CPU
             if server_args.device != "cpu":
-                if (
-                    server_args.attention_backend == "flashinfer"
-                    or server_args.enable_flashinfer_mla
-                ):
+                if server_args.enable_flashinfer_mla:
                     logger.info(
-                        "MLA optimization is turned on. Use flashinfer backend."
+                        "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    # Here we use a special flashinfer_mla tag to differentiate it from normal flashinfer backend
                     server_args.attention_backend = "flashinfer_mla"
                 elif server_args.enable_flashmla:
                     logger.info("MLA optimization is turned on. Use flashmla decode.")
@@ -684,6 +684,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.w_vc = None
         self.w_scale = None
 
+        self.enable_flashinfer_mla = global_server_args_dict["enable_flashinfer_mla"]
         self.flashinfer_mla_disable_ragged = global_server_args_dict[
             "flashinfer_mla_disable_ragged"
         ]
@@ -691,7 +692,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rocm_fused_decode_mla = os.getenv("SGLANG_ROCM_FUSED_DECODE_MLA") == "1"
 
     def no_absorb(self, forward_batch: ForwardBatch) -> bool:
-        if self.attention_backend == "flashinfer_mla":
+        if self.enable_flashinfer_mla:
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             return (
                 not self.flashinfer_mla_disable_ragged
@@ -179,7 +179,7 @@ class ServerArgs:
     tool_call_parser: Optional[str] = None
     enable_hierarchical_cache: bool = False
     hicache_ratio: float = 2.0
-    enable_flashinfer_mla: bool = False  # TODO: remove this argument
+    enable_flashinfer_mla: bool = False
     enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
@@ -836,7 +836,7 @@ class ServerArgs:
         parser.add_argument(
             "--enable-flashinfer-mla",
             action="store_true",
-            help="Enable FlashInfer MLA optimization. This argument will be deprecated soon! Please use '--attention-backend flashinfer' instead to switch on flashinfer MLA!",
+            help="Enable FlashInfer MLA optimization",
         )
         parser.add_argument(
             "--enable-flashmla",
@@ -26,8 +26,7 @@ class TestFlashinferMLA(CustomTestCase):
                 "--enable-torch-compile",
                 "--cuda-graph-max-bs",
                 "2",
-                "--attention-backend",
-                "flashinfer",
+                "--enable-flashinfer-mla",
             ]
         )
         cls.process = popen_launch_server(
@@ -70,8 +69,8 @@ class TestFlashinferMLANoRagged(CustomTestCase):
                 "--disable-cuda-graph",
                 "--cuda-graph-max-bs",
                 "4",
-                "--attention-backend",
-                "flashinfer",
+                "--enable-flashinfer-mla",
+                "--flashinfer-mla-disable-ragged",
             ]
         )
         cls.process = popen_launch_server(
@@ -126,8 +125,7 @@ class TestFlashinferMLAMTP(CustomTestCase):
                 "1",
                 "--speculative-num-draft-tokens",
                 "4",
-                "--attention-backend",
-                "flashinfer",
+                "--enable-flashinfer-mla",
             ]
         )
         cls.process = popen_launch_server(