[Refactor][EAGLE] 3/N delete redundant methods in mtp_proposer (#5420)
### What this PR does / why we need it?
This PR deletes redundant methods from mtp_proposer; all of the deleted
methods are now provided by eagle_proposer. It also removes some methods
from eagle_proposer that are identical to their counterparts in vLLM's
EagleProposer.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Verified by CI.
- vLLM version: release/v0.13.0
- vLLM main: 81786c8774
---------
Signed-off-by: Zetong Li <slippersss@126.com>
@@ -4,6 +4,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
@@ -15,6 +16,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@@ -28,7 +30,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               update_attn_params)
                                               update_attn_dcp_pcp_params,
                                               update_attn_params,
                                               update_mla_attn_dcp_pcp_params,
                                               update_mla_attn_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
    prepare_inputs_padded_kernel
@@ -37,10 +42,6 @@ from vllm_ascend.utils import shared_expert_dp_enabled

PADDING_SLOT_ID = -1

_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'

_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}

# Currently we will fix block size to a small one since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4

@@ -93,27 +94,6 @@ class EagleProposer(VllmEagleProposer):
        self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
                                  "index_topk")

    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        NOTE(2025-12-18): This is an explicit copy from vLLM EagleProposer, only added
        to align with its logics.

        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly uses the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config,
                               "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state",
                                                    True)
        return use_aux_hidden_state

    def load_model(self, model: nn.Module) -> None:
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
@@ -512,48 +492,6 @@ class EagleProposer(VllmEagleProposer):
            draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
        return draft_token_ids

    def _get_attn_metadata(self, attn_metadata):
        if attn_metadata is not None and isinstance(attn_metadata, dict):
            architecture = self.vllm_config.model_config.architecture
            layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
            attn_metadata = attn_metadata[layer_name]

        return attn_metadata

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
                    req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.input_ids.device)
        return next_token_ids

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
@@ -829,3 +767,92 @@ class EagleProposer(VllmEagleProposer):
                max_seq_len=0)

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    def _split_pcp_input(self, req_scheduled_tokens, input_ids,
                         target_hidden_states):
        """
        Split prefill input_ids and target_hidden_states in pcp group.
        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
        3. split target_hidden_states (already include pcp padding):
           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
        4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
        """
        if len(req_scheduled_tokens) == 0:
            # no prefill inputs to split, return empty result
            return (
                0,
                torch.zeros([0], device='npu'),
                torch.zeros([0, target_hidden_states.size(1)], device='npu'),
                0,
                torch.zeros([0]),
                torch.tensor([0], dtype=torch.int32),
            )

        def _pcp_pad_and_split(num_tokens):
            num_pcp_padded_scheduled_tokens = cdiv(
                num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
            pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
            chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)

            # split position_ids (and use split position_ids to split input_ids afterwards)
            req_position_cp: list[int] = []
            req_position_cp.extend(
                self.full_indices[self.pcp_rank *
                                  chunk_size:(self.pcp_rank + 1) * chunk_size])
            req_position_cp.extend(
                self.full_indices[num_pcp_padded_scheduled_tokens -
                                  (self.pcp_rank + 1) *
                                  chunk_size:num_pcp_padded_scheduled_tokens -
                                  self.pcp_rank * chunk_size])

            return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad

        num_pcp_scheduled_tokens = []
        ori_start_index = 0
        pad_start_index = 0
        pcp_split_input_ids_list = []
        pcp_split_hidden_states_list = []
        for ori_num_tokens in req_scheduled_tokens.values():
            req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
                _pcp_pad_and_split(ori_num_tokens)
            actual_num_tokens = len(req_position_pcp)
            num_pcp_scheduled_tokens.append(actual_num_tokens)
            pad_input_ids = F.pad(
                input_ids[ori_start_index:ori_start_index + ori_num_tokens],
                (0, num_pcp_pad))
            ori_start_index += ori_num_tokens
            pcp_chunk_indices = [
                pad_start_index + pos for pos in req_position_pcp
            ]
            pcp_split_input_ids = pad_input_ids[req_position_pcp]
            pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
            pcp_split_input_ids_list.append(pcp_split_input_ids)
            pcp_split_hidden_states_list.append(pcp_split_hidden_states)
            pad_start_index += num_pcp_padded_scheduled_tokens
        num_tokens = sum(num_pcp_scheduled_tokens)
        input_ids = torch.cat(pcp_split_input_ids_list)
        target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
        max_query_len = max(num_pcp_scheduled_tokens)
        seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
        cu_num_tokens = torch.tensor(
            np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens

    # update full-graph params for one spec token
    def _update_full_graph_params(self, forward_context, num_tokens):
        if self.vllm_config.model_config.use_mla:
            if self.pcp_size * self.dcp_size > 1:
                update_mla_attn_dcp_pcp_params(self.update_stream,
                                               forward_context, num_tokens)
            else:
                update_mla_attn_params(self.update_stream, forward_context,
                                       num_tokens,
                                       self.vllm_config.speculative_config)
        else:
            if self.pcp_size * self.dcp_size > 1:
                update_attn_dcp_pcp_params(self.update_stream, forward_context,
                                           num_tokens)
            else:
                update_attn_params(self.update_stream, forward_context,
                                   num_tokens, self.vllm_config)
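
For readers following the new `_split_pcp_input` logic above: the token example in its docstring can be reproduced with a small standalone sketch of the pad-and-split index math. This is not part of the patch; the helper name `pcp_positions` and the `pcp_size=2` setup are illustrative assumptions only.

```python
def pcp_positions(num_tokens: int, pcp_size: int, pcp_rank: int) -> list[int]:
    # Illustrative only: mirrors the index selection inside _pcp_pad_and_split.
    # cdiv-style rounding: pad the request up to a multiple of 2 * pcp_size.
    padded = (num_tokens + 2 * pcp_size - 1) // (2 * pcp_size) * (2 * pcp_size)
    chunk_size = padded // (2 * pcp_size)
    # Each rank takes one chunk from the front and the mirrored chunk from the back.
    front = list(range(pcp_rank * chunk_size, (pcp_rank + 1) * chunk_size))
    back = list(range(padded - (pcp_rank + 1) * chunk_size,
                      padded - pcp_rank * chunk_size))
    return front + back

# 6 tokens [t0..t5] are padded to 8 positions; rank 0 ends up with the padding,
# matching the docstring: pcp0 -> [t0, t1, pad, pad], pcp1 -> [t2, t3, t4, t5].
assert pcp_positions(6, 2, 0) == [0, 1, 6, 7]
assert pcp_positions(6, 2, 1) == [2, 3, 4, 5]
```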