[V1][BUGFIX][0.10.1] FIX mtp on main branch (#2632)
### What this PR does / why we need it?
Fix MTP torchair bug caused by torchair refactor and moe refactor
Depends on PRs:
fused moe fix: https://github.com/vllm-project/vllm-ascend/pull/2627
torchair multi DP fix:
https://github.com/vllm-project/vllm-ascend/pull/2626
### Does this PR introduce _any_ user-facing change?
when dp is enabled, to run mtp online server, need to disable server log
due to the current metrics does not support multi dp
`--disable-log-stats`
### How was this patch tested?
- vLLM version: v0.10.1.1
- vLLM main:
7c8271cd1e
Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
@@ -0,0 +1,90 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from vllm import SamplingParams
|
||||
|
||||
from tests.e2e.conftest import VllmRunner
|
||||
from vllm_ascend.ascend_config import clear_ascend_config
|
||||
|
||||
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sampling_config():
|
||||
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_name():
|
||||
return "wemaster/deepseek_mtp_main_random_bf16"
|
||||
|
||||
|
||||
def test_mtp_torchair_correctness(
|
||||
sampling_config: SamplingParams,
|
||||
model_name: str,
|
||||
):
|
||||
example_prompts = [
|
||||
"Hello, my name is",
|
||||
"The president of the United States is",
|
||||
"The capital of France is",
|
||||
"The future of AI is",
|
||||
]
|
||||
'''
|
||||
Compare the outputs of a original LLM and a speculative LLM
|
||||
should be the same when using mtp speculative decoding.
|
||||
'''
|
||||
clear_ascend_config()
|
||||
with VllmRunner(model_name,
|
||||
tensor_parallel_size=1,
|
||||
gpu_memory_utilization=0.7,
|
||||
max_model_len=256,
|
||||
enforce_eager=False,
|
||||
additional_config={
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
"use_cached_graph": False,
|
||||
"graph_batch_sizes": [1, 2, 4],
|
||||
},
|
||||
}) as ref_llm:
|
||||
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
|
||||
clear_ascend_config()
|
||||
with VllmRunner(model_name,
|
||||
tensor_parallel_size=1,
|
||||
max_num_seqs=256,
|
||||
gpu_memory_utilization=0.7,
|
||||
distributed_executor_backend="mp",
|
||||
enable_expert_parallel=True,
|
||||
speculative_config={
|
||||
"method": "deepseek_mtp",
|
||||
"num_speculative_tokens": 1,
|
||||
},
|
||||
enforce_eager=False,
|
||||
max_model_len=2000,
|
||||
additional_config={
|
||||
"torchair_graph_config": {
|
||||
"enabled": True,
|
||||
"use_cached_graph": False,
|
||||
"graph_batch_sizes": [1, 2, 4],
|
||||
}
|
||||
}) as spec_llm:
|
||||
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
|
||||
|
||||
matches = 0
|
||||
misses = 0
|
||||
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
|
||||
ref_token_ids = ref_output[0][0]
|
||||
spec_token_ids = spec_output[0][0]
|
||||
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
|
||||
matches += 1
|
||||
else:
|
||||
misses += 1
|
||||
print(f"ref_output: {ref_output[1][0]}")
|
||||
print(f"spec_output: {spec_output[1][0]}")
|
||||
|
||||
# Heuristic: expect at least 66% of the prompts to match exactly
|
||||
# Upon failure, inspect the outputs to check for inaccuracy.
|
||||
assert matches > int(0.66 * len(ref_outputs))
|
||||
del spec_llm
|
||||
clear_ascend_config()
|
||||
Reference in New Issue
Block a user