[cherry-pick pr-4254] bugfix for mtp>1 when lm_head_tp>1 (#4360)
### What this PR does / why we need it? Previously, the dummy run executed compute_logits only once, regardless of num_speculative_tokens. This caused execute_model to hang on compute_logits when lm head tensor parallelism exceeded 1. The fix ensures compute_logits executes correctly during dummy run, matching num_speculative_tokens. Signed-off-by: zouyida2052 <zouyida2002@gmail.com>
This commit is contained in:
@@ -116,7 +116,8 @@ class EagleProposer(Proposer):
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp: Optional[torch.Tensor] = None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None):
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None):
|
||||
moe_comm_type = self.runner._select_moe_comm_method(
|
||||
num_tokens, with_prefill)
|
||||
with set_ascend_forward_context(None,
|
||||
@@ -128,6 +129,7 @@ class EagleProposer(Proposer):
|
||||
positions=self.positions[:num_tokens],
|
||||
hidden_states=self.hidden_states[:num_tokens],
|
||||
)
|
||||
dummy_compute_logits(self.hidden_states)
|
||||
|
||||
def generate_token_ids(self,
|
||||
valid_sampled_token_ids: list[list[int]],
|
||||
|
||||
@@ -114,7 +114,8 @@ class MtpProposer(Proposer):
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None) -> None:
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||
if not self.torchair_graph_enabled:
|
||||
# TODO: adapt enable_dbo later
|
||||
(num_tokens, num_tokens_across_dp, with_prefill,
|
||||
@@ -188,6 +189,7 @@ class MtpProposer(Proposer):
|
||||
self.model(input_ids=input_ids,
|
||||
positions=positions,
|
||||
hidden_states=previous_hidden_states)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
@@ -490,6 +492,7 @@ class MtpProposer(Proposer):
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
logits = logits[:num_indices]
|
||||
last_token_indices = last_token_indices[:num_indices]
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
if self.num_speculative_tokens == 1:
|
||||
@@ -554,7 +557,7 @@ class MtpProposer(Proposer):
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||
attn_metadata_i.seq_lens.device, non_blocking=True)
|
||||
attn_metadata_i.seq_lens.device, non_blocking=False)
|
||||
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
||||
exceeds_max_model_len_cpu, 1)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
|
||||
@@ -2465,13 +2465,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
need_dummy_logits = (not self.in_profile_run
|
||||
and lmhead_tp_enable())
|
||||
|
||||
if need_dummy_logits:
|
||||
max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
|
||||
dummy_indices = torch.zeros(max_num_reqs_across_dp,
|
||||
dtype=torch.int32)
|
||||
|
||||
def dummy_compute_logits(hidden_states):
|
||||
return self.model.compute_logits(
|
||||
if not need_dummy_logits:
|
||||
return None
|
||||
return self.model.compute_logits(hidden_states[dummy_indices])
|
||||
|
||||
def dummy_drafter_compute_logits(hidden_states):
|
||||
if not need_dummy_logits or self.drafter is None:
|
||||
return
|
||||
if hasattr(self.drafter, "model") and hasattr(
|
||||
self.drafter.model, "compute_logits"):
|
||||
return self.drafter.model.compute_logits(
|
||||
hidden_states[dummy_indices])
|
||||
|
||||
with set_ascend_forward_context(
|
||||
@@ -2493,7 +2501,6 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
with_prefill, is_torchair_compile, input_ids, positions,
|
||||
attn_metadata, num_tokens, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
if need_dummy_logits:
|
||||
dummy_compute_logits(hidden_states)
|
||||
|
||||
if self.drafter:
|
||||
@@ -2504,10 +2511,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
|
||||
num_reqs=num_reqs,
|
||||
num_tokens_across_dp=num_tokens_across_dp,
|
||||
aclgraph_runtime_mode=aclgraph_runtime_mode,
|
||||
batch_descriptor=batch_descriptor)
|
||||
if need_dummy_logits:
|
||||
self.drafter.model.compute_logits(
|
||||
hidden_states[dummy_indices])
|
||||
batch_descriptor=batch_descriptor,
|
||||
dummy_compute_logits=dummy_drafter_compute_logits)
|
||||
if self.in_profile_run and self.dynamic_eplb:
|
||||
self.model.clear_all_moe_loads()
|
||||
if not self.in_profile_run and self.dynamic_eplb:
|
||||
|
||||
Reference in New Issue
Block a user