[Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)

### What this PR does / why we need it?
This PR aims to refactor eagle-related modules in vllm-ascend.

This is the first PR of the eagle refactoring. Given the existing vllm-eagle,
ascend-eagle, and ascend-mtp implementations, we first make ascend-mtp inherit
from ascend-eagle, and ascend-eagle inherit from vllm-eagle. As an initial
step, this PR only deletes `__init__` in mtp_proposer and simplifies the
corresponding logic in eagle_proposer.

Based on "vllm-eagle <----- ascend-eagle <----- ascend-mtp", our target
is to gradually delete ascend-mtp and enable ascend-eagle to converge to
vllm-eagle. So the main workspace is eagle_proposer. In this way, we
hope that contributors can concurrently refactor eagle.
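
For orientation, here is a minimal sketch of the resulting hierarchy. The
upstream class name on the vllm side is an illustrative stand-in, not taken
from this diff:

```python
# Minimal sketch of the intended hierarchy; the stand-in names below are
# hypothetical and only illustrate the direction of the refactor.


class VllmEagleProposer:
    """Stand-in for vllm's upstream eagle proposer (vllm-eagle)."""


class EagleProposer(VllmEagleProposer):
    """vllm_ascend.spec_decode.eagle_proposer (ascend-eagle).

    Holds the Ascend-specific logic today; the refactor gradually moves
    shared behaviour up so this class converges to the upstream one.
    """


class MtpProposer(EagleProposer):
    """vllm_ascend.spec_decode.mtp_proposer (ascend-mtp).

    After this PR it no longer defines __init__; construction is inherited
    from the ascend EagleProposer, which presumably distinguishes MTP from
    eagle/eagle3 via the speculative config.
    """
```

Later steps in the series then delete the remaining duplicated methods so that
MtpProposer can eventually be removed.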

Incoming changes (a toy sketch of what each deletion means follows this list):
1. delete methods duplicated across vllm-eagle, ascend-eagle and ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ......
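
Conceptually, each per-method deletion just stops mtp_proposer from shadowing
the shared implementation in ascend-eagle. A toy before/after sketch (class
names and method bodies here are hypothetical):

```python
# Hypothetical before/after for one step of the series (illustrative only).


class EagleProposerBase:
    """Stand-in for the shared ascend EagleProposer."""

    def load_model(self, model) -> None:
        print("shared load_model used for both eagle and mtp")


# Before: ascend-mtp shadows the base implementation with a near-copy.
class MtpProposerBefore(EagleProposerBase):

    def load_model(self, model) -> None:
        print("mtp-specific near-duplicate of load_model")


# After: the override is deleted and the base implementation is inherited.
class MtpProposerAfter(EagleProposerBase):
    pass


MtpProposerAfter().load_model(model=None)  # -> uses the shared implementation
```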

RFC: #5467

### Does this PR introduce _any_ user-facing change?
N/A

### How was this patch tested?
By CI.

- vLLM version: v0.12.0
- vLLM main: ad32e3e19c

---------

Signed-off-by: Zetong Li <slippersss@126.com>
Author: Zetong Li
Date: 2025-12-29 16:25:52 +08:00 (committed by GitHub)
Parent: 28b7614322
Commit: 92353c0643
4 changed files with 119 additions and 176 deletions


```diff
@@ -5,8 +5,8 @@ import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
-from vllm.config import (CUDAGraphMode, VllmConfig,
-                         get_layers_from_vllm_config, set_current_vllm_config)
+from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
+                         set_current_vllm_config)
 from vllm.distributed import get_pcp_group
 from vllm.distributed.parallel_state import get_pp_group
 from vllm.forward_context import get_forward_context
@@ -20,12 +20,10 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
 from vllm.utils.math_utils import cdiv
 from vllm.utils.platform_utils import is_pin_memory_available
 from vllm.utils.torch_utils import set_default_torch_dtype
-from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
-                                              CommonAttentionMetadata)
+from vllm.v1.attention.backends.utils import CommonAttentionMetadata
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
-from vllm.v1.utils import CpuGpuBuffer
 from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
@@ -35,9 +33,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                                update_mla_attn_dcp_pcp_params,
                                                update_mla_attn_params)
 from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
-from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
-from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
-                               shared_expert_dp_enabled)
+from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
+from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
 
 logger = init_logger(__name__)
 
@@ -64,102 +61,11 @@ def _load_model(architecture):
     return model
 
 
-class MtpProposer(Proposer):
+class MtpProposer(EagleProposer):
     # TODO: Find out why ModelRunner does not this explicit typing?
     model: Union[nn.Module, ACLGraphWrapper]
 
-    def __init__(
-        self,
-        vllm_config: VllmConfig,
-        device,
-        runner,
-    ):
-        self.name = SpecDcodeType.MTP
-        self.vllm_config = vllm_config
-        self.speculative_config = vllm_config.speculative_config
-        assert self.speculative_config is not None
-        self.draft_model_config = self.speculative_config.draft_model_config
-        self.method = self.speculative_config.method
-        self.runner = runner
-        self.device = device
-        self.dtype = vllm_config.model_config.dtype
-        self.max_model_len = vllm_config.model_config.max_model_len
-        self.block_size = vllm_config.cache_config.block_size
-        self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
-        self.decode_threshold = 1 + self.num_speculative_tokens
-        self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
-        self.token_arange_np = np.arange(self.max_num_tokens)
-        # We need to get the hidden size from the draft model config because
-        # the draft model's hidden size can be different from the target model's
-        # hidden size (e.g., Llama 3.3 70B).
-        self.hidden_size = self.draft_model_config.get_hidden_size()
-        self.enable_shared_expert_dp = shared_expert_dp_enabled()
-        self.pcp_size = self.runner.pcp_size
-        self.dcp_size = self.runner.dcp_size
-        self.pcp_rank = self.runner.pcp_rank
-        self.dcp_rank = self.runner.dcp_rank
-        self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
-        self.draft_indexer_metadata_builder: Optional[
-            AttentionMetadataBuilder] = None
-        self.attn_layer_names: list[str] = []
-        self.indexer_layer_names: list[str] = []
-        self.use_aclgraph = self.runner._use_aclgraph()
-        # persistent buffers for aclgraph graph
-        self.input_ids = torch.zeros(self.max_num_tokens,
-                                     dtype=torch.int32,
-                                     device=device)
-        self.uses_mrope = self.vllm_config.model_config.uses_mrope
-        if self.uses_mrope:
-            # M-RoPE need (3, max_num_tokens)
-            self.mrope_positions = torch.zeros((3, self.max_num_tokens),
-                                               dtype=torch.int64,
-                                               device=device)
-        else:
-            # RoPE need (max_num_tokens,)
-            self.positions = torch.zeros(self.max_num_tokens,
-                                         dtype=torch.int64,
-                                         device=device)
-        self.hidden_states = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-        self.full_indices = range(
-            self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
-            self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
-        # We need +1 here because the arange is used to set query_start_loc,
-        # which has one more element than batch_size.
-        max_batch_size = vllm_config.scheduler_config.max_num_seqs
-        max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
-        self.arange = torch.arange(max_num_slots_for_arange,
-                                   device=device,
-                                   dtype=torch.int32)
-        self.arange_cpu = torch.arange(max_num_slots_for_arange,
-                                       device="cpu",
-                                       dtype=torch.int32)
-        self.inputs_embeds = torch.zeros(
-            (self.max_num_tokens, self.hidden_size),
-            dtype=self.dtype,
-            device=device)
-        self.backup_next_token_ids = CpuGpuBuffer(
-            max_batch_size,
-            dtype=torch.int32,
-            pin_memory=is_pin_memory_available(),
-            device=device,
-            with_numpy=True,
-        )
-        self.use_sparse = hasattr(vllm_config.model_config.hf_config,
-                                  "index_topk")
-        self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
-
     def load_model(self, model) -> None:
         loader = get_model_loader(self.vllm_config.load_config)
```