[main][bugfix] Solved the problem of the d node getting stuck in the pd-separation scenario (#7534)
### What this PR does / why we need it?
This PR fixes a problem where the d node gets stuck in the pd-separation scenario.
After being stuck for a long time, it crashes at `torch.nn.functional.linear(x, weight, bias)`.
Investigation showed that the shapes across the dp
nodes were not aligned; this is the root cause.
- vLLM version: v0.18.0
- vLLM main:
4034c3d32e
Signed-off-by: drslark <slarksblood@qq.com>
This commit is contained in:
@@ -439,7 +439,6 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
target_positions=model_positions,
|
target_positions=model_positions,
|
||||||
inputs_embeds=None,
|
inputs_embeds=None,
|
||||||
multi_steps_attn_metadata=multi_steps_attn_metadata,
|
multi_steps_attn_metadata=multi_steps_attn_metadata,
|
||||||
is_dummy=True,
|
|
||||||
num_tokens=num_tokens,
|
num_tokens=num_tokens,
|
||||||
)
|
)
|
||||||
forward_context = get_forward_context()
|
forward_context = get_forward_context()
|
||||||
@@ -702,7 +701,6 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
inputs_embeds,
|
inputs_embeds,
|
||||||
multi_steps_attn_metadata,
|
multi_steps_attn_metadata,
|
||||||
num_tokens,
|
num_tokens,
|
||||||
is_dummy=False,
|
|
||||||
is_prefill=None,
|
is_prefill=None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
|
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all
|
||||||
@@ -755,7 +753,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
|
self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
|
||||||
)
|
)
|
||||||
|
|
||||||
if lmhead_tp_enable() and not is_dummy:
|
if lmhead_tp_enable():
|
||||||
max_num_reqs_across_dp = (
|
max_num_reqs_across_dp = (
|
||||||
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
||||||
)
|
)
|
||||||
@@ -766,7 +764,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
|
|
||||||
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
|
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||||
logits = logits[:num_indices]
|
logits = logits[:num_indices]
|
||||||
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
||||||
|
|
||||||
@@ -879,7 +877,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
)
|
)
|
||||||
|
|
||||||
num_indices = token_indices_to_sample.shape[0]
|
num_indices = token_indices_to_sample.shape[0]
|
||||||
if lmhead_tp_enable() and not is_dummy:
|
if lmhead_tp_enable():
|
||||||
max_num_reqs_across_dp = (
|
max_num_reqs_across_dp = (
|
||||||
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
self.vllm_config.scheduler_config.max_num_seqs * self.runner.uniform_decode_query_len
|
||||||
)
|
)
|
||||||
@@ -891,7 +889,7 @@ class SpecDecodeBaseProposer(EagleProposer):
|
|||||||
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
sample_hidden_states = last_hidden_states[token_indices_to_sample]
|
||||||
logits = self.model.compute_logits(sample_hidden_states)
|
logits = self.model.compute_logits(sample_hidden_states)
|
||||||
|
|
||||||
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
|
if lmhead_tp_enable() and num_indices < logits.shape[0]:
|
||||||
logits = logits[:num_indices]
|
logits = logits[:num_indices]
|
||||||
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
token_indices_to_sample = token_indices_to_sample[:num_indices]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user