[feat][spec decode] Unified draft parallel (#6766)

### What this PR does / why we need it?
Implement unified, parallelized speculative decoding in vLLM Ascend, which can
simultaneously support parallel speculative inference schemes such as PARD and
P-Eagle. See
https://github.com/vllm-project/vllm-ascend/pull/6565 and
https://github.com/vllm-project/vllm-ascend/pull/4078 for background.
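
For reference, the equivalent offline usage would look roughly like the sketch below (a minimal, untested sketch: the `speculative_config` keys mirror the serve command in the test section, and the model paths are local placeholders):

```python
# Minimal offline sketch of enabling parallel drafting (assumption: the
# speculative_config dict is passed through to the engine the same way the
# --speculative-config serve flag is; model paths are local placeholders).
from vllm import LLM, SamplingParams

llm = LLM(
    model="/model/Llama-3.1-8B-Instruct",
    max_model_len=4096,
    enable_prefix_caching=False,
    speculative_config={
        "method": "draft_model",
        "model": "/model/PARD-Llama-3.2-1B",
        "num_speculative_tokens": 8,
        "parallel_drafting": True,  # new flag introduced by this PR
    },
)
outputs = llm.generate(
    ["Hello, your name is"],
    SamplingParams(temperature=0, max_tokens=64),
)
print(outputs[0].outputs[0].text)
```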

### How was this patch tested?

Run with the parallel drafting script:

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Llama-3.2-1B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```

Baseline script (without speculative decoding):

```bash
export target=/model/Llama-3.1-8B-Instruct
export draft=/model/PARD-Llama-3.2-1B
export CUDA_VISIBLE_DEVICES=6
export ASCEND_RT_VISIBLE_DEVICES=6
vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811
```

Benchmark script:

```bash
MAX_CONCURRENCY=1
NUM_PROMPTS=80
vllm bench serve --port 8811 \
    --temperature 0 \
    --model /model/Llama-3.1-8B-Instruct \
    --backend openai-chat \
    --endpoint /v1/chat/completions \
    --dataset-name hf \
    --dataset-path philschmid/mt-bench \
    --num-prompts ${NUM_PROMPTS} \
    --max-concurrency ${MAX_CONCURRENCY} \
    --seed 1234
```
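
After the benchmark finishes, the per-position acceptance numbers below can also be pulled from the running server (a small sketch, assuming the `requests` package and that the server exposes the same `vllm:spec_decode_*` counters the new test reads via `get_metrics()`):

```python
# Sketch: dump speculative-decoding counters from the server started above
# (assumes port 8811 and a Prometheus-style /metrics endpoint).
import requests

metrics = requests.get("http://localhost:8811/metrics", timeout=10).text
for line in metrics.splitlines():
    if "spec_decode" in line and not line.startswith("#"):
        print(line)
```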

Test results:

| Configuration | TTFT | TPOT | Output token throughput |
| --- | --- | --- | --- |
| Base (without spec decode) | 79.46 ms | 26.99 ms | 36.75 tok/s |
| This PR (with parallel drafting) | 72.24 ms | 13.45 ms | 72.98 tok/s |

Per-position acceptance (positions 0 to 7): 79.48%, 56.93%, 40.00%, 27.90%, 19.79%, 14.25%, 10.57%, 7.61%.
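
As a rough sanity check (not an exact latency model): each per-position value is the fraction of drafts whose token at that position was accepted, so their sum approximates the mean number of draft tokens accepted per verification step:

```python
# Back-of-the-envelope check on the reported per-position acceptance rates.
acceptance_per_pos = [0.7948, 0.5693, 0.40, 0.2790, 0.1979, 0.1425, 0.1057, 0.0761]
mean_accepted = sum(acceptance_per_pos)   # ~2.57 accepted draft tokens per step
tokens_per_step = 1 + mean_accepted       # plus the token emitted by the target, ~3.57
print(f"~{tokens_per_step:.2f} tokens per target forward pass")
# The measured TPOT gain (26.99 ms -> 13.45 ms, ~2.0x) is below this ideal,
# presumably because drafting adds its own per-step latency.
```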

----------------------------------------------------------------------
Script for the Qwen3 model:

```bash
export target=/model/Qwen3-1.7B
export draft=/model/PARD-Qwen3-0.6B
export CUDA_VISIBLE_DEVICES=1
export ASCEND_RT_VISIBLE_DEVICES=1

vllm serve $target \
  --tensor-parallel-size 1 \
  --max-model-len 4096 \
  --no-enable-prefix-caching \
  --port 8811 \
  --speculative-config '{"model": "/model/PARD-Qwen3-0.6B", "method": "draft_model", "num_speculative_tokens": 8, "parallel_drafting": true}'
```

cc  @NickJudyHvv
- vLLM version: v0.15.0
- vLLM main:
9562912cea

---------

Signed-off-by: 01267596 <xiongkai123@cmbchina.com>
Signed-off-by: kx <1670186653@qq.com>
Signed-off-by: HF-001 <1670186653@qq.com>
Co-authored-by: 01267596 <xiongkai123@cmbchina.com>

@@ -4,7 +4,7 @@ from __future__ import annotations
import math
import os
import random
from typing import Any, Union
from typing import Any
import pytest
from transformers import AutoTokenizer
@@ -17,23 +17,32 @@ from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = {
#"eagle": {
# "eagle": {
# "main": "LLM-Research/Meta-Llama-3.1-8B-Instruct",
# "spec": "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
#},
# },
"eagle3": {
"main": "Qwen/Qwen3-8B",
"spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
},
}
DRAFT_PARALLEL_MODELS = {
"draft_parallel": {
"main": "LLM-Research/Meta-Llama-3.1-8B-Instruct",
"spec": "amd/PARD-Llama-3.2-1B",
},
}
# NOTE: golden may change (eagle_proposer only runs in eager mode currently),
# thus please update it if ci fails but you have better acceptance
BASELINES = {
"eagle": [0.74, 0.44, 0.29],
"eagle3": [0.68, 0.40, 0.18],
"draft_parallel": [0.83, 0.50, 0.33, 0.17, 0.17, 0.17, 0.17, 0.00],
}
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
@@ -89,6 +98,7 @@ def eagle3_model_name():
def vl_model_name():
return "Qwen/Qwen3-VL-8B-Instruct"
def vl_eagle3_model_name():
return "MNN/Qwen3-VL-8B-Instruct-Eagle3"
@@ -98,28 +108,28 @@ def test_ngram_correctness(
sampling_config: SamplingParams,
model_name: str,
):
'''
"""
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
"""
with VllmRunner(
model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(
model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
@@ -142,27 +152,27 @@ def test_qwen3_vl_eagle_correctness(
sampling_config: SamplingParams,
vl_model_name: str,
):
'''
"""
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
"""
with VllmRunner(
vl_model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
vl_model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
spec_model_name = vl_eagle3_model_name()
with VllmRunner(
vl_model_name,
speculative_config={
"method": "eagle3",
"model": spec_model_name,
"num_speculative_tokens": 2,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
vl_model_name,
speculative_config={
"method": "eagle3",
"model": spec_model_name,
"num_speculative_tokens": 2,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
@@ -179,27 +189,28 @@ def test_qwen3_vl_eagle_correctness(
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
"""
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with VllmRunner(model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
"""
with VllmRunner(model_name, max_model_len=1024, cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
cudagraph_capture_sizes=[1, 2, 4, 8],
max_model_len=1024) as runner:
with VllmRunner(
model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
cudagraph_capture_sizes=[1, 2, 4, 8],
max_model_len=1024,
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
@@ -221,22 +232,24 @@ def test_suffix_acceptance(
sampling_config: SamplingParams,
model_name: str,
):
'''
"""
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
"""
num_draft = []
num_accept = []
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
disable_log_stats=False) as runner:
with VllmRunner(
model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
disable_log_stats=False,
) as runner:
for i in range(10):
runner.model.chat(test_prompts[i], sampling_config)
metrics = runner.model.get_metrics()
@@ -271,13 +284,10 @@ def test_suffix_acceptance(
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
draft_tensor_parallel_size: Union[None, int],
draft_tensor_parallel_size: None | int,
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
logprobs=1,
max_tokens=10,
ignore_eos=False)
sampling_params = SamplingParams(temperature=0, logprobs=1, max_tokens=10, ignore_eos=False)
ref_llm = LLM(model=model_name, max_model_len=2048)
ref_outputs = ref_llm.chat([prompt], sampling_params)
@@ -290,19 +300,19 @@ def test_eagle_logprobs(
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"max_model_len": 128,
},
max_model_len=128,
cudagraph_capture_sizes=[1, 2, 4, 8],
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"max_model_len": 128,
},
max_model_len=128,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat([prompt], sampling_params)
@@ -314,10 +324,7 @@ def test_eagle_logprobs(
spec_logprobs.append(logprobs[token_id])
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob,
spec_logprob.logprob,
rel_tol=5e-2,
abs_tol=1e-1)
assert math.isclose(ref_logprob.logprob, spec_logprob.logprob, rel_tol=5e-2, abs_tol=1e-1)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
@@ -330,7 +337,7 @@ def test_eagle_logprobs(
def test_llama_qwen_eagle_acceptance(
method: str,
num_speculative_tokens: int,
draft_tensor_parallel_size: Union[None, int],
draft_tensor_parallel_size: None | int,
disable_padded_drafter_batch: bool,
async_scheduling: bool,
):
@@ -375,7 +382,8 @@ def test_llama_qwen_eagle_acceptance(
[prompt],
tokenize=False,
add_generation_prompt=True,
) for prompt in prompts
)
for prompt in prompts
]
speculative_config = {
@@ -389,16 +397,16 @@ def test_llama_qwen_eagle_acceptance(
compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
with VllmRunner(
main_model_name,
max_model_len=2048,
disable_log_stats=False,
tensor_parallel_size=1,
max_num_seqs=256,
distributed_executor_backend="mp",
gpu_memory_utilization=0.7,
speculative_config=speculative_config,
compilation_config=compilation_config,
async_scheduling=async_scheduling,
main_model_name,
max_model_len=2048,
disable_log_stats=False,
tensor_parallel_size=1,
max_num_seqs=256,
distributed_executor_backend="mp",
gpu_memory_utilization=0.7,
speculative_config=speculative_config,
compilation_config=compilation_config,
async_scheduling=async_scheduling,
) as llm:
outputs = llm.model.generate(prompts, sampling_params)
metrics = llm.model.get_metrics()
@@ -419,10 +427,7 @@ def test_llama_qwen_eagle_acceptance(
for pos in range(len(metric.values)):
num_accepted_tokens_per_pos[pos] += metric.values[pos]
acceptance_per_pos = [
num_accepted_tokens / num_drafts
for num_accepted_tokens in num_accepted_tokens_per_pos
]
acceptance_per_pos = [num_accepted_tokens / num_drafts for num_accepted_tokens in num_accepted_tokens_per_pos]
if method == "eagle":
golden = [0.7313432835820896, 0.373134328358209, 0.19402985074626866]
else:
@@ -434,3 +439,98 @@ def test_llama_qwen_eagle_acceptance(
print(f"golden: {golden}")
assert match
@pytest.mark.parametrize("method", DRAFT_PARALLEL_MODELS.keys())
@pytest.mark.parametrize("num_speculative_tokens", [8])
@pytest.mark.parametrize("draft_tensor_parallel_size", [None, 1])
def test_parallel_drafting_acceptance(
method: str,
num_speculative_tokens: int,
draft_tensor_parallel_size: None | int,
):
"""
Test acceptance rate for parallel drafting speculative decoding
using a smaller draft model with parallel_drafting enabled.
"""
main_model_name = DRAFT_PARALLEL_MODELS[method]["main"]
spec_model_name = DRAFT_PARALLEL_MODELS[method]["spec"]
tokenizer = AutoTokenizer.from_pretrained(
main_model_name,
trust_remote_code=True,
)
sampling_params = SamplingParams(
temperature=0,
ignore_eos=False,
max_tokens=256,
)
prompts = [
{
"role": "user",
"content": "Hello, your name is",
},
]
prompts = [
tokenizer.apply_chat_template(
[prompt],
tokenize=False,
add_generation_prompt=True,
)
for prompt in prompts
]
speculative_config = {
"method": "draft_model",
"model": spec_model_name,
"num_speculative_tokens": num_speculative_tokens,
"draft_tensor_parallel_size": draft_tensor_parallel_size,
"parallel_drafting": True,
}
compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
with VllmRunner(
main_model_name,
max_model_len=4096,
disable_log_stats=False,
tensor_parallel_size=1,
max_num_seqs=256,
distributed_executor_backend="mp",
gpu_memory_utilization=0.8,
speculative_config=speculative_config,
compilation_config=compilation_config,
enable_prefix_caching=False,
) as llm:
outputs = llm.model.generate(prompts, sampling_params)
metrics = llm.model.get_metrics()
for output in outputs:
prompt = output.prompt
generated_text = output.outputs[0].text
output_tokens = output.outputs[0].token_ids
print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
print(f"Output tokens: {output_tokens}")
num_drafts = 0
num_accepted_tokens_per_pos = [0] * num_speculative_tokens
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
assert isinstance(metric, Counter)
num_drafts += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
assert isinstance(metric, Vector)
for pos in range(len(metric.values)):
num_accepted_tokens_per_pos[pos] += metric.values[pos]
acceptance_per_pos = [num_accepted_tokens / num_drafts for num_accepted_tokens in num_accepted_tokens_per_pos]
golden = BASELINES[method]
match = all(abs(a - b) < 0.1 for a, b in zip(acceptance_per_pos, golden))
if not match:
print(f"acceptance_per_pos: {acceptance_per_pos}")
print(f"golden: {golden}")
assert match