From 7932255c06413eb92d6dd4ba1bed42489e371add Mon Sep 17 00:00:00 2001
From: lilinsiman
Date: Mon, 2 Feb 2026 19:15:31 +0800
Subject: [PATCH] [Refactor][EAGLE] 6/N route mtp to eagle except pcp/dcp+mtp
 (#6349)

### What this PR does / why we need it?

Overview: This pull request refactors speculative decoding for the Eagle and MTP proposers on Ascend hardware. It fixes a bug where `draft_attn_metadatas` was lost, migrates the lmhead tensor-parallel handling, and adds routing logic to `MtpProposer`.

Details:
1. Migrated the lmhead tensor-parallel handling from the MTP proposer to the Eagle proposer and normalized it in `eagle_proposer.py`.
2. Fixed the bug where `draft_attn_metadatas` was lost after enabling Eagle mode with the merged graph.
3. Added routing for pcp/dcp and `disable_padded_drafter_batch`: in MTP mode, when pcp/dcp is not enabled and the padded drafter batch is not disabled, the normalized `eagle_proposer` path is used.

Illustrative sketches of the routing check, the forward-context plumbing, and the lmhead padding are appended after the diff.

RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Unit tests and existing tests.

- vLLM version: v0.14.1
- vLLM main: https://github.com/vllm-project/vllm/commit/dc917cceb877dfd13f98c538c4c96158047d98bd

---------

Signed-off-by: lilinsiman
---
 vllm_ascend/ascend_forward_context.py     |  2 +
 vllm_ascend/spec_decode/eagle_proposer.py | 84 ++++++++++++++++++-----
 vllm_ascend/spec_decode/mtp_proposer.py   | 24 ++++++-
 vllm_ascend/worker/model_runner_v1.py     |  4 --
 4 files changed, 90 insertions(+), 24 deletions(-)

diff --git a/vllm_ascend/ascend_forward_context.py b/vllm_ascend/ascend_forward_context.py
index 04ffa7b3..26604f0b 100644
--- a/vllm_ascend/ascend_forward_context.py
+++ b/vllm_ascend/ascend_forward_context.py
@@ -43,6 +43,7 @@ def set_ascend_forward_context(
         model_instance: torch.nn.Module = None,
         is_draft_model=False,
         skip_compiled: bool = False,
+        draft_attn_metadatas=None,
 ):
     """A context manager that stores the current forward context,
     can be attention metadata, etc.
@@ -61,6 +62,7 @@ def set_ascend_forward_context(
 
     with set_forward_context(**forward_context_kwargs):
         forward_context = get_forward_context()
+        forward_context.draft_attn_metadatas = draft_attn_metadatas
 
         from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method
 
diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py
index 761ca013..8864155b 100644
--- a/vllm_ascend/spec_decode/eagle_proposer.py
+++ b/vllm_ascend/spec_decode/eagle_proposer.py
@@ -41,7 +41,7 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
 from vllm_ascend.ops.triton.spec_decode.utils import \
     prepare_inputs_padded_kernel
 from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
-from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled
+from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled, lmhead_tp_enable
 
 # Currently we will fix block size to a small one since `num_reqs` can't be too large
 _PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -323,6 +323,13 @@ class EagleProposer(VllmEagleProposer):
                   batch_descriptor=None,
                   dummy_compute_logits=lambda hidden_states: None,
                   is_profile=False):
+        (
+            num_tokens,
+            num_tokens_across_dp,
+            _,
+        ) = self.runner._sync_metadata_across_dp(num_tokens,
+                                                 is_draft_model=True)
+
         # update global cos, sin
         update_cos_sin(self._get_positions(num_tokens))
 
@@ -380,12 +387,7 @@ class EagleProposer(VllmEagleProposer):
             model_previous_hidden_states = self.hidden_states[:num_tokens]
         batch_size = num_tokens // (self.num_speculative_tokens + 1)
 
-        (
-            num_tokens,
-            num_tokens_across_dp,
-            _,
-        ) = self.runner._sync_metadata_across_dp(num_tokens,
-                                                 is_draft_model=True)
+
         with set_ascend_forward_context(
                 multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
                 self.vllm_config,
@@ -395,7 +397,8 @@ class EagleProposer(VllmEagleProposer):
                 in_profile_run=is_profile,
                 batch_descriptor=batch_descriptor,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                is_draft_model=True):
+                is_draft_model=True,
+                draft_attn_metadatas=multi_steps_attn_metadata):
 
             self._runnable(
                 num_input_tokens=num_tokens,
@@ -405,6 +408,7 @@ class EagleProposer(VllmEagleProposer):
                 target_positions=model_positions,
                 inputs_embeds=None,
                 multi_steps_attn_metadata=multi_steps_attn_metadata,
+                is_dummy=True,
             )
             forward_context = get_forward_context()
             if (forward_context.cudagraph_runtime_mode
@@ -461,6 +465,13 @@ class EagleProposer(VllmEagleProposer):
         else:
             num_input_tokens = num_tokens
 
+        (
+            num_input_tokens,
+            num_tokens_across_dp,
+            _,
+        ) = self.runner._sync_metadata_across_dp(num_input_tokens,
+                                                 is_draft_model=True)
+
         has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
         if self.use_cuda_graph:
             aclgraph_runtime_mode, batch_descriptor = \
@@ -498,7 +509,7 @@ class EagleProposer(VllmEagleProposer):
                 common_attn_metadata.slot_mapping[:slot_mapping_lens])
             self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
             common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
-
+        common_attn_metadata.num_input_tokens = num_input_tokens
         # FIXME(woosuk): The below two ops cause synchronization. Optimize.
         builder = self.runner.attn_groups[0][0].get_metadata_builder()
         attn_metadata = builder.build(0, common_attn_metadata,
@@ -537,12 +548,6 @@ class EagleProposer(VllmEagleProposer):
             self.last_token_indices[:last_token_indices_len].copy_(
                 last_token_indices)
 
-        (
-            num_input_tokens,
-            num_tokens_across_dp,
-            _,
-        ) = self.runner._sync_metadata_across_dp(num_input_tokens,
-                                                 is_draft_model=True)
         with set_ascend_forward_context(
                 multi_steps_attn_metadata[0],
                 self.vllm_config,
@@ -551,7 +556,8 @@ class EagleProposer(VllmEagleProposer):
                 num_actual_tokens=num_tokens,
                 batch_descriptor=batch_descriptor,
                 aclgraph_runtime_mode=aclgraph_runtime_mode,
-                is_draft_model=True):
+                is_draft_model=True,
+                draft_attn_metadatas=multi_steps_attn_metadata):
 
             draft_token_ids = self._runnable(
                 num_input_tokens=num_input_tokens,
@@ -575,6 +581,7 @@ class EagleProposer(VllmEagleProposer):
             target_positions,
             inputs_embeds,
             multi_steps_attn_metadata,
+            is_dummy=False,
     ) -> torch.Tensor:
         # The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
         # `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
@@ -585,6 +592,17 @@ class EagleProposer(VllmEagleProposer):
             model_hidden_states, model_positions = self.maybe_pad_and_reduce(
                 model_hidden_states, model_positions)
 
+        # Expand the remaining moe layers to suit vLLM.
+        forward_context = get_forward_context()
+        if forward_context and hasattr(forward_context, 'remaining_moe_layers'):
+            if self.num_speculative_tokens > 1:
+                moe_layers_needed = len(forward_context.remaining_moe_layers) * self.num_speculative_tokens
+                if len(forward_context.remaining_moe_layers) < moe_layers_needed:
+                    original_layers = list(forward_context.remaining_moe_layers)
+                    repeat_count = (moe_layers_needed + len(original_layers) - 1) // len(original_layers)
+                    expanded_layers = original_layers * repeat_count
+                    forward_context.remaining_moe_layers = expanded_layers
+
         ret_hidden_states = self.model(
             input_ids=model_input_ids,
             positions=model_positions,
@@ -600,8 +618,21 @@ class EagleProposer(VllmEagleProposer):
         last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
             last_hidden_states, model_positions, hidden_states)
 
+        num_indices = last_token_indices.shape[0]
+        if lmhead_tp_enable() and not is_dummy:
+            max_num_reqs_across_dp = (
+                self.vllm_config.scheduler_config.max_num_seqs *
+                self.runner.uniform_decode_query_len)
+            last_token_indices = nn.functional.pad(
+                last_token_indices, (0, max_num_reqs_across_dp - num_indices))
+
         sample_hidden_states = last_hidden_states[last_token_indices]
         logits = self.model.compute_logits(sample_hidden_states)
+
+        if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
+            logits = logits[:num_indices]
+            last_token_indices = last_token_indices[:num_indices]
+
         draft_token_ids = logits.argmax(dim=-1)
 
         # Early exit if there is only one draft token to be generated.
@@ -699,10 +730,25 @@ class EagleProposer(VllmEagleProposer):
             last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
                 last_hidden_states, model_positions, hidden_states)
 
-            hidden_states = hidden_states[:batch_size]
-            logits = self.model.compute_logits(last_hidden_states[:batch_size])
+            num_indices = last_token_indices.shape[0]
+            if lmhead_tp_enable() and not is_dummy:
+                max_num_reqs_across_dp = (
+                    self.vllm_config.scheduler_config.max_num_seqs *
+                    self.runner.uniform_decode_query_len)
+                last_token_indices = nn.functional.pad(
+                    last_token_indices,
+                    (0, max_num_reqs_across_dp - num_indices),
+                )
+
+            sample_hidden_states = last_hidden_states[last_token_indices]
+            logits = self.model.compute_logits(sample_hidden_states)
+
+            if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
+                logits = logits[:num_indices]
+                last_token_indices = last_token_indices[:num_indices]
             # TODO(wenlong): get more than one token for tree attention
+            hidden_states = hidden_states[:batch_size]
             draft_token_ids = logits.argmax(dim=-1)
 
             draft_token_ids_tensor[draft_step + 1] = draft_token_ids
 
@@ -810,7 +856,7 @@ class EagleProposer(VllmEagleProposer):
                 block_numbers = clamped_positions[0] // block_size
             else:
                 block_numbers = (clamped_positions // block_size)
-            block_ids = old_attn_metadata.block_tables.gather(
+            block_ids = old_common_metadata.block_table_tensor.gather(
                 dim=1, index=block_numbers.view(-1, 1))
             block_ids = block_ids.view(-1)
             if self.uses_mrope:
diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py
index 7e279322..5a4326ab 100644
--- a/vllm_ascend/spec_decode/mtp_proposer.py
+++ b/vllm_ascend/spec_decode/mtp_proposer.py
@@ -37,7 +37,16 @@ class MtpProposer(EagleProposer):
                   batch_descriptor=None,
                   dummy_compute_logits=lambda hidden_states: None,
                   is_profile=False) -> None:
-
+        if (
+            self.pcp_size * self.dcp_size == 1
+            and not self.speculative_config.disable_padded_drafter_batch
+        ):
+            super().dummy_run(
+                num_tokens, with_prefill, in_graph_capturing, num_reqs,
+                num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor,
+                dummy_compute_logits, is_profile
+            )
+            return
         (
             num_tokens,
             num_tokens_across_dp,
@@ -151,6 +160,19 @@ class MtpProposer(EagleProposer):
         scheduler_output: SchedulerOutput = None,
         num_scheduled_tokens: int = 0,
     ) -> torch.Tensor:
+        if (
+            self.pcp_size * self.dcp_size == 1
+            and not self.speculative_config.disable_padded_drafter_batch
+        ):
+            draft_token_ids = super()._propose(
+                target_token_ids, target_positions, target_hidden_states,
+                next_token_ids, last_token_indices, common_attn_metadata,
+                sampling_metadata, mm_embed_inputs, req_scheduled_tokens,
+                long_seq_metadata, num_prefill_reqs, num_decode_reqs,
+                scheduler_output, num_scheduled_tokens
+            )
+            return draft_token_ids
+
         num_tokens = target_token_ids.shape[0]
         batch_size = next_token_ids.shape[0]
 
diff --git a/vllm_ascend/worker/model_runner_v1.py b/vllm_ascend/worker/model_runner_v1.py
index b4aacea3..fe62d3cf 100644
--- a/vllm_ascend/worker/model_runner_v1.py
+++ b/vllm_ascend/worker/model_runner_v1.py
@@ -113,13 +110,10 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
 from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
 from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
 from vllm_ascend.utils import (
-    AscendDeviceType,
     enable_sp,
-    get_ascend_device_type,
     is_drafter_moe_model,
     is_moe_model,
     lmhead_tp_enable,
-    maybe_trans_nz,
     set_weight_prefetch_method,
 )
 from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -140,7 +137,6 @@ if TYPE_CHECKING:
 else:
     xgr = LazyLoader("xgr", globals(), "xgrammar")
 
-import torch_npu
 
 # if true, allow tensor initialization and casting with internal format (e.g., NZ)
 torch.npu.config.allow_internal_format = True
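
For the routing described in Details item 3, the decision reduces to a single condition on the parallel sizes and the speculative config. Below is a minimal, self-contained sketch of that check; `SpecConfig` and `use_normalized_eagle_path` are illustrative stand-ins rather than vllm-ascend APIs, and only the condition itself mirrors the diff.

```python
# Minimal sketch of the routing condition added to MtpProposer.dummy_run/_propose.
# SpecConfig and use_normalized_eagle_path are hypothetical stand-ins for this
# illustration; only the boolean condition mirrors the patch.
from dataclasses import dataclass


@dataclass
class SpecConfig:
    disable_padded_drafter_batch: bool = False


def use_normalized_eagle_path(pcp_size: int, dcp_size: int,
                              spec_config: SpecConfig) -> bool:
    """Return True when MTP should delegate to the normalized Eagle path."""
    return (pcp_size * dcp_size == 1
            and not spec_config.disable_padded_drafter_batch)


if __name__ == "__main__":
    # pcp/dcp off and padded drafter batch enabled -> delegate to Eagle.
    assert use_normalized_eagle_path(1, 1, SpecConfig())
    # pcp/dcp enabled -> keep the legacy MTP path.
    assert not use_normalized_eagle_path(2, 1, SpecConfig())
    # padded drafter batch disabled -> keep the legacy MTP path.
    assert not use_normalized_eagle_path(1, 1, SpecConfig(True))
```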
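
For Details item 2, the fix threads the per-step draft attention metadata through the forward context so it is still reachable after entering Eagle mode with the merged graph. The following is a simplified, self-contained model of that plumbing; `ForwardContext`, `set_forward_context`, and `get_forward_context` here are toy stand-ins for vLLM's real forward-context machinery, not its actual implementation.

```python
# Simplified model: a context manager installs a forward context that carries
# one attention metadata object per speculative step, and any code running
# inside the context can look the list up instead of losing it.
from contextlib import contextmanager
from dataclasses import dataclass
from typing import Any, List, Optional

_CURRENT_CONTEXT: Optional["ForwardContext"] = None


@dataclass
class ForwardContext:
    attn_metadata: Any = None
    # One metadata object per speculative step, as passed by the proposer.
    draft_attn_metadatas: Optional[List[Any]] = None


@contextmanager
def set_forward_context(attn_metadata: Any,
                        draft_attn_metadatas: Optional[List[Any]] = None):
    """Install a forward context for the duration of one draft forward."""
    global _CURRENT_CONTEXT
    prev = _CURRENT_CONTEXT
    _CURRENT_CONTEXT = ForwardContext(attn_metadata, draft_attn_metadatas)
    try:
        yield _CURRENT_CONTEXT
    finally:
        _CURRENT_CONTEXT = prev


def get_forward_context() -> ForwardContext:
    assert _CURRENT_CONTEXT is not None, "no forward context installed"
    return _CURRENT_CONTEXT


if __name__ == "__main__":
    multi_steps_attn_metadata = ["meta_step_0", "meta_step_1"]
    # The proposer enters the context once and runs all draft steps inside it;
    # each step can retrieve its own metadata from the context.
    with set_forward_context(multi_steps_attn_metadata[0], multi_steps_attn_metadata):
        for step in range(len(multi_steps_attn_metadata)):
            assert get_forward_context().draft_attn_metadatas[step] == f"meta_step_{step}"
```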
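
For Details item 1, the lmhead tensor-parallel handling migrated into `eagle_proposer` pads the sampled-token indices to a fixed size (the diff uses `max_num_seqs * uniform_decode_query_len`) before computing logits, then slices the padded rows back off. A sketch of that pad/compute/unpad pattern, with a toy `compute_logits` standing in for the real lm_head:

```python
# Sketch of the pad/compute/unpad pattern used when lmhead_tp is enabled:
# pad the gather indices to a uniform size, run the logits computation on the
# padded batch, then drop the padded rows. compute_logits is a toy stand-in.
import torch
import torch.nn as nn


def compute_logits(hidden: torch.Tensor, vocab_size: int = 16) -> torch.Tensor:
    # Toy stand-in: project hidden states to a vocabulary dimension.
    torch.manual_seed(0)
    weight = torch.randn(hidden.shape[-1], vocab_size)
    return hidden @ weight


def draft_logits_with_lmhead_tp(last_hidden_states: torch.Tensor,
                                last_token_indices: torch.Tensor,
                                max_num_reqs_across_dp: int) -> torch.Tensor:
    num_indices = last_token_indices.shape[0]
    # Pad the gather indices up to the uniform size (padded slots repeat index 0).
    padded_indices = nn.functional.pad(
        last_token_indices, (0, max_num_reqs_across_dp - num_indices))
    sample_hidden_states = last_hidden_states[padded_indices]
    logits = compute_logits(sample_hidden_states)
    # Drop the padded rows so only real requests produce draft tokens.
    return logits[:num_indices]


if __name__ == "__main__":
    hidden = torch.randn(10, 8)          # 10 tokens, hidden size 8
    indices = torch.tensor([3, 7])       # last token of each real request
    logits = draft_logits_with_lmhead_tp(hidden, indices, max_num_reqs_across_dp=4)
    draft_token_ids = logits.argmax(dim=-1)
    assert draft_token_ids.shape == (2,)
```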