[Bugfix] Resolve MTP > 1 issue when lm head tp > 1 (#4254)
### What this PR does / why we need it?
Previously, the dummy run executed compute_logits only once, regardless
of num_speculative_tokens. This caused execute_model to hang on
compute_logits when lm head tensor parallelism exceeded 1. The fix
ensures compute_logits executes correctly during dummy run, matching
num_speculative_tokens.
I set the `non_blocking` argument to False when moving
`exceeds_max_model_len` to the CPU. From what I understand, using
`non_blocking=True` and immediately accessing the tensor on the CPU can
cause accuracy problems. However, this issue doesn't happen when
transferring data to a device. ref:
https://discuss.pytorch.org/t/should-we-set-non-blocking-to-true/38234/18
- vLLM version: v0.11.0
- vLLM main:
2918c1b49c
---------
Signed-off-by: Jade Zheng <zheng.shoujian@outlook.com>
This commit is contained in:
@@ -213,7 +213,8 @@ class MtpProposer(Proposer):
|
||||
num_reqs: int = 0,
|
||||
num_tokens_across_dp=None,
|
||||
aclgraph_runtime_mode: CUDAGraphMode = CUDAGraphMode.NONE,
|
||||
batch_descriptor=None) -> None:
|
||||
batch_descriptor=None,
|
||||
dummy_compute_logits=lambda hidden_states: None) -> None:
|
||||
|
||||
(
|
||||
num_tokens,
|
||||
@@ -296,6 +297,7 @@ class MtpProposer(Proposer):
|
||||
self.update_stream, forward_context,
|
||||
positions.shape[0],
|
||||
self.vllm_config.speculative_config)
|
||||
dummy_compute_logits(previous_hidden_states)
|
||||
if with_prefill:
|
||||
break
|
||||
|
||||
@@ -756,6 +758,7 @@ class MtpProposer(Proposer):
|
||||
logits = self.model.compute_logits(sample_hidden_states)
|
||||
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||
logits = logits[:num_indices]
|
||||
last_token_indices = last_token_indices[:num_indices]
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
|
||||
if self.num_speculative_tokens == 1:
|
||||
@@ -821,7 +824,7 @@ class MtpProposer(Proposer):
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
exceeds_max_model_len_cpu = exceeds_max_model_len.to(
|
||||
attn_metadata_i.seq_lens.device, non_blocking=True)
|
||||
attn_metadata_i.seq_lens.device, non_blocking=False)
|
||||
attn_metadata_i.seq_lens[:batch_size].masked_fill_(
|
||||
exceeds_max_model_len_cpu, 1)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
|
||||
Reference in New Issue
Block a user