[Refactor][EAGLE] 1/N delete __init__ in mtp_proposer (#5176)
### What this PR does / why we need it?
This PR aims to refactor eagle-related modules in vllm-ascend.
This is the starting PR of eagle refactoring. Provided with vllm-eagle,
ascend-eagle and ascend-mtp, we first let ascend-mtp inherit from
ascend-eagle and let ascend-eagle inherit from vllm-eagle. As a
initialization, we just delete `__init__` in mtp_proposer and simplify
the corresponding logic in eagle_proposer.
Based on "vllm-eagle <----- ascend-eagle <----- ascend-mtp", our target
is to gradually delete ascend-mtp and enable ascend-eagle to converge to
vllm-eagle. So the main workspace is eagle_proposer. In this way, we
hope that contributors can concurrently refactor eagle.
Incoming changes:
1. delete common methods in vllm-eagle & ascend-eagle & ascend-mtp
2. delete `load_model` in mtp_proposer
3. delete `dummy_run` and `propose` in mtp_proposer
4. ......
RFC: #5467
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM version: v0.12.0
- vLLM main:
ad32e3e19c
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -5,8 +5,8 @@ import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from vllm.config import (CUDAGraphMode, VllmConfig,
|
||||
get_layers_from_vllm_config, set_current_vllm_config)
|
||||
from vllm.config import (CUDAGraphMode, get_layers_from_vllm_config,
|
||||
set_current_vllm_config)
|
||||
from vllm.distributed import get_pcp_group
|
||||
from vllm.distributed.parallel_state import get_pp_group
|
||||
from vllm.forward_context import get_forward_context
|
||||
@@ -20,12 +20,10 @@ from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.utils.math_utils import cdiv
|
||||
from vllm.utils.platform_utils import is_pin_memory_available
|
||||
from vllm.utils.torch_utils import set_default_torch_dtype
|
||||
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
|
||||
CommonAttentionMetadata)
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
||||
from vllm.v1.utils import CpuGpuBuffer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
|
||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||
@@ -35,9 +33,8 @@ from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||
update_mla_attn_dcp_pcp_params,
|
||||
update_mla_attn_params)
|
||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
|
||||
from vllm_ascend.utils import (ProfileExecuteDuration, lmhead_tp_enable,
|
||||
shared_expert_dp_enabled)
|
||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -64,102 +61,11 @@ def _load_model(architecture):
|
||||
return model
|
||||
|
||||
|
||||
class MtpProposer(Proposer):
|
||||
class MtpProposer(EagleProposer):
|
||||
|
||||
# TODO: Find out why ModelRunner does not this explicit typing?
|
||||
model: Union[nn.Module, ACLGraphWrapper]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vllm_config: VllmConfig,
|
||||
device,
|
||||
runner,
|
||||
):
|
||||
self.name = SpecDcodeType.MTP
|
||||
self.vllm_config = vllm_config
|
||||
self.speculative_config = vllm_config.speculative_config
|
||||
assert self.speculative_config is not None
|
||||
self.draft_model_config = self.speculative_config.draft_model_config
|
||||
self.method = self.speculative_config.method
|
||||
|
||||
self.runner = runner
|
||||
self.device = device
|
||||
self.dtype = vllm_config.model_config.dtype
|
||||
self.max_model_len = vllm_config.model_config.max_model_len
|
||||
self.block_size = vllm_config.cache_config.block_size
|
||||
self.num_speculative_tokens = self.speculative_config.num_speculative_tokens
|
||||
self.decode_threshold = 1 + self.num_speculative_tokens
|
||||
self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens
|
||||
self.token_arange_np = np.arange(self.max_num_tokens)
|
||||
# We need to get the hidden size from the draft model config because
|
||||
# the draft model's hidden size can be different from the target model's
|
||||
# hidden size (e.g., Llama 3.3 70B).
|
||||
self.hidden_size = self.draft_model_config.get_hidden_size()
|
||||
self.enable_shared_expert_dp = shared_expert_dp_enabled()
|
||||
|
||||
self.pcp_size = self.runner.pcp_size
|
||||
self.dcp_size = self.runner.dcp_size
|
||||
self.pcp_rank = self.runner.pcp_rank
|
||||
self.dcp_rank = self.runner.dcp_rank
|
||||
|
||||
self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
|
||||
self.draft_indexer_metadata_builder: Optional[
|
||||
AttentionMetadataBuilder] = None
|
||||
self.attn_layer_names: list[str] = []
|
||||
self.indexer_layer_names: list[str] = []
|
||||
|
||||
self.use_aclgraph = self.runner._use_aclgraph()
|
||||
|
||||
# persistent buffers for aclgraph graph
|
||||
self.input_ids = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int32,
|
||||
device=device)
|
||||
self.uses_mrope = self.vllm_config.model_config.uses_mrope
|
||||
if self.uses_mrope:
|
||||
# M-RoPE need (3, max_num_tokens)
|
||||
self.mrope_positions = torch.zeros((3, self.max_num_tokens),
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
else:
|
||||
# RoPE need (max_num_tokens,)
|
||||
self.positions = torch.zeros(self.max_num_tokens,
|
||||
dtype=torch.int64,
|
||||
device=device)
|
||||
self.hidden_states = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
self.full_indices = range(
|
||||
self.runner.max_num_tokens * self.pcp_size * self.dcp_size +
|
||||
self.pcp_size * self.dcp_size * self.runner.max_num_reqs)
|
||||
|
||||
# We need +1 here because the arange is used to set query_start_loc,
|
||||
# which has one more element than batch_size.
|
||||
max_batch_size = vllm_config.scheduler_config.max_num_seqs
|
||||
max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
|
||||
self.arange = torch.arange(max_num_slots_for_arange,
|
||||
device=device,
|
||||
dtype=torch.int32)
|
||||
self.arange_cpu = torch.arange(max_num_slots_for_arange,
|
||||
device="cpu",
|
||||
dtype=torch.int32)
|
||||
|
||||
self.inputs_embeds = torch.zeros(
|
||||
(self.max_num_tokens, self.hidden_size),
|
||||
dtype=self.dtype,
|
||||
device=device)
|
||||
|
||||
self.backup_next_token_ids = CpuGpuBuffer(
|
||||
max_batch_size,
|
||||
dtype=torch.int32,
|
||||
pin_memory=is_pin_memory_available(),
|
||||
device=device,
|
||||
with_numpy=True,
|
||||
)
|
||||
self.use_sparse = hasattr(vllm_config.model_config.hf_config,
|
||||
"index_topk")
|
||||
self.use_async_scheduling = self.vllm_config.scheduler_config.async_scheduling
|
||||
|
||||
def load_model(self, model) -> None:
|
||||
loader = get_model_loader(self.vllm_config.load_config)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user