[Feat] Support FlashMLA backend with MTP and FP8 KV cache (#6109)
Co-authored-by: Yingyi <yingyihuang2000@outlook.com>
Co-authored-by: neiltian <neiltian@tencent.com>
Co-authored-by: lukec <118525388+sleepcoo@users.noreply.github.com>
Co-authored-by: kexueyu <kexueyu@tencent.com>
Co-authored-by: vincentmeng <vincentmeng@tencent.com>
Co-authored-by: pengmeng <pengmeng@tencent.com>
@@ -8,6 +8,7 @@
| **FA3** | ✅ | ✅ | ✅ | ✅ | ✅ |
| **Triton** | ❌ | ✅ | ✅ | ❌ | ❌ |
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
+| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |

## User guide
@@ -30,10 +31,15 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --trust-r

```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend triton
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-V3 --attention-backend triton --trust-remote-code
```
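
Once a server launched with any of these backends is up, you can sanity-check it with a single request. The snippet below is a minimal sketch that assumes the default port `30000` and the OpenAI-compatible `/v1/chat/completions` route served by `sglang.launch_server`; adjust the host and port if you pass `--host` or `--port`.

```bash
# Minimal sanity check against a locally launched server (assumes the default port 30000).
curl -s http://localhost:30000/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
        "model": "meta-llama/Meta-Llama-3.1-8B-Instruct",
        "messages": [{"role": "user", "content": "Say hello in one short sentence."}],
        "max_tokens": 32
      }'
```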
- Torch Native

```bash
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend torch_native
```
- FlashMLA

```bash
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --trust-remote-code
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
```
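
For the FP8 KV cache variant in particular, it can be useful to spot-check that generation quality remains reasonable. The request below is a minimal sketch using SGLang's native `/generate` endpoint, again assuming the default port `30000`; the prompt and sampling parameters are purely illustrative.

```bash
# Quick spot check of generation output via the native /generate endpoint (assumes port 30000).
curl -s http://localhost:30000/generate \
  -H "Content-Type: application/json" \
  -d '{
        "text": "Explain the difference between MLA and standard multi-head attention in two sentences.",
        "sampling_params": {"temperature": 0.0, "max_new_tokens": 64}
      }'
```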
@@ -158,7 +158,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati

python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
```
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk`, and `--speculative-num-draft-tokens` can be searched with the [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for a given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
-- FlashAttention3 and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the FlashMLA backend and the CutlassMLA backend is still under development.
+- FlashAttention3, FlashMLA, and Triton backends fully support MTP usage. For the FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding, the `--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development.
- To enable DeepSeek MTP for large batch sizes (>32), some parameters should be changed (reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)); see the example launch command after this list:
  - Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP; for larger batch sizes, you should increase it beyond the default.
  - Set `--cuda-graph-bs`. It is a list of batch sizes for CUDA graph capture. The default captured batch sizes for speculative decoding are set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes in it.
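
As an illustration of the two adjustments above, here is a hedged sketch of a large-batch MTP launch. The numeric values for `--max-running-requests` and `--cuda-graph-bs` are placeholders rather than tuned recommendations, and the space-separated list syntax for `--cuda-graph-bs` is an assumption about the CLI.

```bash
# Illustrative large-batch MTP launch: raise the running-request cap and capture larger CUDA graph batch sizes.
# The numeric values below are placeholders; tune them for your hardware and workload.
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 \
  --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 \
  --max-running-requests 128 \
  --cuda-graph-bs 1 2 4 8 16 32 64 128 \
  --trust-remote-code --tp 8
```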