[Refactor][EAGLE] 2/N: load model and generate token (#5437)

### What this PR does / why we need it?
1. Refactor the eagle and mtp functions `load_model` and `generate_token_ids` (see the sketch below)
2. Remove redundant code in the mtp and eagle files
3. Refactor the corresponding unit tests

This is part 2/N of the refactor that merges mtp and eagle.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
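
A rough sketch of how the two refactored entry points are driven after this change. Only the `load_model` and `generate_token_ids` signatures come from this PR; `proposer`, `target_model`, and the per-step tensor names are illustrative placeholders, not exact vllm-ascend attributes:

```python
# Sketch only: `proposer` is an already-constructed EagleProposer
# (the same flow applies to both the eagle and mtp methods after this refactor).

# Load the draft model. Depending on the pipeline-parallel size and, for MTP,
# on weight equality, embed_tokens / lm_head are shared with the target model.
proposer.load_model(target_model)

# After the target model's forward pass and sampling, produce draft tokens.
draft_token_ids = proposer.generate_token_ids(
    sampled_token_ids=sampled_token_ids,        # torch.Tensor or list[list[int]]
    sampling_metadata=sampling_metadata,
    scheduler_output=scheduler_output,
    spec_decode_metadata=spec_decode_metadata,
    positions=positions,
    num_scheduled_tokens=num_scheduled_tokens,
    hidden_states=hidden_states,
    aux_hidden_states=aux_hidden_states,        # used by the eagle3 method only
)
```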

### Does this PR introduce _any_ user-facing change?
no

### How was this patch tested?
Unit tests and the existing test suites.

- vLLM version: release/v0.13.0
- vLLM main:
81786c8774

---------

Signed-off-by: lilinsiman <lilinsiman@gmail.com>
Author: lilinsiman
Date: 2026-01-05 14:07:54 +08:00
Committed by: GitHub
Parent: 50e7934415
Commit: 52863c4165
8 changed files with 229 additions and 609 deletions


@@ -4,7 +4,6 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
@@ -13,6 +12,7 @@ from vllm.logger import logger
from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
@@ -109,25 +109,54 @@ class EagleProposer(VllmEagleProposer):
def load_model(self, model: nn.Module) -> None:
target_attn_layer_names = set(
get_layers_from_vllm_config(self.vllm_config, Attention).keys())
get_layers_from_vllm_config(self.vllm_config,
AttentionLayerBase).keys())
target_indexer_layer_names = set(
get_layers_from_vllm_config(self.vllm_config,
DeepseekV32IndexerCache).keys())
self.model = get_model(vllm_config=self.vllm_config,
model_config=self.vllm_config.
speculative_config.draft_model_config)
draft_attn_layer_names = (get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys() -
target_attn_layer_names)
self.attn_layer_name = next(iter(draft_attn_layer_names))
indexer_layers = get_layers_from_vllm_config(
self.vllm_config, DeepseekV32IndexerCache).keys()
draft_attn_layer = get_layers_from_vllm_config(
self.vllm_config, AttentionLayerBase).keys()
draft_attn_layer_names = draft_attn_layer - target_attn_layer_names
draft_indexer_layer_names = indexer_layers - target_indexer_layer_names
draft_attn_layer_names = draft_attn_layer_names - draft_indexer_layer_names
assert len(draft_attn_layer_names) == 1
self.attn_layer_name = list(draft_attn_layer_names)
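# Kept as a (single-element) list so the per-layer loops below can iterate
# over self.attn_layer_name uniformly; the assert above guarantees exactly one
# draft attention layer remains after removing target and indexer layers.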
# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
if self.method == "mtp":
if self.vllm_config.model_config.is_deepseek_mla and \
torch.equal(self.model.model.embed_tokens.weight,
model.model.embed_tokens.weight):
# If pp > 1, the MTP weights and the main model's embedding are not on the same device.
# Check whether the MTP model reuses the main model's embedding and LM head.
logger.info(
"The MTP head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else:
logger.info(
" The MTP head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
else:
logger.info(
"The EAGLE head shares the same vocab embedding" \
" with the target model."
)
self.model.model.embed_tokens = model.model.embed_tokens
else:
logger.info(
"Since PP > 1, the EAGLE head loaded its own vocab embedding" \
"Since PP > 1 or other reasons the model head loaded its own vocab embedding" \
" weights instead of sharing them with the target model."
)
@@ -141,6 +170,13 @@ class EagleProposer(VllmEagleProposer):
else:
self.model.lm_head = model.lm_head
if self.method == "mtp" and \
self.vllm_config.model_config.is_deepseek_mla:
for _, layer_module in self.model.model.layers.items():
if torch.equal(layer_module.shared_head.head.weight,
model.lm_head.weight):
layer_module.shared_head.head = model.lm_head
if self.vllm_config.compilation_config.cudagraph_mode.has_full_cudagraphs(
) and self.use_cuda_graph:
self.update_stream = torch.npu.Stream()
@@ -205,7 +241,7 @@ class EagleProposer(VllmEagleProposer):
attn_metadata_eagle = builder.build_for_graph_capture(
common_attn_metadata, AscendAttentionState.ChunkedPrefill)
attn_metadata = {}
for layer_name in [self.attn_layer_name]:
for layer_name in self.attn_layer_name:
attn_metadata[layer_name] = attn_metadata_eagle
for i in range(self.num_speculative_tokens):
if i > 0 and in_graph_capturing and aclgraph_runtime_mode == CUDAGraphMode.FULL:
@@ -235,135 +271,6 @@ class EagleProposer(VllmEagleProposer):
self.vllm_config,
)
def generate_token_ids(self,
sampled_token_ids: torch.Tensor | list[list[int]],
sampling_metadata: SamplingMetadata = None,
scheduler_output: SchedulerOutput = None,
spec_decode_metadata: SpecDecodeMetadata = None,
positions: torch.Tensor = None,
num_scheduled_tokens: int = 0,
hidden_states: torch.Tensor = None,
aux_hidden_states: torch.Tensor = None):
common_attn_metadata = self.runner.spec_decode_common_attn_metadata
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# When padded-batch is disabled, the sampled_token_ids should be
# the cpu-side list[list[int]] of valid sampled tokens for each
# request, with invalid requests having empty lists.
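# (Illustrative example, not from this diff) e.g. [[123, 456], [], [789]],
# where the middle request is invalid and thus contributes no tokens.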
assert isinstance(sampled_token_ids, list), \
"sampled_token_ids should be a python list when" \
"padded-batch is disabled."
next_token_ids = self.prepare_next_token_ids_cpu(
sampled_token_ids, self.runner.requests,
self.runner.input_batch, scheduler_output.num_scheduled_tokens)
else:
# When using padded-batch, the sampled_token_ids should be
# the gpu tensor of sampled tokens for each request, of shape
# (num_reqs, num_spec_tokens + 1) with rejected tokens having
# value -1.
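# (Illustrative example, not from this diff) with num_spec_tokens = 2, a row
# [t0, t1, -1] means the last draft token for that request was rejected.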
assert isinstance(sampled_token_ids, torch.Tensor), \
"sampled_token_ids should be a torch.Tensor when" \
"padded-batch is enabled."
next_token_ids, valid_sampled_tokens_count = \
self.prepare_next_token_ids_padded(
common_attn_metadata,
sampled_token_ids,
self.runner.requests,
self.runner.input_batch,
self.runner.discard_request_indices.gpu,
self.runner.num_discarded_requests
)
self._copy_valid_sampled_token_count(next_token_ids,
valid_sampled_tokens_count)
req_scheduled_tokens = scheduler_output.num_scheduled_tokens
if self.pcp_size > 1:
long_seq_metadata = self.runner.long_seq_metadata
input_ids_pcp_full = self.runner.pcp_manager.input_ids_pcp_full.gpu
query_start_loc_pcp_full = self.runner.pcp_manager.query_start_loc_pcp_full.gpu
query_start_loc_pcp_full_cpu = self.runner.pcp_manager.query_start_loc_pcp_full.cpu
num_reqs = self.runner.input_batch.num_reqs
ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs+1] - \
query_start_loc_pcp_full_cpu[:num_reqs]
num_prefill_reqs = (ori_query_lens
> self.decode_threshold).sum().item()
num_decode_reqs = num_reqs - num_prefill_reqs
else:
long_seq_metadata = None
num_prefill_reqs = 0
num_decode_reqs = 0
if spec_decode_metadata is None:
# update pcp related params
if self.pcp_size > 1:
token_indices_to_sample = \
query_start_loc_pcp_full_cpu[1:num_reqs + 1] - 1
target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
target_hidden_states = hidden_states
else:
token_indices_to_sample = None
# input_ids can be None for multimodal models.
target_token_ids = self.runner.input_ids.gpu[:
num_scheduled_tokens]
target_positions = positions[:num_scheduled_tokens]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[:num_scheduled_tokens] for h in aux_hidden_states],
dim=-1)
else:
target_hidden_states = hidden_states[:num_scheduled_tokens]
else:
if self.pcp_size > 1:
common_attn_metadata.query_start_loc_cpu = \
query_start_loc_pcp_full_cpu[:num_reqs + 1]
common_attn_metadata.query_start_loc = \
query_start_loc_pcp_full[:num_reqs + 1]
if self.vllm_config.speculative_config.disable_padded_drafter_batch:
# NOTE: Currently, MTP-fullgraph is incompatible with pcp
token_indices_to_sample = None
common_attn_metadata, token_indices =\
self.prepare_inputs(
common_attn_metadata,
sampled_token_ids,
spec_decode_metadata.num_draft_tokens)
else:
common_attn_metadata, token_indices, \
token_indices_to_sample =\
self.prepare_inputs_padded(
common_attn_metadata,
spec_decode_metadata,
valid_sampled_tokens_count)
if self.pcp_size > 1:
target_token_ids = input_ids_pcp_full[token_indices]
target_positions = positions
target_hidden_states = hidden_states
else:
target_token_ids = self.runner.input_ids.gpu[token_indices]
target_positions = positions[token_indices]
if self.method == "eagle3":
target_hidden_states = torch.cat(
[h[token_indices] for h in aux_hidden_states], dim=-1)
else:
target_hidden_states = hidden_states[token_indices]
draft_token_ids = self._propose(
target_token_ids=target_token_ids,
target_positions=target_positions,
target_hidden_states=target_hidden_states,
next_token_ids=next_token_ids,
last_token_indices=token_indices_to_sample,
common_attn_metadata=common_attn_metadata,
sampling_metadata=sampling_metadata,
req_scheduled_tokens=req_scheduled_tokens,
long_seq_metadata=long_seq_metadata,
num_prefill_reqs=num_prefill_reqs,
num_decode_reqs=num_decode_reqs,
scheduler_output=scheduler_output,
num_scheduled_tokens=num_scheduled_tokens,
)
return draft_token_ids
def _propose(
self,
# [num_tokens]
@@ -430,9 +337,11 @@ class EagleProposer(VllmEagleProposer):
self.runner.get_model())
# update global cos, sin
update_cos_sin(self.positions[:num_input_tokens])
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_name:
per_layer_attn_metadata[layer_name] = attn_metadata
with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata},
per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens,
num_actual_tokens=num_tokens,
@@ -558,7 +467,7 @@ class EagleProposer(VllmEagleProposer):
# Run the model.
with set_ascend_forward_context(
{self.attn_layer_name: attn_metadata},
per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size,
num_actual_tokens=batch_size,
@@ -696,28 +605,6 @@ class EagleProposer(VllmEagleProposer):
return next_token_ids, valid_sampled_tokens_count
def _copy_valid_sampled_token_count(
self, next_token_ids: torch.Tensor,
valid_sampled_tokens_count: torch.Tensor) -> None:
if self.runner.valid_sampled_token_count_event is not None:
default_stream = torch.npu.current_stream()
# initialize a new stream to overlap the copy operation with
# prepare_input of draft model.
with torch.npu.stream(
self.runner.valid_sampled_token_count_copy_stream):
self.runner.valid_sampled_token_count_copy_stream.wait_stream(
default_stream) # type: ignore
self.runner.valid_sampled_token_count_cpu[:
valid_sampled_tokens_count
.shape[0]].copy_(
valid_sampled_tokens_count,
non_blocking=True
)
self.runner.valid_sampled_token_count_event.record()
self.runner.input_batch.prev_sampled_token_ids = next_token_ids.unsqueeze(
1)
def prepare_inputs(
self,
common_attn_metadata: CommonAttentionMetadata,