[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 (#4360)

### What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. This caused execute_model to hang on
compute_logits when lm head tensor parallelism exceeded 1. The fix
ensures compute_logits executes correctly during dummy run, matching
num_speculative_tokens.

Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
zouyida2052
2025-12-01 11:11:15 +08:00
committed by GitHub
parent cd9f5c0611
commit 2b4f7a5016
3 changed files with 25 additions and 15 deletions

View File

@@ -116,7 +116,8 @@ class EagleProposer(Proposer):
num_reqs: int = 0,
num_tokens_across_dp: Optional[torch.Tensor] = None,
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
batch_descriptor=None):
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None):
moe_comm_type = self.runner._select_moe_comm_method(
num_tokens, with_prefill)
with set_ascend_forward_context(None,
@@ -128,6 +129,7 @@ class EagleProposer(Proposer):
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
dummy_compute_logits(self.hidden_states)
def generate_token_ids(self,
valid_sampled_token_ids: list[list[int]],