feat: add mtp ut and fix some bugs (#2453)

### What this PR does / why we need it?
Fix mtp mode ut

### Does this PR introduce _any_ user-facing change?
Nothing

### How was this patch tested?
This can be tested in the same way as a unit test.


- vLLM version: v0.10.0
- vLLM main:
53415653ff

Signed-off-by: 赵江江 <zhaojiangjiang1@h-partners.com>
Co-authored-by: 赵江江 <zhaojiangjiang1@h-partners.com>
This commit is contained in:
ZhaoJiangJiang
2025-08-22 17:09:08 +08:00
committed by GitHub
parent dd04a96ee3
commit 3629bc4431
10 changed files with 129 additions and 75 deletions

View File

@@ -1,43 +1,13 @@
from __future__ import annotations
import random
from typing import Any
import os
import pytest
from vllm import LLM, SamplingParams
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
@pytest.fixture
def test_prompts():
prompt_types = ["repeat", "sentence"]
num_prompts = 10
prompts = []
random.seed(0)
random_prompt_type_choices = random.choices(prompt_types, k=num_prompts)
# Generate a mixed batch of prompts, some of which can be easily
# predicted by n-gram matching and some which likely cannot.
for kind in random_prompt_type_choices:
word_choices = ["test", "temp", "hello", "where"]
word = random.choice(word_choices)
if kind == "repeat":
prompt = f"""
please repeat the word '{word}' 10 times.
give no other output than the word at least ten times in a row,
in lowercase with spaces between each word and without quotes.
"""
elif kind == "sentence":
prompt = f"""
please give a ten-word sentence that
uses the word {word} at least once.
give no other output than that simple sentence without quotes.
"""
else:
raise ValueError(f"Unknown prompt type: {kind}")
prompts.append([{"role": "user", "content": prompt}])
return prompts
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
@@ -50,39 +20,56 @@ def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
@pytest.mark.skipif(
True, reason="TODO: Enable me after test_mtp_correctness is fixed")
def test_mtp_correctness(
test_prompts: list[list[dict[str, Any]]],
sampling_config: SamplingParams,
model_name: str,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
ref_llm = LLM(model=model_name, max_model_len=256, enforce_eager=True)
ref_outputs = ref_llm.chat(test_prompts, sampling_config)
del ref_llm
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=True) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
with VllmRunner(
model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
enforce_eager=True,
max_model_len=2000,
additional_config={"ascend_scheduler_config": {
"enabled": False
}}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
spec_llm = LLM(model=model_name,
trust_remote_code=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
max_model_len=256,
enforce_eager=True)
spec_outputs = spec_llm.chat(test_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
if ref_output.outputs[0].text == spec_output.outputs[0].text:
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output.outputs[0].text}")
print(f"spec_output: {spec_output.outputs[0].text}")
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.