[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 (#4360)
### What this PR does / why we need it? Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures compute_logits executes correctly during dummy run, matching num_speculative_tokens. Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -116,7 +116,8 @@ class EagleProposer(Proposer):
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None):
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
with set_ascend_forward_context(None,
|
||||
@@ -128,6 +129,7 @@ class EagleProposer(Proposer):
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
)
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
def generate_token_ids(self,
|
||||
valid_sampled_token_ids: list[list[int]],
|
||||
|
||||
Reference in New Issue
Block a user