[Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)

### What this PR does / why we need it?
This PR aims to refactor eagle-related modules in vllm-ascend.

This is the starting PR of eagle refactoring. Provided with vllm-eagle,
ascend-eagle and ascend-mtp, we first let ascend-mtp inherit from
ascend-eagle and let ascend-eagle inherit from vllm-eagle. As a
initialization, we just delete `__init__` in mtp_proposer and simplify
the corresponding logic in eagle_proposer.

Based on "vllm-eagle <----- ascend-eagle <----- ascend-mtp", our target
is to gradually delete ascend-mtp and enable ascend-eagle to converge to
vllm-eagle. So the main workspace is eagle_proposer. In this way, we
hope that contributors can concurrently refactor eagle.

Incoming changes:
1. delete common methods in vllm-eagle & ascend-eagle & ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ......

RFC: #5467

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
by ci

- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c

---------

Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
Zetong Li
2025-12-29 16:25:52 +08:00
committed by GitHub
parent 28b7614322
commit 92353c0643
4 changed files with 119 additions and 176 deletions

View File

@@ -18,8 +18,8 @@ from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -29,7 +29,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
update_attn_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import shared_expert_dp_enabled
PADDING_SLOT_ID = -1
@@ -38,29 +38,15 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
class EagleProposer(Proposer):
class EagleProposer(VllmEagleProposer):
def __init__(self,
vllm_config: VllmConfig,
device: torch.device,
runner=None):
self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
self.vllm_config = vllm_config
self.device = device
self.runner = runner
self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
super().__init__(vllm_config, device, runner)
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
self.block_size = vllm_config.cache_config.block_size
# We need to get the hidden size from the draft model config because
# the draft model's hidden size can be different from the target model's
# hidden size (e.g., Llama 3.3 70B).
self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
)
# there is synchronization between mtp steps when enabling aclgraph,
# disable aclgraph when use async scheduling to avoid the
# synchronization overhead.
@@ -77,45 +63,28 @@ class EagleProposer(Proposer):
sorted(
self.vllm_config.compilation_config.cudagraph_capture_sizes))
max_batch_size = vllm_config.scheduler_config.max_num_seqs
# Currently we do not use pcp. This is used to adapt the pcp branch.
self.pcp_size = 0
self.backup_next_token_ids = CpuGpuBuffer(
max_batch_size,
dtype=torch.int32,
pin_memory=is_pin_memory_available(),
device=device,
with_numpy=True,
)
self.pcp_size = self.runner.pcp_size
self.decode_threshold = 1 + self.num_speculative_tokens
# persistent buffers for cuda graph
self.input_ids = torch.zeros(
self.vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(
self.vllm_config.scheduler_config.max_num_batched_tokens,
dtype=torch.int64,
device=device)
self.hidden_states = torch.zeros(
(self.vllm_config.scheduler_config.max_num_batched_tokens,
self.hidden_size),
dtype=self.vllm_config.model_config.dtype,
device=device)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
self.arange = torch.arange(max_num_slots_for_arange,
device=device,
dtype=torch.int32)
self.arange_cpu = torch.arange(max_num_slots_for_arange,
self.arange_cpu = torch.arange(self.arange.shape[0],
device="cpu",
dtype=torch.int32)
self.attn_mask_builder = AttentionMaskBuilder(self.device)
self.eagle3_use_aux_hidden_state: bool = (
self._get_eagle3_use_aux_hidden_state_from_config())
self.enable_shared_expert_dp = shared_expert_dp_enabled()
self.dcp_size = self.runner.dcp_size
self.pcp_rank = self.runner.pcp_rank
self.dcp_rank = self.runner.dcp_rank
self.use_aclgraph = self.runner._use_aclgraph()
self.full_indices = range(
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
"index_topk")
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
"""
@@ -165,7 +134,7 @@ class EagleProposer(Proposer):
# share lm_head with the target model if needed
# some model definition do not define lm_head explicitly
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
if self.method == "eagle" and hasattr(model, "lm_head"):
logger.info("Loading EAGLE LM head weights from the target model.")
if supports_multimodal(model):
self.model.lm_head = model.get_language_model().lm_head
@@ -337,7 +306,7 @@ class EagleProposer(Proposer):
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.name == SpecDcodeType.EAGLE3:
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
@@ -371,7 +340,7 @@ class EagleProposer(Proposer):
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
@@ -424,7 +393,7 @@ class EagleProposer(Proposer):
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.name == SpecDcodeType.EAGLE3:
if self.method == "eagle3":
assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)