[FEAT] Refactor spec decode to support efficient padded speculation (#3528)

### What this PR does / why we need it?
1. Refactor the file `mtp_proposer.py`, splits torchair related codes
into `mtp_torchair_proposer.py`
2. According to https://github.com/vllm-project/vllm/pull/24539,
implements padded speculative decoding as described in
https://github.com/vllm-project/vllm/issues/21984.
### Does this PR introduce _any_ user-facing change?
Users can use `disable_padded_drafter_batch` to disable/enable padded
speculation; the default is `False`.
offline example:
```
speculative_config={"method": "deepseek_mtp", "num_speculative_tokens": 1, "disable_padded_drafter_batch": False}
```

### How was this patch tested?

- [x] eager with pad/unpad
- [x] aclgraph with pad/unpad
- [x] torchair with pad/unpad

Performance test of deepseek-r1 with tp16, dp1
aclgraph with pad ITL: 168ms
aclgraph with unpad ITL: 169ms
original: 178ms


- vLLM version: v0.11.0rc3
- vLLM main:
83f478bb19

---------

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-10-30 16:53:05 +08:00
committed by GitHub
parent 10772d94e3
commit eff3e5fc6f
7 changed files with 1203 additions and 440 deletions

View File

@@ -1,11 +1,15 @@
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def sampling_config():
@@ -17,12 +21,12 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def mtp_correctness(
sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
):
def mtp_correctness(sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
enforce_eager=False,
disable_padded_drafter_batch=True):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
@@ -37,7 +41,7 @@ def mtp_correctness(
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False) as ref_llm:
enforce_eager=enforce_eager) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
graph_mode_str = "PIECEWISE"
@@ -54,8 +58,9 @@ def mtp_correctness(
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": num_speculative_tokens,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
},
enforce_eager=False,
enforce_eager=enforce_eager,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str),
@@ -82,6 +87,20 @@ def mtp_correctness(
del spec_llm
def test_mtp1_correctness_eager(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP correctness in eager mode with one speculative token."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=1,
                    enforce_eager=True)
def test_mtp2_correctness_eager(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Check MTP correctness in eager mode with two speculative tokens."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=2,
                    enforce_eager=True)
@pytest.mark.skip("TODO(cmq): Revert me when mtp aclgraph is fixed")
def test_mtp1_correctness_piecewise_graph(
sampling_config: SamplingParams,
@@ -110,3 +129,47 @@ def test_mtp2_correctness_full_graph(
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
def test_mtp1_correctness_eager_with_pad(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Eager-mode MTP correctness, one speculative token, padded drafter batch enabled."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=1,
                    enforce_eager=True,
                    disable_padded_drafter_batch=False)
def test_mtp2_correctness_eager_with_pad(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Eager-mode MTP correctness, two speculative tokens, padded drafter batch enabled."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=2,
                    enforce_eager=True,
                    disable_padded_drafter_batch=False)
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp1_correctness_piecewise_graph_with_pad(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Piecewise-graph MTP correctness, one speculative token, padded drafter batch enabled."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=1,
                    disable_padded_drafter_batch=False)
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp2_correctness_piecewise_graph_with_pad(
    sampling_config: SamplingParams,
    model_name: str,
):
    """Piecewise-graph MTP correctness, two speculative tokens, padded drafter batch enabled."""
    mtp_correctness(sampling_config,
                    model_name,
                    num_speculative_tokens=2,
                    disable_padded_drafter_batch=False)