[main][test] Refactor the mtp and eagle test case (#5326)

### What this PR does / why we need it?
1. Refactor the current test with mtp and eagle cases
2. Add new necessary cases with mtp and eagle

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
ut

- vLLM version: release/v0.13.0
- vLLM main:
5fbfa8d9ef

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2025-12-31 09:22:58 +08:00
committed by GitHub
parent bdc721d35a
commit 46862ce1af
6 changed files with 362 additions and 349 deletions

View File

@@ -0,0 +1,206 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
"""
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig
from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = ["wemaster/deepseek_mtp_main_random_bf16"]
MODELS_EAGLE = [
"vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
"RedHatAI/Qwen3-8B-speculator.eagle3"
]
MODELS_MAIN = ["LLM-Research/Meta-Llama-3.1-8B-Instruct", "Qwen/Qwen3-8B"]
VALID_COMBINATIONS = {("eagle", "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
"LLM-Research/Meta-Llama-3.1-8B-Instruct"),
("eagle3", "RedHatAI/Qwen3-8B-speculator.eagle3",
"Qwen/Qwen3-8B")}
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1, 2, 3])
@pytest.mark.parametrize("cudagraph_mode", ["PIECEWISE", "FULL_DECODE_ONLY"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
cudagraph_mode: str,
disable_padded_drafter_batch: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method":
"mtp",
"num_speculative_tokens":
num_speculative_tokens,
"disable_padded_drafter_batch":
disable_padded_drafter_batch,
},
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=cudagraph_mode,
cudagraph_capture_sizes=[20],
)) as spec_llm:
sampling_config = SamplingParams(temperature=0,
max_tokens=256,
ignore_eos=False)
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
compilation_config=CompilationConfig(
cudagraph_mode=cudagraph_mode,
cudagraph_capture_sizes=[20],
)) as ref_llm:
sampling_config = SamplingParams(temperature=0,
max_tokens=256,
ignore_eos=False)
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
cleanup_dist_env_and_memory()
del spec_llm
@pytest.mark.parametrize("model_name", MODELS_EAGLE)
@pytest.mark.parametrize("model_name_main", MODELS_MAIN)
@pytest.mark.parametrize("num_speculative_tokens", [1, 2])
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
num_speculative_tokens: int,
method: str,
disable_padded_drafter_batch: bool,
async_scheduling: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
if (method, model_name, model_name_main) not in VALID_COMBINATIONS or \
(async_scheduling and disable_padded_drafter_batch):
pytest.skip(
f"Invalid combination: method={method}, model_name={model_name}, model_name_main={model_name_main}, or case not support yet"
)
sampling_params = SamplingParams(
max_tokens=300,
temperature=0.0,
ignore_eos=False,
)
with VllmRunner(model_name_main,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=async_scheduling,
speculative_config={
"disable_padded_drafter_batch":
disable_padded_drafter_batch,
"method": method,
"model": model_name,
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 128,
"draft_vocab_size": 128256,
},
compilation_config=CompilationConfig(
cudagraph_mode="FULL_DECODE_ONLY",
cudagraph_capture_sizes=[12])) as llm:
spec_outputs = llm.generate(example_prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
with VllmRunner(model_name_main,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=async_scheduling,
compilation_config=CompilationConfig(
cudagraph_mode="FULL_DECODE_ONLY",
cudagraph_capture_sizes=[12])) as llm:
ref_outputs = llm.generate(example_prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
matches = 0
misses = 0
threshold = 0.66
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66.6% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(threshold * len(ref_outputs))
cleanup_dist_env_and_memory()

View File

@@ -0,0 +1,373 @@
# SPDX-License-Identifier: Apache-2.0
from __future__ import annotations
import math
import os
import random
from typing import Any
import pytest
from transformers import AutoTokenizer
from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.v1.metrics.reader import Counter, Vector
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = {
"eagle": {
"main": "LLM-Research/Meta-Llama-3.1-8B-Instruct",
"spec": "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
},
"eagle3": {
"main": "Qwen/Qwen3-8B",
"spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
},
}
# NOTE: golden may change (eagle_proposer only runs in eager mode currently),
# thus please update it if ci fails but you have better acceptance
BASELINES = {
"eagle": [0.74, 0.44, 0.29],
"eagle3": [0.68, 0.40, 0.18],
}
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
num_prompts = 100
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
return prompts
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=10, ignore_eos=False)
@pytest.fixture
def model_name():
return "LLM-Research/Meta-Llama-3.1-8B-Instruct"
def eagle_model_name():
return "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B"
def eagle3_model_name():
return "vllm-ascend/EAGLE3-LLaMA3.1-Instruct-8B"
def test_ngram_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with VllmRunner(
model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(
model_name,
speculative_config={
"method": "ngram",
"prompt_lookup_max": 5,
"prompt_lookup_min": 3,
"num_speculative_tokens": 3,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using ngram speculative decoding.
'''
with VllmRunner(model_name,
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8]) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"num_speculative_tokens": 8,
},
cudagraph_capture_sizes=[1, 2, 4, 8],
max_model_len=1024) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_acceptance(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
'''
Check that suffix decoding caching takes effect and improves acceptance
lengths and acceptance rates over multiple runs of the same prompts.
'''
num_draft = []
num_accept = []
with VllmRunner(model_name,
speculative_config={
"method": "suffix",
"suffix_decoding_max_spec_factor": 2.0,
"suffix_decoding_max_cached_requests": 1000,
"num_speculative_tokens": 10,
},
max_model_len=1024,
cudagraph_capture_sizes=[1, 2, 4, 8],
disable_log_stats=False) as runner:
for i in range(10):
runner.model.chat(test_prompts[i], sampling_config)
metrics = runner.model.get_metrics()
for metric in metrics:
print(metric)
if metric.name == "vllm:spec_decode_num_draft_tokens":
num_draft.append(metric.value)
if metric.name == "vllm:spec_decode_num_accepted_tokens":
num_accept.append(metric.value)
# Calculate the acceptance rates for the first and last runs.
first_accept_tokens = num_accept[0]
first_draft_tokens = num_draft[0]
first_accept_rate = first_accept_tokens / first_draft_tokens
# Take the diff since the stats are cumulative.
last_accept_tokens = num_accept[-1] - num_accept[-2]
last_draft_tokens = num_draft[-1] - num_draft[-2]
last_accept_rate = last_accept_tokens / last_draft_tokens
# Expect the acceptance length to improve.
assert first_accept_tokens < last_accept_tokens
# Expect the acceptance rate to improve.
assert first_accept_rate < last_accept_rate
# Heuristic: expect at least 80% acceptance rate at the end.
assert last_accept_rate > 0.60
@pytest.mark.parametrize("use_eagle3", [True], ids=["eagle3"])
def test_eagle_logprobs(
model_name: str,
use_eagle3: bool,
):
prompt = {"role": "user", "content": "Hello world " * 10}
sampling_params = SamplingParams(temperature=0,
logprobs=1,
max_tokens=10,
ignore_eos=False)
ref_llm = LLM(model=model_name, max_model_len=2048)
ref_outputs = ref_llm.chat([prompt], sampling_params)
ref_logprobs = []
for output in ref_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
ref_logprobs.append(logprobs[token_id])
del ref_llm
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(
model_name,
max_num_seqs=1,
max_num_batched_tokens=2048,
gpu_memory_utilization=0.6,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
},
max_model_len=128,
cudagraph_capture_sizes=[1, 2, 4, 8],
) as runner:
spec_outputs = runner.model.chat([prompt], sampling_params)
# Collect logprobs outputs from spec decode LLM.
spec_logprobs = []
for output in spec_outputs[0].outputs:
for logprobs in output.logprobs:
for token_id in logprobs:
spec_logprobs.append(logprobs[token_id])
for ref_logprob, spec_logprob in zip(ref_logprobs, spec_logprobs):
assert math.isclose(ref_logprob.logprob,
spec_logprob.logprob,
rel_tol=5e-2,
abs_tol=1e-1)
assert ref_logprob.rank == spec_logprob.rank
assert ref_logprob.decoded_token == spec_logprob.decoded_token
@pytest.mark.parametrize("method", MODELS.keys())
@pytest.mark.parametrize("num_speculative_tokens", [3])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen_eagle_acceptance(
method: str,
num_speculative_tokens: int,
disable_padded_drafter_batch: bool,
async_scheduling: bool,
):
if disable_padded_drafter_batch and async_scheduling:
pytest.skip(
"skip disable_padded_drafter_batch=True and async_scheduling=True",
)
main_model_name = MODELS[method]["main"]
spec_model_name = MODELS[method]["spec"]
tokenizer = AutoTokenizer.from_pretrained(
main_model_name,
trust_remote_code=True,
)
sampling_params = SamplingParams(
temperature=0,
ignore_eos=False,
max_tokens=256,
)
prompts = [
{
"role": "user",
"content": "Hello, my name is",
},
{
"role": "user",
"content": "The president of the United States is",
},
{
"role": "user",
"content": "The capital of France is",
},
{
"role": "user",
"content": "The future of AI is",
},
]
prompts = [
tokenizer.apply_chat_template(
[prompt],
tokenize=False,
add_generation_prompt=True,
) for prompt in prompts
]
speculative_config = {
"method": method,
"num_speculative_tokens": num_speculative_tokens,
"disable_padded_drafter_batch": disable_padded_drafter_batch,
"model": spec_model_name,
}
compilation_config = CompilationConfig(cudagraph_capture_sizes=[12])
with VllmRunner(
main_model_name,
max_model_len=2048,
disable_log_stats=False,
tensor_parallel_size=1,
max_num_seqs=256,
distributed_executor_backend="mp",
gpu_memory_utilization=0.7,
speculative_config=speculative_config,
compilation_config=compilation_config,
async_scheduling=async_scheduling,
) as llm:
_ = llm.generate(prompts, sampling_params)
metrics = llm.model.get_metrics()
num_drafts = 0
num_accepted_tokens_per_pos = [0] * num_speculative_tokens
for metric in metrics:
if metric.name == "vllm:spec_decode_num_drafts":
assert isinstance(metric, Counter)
num_drafts += metric.value
elif metric.name == "vllm:spec_decode_num_accepted_tokens_per_pos":
assert isinstance(metric, Vector)
for pos in range(len(metric.values)):
num_accepted_tokens_per_pos[pos] += metric.values[pos]
acceptance_per_pos = [
num_accepted_tokens / num_drafts
for num_accepted_tokens in num_accepted_tokens_per_pos
]
golden = BASELINES[method]
match = all(abs(a - b) < 0.06 for a, b in zip(acceptance_per_pos, golden))
if not match:
print(f"acceptance_per_pos: {acceptance_per_pos}")
print(f"golden: {golden}")
assert match