From 8f278fc101a1c6f3aa3ccb47d7c698a5555fb646 Mon Sep 17 00:00:00 2001
From: lilinsiman
Date: Tue, 17 Mar 2026 16:14:45 +0800
Subject: [PATCH] [eagle3][pcp] fix bug for eagle3 and cp enable (#7309)

### What this PR does / why we need it?
This PR fixes a bug that breaks eagle3 speculative decoding when context
parallel (cp) is enabled. The bug was introduced by the parallel speculative
inference PR.

### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
e2e tests and unit tests. This PR adds the e2e test `test_pcp_eagle3_eager`,
which runs eagle3 with prefill context parallel enabled.

- vLLM version: v0.17.0
- vLLM main: https://github.com/vllm-project/vllm/commit/4034c3d32e30d01639459edd3ab486f56993876d

---------

Signed-off-by: lilinsiman
---
 .../4-cards/long_sequence/test_mtp.py     |  25 +++
 vllm_ascend/spec_decode/eagle_proposer.py | 143 ++++++++++--------
 2 files changed, 105 insertions(+), 63 deletions(-)

diff --git a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
index cc8eb9b5..6e56bf1c 100644
--- a/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
+++ b/tests/e2e/multicard/4-cards/long_sequence/test_mtp.py
@@ -29,6 +29,10 @@ prompts = [
     "The president of United States is", "AI future is"
 ]
 model = "wemaster/deepseek_mtp_main_random_bf16"
+model_eagle3 = {
+    "main": "Qwen/Qwen3-8B",
+    "spec": "RedHatAI/Qwen3-8B-speculator.eagle3",
+}
 
 @wait_until_npu_memory_free()
 def test_pcp_dcp_mtp1_eager():
@@ -141,3 +145,24 @@ def test_dcp_mtp3_full_graph():
             async_scheduling=False,
     ) as runner:
         runner.generate_greedy(prompts, 32)
+
+
+@wait_until_npu_memory_free()
+def test_pcp_eagle3_eager():
+    with VllmRunner(
+            model_eagle3["main"],
+            max_model_len=1024,
+            tensor_parallel_size=2,
+            enforce_eager=True,
+            prefill_context_parallel_size=2,
+            decode_context_parallel_size=1,
+            max_num_batched_tokens=1024,
+            block_size=128,
+            speculative_config={
+                "num_speculative_tokens": 3,
+                "method": "eagle3",
+                "model": model_eagle3["spec"]
+            },
+            async_scheduling=False,
+    ) as runner:
+        runner.generate_greedy(prompts, 32)
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 00c26e30..62e7fe80 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -475,7 +475,7 @@ class SpecDecodeBaseProposer(EagleProposer):
         target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
         assert target_hidden_states.shape[-1] == self.hidden_size
 
-        num_tokens, token_indices_to_sample, common_attn_metadata = self.set_inputs_first_pass(
+        num_tokens, token_indices_to_sample, common_attn_metadata, long_seq_args = self.set_inputs_first_pass(
             target_token_ids=target_token_ids,
             next_token_ids=next_token_ids,
             target_positions=target_positions,
@@ -483,65 +483,15 @@
             token_indices_to_sample=token_indices_to_sample,
             cad=common_attn_metadata,
             num_rejected_tokens_gpu=num_rejected_tokens_gpu,
+            req_scheduled_tokens=req_scheduled_tokens,
+            long_seq_metadata=long_seq_metadata,
+            num_prefill_reqs=num_prefill_reqs,
+            num_decode_reqs=num_decode_reqs,
         )
-
-        assert self.runner is not None
-        # update pcp related params
         if self.pcp_size * self.dcp_size > 1:
-            assert long_seq_metadata is not None
-            common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata
-            ori_token_indices_to_sample = token_indices_to_sample.clone()
-            query_lens_d = self.runner.query_lens[:num_decode_reqs]
-            if self.pcp_size > 1:
-                # 1. preprocess decode/prefill input_ids & target_hidden_states
-                # decode input_ids: keep unchanged
-                # decode target_hidden_states: remove padding
-                # prefill input_ids: add padding and pcp split
-                # prefill target_hidden_states: pcp split
-                num_tokens_d = query_lens_d.sum().item()
-                num_tokens_d_padded = num_tokens_d * self.pcp_size
-                input_ids_d = self.input_ids[:num_tokens_d]
-                input_ids_p = self.input_ids[num_tokens_d:num_tokens]
-                target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
-                if num_tokens_d:
-                    # remove padding (from pcp all-gather) in decode part
-                    mask_start_loc = torch.cat(
-                        [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
-                    )
-                    mask_len = query_lens_d
-                    mask = []
-                    for req_id in range(num_decode_reqs):
-                        mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
-                    target_hidden_states_d = target_hidden_states_d_padded[mask]
-                else:
-                    target_hidden_states_d = target_hidden_states_d_padded
-                target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
-                req_scheduled_tokens_p = {}
-                for i, req_id in enumerate(self.runner.input_batch.req_ids):
-                    if i >= num_decode_reqs:
-                        req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
-                (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
-                    self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
-                )
-                num_tokens = num_tokens_d + num_tokens_p
-                target_positions = target_positions[:num_tokens]
-                self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
-                target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
-            # 2. update sample_indices according to main model
-            if num_decode_reqs:
-                token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
-                    token_indices_to_sample[:num_decode_reqs]
-                ]
-            if num_prefill_reqs:
-                token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
-            # 3. update attn_metadata params that may be influenced by pcp
-            common_attn_metadata.num_actual_tokens = num_tokens
-            common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p)
-            common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p
-            common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
-            query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item()
-            common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
-            common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
+            assert long_seq_args is not None
+            query_lens_d, ori_token_indices_to_sample = long_seq_args
+        assert self.runner is not None
         if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]:
             num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens]
             if not (
@@ -986,7 +936,11 @@ class SpecDecodeBaseProposer(EagleProposer):
         token_indices_to_sample: torch.Tensor | None,
         cad: CommonAttentionMetadata,
         num_rejected_tokens_gpu: torch.Tensor | None,
-    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata]:
+        req_scheduled_tokens=None,
+        long_seq_metadata=None,
+        num_prefill_reqs=0,
+        num_decode_reqs=0,
+    ) -> tuple[int, torch.Tensor, CommonAttentionMetadata, tuple[Any, Any] | None]:
         if not self.needs_extra_input_slots:
             # Default EAGLE pathway: no reshaping of input tensors needed.
             # Simply rotate the input ids and leave the positions unchanged,
@@ -1002,6 +956,68 @@ class SpecDecodeBaseProposer(EagleProposer):
             # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
             self.input_ids[token_indices_to_sample] = next_token_ids
 
+            assert self.runner is not None
+            # update pcp related params
+            ori_token_indices_to_sample = None
+            query_lens_d = None
+            if self.pcp_size * self.dcp_size > 1:
+                assert long_seq_metadata is not None
+                cad.prefill_context_parallel_metadata = long_seq_metadata
+                ori_token_indices_to_sample = token_indices_to_sample.clone()
+                query_lens_d = self.runner.query_lens[:num_decode_reqs]
+                if self.pcp_size > 1:
+                    # 1. preprocess decode/prefill input_ids & target_hidden_states
+                    # decode input_ids: keep unchanged
+                    # decode target_hidden_states: remove padding
+                    # prefill input_ids: add padding and pcp split
+                    # prefill target_hidden_states: pcp split
+                    assert query_lens_d is not None
+                    num_tokens_d = query_lens_d.sum().item()
+                    num_tokens_d_padded = num_tokens_d * self.pcp_size
+                    input_ids_d = self.input_ids[:num_tokens_d]
+                    input_ids_p = self.input_ids[num_tokens_d:num_tokens]
+                    target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded]
+                    if num_tokens_d:
+                        # remove padding (from pcp all-gather) in decode part
+                        mask_start_loc = torch.cat(
+                            [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]]
+                        )
+                        mask_len = query_lens_d
+                        mask = []
+                        for req_id in range(num_decode_reqs):
+                            assert None not in (mask_start_loc, mask_len)
+                            mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id]))
+                        target_hidden_states_d = target_hidden_states_d_padded[mask]
+                    else:
+                        target_hidden_states_d = target_hidden_states_d_padded
+                    target_hidden_states_p = target_hidden_states[num_tokens_d_padded:]
+                    req_scheduled_tokens_p = {}
+                    for i, req_id in enumerate(self.runner.input_batch.req_ids):
+                        if i >= num_decode_reqs:
+                            req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id]
+                    (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = (
+                        self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p)
+                    )
+                    num_tokens = num_tokens_d + num_tokens_p
+                    target_positions = target_positions[:num_tokens]
+                    self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0))
+                    target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0)
+                # 2. update sample_indices according to main model
+                if num_decode_reqs:
+                    token_indices_to_sample[:num_decode_reqs] = self.runner.logits_indices[
+                        token_indices_to_sample[:num_decode_reqs]
+                    ]
+                if num_prefill_reqs:
+                    token_indices_to_sample[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:]
+                # 3. update attn_metadata params that may be influenced by pcp
+                cad.num_actual_tokens = num_tokens
+                cad.max_query_len = max(self.decode_threshold, max_query_len_p)
+                cad.seq_lens[-num_prefill_reqs:] = seq_lens_p
+                cad.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p
+                query_start_loc_p = cu_num_tokens_p[1:] + cad.query_start_loc[num_decode_reqs].item()
+                cad.query_start_loc[-num_prefill_reqs:] = query_start_loc_p
+                cad.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p
+
             # copy inputs to buffer for cudagraph
             if self.uses_xdrope_dim > 0 and self.draft_uses_xdrope_dim == 0:
                 target_positions = target_positions[0]
@@ -1009,7 +1025,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             self._set_positions(num_tokens, target_positions)
             self.hidden_states[:num_tokens] = target_hidden_states
 
-            return num_tokens, token_indices_to_sample, cad
+            return num_tokens, token_indices_to_sample, cad, (query_lens_d, ori_token_indices_to_sample)
         else:
             assert self.is_rejected_token_mask is not None
             assert self.is_masked_token_mask is not None
@@ -1057,7 +1073,7 @@ class SpecDecodeBaseProposer(EagleProposer):
             # Use torch.where to avoid DtoH sync from boolean indexing
             mask = self.is_masked_token_mask[:total_num_output_tokens]
             torch.where(
-                mask.unsqueeze(1),
+                mask.unsqueeze(1),  # type: ignore
                 self.parallel_drafting_hidden_state_tensor,
                 self.hidden_states[:total_num_output_tokens],
                 out=self.hidden_states[:total_num_output_tokens],
@@ -1093,7 +1109,7 @@ class SpecDecodeBaseProposer(EagleProposer):
                 new_slot_mapping=new_slot_mapping,
             )
 
-            return total_num_output_tokens, token_indices_to_sample, new_cad
+            return total_num_output_tokens, token_indices_to_sample, new_cad, None
 
     def model_returns_tuple(self) -> bool:
         return self.method not in ("mtp", "draft_model")
@@ -1198,7 +1214,8 @@ class SpecDecodeBaseProposer(EagleProposer):
                 # update slot_mapping
                 slot_indices += self.pcp_size
                 slot_mapping = mtp_slot_mapping[slot_indices]
-                common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping
+                self.slot_mapping_group[draft_step][: batch_size * self.pcp_size] = slot_mapping
+                common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step]
             else:
                 # NOTE: In vllm, `block_size = attn_metadata_builder.kv_cache_spec.block_size`.
                 # However, in vllm-ascend, the above value can be multiple of `kernel_block_size`,