[Bugfix] Resolve MTP > 1 issue when lm head tp > 1 (#4254)
### What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. With lm head tensor parallelism greater
than 1, the real execute_model then hung inside compute_logits, since
the dummy run had not issued the matching calls. The fix makes the
dummy run execute compute_logits once per speculative step, matching
num_speculative_tokens.
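As a rough illustration of the invariant (all names below are hypothetical stand-ins, not the actual NPUModelRunner API): under lm head TP, each compute_logits call enters a collective, so the dummy run must issue the same number of calls as the real decode path, one for the target model plus one per MTP step.

```python
import torch

class _StubLMHead:
    """Stand-in for a model with a (possibly TP-sharded) lm head."""

    def compute_logits(self, hidden_states):
        # In the real runner this call enters a collective under lm-head TP.
        return hidden_states.float()

def dummy_run_sketch(model, drafter, hidden_states, num_speculative_tokens):
    dummy_indices = torch.zeros(1, dtype=torch.int32)
    # One logits call for the target model, as on the real decode path.
    model.compute_logits(hidden_states[dummy_indices])
    # One call per speculative (MTP) step for the drafter; skipping these
    # left the real execute_model waiting inside a collective.
    for _ in range(num_speculative_tokens):
        drafter.compute_logits(hidden_states[dummy_indices])

dummy_run_sketch(_StubLMHead(), _StubLMHead(),
                 torch.zeros(4, 8), num_speculative_tokens=2)
```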
I set the `non_blocking` argument to False when moving
`exceeds_max_model_len` to the CPU. From what I understand, using
`non_blocking=True` and immediately reading the tensor on the CPU can
observe stale data, since the copy may still be in flight, which leads
to accuracy problems. Transferring data to a device doesn't have this
issue, because later kernels on the same stream are ordered after the
copy. ref:
https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/18
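A small sketch of the hazard (the device here is a placeholder for the Ascend NPU this PR targets; on a CPU-only build `non_blocking` is a no-op):

```python
import torch

# Placeholder device; the PR targets Ascend NPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
exceeds_max_model_len = torch.zeros(8, dtype=torch.bool, device=device)

# Risky on an accelerator: the device-to-CPU copy may still be in flight
# when .any() reads the tensor, so the host can observe stale values.
risky = exceeds_max_model_len.to("cpu", non_blocking=True)

# Safe: the call does not return until the copy has completed.
safe = exceeds_max_model_len.to("cpu", non_blocking=False)
assert not safe.any()
```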
- vLLM version: v0.11.0
- vLLM main: 2918c1b49c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
```diff
@@ -3003,14 +3003,21 @@ class NPUModelRunner(LoRAModelRunnerMixin):
         need_dummy_logits = (not self.in_profile_run
                              and lmhead_tp_enable())
-        if need_dummy_logits:
-            max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
-            dummy_indices = torch.zeros(max_num_reqs_across_dp,
-                                        dtype=torch.int32)
-
-            def dummy_compute_logits(hidden_states):
-                return self.model.compute_logits(
-                    hidden_states[dummy_indices])
+        max_num_reqs_across_dp = num_tokens if not with_prefill else max_num_reqs
+        dummy_indices = torch.zeros(max_num_reqs_across_dp,
+                                    dtype=torch.int32)
+
+        def dummy_compute_logits(hidden_states):
+            if not need_dummy_logits:
+                return None
+            return self.model.compute_logits(hidden_states[dummy_indices])
+
+        def dummy_drafter_compute_logits(hidden_states):
+            if not need_dummy_logits or self.drafter is None:
+                return
+            if hasattr(self.drafter, "model") and hasattr(
+                    self.drafter.model, "compute_logits"):
+                return self.drafter.model.compute_logits(
+                    hidden_states[dummy_indices])
 
         with set_ascend_forward_context(
@@ -3032,8 +3039,7 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                 with_prefill, is_torchair_compile, input_ids, positions,
                 attn_metadata, num_tokens, intermediate_tensors,
                 inputs_embeds)
-            if need_dummy_logits:
-                dummy_compute_logits(hidden_states)
+            dummy_compute_logits(hidden_states)
 
             if self.drafter:
                 self.drafter.dummy_run(
@@ -3042,10 +3048,8 @@ class NPUModelRunner(LoRAModelRunnerMixin):
                     num_reqs=num_reqs,
                     num_tokens_across_dp=num_tokens_across_dp,
                     aclgraph_runtime_mode=aclgraph_runtime_mode,
-                    batch_descriptor=batch_descriptor)
-                if need_dummy_logits:
-                    self.drafter.model.compute_logits(
-                        hidden_states[dummy_indices])
+                    batch_descriptor=batch_descriptor,
+                    dummy_compute_logits=dummy_drafter_compute_logits)
         if self.in_profile_run and self.dynamic_eplb:
             self.model.clear_all_moe_loads()
         if not self.in_profile_run and self.dynamic_eplb:
```
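For context on the new shape of the call, here is a toy sketch (hypothetical classes, not the vllm-ascend drafter) of the callback design: the runner owns dummy_drafter_compute_logits and hands it to drafter.dummy_run, which can then trigger the lm-head logits path once per speculative step.

```python
import torch

class ToyDrafter:
    """Hypothetical drafter; the real one lives in vllm-ascend."""

    def __init__(self, num_speculative_tokens: int):
        self.num_speculative_tokens = num_speculative_tokens

    def dummy_run(self, hidden_states, dummy_compute_logits=None):
        for _ in range(self.num_speculative_tokens):
            # ... a dummy forward pass per MTP step would go here ...
            if dummy_compute_logits is not None:
                # Fires the lm-head logits path (a collective under TP)
                # once per speculative step, keeping ranks in lockstep.
                dummy_compute_logits(hidden_states)

drafter = ToyDrafter(num_speculative_tokens=2)
drafter.dummy_run(torch.zeros(4, 8), lambda h: h.float())
```

Pushing the callback into dummy_run, instead of calling compute_logits once after it returns, is what lets the number of logits calls track num_speculative_tokens.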