[main][test] Refactor the MTP and Eagle test cases (#5326)

### What this PR does / why we need it?
1. Refactor the existing MTP and Eagle test cases into parametrized tests
2. Add the necessary new MTP and Eagle cases

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
Unit tests: the refactored e2e MTP and Eagle correctness tests in this PR.
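
For local reproduction, something like the sketch below could be used. This is illustrative only: the collection root and options are assumptions, not the project's documented workflow; the `-k` expression simply names the two tests added in the new file further down.

```python
# Hypothetical local runner for the refactored spec-decode tests;
# collection root and options are assumptions, not project policy.
import sys

import pytest

if __name__ == "__main__":
    sys.exit(
        pytest.main([
            "-v",
            "-k", "deepseek_mtp_correctness or llama_qwen3_eagle_correctness",
        ]))
```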

- vLLM version: release/v0.13.0
- vLLM main: 5fbfa8d9ef

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
lilinsiman authored on 2025-12-31 09:22:58 +08:00, committed by GitHub
parent bdc721d35a, commit 46862ce1af
6 changed files with 362 additions and 349 deletions


@@ -0,0 +1,206 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2025 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is a part of the vllm-ascend project.
# Adapted from vllm/tests/basic_correctness/test_basic_correctness.py
#
"""Compare the short outputs of HF and vLLM when using greedy sampling.
"""
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig
from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
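# Use the spawn start method for vLLM worker processes in these multi-process tests.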
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
MODELS = ["wemaster/deepseek_mtp_main_random_bf16"]
MODELS_EAGLE = [
"vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
"RedHatAI/Qwen3-8B-speculator.eagle3"
]
MODELS_MAIN = ["LLM-Research/Meta-Llama-3.1-8B-Instruct", "Qwen/Qwen3-8B"]
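# Only these (method, draft model, target model) pairings are meaningful; the
# parametrized Eagle test below skips every other combination.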
VALID_COMBINATIONS = {("eagle", "vllm-ascend/EAGLE-LLaMA3.1-Instruct-8B",
"LLM-Research/Meta-Llama-3.1-8B-Instruct"),
("eagle3", "RedHatAI/Qwen3-8B-speculator.eagle3",
"Qwen/Qwen3-8B")}
@pytest.mark.parametrize("model_name", MODELS)
@pytest.mark.parametrize("num_speculative_tokens", [1, 2, 3])
@pytest.mark.parametrize("cudagraph_mode", ["PIECEWISE", "FULL_DECODE_ONLY"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
def test_deepseek_mtp_correctness(model_name: str, num_speculative_tokens: int,
cudagraph_mode: str,
disable_padded_drafter_batch: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
The outputs of the original LLM and the speculative LLM should be
the same when using MTP speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method":
"mtp",
"num_speculative_tokens":
num_speculative_tokens,
"disable_padded_drafter_batch":
disable_padded_drafter_batch,
},
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=cudagraph_mode,
cudagraph_capture_sizes=[20],
)) as spec_llm:
sampling_config = SamplingParams(temperature=0,
max_tokens=256,
ignore_eos=False)
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
compilation_config=CompilationConfig(
cudagraph_mode=cudagraph_mode,
cudagraph_capture_sizes=[20],
)) as ref_llm:
sampling_config = SamplingParams(temperature=0,
max_tokens=256,
ignore_eos=False)
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
cleanup_dist_env_and_memory()
del spec_llm
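# The Eagle/Eagle3 test pairs each drafter with its matching target model and
# also sweeps padded vs. unpadded drafter batches and async scheduling.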
@pytest.mark.parametrize("model_name", MODELS_EAGLE)
@pytest.mark.parametrize("model_name_main", MODELS_MAIN)
@pytest.mark.parametrize("num_speculative_tokens", [1, 2])
@pytest.mark.parametrize("method", ["eagle", "eagle3"])
@pytest.mark.parametrize("disable_padded_drafter_batch", [True, False])
@pytest.mark.parametrize("async_scheduling", [True, False])
def test_llama_qwen3_eagle_correctness(model_name: str, model_name_main: str,
num_speculative_tokens: int,
method: str,
disable_padded_drafter_batch: bool,
async_scheduling: bool):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
if (method, model_name, model_name_main) not in VALID_COMBINATIONS or \
(async_scheduling and disable_padded_drafter_batch):
pytest.skip(
f"Invalid combination: method={method}, model_name={model_name}, model_name_main={model_name_main}, or case not support yet"
)
sampling_params = SamplingParams(
max_tokens=300,
temperature=0.0,
ignore_eos=False,
)
with VllmRunner(model_name_main,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=async_scheduling,
speculative_config={
"disable_padded_drafter_batch":
disable_padded_drafter_batch,
"method": method,
"model": model_name,
"num_speculative_tokens": num_speculative_tokens,
"max_model_len": 128,
"draft_vocab_size": 128256,
},
compilation_config=CompilationConfig(
cudagraph_mode="FULL_DECODE_ONLY",
cudagraph_capture_sizes=[12])) as llm:
spec_outputs = llm.generate(example_prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
with VllmRunner(model_name_main,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=async_scheduling,
compilation_config=CompilationConfig(
cudagraph_mode="FULL_DECODE_ONLY",
cudagraph_capture_sizes=[12])) as llm:
ref_outputs = llm.generate(example_prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
matches = 0
misses = 0
threshold = 0.66
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(threshold * len(ref_outputs))
cleanup_dist_env_and_memory()


@@ -12,7 +12,7 @@ from vllm import LLM, SamplingParams
from vllm.config import CompilationConfig
from vllm.v1.metrics.reader import Counter, Vector
from tests.e2e.conftest import VllmRunner, cleanup_dist_env_and_memory
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@@ -130,127 +130,6 @@ def test_ngram_correctness(
assert matches > int(0.66 * len(ref_outputs))
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eagle_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle speculative decoding.
'''
# NOTE: e2e of eagle has many problems before.
# We first check whether it is functioning properly.
# Should fix the e2e with VllmRunner in future.
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
tokenizer = AutoTokenizer.from_pretrained(model_name,
trust_remote_code=True)
prompts = [{
"role": "user",
"content": "Hello, my name is"
}, {
"role": "user",
"content": "The president of the United States is"
}, {
"role": "user",
"content": "The capital of France is"
}, {
"role": "user",
"content": "The future of AI is"
}]
prompts = [
tokenizer.apply_chat_template(
[prompt],
tokenize=False,
add_generation_prompt=True,
) for prompt in prompts
]
sampling_params = SamplingParams(
max_tokens=300,
temperature=0.8,
top_p=0.7,
top_k=4,
ignore_eos=False,
)
# Create an LLM.
llm = LLM(
model=model_name,
tensor_parallel_size=1,
pipeline_parallel_size=1,
data_parallel_size=1,
disable_log_stats=False,
max_model_len=4096,
seed=1024,
async_scheduling=True,
compilation_config={
"level": 3,
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_num_of_warmups": 1,
"cudagraph_capture_sizes": [12],
},
speculative_config={
"disable_padded_drafter_batch": False,
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 2,
"max_model_len": 128,
"draft_vocab_size": 128256,
},
)
llm.generate(prompts, sampling_params)
cleanup_dist_env_and_memory()
del llm
@pytest.mark.parametrize("use_eagle3", [False, True], ids=["eagle", "eagle3"])
def test_eaqgle_fullgraph_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
use_eagle3: bool,
):
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using eagle3 speculative decoding
in full-graph mode.
'''
spec_model_name = eagle3_model_name() if use_eagle3 else eagle_model_name()
with VllmRunner(model_name, max_model_len=1024) as ref_llm:
ref_outputs = ref_llm.model.chat(test_prompts, sampling_config)
with VllmRunner(model_name,
speculative_config={
"method": "eagle3" if use_eagle3 else "eagle",
"model": spec_model_name,
"num_speculative_tokens": 4,
},
compilation_config={
"level": 3,
"cudagraph_mode": "FULL_DECODE_ONLY",
"cudagraph_num_of_warmups": 1,
"cudagraph_capture_sizes": [5, 10, 15, 20],
},
max_model_len=1024) as runner:
spec_outputs = runner.model.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
# Heuristic: expect at least 70% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
def test_suffix_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,


@@ -1,175 +0,0 @@
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from vllm.config import CompilationConfig, CUDAGraphMode
from tests.e2e.conftest import VllmRunner
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def mtp_correctness(sampling_config: SamplingParams,
model_name: str,
num_speculative_tokens: int,
graph_mode: CUDAGraphMode = CUDAGraphMode.PIECEWISE,
enforce_eager=False,
disable_padded_drafter_batch=True):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
cudagraph_capture_sizes=[12],
enforce_eager=enforce_eager) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
graph_mode_str = "PIECEWISE"
if graph_mode == CUDAGraphMode.FULL:
graph_mode_str = "FULL_DECODE_ONLY"
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method":
"mtp",
"num_speculative_tokens":
num_speculative_tokens,
"disable_padded_drafter_batch":
disable_padded_drafter_batch,
},
enforce_eager=enforce_eager,
max_model_len=2000,
compilation_config=CompilationConfig(
cudagraph_mode=graph_mode_str,
cudagraph_capture_sizes=[12],
)) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
def test_mtp1_correctness_eager(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, enforce_eager=True)
def test_mtp2_correctness_eager(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, enforce_eager=True)
def test_mtp1_correctness_piecewise_graph(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1)
def test_mtp2_correctness_piecewise_graph(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2)
def test_mtp1_correctness_full_graph(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 1, CUDAGraphMode.FULL)
def test_mtp2_correctness_full_graph(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config, model_name, 2, CUDAGraphMode.FULL)
def test_mtp1_correctness_eager_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
1,
enforce_eager=True,
disable_padded_drafter_batch=False)
def test_mtp2_correctness_eager_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
2,
enforce_eager=True,
disable_padded_drafter_batch=False)
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp1_correctness_piecewise_graph_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
1,
disable_padded_drafter_batch=False)
@pytest.mark.skip("TODO(xyx): Revert me when mtp aclgraph is fixed")
def test_mtp2_correctness_piecewise_graph_with_pad(
sampling_config: SamplingParams,
model_name: str,
):
mtp_correctness(sampling_config,
model_name,
2,
disable_padded_drafter_batch=False)