[main][bugfix] Solved the problem of the d node getting stuck in the pd-separation scenario (#7534)

### What this PR does / why we need it?
A problem of the d node getting stuck in the pd-separation scenario is
solved.

We find it will crash at `torch.nn.functional.linear(x, weight, bias)`
after being stuck for a long time.
we found that the shapes of each dp
node were not aligned. this is the root cause.

- vLLM version: v0.18.0
- vLLM main:
4034c3d32e

Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
drslark
2026-03-23 18:53:07 +08:00
committed by GitHub
parent a253235a59
commit 41dadd4312

View File

@@ -439,7 +439,6 @@ class SpecDecodeBaseProposer(EagleProposer):
target_positions=model_positions, target_positions=model_positions,
inputs_embeds=None, inputs_embeds=None,
multi_steps_attn_metadata=multi_steps_attn_metadata, multi_steps_attn_metadata=multi_steps_attn_metadata,
is_dummy=True,
num_tokens=num_tokens, num_tokens=num_tokens,
) )
forward_context = get_forward_context() forward_context = get_forward_context()
@@ -702,7 +701,6 @@ class SpecDecodeBaseProposer(EagleProposer):
inputs_embeds, inputs_embeds,
multi_steps_attn_metadata, multi_steps_attn_metadata,
num_tokens, num_tokens,
is_dummy=False,
is_prefill=None, is_prefill=None,
) -> torch.Tensor: ) -> torch.Tensor:
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
@@ -755,7 +753,7 @@ class SpecDecodeBaseProposer(EagleProposer):
self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]], self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
) )
if lmhead_tp_enable() and not is_dummy: if lmhead_tp_enable():
max_num_reqs_across_dp = ( max_num_reqs_across_dp = (
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
) )
@@ -766,7 +764,7 @@ class SpecDecodeBaseProposer(EagleProposer):
sample_hidden_states = last_hidden_states[token_indices_to_sample] sample_hidden_states = last_hidden_states[token_indices_to_sample]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy: if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices] logits = logits[:num_indices]
token_indices_to_sample = token_indices_to_sample[:num_indices] token_indices_to_sample = token_indices_to_sample[:num_indices]
@@ -879,7 +877,7 @@ class SpecDecodeBaseProposer(EagleProposer):
) )
num_indices = token_indices_to_sample.shape[0] num_indices = token_indices_to_sample.shape[0]
if lmhead_tp_enable() and not is_dummy: if lmhead_tp_enable():
max_num_reqs_across_dp = ( max_num_reqs_across_dp = (
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
) )
@@ -891,7 +889,7 @@ class SpecDecodeBaseProposer(EagleProposer):
sample_hidden_states = last_hidden_states[token_indices_to_sample] sample_hidden_states = last_hidden_states[token_indices_to_sample]
logits = self.model.compute_logits(sample_hidden_states) logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy: if lmhead_tp_enable() and num_indices < logits.shape[0]:
logits = logits[:num_indices] logits = logits[:num_indices]
token_indices_to_sample = token_indices_to_sample[:num_indices] token_indices_to_sample = token_indices_to_sample[:num_indices]