[Refactor][EAGLE] 6/N route mtp to eagle except pcp/dcp+mtp (#6349)

### What this PR does / why we need it?

Overview: This pull request refactors speculative decoding for Eagle and
MTP proposers on Ascend hardware. It fixes a bug related to
draft_attn_metadatas being lost, migrates the lmhead feature, and adds
routing logic in MtpProposer.

Details:
1. Migrated the lmhead feature from mtp to eagle and normalized it in
eagle_proposer.
2. Fixed the bug where draft_attn_metadatas was lost after enabling
eagle mode in the merge graph.
3. Added the routing for pcp and disable padded drafter batch; in mtp
mode, if pcp and disable padded drafter batch are not enabled, the
normalized file eagle_proposer will be used.

RFC: https://github.com/vllm-project/vllm-ascend/issues/5467

### Does this PR introduce _any_ user-facing change?
No

### How was this patch tested?
ut and test

- vLLM version: v0.14.1
- vLLM main:
dc917cceb8

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
This commit is contained in:
lilinsiman
2026-02-02 19:15:31 +08:00
committed by GitHub
parent c08364f761
commit 7932255c06
4 changed files with 90 additions and 24 deletions

View File

@@ -43,6 +43,7 @@ def set_ascend_forward_context(
model_instance: torch.nn.Module = None,
is_draft_model=False,
skip_compiled: bool = False,
draft_attn_metadatas=None,
):
"""A context manager that stores the current forward context,
can be attention metadata, etc.
@@ -61,6 +62,7 @@ def set_ascend_forward_context(
with set_forward_context(**forward_context_kwargs):
forward_context = get_forward_context()
forward_context.draft_attn_metadatas = draft_attn_metadatas
from vllm_ascend.ops.fused_moe.moe_comm_method import get_moe_comm_method

View File

@@ -41,7 +41,7 @@ from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
prepare_inputs_padded_kernel
from vllm_ascend.ops.triton.triton_utils import get_vectorcore_num
from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled
from vllm_ascend.utils import enable_sp, shared_expert_dp_enabled, lmhead_tp_enable
# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -323,6 +323,13 @@ class EagleProposer(VllmEagleProposer):
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None,
is_profile=False):
(
num_tokens,
num_tokens_across_dp,
_,
) = self.runner._sync_metadata_across_dp(num_tokens,
is_draft_model=True)
# update global cos, sin
update_cos_sin(self._get_positions(num_tokens))
@@ -380,12 +387,7 @@ class EagleProposer(VllmEagleProposer):
model_previous_hidden_states = self.hidden_states[:num_tokens]
batch_size = num_tokens // (self.num_speculative_tokens + 1)
(
num_tokens,
num_tokens_across_dp,
_,
) = self.runner._sync_metadata_across_dp(num_tokens,
is_draft_model=True)
with set_ascend_forward_context(
multi_steps_attn_metadata[0] if multi_steps_attn_metadata else None,
self.vllm_config,
@@ -395,7 +397,8 @@ class EagleProposer(VllmEagleProposer):
in_profile_run=is_profile,
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
is_draft_model=True,
draft_attn_metadatas=multi_steps_attn_metadata):
self._runnable(
num_input_tokens=num_tokens,
@@ -405,6 +408,7 @@ class EagleProposer(VllmEagleProposer):
target_positions=model_positions,
inputs_embeds=None,
multi_steps_attn_metadata=multi_steps_attn_metadata,
is_dummy=True,
)
forward_context = get_forward_context()
if (forward_context.cudagraph_runtime_mode
@@ -461,6 +465,13 @@ class EagleProposer(VllmEagleProposer):
else:
num_input_tokens = num_tokens
(
num_input_tokens,
num_tokens_across_dp,
_,
) = self.runner._sync_metadata_across_dp(num_input_tokens,
is_draft_model=True)
has_lora = len(self.runner.input_batch.lora_id_to_lora_request) > 0
if self.use_cuda_graph:
aclgraph_runtime_mode, batch_descriptor = \
@@ -498,7 +509,7 @@ class EagleProposer(VllmEagleProposer):
common_attn_metadata.slot_mapping[:slot_mapping_lens])
self.slot_mapping_group[0][slot_mapping_lens:].fill_(-1)
common_attn_metadata.slot_mapping = self.slot_mapping_group[0][:slot_mapping_lens]
common_attn_metadata.num_input_tokens = num_input_tokens
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
builder = self.runner.attn_groups[0][0].get_metadata_builder()
attn_metadata = builder.build(0, common_attn_metadata,
@@ -537,12 +548,6 @@ class EagleProposer(VllmEagleProposer):
self.last_token_indices[:last_token_indices_len].copy_(
last_token_indices)
(
num_input_tokens,
num_tokens_across_dp,
_,
) = self.runner._sync_metadata_across_dp(num_input_tokens,
is_draft_model=True)
with set_ascend_forward_context(
multi_steps_attn_metadata[0],
self.vllm_config,
@@ -551,7 +556,8 @@ class EagleProposer(VllmEagleProposer):
num_actual_tokens=num_tokens,
batch_descriptor=batch_descriptor,
aclgraph_runtime_mode=aclgraph_runtime_mode,
is_draft_model=True):
is_draft_model=True,
draft_attn_metadatas=multi_steps_attn_metadata):
draft_token_ids = self._runnable(
num_input_tokens=num_input_tokens,
@@ -575,6 +581,7 @@ class EagleProposer(VllmEagleProposer):
target_positions,
inputs_embeds,
multi_steps_attn_metadata,
is_dummy=False,
) -> torch.Tensor:
# The lifecycle of `input_ids`, `positions`, `hidden_states` runs through all speculative tokens' proposings.
# `model_input_ids`, `model_positions` and `model_hidden_states` are used to represent the inputs of speculative model.
@@ -585,6 +592,17 @@ class EagleProposer(VllmEagleProposer):
model_hidden_states, model_positions = self.maybe_pad_and_reduce(
model_hidden_states, model_positions)
# Expend the remaining moe layers for suiting vllm.
forward_context = get_forward_context()
if forward_context and hasattr(forward_context, 'remaining_moe_layers'):
if self.num_speculative_tokens > 1:
moe_layers_needed = len(forward_context.remaining_moe_layers) * self.num_speculative_tokens
if len(forward_context.remaining_moe_layers) < moe_layers_needed:
original_layers = list(forward_context.remaining_moe_layers)
repeat_count = (moe_layers_needed + len(original_layers) - 1) // len(original_layers)
expanded_layers = original_layers * repeat_count
forward_context.remaining_moe_layers = expanded_layers
ret_hidden_states = self.model(
input_ids=model_input_ids,
positions=model_positions,
@@ -600,8 +618,21 @@ class EagleProposer(VllmEagleProposer):
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
last_hidden_states, model_positions, hidden_states)
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable() and not is_dummy:
max_num_reqs_across_dp = (
self.vllm_config.scheduler_config.max_num_seqs *
self.runner.uniform_decode_query_len)
last_token_indices = nn.functional.pad(
last_token_indices, (0, max_num_reqs_across_dp - num_indices))
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
logits = logits[:num_indices]
last_token_indices = last_token_indices[:num_indices]
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
@@ -699,10 +730,25 @@ class EagleProposer(VllmEagleProposer):
last_hidden_states, model_positions, hidden_states = self.maybe_all_gather_and_unpad(
last_hidden_states, model_positions, hidden_states)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
num_indices = last_token_indices.shape[0]
if lmhead_tp_enable() and not is_dummy:
max_num_reqs_across_dp = (
self.vllm_config.scheduler_config.max_num_seqs *
self.runner.uniform_decode_query_len)
last_token_indices = nn.functional.pad(
last_token_indices,
(0, max_num_reqs_across_dp - num_indices),
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states)
if lmhead_tp_enable() and num_indices < logits.shape[0] and not is_dummy:
logits = logits[:num_indices]
last_token_indices = last_token_indices[:num_indices]
# TODO(wenlong): get more than one token for tree attention
hidden_states = hidden_states[:batch_size]
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_tensor[draft_step + 1] = draft_token_ids
@@ -810,7 +856,7 @@ class EagleProposer(VllmEagleProposer):
block_numbers = clamped_positions[0] // block_size
else:
block_numbers = (clamped_positions // block_size)
block_ids = old_attn_metadata.block_tables.gather(
block_ids = old_common_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
if self.uses_mrope:

View File

@@ -37,7 +37,16 @@ class MtpProposer(EagleProposer):
batch_descriptor=None,
dummy_compute_logits=lambda hidden_states: None,
is_profile=False) -> None:
if (
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
):
super().dummy_run(
num_tokens, with_prefill, in_graph_capturing, num_reqs,
num_tokens_across_dp, aclgraph_runtime_mode, batch_descriptor,
dummy_compute_logits, is_profile
)
return
(
num_tokens,
num_tokens_across_dp,
@@ -151,6 +160,19 @@ class MtpProposer(EagleProposer):
scheduler_output: SchedulerOutput = None,
num_scheduled_tokens: int = 0,
) -> torch.Tensor:
if (
self.pcp_size * self.dcp_size == 1
and not self.speculative_config.disable_padded_drafter_batch
):
draft_token_ids = super()._propose(
target_token_ids, target_positions, target_hidden_states,
next_token_ids, last_token_indices, common_attn_metadata,
sampling_metadata, mm_embed_inputs, req_scheduled_tokens,
long_seq_metadata, num_prefill_reqs, num_decode_reqs,
scheduler_output, num_scheduled_tokens
)
return draft_token_ids
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]

View File

@@ -113,13 +113,10 @@ from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.spec_decode.medusa_proposer import MedusaProposer
from vllm_ascend.spec_decode.mtp_proposer import MtpProposer
from vllm_ascend.utils import (
AscendDeviceType,
enable_sp,
get_ascend_device_type,
is_drafter_moe_model,
is_moe_model,
lmhead_tp_enable,
maybe_trans_nz,
set_weight_prefetch_method,
)
from vllm_ascend.worker.npu_input_batch import NPUInputBatch
@@ -140,7 +137,6 @@ if TYPE_CHECKING:
else:
xgr = LazyLoader("xgr", globals(), "xgrammar")
import torch_npu
# if true, allow tensor initialization and casting with internal format (e.g., NZ)
torch.npu.config.allow_internal_format = True