[V1][BUGFIX][0.10.1] FIX mtp on main branch (#2632)

### What this PR does / why we need it?
Fix the MTP TorchAir bug introduced by the TorchAir refactor and the MoE refactor.

Depends on PRs:
fused moe fix: https://github.com/vllm-project/vllm-ascend/pull/2627 
torchair multi DP fix:
https://github.com/vllm-project/vllm-ascend/pull/2626

### Does this PR introduce _any_ user-facing change?
When DP (data parallelism) is enabled, running the MTP online server requires disabling
server log stats with `--disable-log-stats`, because the current metrics implementation
does not yet support multi-DP.
### How was this patch tested?


- vLLM version: v0.10.1.1
- vLLM main:
7c8271cd1e

Signed-off-by: xuyexiong <xuyexiong@huawei.com>
This commit is contained in:
xuyexiong
2025-09-02 11:12:41 +08:00
committed by GitHub
parent fef18b60bc
commit 214b32a346
4 changed files with 125 additions and 4 deletions

View File

@@ -45,6 +45,7 @@ from vllm_ascend.distributed.communication_op import \
from vllm_ascend.distributed.parallel_state import get_mc2_group
from vllm_ascend.ops.expert_load_balancer import ExpertLoadBalancer
from vllm_ascend.ops.sequence_parallel import MetadataForPadding
from vllm_ascend.quantization.quant_config import AscendFusedMoEMethod
from vllm_ascend.torchair.utils import npu_stream_switch, npu_wait_tensor
from vllm_ascend.utils import (AscendSocVersion, dispose_tensor,
get_all_reduce_merge_state,
@@ -1055,7 +1056,13 @@ class TorchairAscendFusedMoE(FusedMoE):
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
self.moe)
else:
self.quant_method = quant_config.get_quant_method(self, prefix)
if quant_config.is_layer_skipped_ascend(
prefix, quant_config.packed_modules_mapping):
self.quant_method = TorchairAscendUnquantizedFusedMoEMethod(
self.moe)
else:
self.quant_method = AscendFusedMoEMethod(
quant_config, prefix, quant_config.packed_modules_mapping)
assert self.quant_method is not None

View File

@@ -18,6 +18,8 @@ from vllm_ascend.ascend_config import get_ascend_config
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.models.deepseek_mtp import CustomDeepSeekMTP
from vllm_ascend.torchair.models.torchair_deepseek_mtp import \
TorchairDeepSeekMTP
from vllm_ascend.torchair.utils import TorchairCommonAttentionMetadata
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
@@ -266,8 +268,12 @@ class MtpProposer:
with set_default_torch_dtype(
draft_model_config.dtype), set_current_vllm_config(
self.vllm_config):
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
if self.torchair_graph_enabled:
self.model = TorchairDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
else:
self.model = CustomDeepSeekMTP(
vllm_config=self.vllm_config).to(target_device)
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -