[V1][BUGFIX][0.10.1] FIX mtp on main branch (#2632)

### What this PR does / why we need it?
Fix MTP torchair bug caused by torchair refactor and moe refactor

Depends on PRs:
fused moe fix: https://github.com/vllm-project/vllm-ascend/pull/2627 
torchair multi DP fix:
https://github.com/vllm-project/vllm-ascend/pull/2626

### Does this PR introduce _any_ user-facing change?
when dp is enabled, to run mtp online server, need to disable server log
due to the current metrics does not support multi dp
`--disable-log-stats`
### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
7c8271cd1e

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-09-02 11:12:41 +08:00
committed by GitHub
parent fef18b60bc
commit 214b32a346
4 changed files with 125 additions and 4 deletions

View File

@@ -0,0 +1,90 @@
from __future__ import annotations
import os
import pytest
from vllm import SamplingParams
from tests.e2e.conftest import VllmRunner
from vllm_ascend.ascend_config import clear_ascend_config
os.environ["VLLM_WORKER_MULTIPROC_METHOD"] = "spawn"
@pytest.fixture
def sampling_config():
return SamplingParams(temperature=0, max_tokens=256, ignore_eos=False)
@pytest.fixture
def model_name():
return "wemaster/deepseek_mtp_main_random_bf16"
def test_mtp_torchair_correctness(
sampling_config: SamplingParams,
model_name: str,
):
example_prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
"The future of AI is",
]
'''
Compare the outputs of a original LLM and a speculative LLM
should be the same when using mtp speculative decoding.
'''
clear_ascend_config()
with VllmRunner(model_name,
tensor_parallel_size=1,
gpu_memory_utilization=0.7,
max_model_len=256,
enforce_eager=False,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
},
}) as ref_llm:
ref_outputs = ref_llm.generate(example_prompts, sampling_config)
clear_ascend_config()
with VllmRunner(model_name,
tensor_parallel_size=1,
max_num_seqs=256,
gpu_memory_utilization=0.7,
distributed_executor_backend="mp",
enable_expert_parallel=True,
speculative_config={
"method": "deepseek_mtp",
"num_speculative_tokens": 1,
},
enforce_eager=False,
max_model_len=2000,
additional_config={
"torchair_graph_config": {
"enabled": True,
"use_cached_graph": False,
"graph_batch_sizes": [1, 2, 4],
}
}) as spec_llm:
spec_outputs = spec_llm.generate(example_prompts, sampling_config)
matches = 0
misses = 0
for ref_output, spec_output in zip(ref_outputs, spec_outputs):
ref_token_ids = ref_output[0][0]
spec_token_ids = spec_output[0][0]
if ref_token_ids == spec_token_ids[:len(ref_token_ids)]:
matches += 1
else:
misses += 1
print(f"ref_output: {ref_output[1][0]}")
print(f"spec_output: {spec_output[1][0]}")
# Heuristic: expect at least 66% of the prompts to match exactly
# Upon failure, inspect the outputs to check for inaccuracy.
assert matches > int(0.66 * len(ref_outputs))
del spec_llm
clear_ascend_config()

View File

@@ -23,6 +23,8 @@ from pytest_mock import MockerFixture
from vllm.model_executor.layers.fused_moe import FusedMoEMethodBase
from vllm_ascend.ascend_forward_context import _get_fused_moe_state
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.quantization.quantizer import W8A8Quantizer
from vllm_ascend.torchair.ops.torchair_fused_moe import (
TorchairAscendFusedMoE, TorchairAscendUnquantizedFusedMoEMethod)
from vllm_ascend.utils import AscendSocVersion, adapt_patch # noqa E402
@@ -233,12 +235,28 @@ class TestTorchairAscendFusedMoe:
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = False
with patch(
'vllm_ascend.quantization.quantizer.AscendQuantizer.get_quantizer',
return_value=W8A8Quantizer):
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert isinstance(moe.quant_method, AscendFusedMoEMethod)
def test_init_with_mixed_quant(self, mock_dist_env, default_moe_config):
mock_quant_config = MagicMock()
mock_quant_method = MockFusedMoEMethod()
mock_quant_config.get_quant_method.return_value = mock_quant_method
mock_quant_config.is_layer_skipped_ascend.return_value = True
moe = TorchairAscendFusedMoE(**default_moe_config,
quant_config=mock_quant_config)
assert moe.quant_method is not None
assert moe.quant_method == mock_quant_method
assert isinstance(moe.quant_method,
TorchairAscendUnquantizedFusedMoEMethod)
@pytest.mark.parametrize(
"others_param",

View File

@@ -45,6 +45,7 @@ from vllm_ascend.distributed.communication_op import \
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
@@ -1055,7 +1056,13 @@ class TorchairAscendFusedMoE(FusedMoE):
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
self.moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if quant_config.is_layer_skipped_ascend(
prefix, quant_config.packed_modules_mapping):
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
self.moe)
else:
self.quant_method = AscendFusedMoEMethod(
quant_config, prefix, quant_config.packed_modules_mapping)
assert self.quant_method is not None

View File

@@ -18,6 +18,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
@@ -266,8 +268,12 @@ class MtpProposer:
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
if self.torchair_graph_enabled:
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -