[Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)
### What this PR does / why we need it?
This PR aims to refactor eagle-related modules in vllm-ascend.
This is the starting PR of eagle refactoring. Provided with vllm-eagle,
ascend-eagle and ascend-mtp, we first let ascend-mtp inherit from
ascend-eagle and let ascend-eagle inherit from vllm-eagle. As a
initialization, we just delete `__init__` in mtp_proposer and simplify
the corresponding logic in eagle_proposer.
Based on "vllm-eagle <----- ascend-eagle <----- ascend-mtp", our target
is to gradually delete ascend-mtp and enable ascend-eagle to converge to
vllm-eagle. So the main workspace is eagle_proposer. In this way, we
hope that contributors can concurrently refactor eagle.
Incoming changes:
1. delete common methods in vllm-eagle & ascend-eagle & ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ......
RFC: #5467
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -18,8 +18,8 @@ from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer as VllmEagleProposer
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
@@ -29,7 +29,7 @@ from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
update_attn_params)
|
||||
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import shared_expert_dp_enabled
|
||||
|
||||
PADDING_SLOT_ID = -1
|
||||
|
||||
@@ -38,29 +38,15 @@ _DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
||||
|
||||
|
||||
class EagleProposer(Proposer):
|
||||
class EagleProposer(VllmEagleProposer):
|
||||
|
||||
def __init__(self,
|
||||
vllm_config: VllmConfig,
|
||||
device: torch.device,
|
||||
runner=None):
|
||||
self.name = SpecDcodeType.EAGLE if vllm_config.speculative_config.method == "eagle" else SpecDcodeType.EAGLE3
|
||||
self.vllm_config = vllm_config
|
||||
self.device = device
|
||||
self.runner = runner
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
self.num_speculative_tokens = vllm_config.speculative_config.num_speculative_tokens
|
||||
super().__init__(vllm_config, device, runner)
|
||||
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = vllm_config.speculative_config.draft_model_config.get_hidden_size(
|
||||
)
|
||||
|
||||
# there is synchronization between mtp steps when enabling aclgraph,
|
||||
# disable aclgraph when use async scheduling to avoid the
|
||||
# synchronization overhead.
|
||||
@@ -77,45 +63,28 @@ class EagleProposer(Proposer):
|
||||
sorted(
|
||||
self.vllm_config.compilation_config.cudagraph_capture_sizes))
|
||||
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
# Currently we do not use pcp. This is used to adapt the pcp branch.
|
||||
self.pcp_size = 0
|
||||
self.backup_next_token_ids = CpuGpuBuffer(
|
||||
max_batch_size,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
device=device,
|
||||
with_numpy=True,
|
||||
)
|
||||
self.pcp_size = self.runner.pcp_size
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
|
||||
# persistent buffers for cuda graph
|
||||
self.input_ids = torch.zeros(
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.positions = torch.zeros(
|
||||
self.vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.vllm_config.scheduler_config.max_num_batched_tokens,
|
||||
self.hidden_size),
|
||||
dtype=self.vllm_config.model_config.dtype,
|
||||
device=device)
|
||||
self.max_num_tokens = (
|
||||
vllm_config.scheduler_config.max_num_batched_tokens)
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
max_num_slots_for_arange = max(self.max_num_tokens, max_batch_size + 1)
|
||||
self.arange = torch.arange(max_num_slots_for_arange,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(max_num_slots_for_arange,
|
||||
self.arange_cpu = torch.arange(self.arange.shape[0],
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
self.attn_mask_builder = AttentionMaskBuilder(self.device)
|
||||
self.eagle3_use_aux_hidden_state: bool = (
|
||||
self._get_eagle3_use_aux_hidden_state_from_config())
|
||||
|
||||
self.enable_shared_expert_dp = shared_expert_dp_enabled()
|
||||
|
||||
self.dcp_size = self.runner.dcp_size
|
||||
self.pcp_rank = self.runner.pcp_rank
|
||||
self.dcp_rank = self.runner.dcp_rank
|
||||
|
||||
self.use_aclgraph = self.runner._use_aclgraph()
|
||||
|
||||
self.full_indices = range(
|
||||
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
|
||||
|
||||
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
|
||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
||||
"""
|
||||
@@ -165,7 +134,7 @@ class EagleProposer(Proposer):
|
||||
# share lm_head with the target model if needed
|
||||
# some model definition do not define lm_head explicitly
|
||||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
|
||||
if self.name == SpecDcodeType.EAGLE and hasattr(model, "lm_head"):
|
||||
if self.method == "eagle" and hasattr(model, "lm_head"):
|
||||
logger.info("Loading EAGLE LM head weights from the target model.")
|
||||
if supports_multimodal(model):
|
||||
self.model.lm_head = model.get_language_model().lm_head
|
||||
@@ -337,7 +306,7 @@ class EagleProposer(Proposer):
|
||||
target_token_ids = self.runner.input_ids.gpu[:
|
||||
num_scheduled_tokens]
|
||||
target_positions = positions[:num_scheduled_tokens]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
if self.method == "eagle3":
|
||||
target_hidden_states = torch.cat(
|
||||
[h[:num_scheduled_tokens] for h in aux_hidden_states],
|
||||
dim=-1)
|
||||
@@ -371,7 +340,7 @@ class EagleProposer(Proposer):
|
||||
else:
|
||||
target_token_ids = self.runner.input_ids.gpu[token_indices]
|
||||
target_positions = positions[token_indices]
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
if self.method == "eagle3":
|
||||
target_hidden_states = torch.cat(
|
||||
[h[token_indices] for h in aux_hidden_states], dim=-1)
|
||||
else:
|
||||
@@ -424,7 +393,7 @@ class EagleProposer(Proposer):
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.name == SpecDcodeType.EAGLE3:
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.get_model(), Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
|
||||
Reference in New Issue
Block a user