[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 (#4360)

### What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. This caused execute_model to hang on
compute_logits when lm-head tensor parallelism was greater than 1. The
fix ensures that the dummy run executes compute_logits once per
speculative token, so the number of calls matches num_speculative_tokens.

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-12-01 11:11:15 +08:00
committed by GitHub
parent cd9f5c0611
commit 2b4f7a5016
3 changed files with 25 additions and 15 deletions

View File

@@ -2465,13 +2465,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
need_dummy_logits = (not self.in_profile_run
and lmhead_tp_enable())
if need_dummy_logits:
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
dummy_indices = torch.zeros(max_num_reqs_across_dp,
dtype=torch.int32)
def dummy_compute_logits(hidden_states):
return self.model.compute_logits(
def dummy_compute_logits(hidden_states):
if not need_dummy_logits:
return None
return self.model.compute_logits(hidden_states[dummy_indices])
def dummy_drafter_compute_logits(hidden_states):
if not need_dummy_logits or self.drafter is None:
return
if hasattr(self.drafter, "model") and hasattr(
self.drafter.model, "compute_logits"):
return self.drafter.model.compute_logits(
hidden_states[dummy_indices])
with set_ascend_forward_context(
@@ -2493,8 +2501,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
with_prefill, is_torchair_compile, input_ids, positions,
attn_metadata, num_tokens, intermediate_tensors,
inputs_embeds)
if need_dummy_logits:
dummy_compute_logits(hidden_states)
dummy_compute_logits(hidden_states)
if self.drafter:
self.drafter.dummy_run(
@@ -2504,10 +2511,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
num_reqs=num_reqs,
num_tokens_across_dp=num_tokens_across_dp,
aclgraph_runtime_mode=aclgraph_runtime_mode,
batch_descriptor=batch_descriptor)
if need_dummy_logits:
self.drafter.model.compute_logits(
hidden_states[dummy_indices])
batch_descriptor=batch_descriptor,
dummy_compute_logits=dummy_drafter_compute_logits)
if self.in_profile_run and self.dynamic_eplb:
self.model.clear_all_moe_loads()
if not self.in_profile_run and self.dynamic_eplb: