[eagle][cp][bugfix] Fix the bug in eagle and cp enabled (#6981)

### What this PR does / why we need it?
When eagle and cp are enabled at the same time, an error occurs in
pcp_allgather due to the handling of hidden_states. This PR fixes that issue.

- vLLM version: v0.16.0
- vLLM main:
15d76f74e2
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-03-06 20:49:49 +08:00
committed by GitHub
parent 1c0ecf806a
commit 01d3515dcf
2 changed files with 13 additions and 3 deletions

View File

@@ -718,7 +718,17 @@ class AscendEagleProposer(EagleProposer):
             hidden_states = torch.index_select(
                 hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]]
             )
-            last_hidden_states = hidden_states  # TODO: check it
+            if self.method == "mtp":
+                last_hidden_states = hidden_states
+            else:
+                # eagle and eagle3 need allgather last_hidden_states.
+                last_hidden_states = last_hidden_states[:num_tokens]
+                last_hidden_states = get_pcp_group().all_gather(last_hidden_states, 0)
+                last_hidden_states = torch.index_select(
+                    last_hidden_states,
+                    0,
+                    self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]],
+                )

         num_indices = last_token_indices.shape[0]
         if lmhead_tp_enable() and not is_dummy:
@@ -957,7 +967,7 @@ class AscendEagleProposer(EagleProposer):
         if self.pcp_size * self.dcp_size > 1:
             num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens(
-                ori_seq_len + draft_step,
+                ori_seq_len + draft_step + 1,
                 self.pcp_size,
                 self.dcp_size,
                 self.runner.parallel_config.cp_kv_cache_interleave_size,

View File

@@ -1048,7 +1048,7 @@ class NPUModelRunner(GPUModelRunner):
             target_positions = positions
             target_hidden_states = hidden_states
             if self.use_aux_hidden_state_outputs:
-                target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1)
+                target_hidden_states = torch.cat([h for h in aux_hidden_states], dim=-1)
         else:
             target_token_ids = self.input_ids.gpu[token_indices]
             target_positions = self._get_positions(token_indices)