[Refactor][EAGLE] 3/N delete redundant methods in mtp_proposer (#5420)

### What this PR does / why we need it?
This PR aims to delete redundant methods in mtp_proposer. All the
deleted methods now can be found in eagle_proposer. We also remove some
methods in eagle_proposer since they are identical to those in
vllm-eagle.

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2026-01-06 16:47:39 +08:00
committed by GitHub
parent b94d589769
commit fe3f2c7702
3 changed files with 97 additions and 483 deletions

View File

@@ -4,6 +4,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
@@ -15,6 +16,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@@ -28,7 +30,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_params)
update_attn_dcp_pcp_params,
update_attn_params,
update_mla_attn_dcp_pcp_params,
update_mla_attn_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
prepare_inputs_padded_kernel
@@ -37,10 +42,6 @@ from vllm_ascend.utils import shared_expert_dp_enabled
PADDING_SLOT_ID = -1
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4
@@ -93,27 +94,6 @@ class EagleProposer(VllmEagleProposer):
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
"index_topk")
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
"""
NOTE(2025-12-18): This is an explicit copy from vLLM EagleProposer, only added
to align with its logics.
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
hidden states and directly uses the last layer output just like eagle1.
They might indicate this by setting "use_aux_hidden_state" to False
inside the "eagle_config" dict of their hf_config.
"""
if self.method != "eagle3":
return False
# Assume that eagle3 heads use aux hidden states by default
use_aux_hidden_state = True
eagle_config = getattr(self.draft_model_config.hf_config,
"eagle_config", None)
if eagle_config is not None:
use_aux_hidden_state = eagle_config.get("use_aux_hidden_state",
True)
return use_aux_hidden_state
def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
@@ -512,48 +492,6 @@ class EagleProposer(VllmEagleProposer):
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
return draft_token_ids
def _get_attn_metadata(self, attn_metadata):
if attn_metadata is not None and isinstance(attn_metadata, dict):
architecture = self.vllm_config.model_config.architecture
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
attn_metadata = attn_metadata[layer_name]
return attn_metadata
def prepare_next_token_ids_cpu(
self,
sampled_token_ids: list[list[int]],
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
num_scheduled_tokens: dict[str, int],
) -> torch.Tensor:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids for each request based on the sampled
token ids from the CPU. If a request has no sampled token ids (e.g.,
during the initial decoding steps), it falls back to using the request
state to get the next token id.
"""
req_ids = gpu_input_batch.req_ids
next_token_ids: list[int] = []
for i, token_ids in enumerate(sampled_token_ids):
if token_ids:
# Common case.
next_token_id = token_ids[-1]
else:
# Partial prefill (rare case).
# Get the next token id from the request state.
req_id = req_ids[i]
req_state = requests[req_id]
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
req_id]
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
next_token_ids = torch.tensor(next_token_ids,
dtype=torch.int32,
device=self.input_ids.device)
return next_token_ids
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
@@ -829,3 +767,92 @@ class EagleProposer(VllmEagleProposer):
max_seq_len=0)
return spec_common_attn_metadata, token_indices, token_indices_to_sample
def _split_pcp_input(self, req_scheduled_tokens, input_ids,
target_hidden_states):
"""
Split prefill input_ids and target_hidden_states in pcp group.
1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
3. split target_hidden_states (already include pcp padding):
[h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
"""
if len(req_scheduled_tokens) == 0:
# no prefill inputs to split, return empty result
return (
0,
torch.zeros([0], device='npu'),
torch.zeros([0, target_hidden_states.size(1)], device='npu'),
0,
torch.zeros([0]),
torch.tensor([0], dtype=torch.int32),
)
def _pcp_pad_and_split(num_tokens):
num_pcp_padded_scheduled_tokens = cdiv(
num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)
# split position_ids (and use split position_ids to split input_ids afterwards)
req_position_cp: list[int] = []
req_position_cp.extend(
self.full_indices[self.pcp_rank *
chunk_size:(self.pcp_rank + 1) * chunk_size])
req_position_cp.extend(
self.full_indices[num_pcp_padded_scheduled_tokens -
(self.pcp_rank + 1) *
chunk_size:num_pcp_padded_scheduled_tokens -
self.pcp_rank * chunk_size])
return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
num_pcp_scheduled_tokens = []
ori_start_index = 0
pad_start_index = 0
pcp_split_input_ids_list = []
pcp_split_hidden_states_list = []
for ori_num_tokens in req_scheduled_tokens.values():
req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
_pcp_pad_and_split(ori_num_tokens)
actual_num_tokens = len(req_position_pcp)
num_pcp_scheduled_tokens.append(actual_num_tokens)
pad_input_ids = F.pad(
input_ids[ori_start_index:ori_start_index + ori_num_tokens],
(0, num_pcp_pad))
ori_start_index += ori_num_tokens
pcp_chunk_indices = [
pad_start_index + pos for pos in req_position_pcp
]
pcp_split_input_ids = pad_input_ids[req_position_pcp]
pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
pcp_split_input_ids_list.append(pcp_split_input_ids)
pcp_split_hidden_states_list.append(pcp_split_hidden_states)
pad_start_index += num_pcp_padded_scheduled_tokens
num_tokens = sum(num_pcp_scheduled_tokens)
input_ids = torch.cat(pcp_split_input_ids_list)
target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
max_query_len = max(num_pcp_scheduled_tokens)
seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
cu_num_tokens = torch.tensor(
np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
# update full-graph params for one spec token
def _update_full_graph_params(self, forward_context, num_tokens):
if self.vllm_config.model_config.use_mla:
if self.pcp_size * self.dcp_size > 1:
update_mla_attn_dcp_pcp_params(self.update_stream,
forward_context, num_tokens)
else:
update_mla_attn_params(self.update_stream, forward_context,
num_tokens,
self.vllm_config.speculative_config)
else:
if self.pcp_size * self.dcp_size > 1:
update_attn_dcp_pcp_params(self.update_stream, forward_context,
num_tokens)
else:
update_attn_params(self.update_stream, forward_context,
num_tokens, self.vllm_config)