From c13d90b766748221dfd4b29c3c05eb384c858b46 Mon Sep 17 00:00:00 2001
From: lilinsiman
Date: Fri, 27 Feb 2026 16:06:56 +0800
Subject: [PATCH] [Refactor][EAGLE] 7/N Merged PCP and disable_padded interface (#6811)

### What this PR does / why we need it?
[Refactor][EAGLE] 7/N Merged PCP and disable_padded interface into eagle_proposer.py

This pull request significantly refactors the speculative decoding mechanism by merging the Parallel Context Processing (PCP) and Multi-Token Prediction (MTP) functionality directly into eagle_proposer.py. The changes improve the efficiency and correctness of distributed speculative decoding, in particular by enabling the Eagle feature to work with the disable_padded interface. This involves detailed adjustments to attention metadata, input/output processing, and state management to ensure correct operation in parallel environments.

1. The PCP and MTP features are migrated into eagle_proposer.py
2. The Eagle and PCP features are integrated
3. The Eagle feature is enabled to use the disable_padded interface

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
Tests and UT

- vLLM version: v0.15.0
- vLLM main: https://github.com/vllm-project/vllm/commit/83b47f67b1dfad505606070ae4d9f83e50ad4ebd

---------
Signed-off-by: lilinsiman
---
 tests/ut/attention/test_attention_cp.py   |   1 +
 vllm_ascend/attention/attention_v1.py     |   1 +
 .../context_parallel/attention_cp.py      |  12 +-
 vllm_ascend/spec_decode/eagle_proposer.py | 251 +++++++++++++++---
 vllm_ascend/spec_decode/mtp_proposer.py   |  12 +-
 vllm_ascend/worker/model_runner_v1.py     |  28 +-
 6 files changed, 245 insertions(+), 60 deletions(-)

diff --git a/tests/ut/attention/test_attention_cp.py b/tests/ut/attention/test_attention_cp.py index 3cdbeb6b..877f593b 100644 --- a/tests/ut/attention/test_attention_cp.py +++ b/tests/ut/attention/test_attention_cp.py @@ -118,6 +118,7 @@ class TestAscendAttentionCPImpl(TestBase): attn_metadata = MagicMock() attn_metadata.decode_meta = MagicMock() + attn_metadata.num_decodes_flatten = 5 attn_metadata.decode_meta.batch_seq_mask = torch.tensor( [1, 0], dtype=torch.bool) output = self.impl._forward_decode_pcp_dcp(query, attn_metadata) diff --git a/vllm_ascend/attention/attention_v1.py b/vllm_ascend/attention/attention_v1.py index 988f40d8..8b1b5a15 100644 --- a/vllm_ascend/attention/attention_v1.py +++ b/vllm_ascend/attention/attention_v1.py @@ -159,6 +159,7 @@ class AscendMetadata: num_decode_tokens: int = 0 num_prefills: int = 0 num_decodes: int = 0 + num_decodes_flatten: int = 0 # The sequence length per sequence. Sequence length means the computed # tokens + new tokens (is None if it is a decoding). 
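The new `num_decodes_flatten` field counts decode tokens rather than decode requests: with speculative decoding each decode request contributes `1 + num_speculative_tokens` query tokens, so the PCP/DCP metadata builder below slices the block table and `num_computed_tokens_of_pcp_dcp` with this flattened count instead of `num_decodes`. A minimal standalone sketch of that indexing (torch only; the sizes and names are illustrative, not the real builder):

```python
# Standalone sketch (illustrative sizes/names, not the real builder) of why the
# decode section is indexed by flattened decode tokens once spec decoding is on.
import torch

num_decodes = 2                       # decode requests in the batch
num_speculative_tokens = 2            # each decode request carries 1 + 2 query tokens
query_lens = torch.tensor([3, 3, 7])  # per-request query lens; last entry is a prefill

# A decode request no longer maps to a single row, so slicing with
# `[:num_decodes]` would keep only 2 of its 6 decode rows.
num_decodes_flatten = int(query_lens[:num_decodes].sum())  # 6, not 2

# flattened block table: one row per decode token, then the prefill rows
block_table = torch.arange(7 * 4).view(7, 4)
decode_rows = block_table[:num_decodes_flatten]            # (6, 4) rows for decode tokens
prefill_rows = block_table[num_decodes_flatten:]           # (1, 4) remaining prefill row
print(num_decodes_flatten, decode_rows.shape, prefill_rows.shape)
```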
diff --git a/vllm_ascend/attention/context_parallel/attention_cp.py b/vllm_ascend/attention/context_parallel/attention_cp.py index c2f919fe..121743f9 100644 --- a/vllm_ascend/attention/context_parallel/attention_cp.py +++ b/vllm_ascend/attention/context_parallel/attention_cp.py @@ -117,6 +117,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): block_table = common_attn_metadata.block_table_tensor query_lens = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1] + self.num_decodes_flatten = query_lens[:num_decodes].sum().item() seq_lens = common_attn_metadata.seq_lens_cpu[:num_reqs] long_seq_metadata = common_attn_metadata.prefill_context_parallel_metadata @@ -146,7 +147,7 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): pcp_size = get_pcp_group().world_size if self.chunked_prefill_enabled and max_context_len_cpu > 0: local_context_lens_allranks = ( - torch.tensor(num_computed_tokens_of_pcp_dcp)[num_decodes:num_reqs] + torch.tensor(num_computed_tokens_of_pcp_dcp)[self.num_decodes_flatten :] .to(self.device) .to(dtype=torch.int32) ) @@ -214,23 +215,24 @@ class AscendAttentionCPMetadataBuilder(AscendAttentionMetadataBuilder): prefill_metadata = AscendMetadataForPrefill( pcp_metadata=pcp_metadata, chunked_context=chunked_context_metadata, - block_tables=block_table[num_decodes:], + block_tables=block_table[self.num_decodes_flatten :, ...], actual_seq_lengths_q=torch.cumsum(query_lens, dim=0), ) if num_decodes > 0: num_computed_tokens_array = np.array(num_computed_tokens_of_pcp_dcp) - num_computed_tokens_array = num_computed_tokens_array[:num_decodes] + num_computed_tokens_array = num_computed_tokens_array[: self.num_decodes_flatten] # TODO: numpy array mode of the shared memory is used to improve performance decode_metadata = AscendMetadataForDecode( num_computed_tokens_of_pcp_dcp=num_computed_tokens_array, - block_tables=block_table[:num_decodes], + block_tables=block_table[: self.num_decodes_flatten], ) attn_metadata = AscendMetadata( num_actual_tokens=num_actual_tokens, num_decode_tokens=num_decode_tokens, num_actual_tokens_pcp_padded=num_actual_tokens_pcp_padded, + num_decodes_flatten=self.num_decodes_flatten, block_tables=block_table, query_start_loc=query_start_loc, seq_lens=seq_lens, @@ -550,7 +552,7 @@ class AscendAttentionCPImpl(AscendAttentionBackendImpl): "actual_seq_lengths_kv": attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp[ :, self.pcp_rank, self.dcp_rank ], - "actual_seq_lengths": attn_metadata.actual_seq_lengths_q[: attn_metadata.num_decodes], + "actual_seq_lengths": torch.arange(attn_metadata.num_decodes_flatten) + 1, } graph_params = get_graph_params() forward_context: ForwardContext = get_forward_context() diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 69afa3b6..cc0dcd90 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -10,6 +10,7 @@ import torch.nn as nn import torch.nn.functional as F from vllm.config import CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config from vllm.distributed.parallel_state import ( + get_pcp_group, get_pp_group, get_tp_group, get_world_group, @@ -326,6 +327,12 @@ class EagleProposer(VllmEagleProposer): decode_token_per_req=self.runner.decode_token_per_req, max_seq_len=0, ) + if self.pcp_size * self.dcp_size > 1: + # update long_seq related params and flatten block_table + common_attn_metadata.prefill_context_parallel_metadata = 
self.runner.pcp_manager.long_seq_metadata + common_attn_metadata.block_table_tensor = self.runner.input_batch.block_table[0].get_device_tensor()[ + : num_reqs * self.decode_threshold + ] builder = self.runner.attn_groups[0][0].get_metadata_builder() # update the tensor's address for each step. @@ -343,7 +350,9 @@ class EagleProposer(VllmEagleProposer): model_positions = self._get_positions(num_tokens) - batch_size = num_tokens // (self.num_speculative_tokens + 1) if not is_profile else self.runner.max_num_reqs + batch_size = num_tokens // (self.num_speculative_tokens + 1) # if not is_profile else self.runner.max_num_reqs + if is_profile: + batch_size = min(batch_size, self.runner.max_num_reqs) with set_ascend_forward_context( multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None, @@ -371,6 +380,7 @@ class EagleProposer(VllmEagleProposer): inputs_embeds=None, multi_steps_attn_metadata=multi_steps_attn_metadata, is_dummy=True, + num_tokens=num_tokens, ) forward_context = get_forward_context() if forward_context.cudagraph_runtime_mode == CUDAGraphMode.FULL and not forward_context.capturing: @@ -414,6 +424,62 @@ class EagleProposer(VllmEagleProposer): # Replace the last token with the next token. # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4] self.input_ids[last_token_indices] = next_token_ids + + assert self.runner is not None + # update pcp related params + if self.pcp_size * self.dcp_size > 1: + assert long_seq_metadata is not None + common_attn_metadata.prefill_context_parallel_metadata = long_seq_metadata + ori_last_token_indices = last_token_indices.clone() + query_lens_d = self.runner.query_lens[:num_decode_reqs] + if self.pcp_size > 1: + # 1. preprocess decode/prefill input_ids & target_hidden_states + # decode input_ids: keep unchanged + # decode target_hidden_states: remove padding + # prefill input_ids: add padding and pcp split + # prefill target_hidden_states: pcp split + num_tokens_d = query_lens_d.sum().item() + num_tokens_d_padded = num_tokens_d * self.pcp_size + input_ids_d = self.input_ids[:num_tokens_d] + input_ids_p = self.input_ids[num_tokens_d:num_tokens] + target_hidden_states_d_padded = target_hidden_states[:num_tokens_d_padded] + if num_tokens_d: + # remove padding (from pcp all-gather) in decode part + mask_start_loc = torch.cat( + [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d * self.pcp_size, dim=0)[:-1]] + ) + mask_len = query_lens_d + mask = [] + for req_id in range(num_decode_reqs): + mask += list(range(mask_start_loc[req_id], mask_start_loc[req_id] + mask_len[req_id])) + target_hidden_states_d = target_hidden_states_d_padded[mask] + else: + target_hidden_states_d = target_hidden_states_d_padded + target_hidden_states_p = target_hidden_states[num_tokens_d_padded:] + req_scheduled_tokens_p = {} + for i, req_id in enumerate(self.runner.input_batch.req_ids): + if i >= num_decode_reqs: + req_scheduled_tokens_p[req_id] = req_scheduled_tokens[req_id] + (num_tokens_p, input_ids_p, target_hidden_states_p, max_query_len_p, seq_lens_p, cu_num_tokens_p) = ( + self._split_pcp_input(req_scheduled_tokens_p, input_ids_p, target_hidden_states_p) + ) + num_tokens = num_tokens_d + num_tokens_p + target_positions = target_positions[:num_tokens] + self.input_ids[:num_tokens].copy_(torch.cat([input_ids_d, input_ids_p], dim=0)) + target_hidden_states = torch.cat([target_hidden_states_d, target_hidden_states_p], dim=0) + # 2. 
update sample_indices according to main model + if num_decode_reqs: + last_token_indices[:num_decode_reqs] = self.runner.logits_indices[last_token_indices[:num_decode_reqs]] + if num_prefill_reqs: + last_token_indices[-num_prefill_reqs:] = self.runner.logits_indices[-num_prefill_reqs:] + # 3. update attn_metadata params that may be influenced by pcp + common_attn_metadata.num_actual_tokens = num_tokens + common_attn_metadata.max_query_len = max(self.decode_threshold, max_query_len_p) + common_attn_metadata.seq_lens[-num_prefill_reqs:] = seq_lens_p + common_attn_metadata.seq_lens_cpu[-num_prefill_reqs:] = seq_lens_p + query_start_loc_p = cu_num_tokens_p[1:] + common_attn_metadata.query_start_loc[num_decode_reqs].item() + common_attn_metadata.query_start_loc[-num_prefill_reqs:] = query_start_loc_p + common_attn_metadata.query_start_loc_cpu[-num_prefill_reqs:] = query_start_loc_p if self.use_cuda_graph and num_tokens <= self.runner.cudagraph_batch_sizes[-1]: num_input_tokens = self.runner.cudagraph_dispatcher._bs_to_padded_graph_size[num_tokens] if not ( @@ -468,11 +534,7 @@ class EagleProposer(VllmEagleProposer): # only tensor which will be used in current FIA. # Strictly speaking, `query_start_loc`, `seq_lens` should also have # their memory allocated separately for each step just like `slot_mapping`. - slot_mapping_lens = ( - num_input_tokens - if num_input_tokens < common_attn_metadata.slot_mapping.shape[0] - else common_attn_metadata.slot_mapping.shape[0] - ) + slot_mapping_lens = common_attn_metadata.slot_mapping.shape[0] self.slot_mapping_group[0][:slot_mapping_lens].copy_(common_attn_metadata.slot_mapping[:slot_mapping_lens]) self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1) common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens] @@ -491,21 +553,87 @@ class EagleProposer(VllmEagleProposer): per_layer_attn_metadata[layer_name] = attn_metadata multi_steps_attn_metadata = [per_layer_attn_metadata] - # Copy the old attn_metadata and update - for draft_step in range(1, self.num_speculative_tokens): - common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( - draft_step, - attn_metadata, - common_attn_metadata, - batch_size, - num_input_tokens, - used_update_positions, - aclgraph_runtime_mode, - ) - per_layer_attn_metadata = dict() - for layer_name in self.attn_layer_names: - per_layer_attn_metadata[layer_name] = attn_metadata - multi_steps_attn_metadata.append(per_layer_attn_metadata) + attn_metadata_i = per_layer_attn_metadata[self.attn_layer_names[0]] + if self.pcp_size * self.dcp_size > 1: + if self.num_speculative_tokens > 1 and not attn_metadata_i.num_prefills: + # For pcp/dcp, tokens are split across different cp ranks, + # so we can not simply update slot_mapping by += 1. + # Instead, we pre-allocate mtp slot_mapping in model_runner + # (_generate_pcp_mtp_input), and use updated slot_indices + # to get corresponding slot_mapping in each step. 
+ num_reject_tokens = ( + torch.tensor(self.runner.pcp_manager.cu_num_tokens_pcp_full, dtype=torch.int32).to(self.device) + - ori_last_token_indices + - 1 + ) + num_accept_tokens = query_lens_d.to(self.device) - num_reject_tokens + ori_seq_len = attn_metadata_i.seq_lens[:batch_size].clone() + mtp_slot_mapping = self.runner.pcp_manager.mtp_slot_pad + + # slot_mapping index base offset: + # scheduled tokens + pre-allocated mtp tokens + accepted tokens + slot_idx_base = ( + torch.cat( + [ + torch.tensor([0], dtype=torch.int32, device=self.device), + (torch.cumsum(query_lens_d, dim=0)[:-1] * self.pcp_size).to(self.device), + ] + ) + + torch.arange(num_decode_reqs, device=self.device) + * (self.num_speculative_tokens - 1) + * self.pcp_size + + (num_accept_tokens - 1) * self.pcp_size + ) + slot_indices_list = [] + for req_id in range(num_decode_reqs): + slot_indices_list.append( + torch.arange(slot_idx_base[req_id], slot_idx_base[req_id] + self.pcp_size, device=self.device) + ) + slot_indices = torch.cat(slot_indices_list, dim=0) + + # fold block_table (restore it to original size before flattened) + block_indices = torch.cat( + [torch.tensor([0], dtype=torch.int32), torch.cumsum(query_lens_d, dim=0)[:-1]] + ) + common_attn_metadata.block_table_tensor[:batch_size] = common_attn_metadata.block_table_tensor[ + block_indices + ] + common_attn_metadata.block_table_tensor = common_attn_metadata.block_table_tensor[:batch_size] + + # Copy the old attn_metadata and update + for draft_step in range(1, self.num_speculative_tokens): + common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode, + ori_seq_len, + slot_indices, + mtp_slot_mapping, + ) + per_layer_attn_metadata = dict() + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + multi_steps_attn_metadata.append(per_layer_attn_metadata) + else: + # Copy the old attn_metadata and update + for draft_step in range(1, self.num_speculative_tokens): + common_attn_metadata, attn_metadata = self.attn_update_stack_num_spec_norm( + draft_step, + attn_metadata, + common_attn_metadata, + batch_size, + num_input_tokens, + used_update_positions, + aclgraph_runtime_mode, + ) + per_layer_attn_metadata = dict() + for layer_name in self.attn_layer_names: + per_layer_attn_metadata[layer_name] = attn_metadata + multi_steps_attn_metadata.append(per_layer_attn_metadata) last_token_indices_len = last_token_indices.shape[0] self.last_token_indices[:last_token_indices_len].copy_(last_token_indices) @@ -533,6 +661,8 @@ class EagleProposer(VllmEagleProposer): target_positions=target_positions, inputs_embeds=inputs_embeds, multi_steps_attn_metadata=multi_steps_attn_metadata, + num_tokens=num_tokens, + is_prefill=attn_metadata_i.num_prefills, ) forward_context = get_forward_context() @@ -548,7 +678,9 @@ class EagleProposer(VllmEagleProposer): target_positions, inputs_embeds, multi_steps_attn_metadata, + num_tokens, is_dummy=False, + is_prefill=None, ) -> torch.Tensor: # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all # speculative tokens' proposings. 
`model_input_ids`, `model_positions` and @@ -575,6 +707,15 @@ class EagleProposer(VllmEagleProposer): last_hidden_states, model_positions, hidden_states ) + if self.pcp_size > 1: + # remove graph padding before all_gather + hidden_states = hidden_states[:num_tokens] + hidden_states = get_pcp_group().all_gather(hidden_states, 0) + hidden_states = torch.index_select( + hidden_states, 0, self.runner.pcp_manager.pcp_allgather_restore_idx.gpu[: hidden_states.shape[0]] + ) + last_hidden_states = hidden_states # TODO: check it + num_indices = last_token_indices.shape[0] if lmhead_tp_enable() and not is_dummy: max_num_reqs_across_dp = ( @@ -596,6 +737,13 @@ class EagleProposer(VllmEagleProposer): # [batch_size, 1] return draft_token_ids.view(-1, 1) + if self.pcp_size * self.dcp_size > 1 and is_prefill: + draft_token_ids = logits.argmax(dim=-1) + draft_token_ids_list = [] + for _ in range(self.num_speculative_tokens): + draft_token_ids_list.append(draft_token_ids) + return torch.stack(draft_token_ids_list, dim=1) + # Generate the remaining draft tokens. draft_token_ids_tensor = torch.zeros( (self.num_speculative_tokens, *draft_token_ids.shape), dtype=draft_token_ids.dtype, device=self.device @@ -722,6 +870,9 @@ class EagleProposer(VllmEagleProposer): input_batch_size, used_update_positions, aclgraph_runtime_mode, + ori_seq_len=None, + slot_indices=None, + mtp_slot_mapping=None, ): assert draft_step > 0 common_attn_metadata = self.shallow_copy_metadata(old_common_metadata) @@ -797,28 +948,42 @@ class EagleProposer(VllmEagleProposer): attn_metadata_builder = self._get_attention_metadata_builder() else: attn_metadata_builder = self.attn_metadata_builder - block_size = attn_metadata_builder.kv_cache_spec.block_size - # Compute the slot mapping. - if self.uses_mrope: - block_numbers = clamped_positions[0] // block_size + if self.pcp_size * self.dcp_size > 1: + num_computed_tokens_of_pcp_dcp = self.runner.pcp_manager._get_cp_local_seq_lens( + ori_seq_len + draft_step, + self.pcp_size, + self.dcp_size, + self.runner.parallel_config.cp_kv_cache_interleave_size, + ) + cp_seq_len = num_computed_tokens_of_pcp_dcp[:, self.pcp_rank, self.dcp_rank] + # update slot_mapping + slot_indices += self.pcp_size + slot_mapping = mtp_slot_mapping[slot_indices] + common_attn_metadata.slot_mapping[: batch_size * self.pcp_size] = slot_mapping else: - block_numbers = clamped_positions // block_size - block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1)) - block_ids = block_ids.view(-1) - if self.uses_mrope: - slot_mapping = block_ids * block_size + clamped_positions[0] % block_size - else: - slot_mapping = block_ids * block_size + clamped_positions % block_size + block_size = attn_metadata_builder.kv_cache_spec.block_size - # Mask out the slot mappings that exceed the max model length. - # Otherwise, the KV cache will be inadvertently updated with the - # padding tokens. - slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) - self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) - self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) - # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] - common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] + # Compute the slot mapping. 
+ if self.uses_mrope: + block_numbers = clamped_positions[0] // block_size + else: + block_numbers = clamped_positions // block_size + block_ids = old_common_metadata.block_table_tensor.gather(dim=1, index=block_numbers.view(-1, 1)) + block_ids = block_ids.view(-1) + if self.uses_mrope: + slot_mapping = block_ids * block_size + clamped_positions[0] % block_size + else: + slot_mapping = block_ids * block_size + clamped_positions % block_size + + # Mask out the slot mappings that exceed the max model length. + # Otherwise, the KV cache will be inadvertently updated with the + # padding tokens. + slot_mapping.masked_fill_(exceeds_max_model_len, PADDING_SLOT_ID) + self.slot_mapping_group[draft_step][: slot_mapping.shape[0]].copy_(slot_mapping.to(torch.int32)) + self.slot_mapping_group[draft_step][slot_mapping.shape[0] :].fill_(PADDING_SLOT_ID) + # Set the address of the attn_metadata.slot_mapping to the self.slot_mapping_group[idx] + common_attn_metadata.slot_mapping = self.slot_mapping_group[draft_step][: slot_mapping.shape[0]] # Rebuild attention metadata attn_metadata = attn_metadata_builder.build_for_drafting( # type: ignore @@ -826,6 +991,12 @@ class EagleProposer(VllmEagleProposer): draft_index=draft_step, ) + if self.pcp_size * self.dcp_size > 1: + if self.vllm_config.model_config.use_mla: + attn_metadata.decode.cp_seq_len = cp_seq_len + else: + attn_metadata.decode_meta.num_computed_tokens_of_pcp_dcp = num_computed_tokens_of_pcp_dcp + return common_attn_metadata, attn_metadata def prepare_next_token_ids_padded( diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 61b15e93..5da99cae 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -39,11 +39,7 @@ class MtpProposer(EagleProposer): # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. # TODO: this conditional check should be removed after bug fixing. - if ( - self.pcp_size * self.dcp_size == 1 - and not self.speculative_config.disable_padded_drafter_batch - and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): + if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): super().dummy_run( num_tokens, with_prefill, @@ -175,11 +171,7 @@ class MtpProposer(EagleProposer): # Currently, both GLM and DS encounter issues when enabling the fullgraph mode and running on EagleProposer. # Therefore, we temporarily bypass this problem by adding a conditional check for fullgraph. # TODO: this conditional check should be removed after bug fixing. 
- if ( - self.pcp_size * self.dcp_size == 1 - and not self.speculative_config.disable_padded_drafter_batch - and not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs() - ): + if not self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(): draft_token_ids = super()._propose( target_token_ids, target_positions, diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py index fee69da7..51d2ecef 100644 --- a/vllm_ascend/worker/model_runner_v1.py +++ b/vllm_ascend/worker/model_runner_v1.py @@ -561,7 +561,6 @@ class NPUModelRunner(GPUModelRunner): dtype=np.int32, ) attn_state = self._build_attn_state(num_reqs, num_scheduled_tokens, num_valid_tokens) - self.attn_state = attn_state # type: ignore # Determine if it's a splitfuse batch with_prefill = attn_state not in [AscendAttentionState.DecodeOnly, AscendAttentionState.SpecDecoding] @@ -800,7 +799,7 @@ class NPUModelRunner(GPUModelRunner): attn_state = AscendAttentionState.SpecDecoding # Speculative decoding. elif np.all(num_valid_tokens == 1): - if self.speculative_config and self.speculative_config.method == "mtp": + if self.speculative_config: attn_state = AscendAttentionState.SpecDecoding else: attn_state = AscendAttentionState.ChunkedPrefill @@ -809,6 +808,14 @@ class NPUModelRunner(GPUModelRunner): attn_state = AscendAttentionState.ChunkedPrefill else: attn_state = AscendAttentionState.PrefillCacheHit + + # When the PCP feature is combined with eagle3, attn_state needs to be restored to ChunkedPrefill. + # TODO: Resolve the conflict between the planned removal of attn_state and PCP, which still requires this interface. + if attn_state == AscendAttentionState.SpecDecoding and self.speculative_config.method != "mtp": + self.attn_state = AscendAttentionState.ChunkedPrefill # type: ignore + else: + self.attn_state = attn_state # type: ignore + return attn_state def _calc_spec_decode_metadata( @@ -977,6 +984,10 @@ class NPUModelRunner(GPUModelRunner): target_token_ids = input_ids_pcp_full[:num_scheduled_tokens] target_positions = self._get_positions(num_scheduled_tokens) target_hidden_states = hidden_states + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat( + [h[:num_scheduled_tokens] for h in aux_hidden_states], dim=-1 + ) else: token_indices_to_sample = None # input_ids can be None for multimodal models. @@ -1014,6 +1025,8 @@ class NPUModelRunner(GPUModelRunner): target_token_ids = input_ids_pcp_full[token_indices] target_positions = positions target_hidden_states = hidden_states + if self.use_aux_hidden_state_outputs: + target_hidden_states = torch.cat([h[token_indices] for h in aux_hidden_states], dim=-1) else: target_token_ids = self.input_ids.gpu[token_indices] target_positions = self._get_positions(token_indices) @@ -1260,13 +1273,18 @@ class NPUModelRunner(GPUModelRunner): num_tokens_padded, input_ids, positions, intermediate_tensors, inputs_embeds, **model_kwargs ) with record_function_or_nullcontext("post process"): + aux_hidden_states = None + if self.use_aux_hidden_state_outputs: + hidden_states, aux_hidden_states = hidden_states if self.pcp_size > 1: # NOTE we must `slice` hidden_states because pcp_allgather_restore_idx # ignores the padding from CUDA Graph. 
hidden_states = self.pcp_manager.get_restore_hidden_states(hidden_states) - aux_hidden_states = None - if self.use_aux_hidden_state_outputs: - hidden_states, aux_hidden_states = hidden_states + if aux_hidden_states is not None: + aux_hidden_states = [ + self.pcp_manager.get_restore_hidden_states(aux_hidden_states_pcp) + for aux_hidden_states_pcp in aux_hidden_states + ] if not self.broadcast_pp_output: # Common case.
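Two of the PCP-specific steps added to `eagle_proposer._propose` above are easier to follow with concrete numbers. First, the decode-padding removal: after the pcp all-gather, each decode request owns `query_len * pcp_size` rows of `target_hidden_states`, and only the first `query_len` rows per request are kept. A standalone sketch with made-up sizes:

```python
# Standalone sketch (hypothetical sizes) of the decode-padding removal:
# keep the first query_len rows of each request's query_len * pcp_size block.
import torch

pcp_size = 2
query_lens_d = torch.tensor([3, 2])              # decode query lens per request
hidden = torch.arange(10).unsqueeze(1).float()   # (3 + 2) * pcp_size = 10 padded rows

mask_start_loc = torch.cat(
    [torch.tensor([0]), torch.cumsum(query_lens_d * pcp_size, dim=0)[:-1]]
)
keep = torch.cat(
    [torch.arange(s, s + l) for s, l in zip(mask_start_loc.tolist(), query_lens_d.tolist())]
)
print(keep.tolist())        # [0, 1, 2, 6, 7]
print(hidden[keep].shape)   # torch.Size([5, 1])
```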
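Second, the multi-step slot mapping: because tokens are split across cp ranks, the drafter cannot advance KV slots with `slot_mapping += 1`; instead it indexes a pre-allocated slot buffer (`mtp_slot_pad`) with per-request base offsets that move forward by `pcp_size` each draft step. A worked example of that offset arithmetic (standalone torch, made-up sizes; not the runner's real buffer):

```python
# Worked example (made-up sizes) of the slot-index arithmetic used with the
# pre-allocated MTP slot buffer under PCP: per request the buffer holds
# query_len * pcp_size scheduled slots followed by (k - 1) * pcp_size draft slots.
import torch

pcp_size = 2
num_speculative_tokens = 2                  # k
query_lens_d = torch.tensor([3, 3])         # decode query lens (1 target + k drafts)
num_accept_tokens = torch.tensor([1, 2])    # accepted tokens per request this step
num_decode_reqs = query_lens_d.numel()

slot_idx_base = (
    torch.cat([torch.tensor([0]), torch.cumsum(query_lens_d, dim=0)[:-1] * pcp_size])
    + torch.arange(num_decode_reqs) * (num_speculative_tokens - 1) * pcp_size
    + (num_accept_tokens - 1) * pcp_size
)                                            # tensor([ 0, 10])
slot_indices = torch.cat(
    [torch.arange(b, b + pcp_size) for b in slot_idx_base.tolist()]
)                                            # tensor([ 0,  1, 10, 11])

mtp_slot_pad = torch.arange(100, 100 + 16)   # stand-in for the pre-allocated buffer
for draft_step in range(1, num_speculative_tokens):
    slot_indices = slot_indices + pcp_size   # advance per step instead of `+= 1`
    print(draft_step, mtp_slot_pad[slot_indices].tolist())  # 1 [102, 103, 112, 113]
```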