Signed-off-by: Faraz Khoubsirat <58580514+farazkh80@users.noreply.github.com>
This commit is contained in:
@@ -9,8 +9,12 @@
|
||||
| **Triton** | ❌ | ✅ | ✅ | ✅ | ❌ |
|
||||
| **Torch Native** | ❌ | ❌ | ❌ | ❌ | ❌ |
|
||||
| **FlashMLA** | ✅ | ✅ | ✅ | ❌ | ❌ |
|
||||
| **TRTLLM MLA** | ✅ | ❌ | ✅ | ✅ | ❌ |
|
||||
| **Ascend** | ✅ | ❌ | ❌ | ❌ | ❌ |
|
||||
|
||||
**Notes:**
|
||||
- TRTLLM MLA only implements decode operations. For prefill operations (including multimodal inputs), it falls back to FlashInfer MLA backend.
|
||||
|
||||
Note: Every kernel backend is compatible with a page size > 1 by specifying an argument such as `--page-size 16`.
|
||||
This is because a page size of 16 can be converted to a page size of 1 in the kernel backend.
|
||||
The "❌" and "✅" symbols in the table above under "Page Size > 1" indicate whether the kernel actually operates with a page size greater than 1, rather than treating a page size of 16 as a page size of 1.
|
||||
@@ -48,6 +52,11 @@ python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attenti
|
||||
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend flashmla --kv-cache-dtype fp8_e4m3 --trust-remote-code
|
||||
```
|
||||
|
||||
- TRTLLM MLA (Optimized for Blackwell Architecture, e.g., B200)
|
||||
```bash
|
||||
python3 -m sglang.launch_server --tp 8 --model deepseek-ai/DeepSeek-R1 --attention-backend trtllm_mla --trust-remote-code
|
||||
```
|
||||
|
||||
- Ascend
|
||||
```bash
|
||||
python3 -m sglang.launch_server --model meta-llama/Meta-Llama-3.1-8B-Instruct --attention-backend ascend
|
||||
|
||||
@@ -90,7 +90,7 @@ Please refer to [the example](https://github.com/sgl-project/sglang/tree/main/be
|
||||
|
||||
- **Weight Absorption**: By applying the associative law of matrix multiplication to reorder computation steps, this method balances computation and memory access and improves efficiency in the decoding phase.
|
||||
|
||||
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads.
|
||||
- **MLA Attention Backends**: Currently SGLang supports different optimized MLA attention backends, including [FlashAttention3](https://github.com/Dao-AILab/flash-attention), [Flashinfer](https://docs.flashinfer.ai/api/mla.html), [FlashMLA](https://github.com/deepseek-ai/FlashMLA), [CutlassMLA](https://github.com/sgl-project/sglang/pull/5390), **TRTLLM MLA** (optimized for Blackwell architecture), and [Triton](https://github.com/triton-lang/triton) backends. The default FA3 provides good performance across wide workloads.
|
||||
|
||||
- **FP8 Quantization**: W8A8 FP8 and KV Cache FP8 quantization enables efficient FP8 inference. Additionally, we have implemented Batched Matrix Multiplication (BMM) operator to facilitate FP8 inference in MLA with weight absorption.
|
||||
|
||||
@@ -104,7 +104,7 @@ Overall, with these optimizations, we have achieved up to **7x** acceleration in
|
||||
<img src="https://lmsys.org/images/blog/sglang_v0_3/deepseek_mla.svg" alt="Multi-head Latent Attention for DeepSeek Series Models">
|
||||
</p>
|
||||
|
||||
**Usage**: MLA optimization is enabled by default.
|
||||
**Usage**: MLA optimization is enabled by default. For MLA models on Blackwell architecture (e.g., B200), the default backend is FlashInfer. To use the optimized TRTLLM MLA backend for decode operations, explicitly specify `--attention-backend trtllm_mla`. Note that TRTLLM MLA only optimizes decode operations - prefill operations (including multimodal inputs) will fall back to FlashInfer MLA.
|
||||
|
||||
**Reference**: Check [Blog](https://lmsys.org/blog/2024-09-04-sglang-v0-3/#deepseek-multi-head-latent-attention-mla-throughput-optimizations) and [Slides](https://github.com/sgl-project/sgl-learning-materials/blob/main/slides/lmsys_1st_meetup_deepseek_mla.pdf) for more details.
|
||||
|
||||
@@ -161,7 +161,7 @@ Add arguments `--speculative-algorithm`, `--speculative-num-steps`, `--speculati
|
||||
python3 -m sglang.launch_server --model-path deepseek-ai/DeepSeek-V3-0324 --speculative-algorithm EAGLE --speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2 --trust-remote-code --tp 8
|
||||
```
|
||||
- The best configuration for `--speculative-num-steps`, `--speculative-eagle-topk` and `--speculative-num-draft-tokens` can be searched with [bench_speculative.py](https://github.com/sgl-project/sglang/blob/main/scripts/playground/bench_speculative.py) script for given batch size. The minimum configuration is `--speculative-num-steps 1 --speculative-eagle-topk 1 --speculative-num-draft-tokens 2`, which can achieve speedup for larger batch sizes.
|
||||
- FlashAttention3 FlashMLA and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA backend is still under development.
|
||||
- FlashAttention3, FlashMLA, and Triton backend fully supports MTP usage. For FlashInfer backend (`--attention-backend flashinfer`) with speculative decoding,`--speculative-eagle-topk` parameter should be set to `1`. MTP support for the CutlassMLA and TRTLLM MLA backends are still under development.
|
||||
- To enable DeepSeek MTP for large batch sizes (>32), there are some parameters should be changed (Reference [this discussion](https://github.com/sgl-project/sglang/issues/4543#issuecomment-2737413756)):
|
||||
- Adjust `--max-running-requests` to a larger number. The default value is `32` for MTP. For larger batch sizes, you should increase this value beyond the default value.
|
||||
- Set `--cuda-graph-bs`. It's a list of batch sizes for cuda graph capture. The default captured batch sizes for speculative decoding is set [here](https://github.com/sgl-project/sglang/blob/49420741746c8f3e80e0eb17e7d012bfaf25793a/python/sglang/srt/model_executor/cuda_graph_runner.py#L126). You can include more batch sizes into it.
|
||||
|
||||
Reference in New Issue
Block a user