From 01d3515dcfa28468c966cebee341104533e6a4c1 Mon Sep 17 00:00:00 2001 From: lilinsiman Date: Fri, 6 Mar 2026 20:49:49 +0800 Subject: [PATCH] [eagle][cp][bugfix] Fix the bug in eagle and cp enabled (#6981) ### What this PR does / why we need it? When eagle and cp are enabled at the same time, there is an error in pcp_allgather due to hidden_states. This PR fixes this issue. - vLLM version: v0.16.0 - vLLM main: https://github.com/vllm-project/vllm/commit/15d76f74e2fdb12a95ea00f0ca283acf6219a2b7 --------- Signed-off-by: lilinsiman --- vllm_ascend/spec_decode/eagle_proposer.py | 14 ++++++++++++-- vllm_ascend/worker/model_runner_v1.py | 2 +- 2 files changed, 13 insertions(+), 3 deletions(-) diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 0a5e5d74..3086ae30 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -718,7 +718,17 @@ class AscendEagleProposer(EagleProposer): hidden_states = torch.index_select( hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] ) - last_hidden_states = hidden_states # TODO: check it + if self.method == "mtp": + last_hidden_states = hidden_states + else: + # eagle and eagle3 need allgather last_hidden_states. + last_hidden_states = last_hidden_states[:num_tokens] + last_hidden_states = get_pcp_group().all_gather(last_hidden_states, 0) + last_hidden_states = torch.index_select( + last_hidden_states, + 0, + self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: last_hidden_states.shape[0]], + ) num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: @@ -957,7 +967,7 @@ class AscendEagleProposer(EagleProposer): if self.pcp_size * self.dcp_size > 1: num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( - ori_seq_len + draft_step, + ori_seq_len + draft_step + 1, self.pcp_size, self.dcp_size, self.runner.parallel_config.cp_kv_cache_interleave_size, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index 7fa58a44..64f4c22d 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -1048,7 +1048,7 @@ class NPUModelRunner(GPUModelRunner): target_positions = positions target_hidden_states = hidden_states if self.use_aux_hidden_state_outputs: - target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1) + target_hidden_states = torch.cat([h for h in aux_hidden_states], dim=-1) else: target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices)