[Bugfix] fix MTP support for lmhead_tensor_parallel_size (#3921)

### What this PR does / why we need it?
Fix the issue of MTP being enabled and setting
Imhead_tensor_parallel_size=16 causing the inference to hang.


Signed-off-by: wyh145 <1987244901@qq.com>
This commit is contained in:
Nagisa125
2025-10-31 14:34:28 +08:00
committed by GitHub
parent ee2e55e602
commit 9f7de45b75
2 changed files with 3 additions and 2 deletions

View File

@@ -51,7 +51,7 @@ class AscendVocabParallelEmbedding(VocabParallelEmbedding):
prefix: str = ""):
nn.Module.__init__(self)
if lmhead_tp_enable() and prefix.find("lm_head") != -1:
if lmhead_tp_enable() and prefix.find("head") != -1:
self.comm_group = get_lmhead_tp_group()
else:
self.comm_group = get_tp_group()

View File

@@ -2516,7 +2516,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
self.drafter.model.compute_logits(
hidden_states[dummy_indices])
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb: