[Refactor][EAGLE] 2/N: load model and generate token (#5437)
### What this PR does / why we need it?
1. Refactor the eagle and mtp functions `load_model` and `generate_token_ids` (see the interface sketch below).
2. Remove redundant code in the mtp and eagle files.
3. Refactor the corresponding unit tests.
This is part 2/N of the effort to refactor and merge mtp and eagle.
Related RFC: https://github.com/vllm-project/vllm-ascend/issues/5467
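For orientation, the drafter-side entry points this series converges on might look roughly as follows. This is a minimal sketch assuming a common interface: the `Drafter` protocol name and the `load_model` signature are illustrative, not the actual vllm-ascend definitions; the `generate_token_ids` arguments mirror the call site in the diff below.

```python
# Illustrative sketch only: the real drafter classes live in vllm-ascend's
# mtp/eagle proposer files; the names below are assumptions for exposition.
from typing import Protocol

import torch


class Drafter(Protocol):
    def load_model(self, target_model: torch.nn.Module) -> None:
        """Load the draft model (e.g. the MTP head or EAGLE weights)."""
        ...

    def generate_token_ids(self, valid_sampled_token_ids, sampling_metadata,
                           scheduler_output, spec_decode_metadata, positions,
                           num_scheduled_tokens, hidden_states,
                           aux_hidden_states) -> list[list[int]]:
        """Propose draft token ids from the target model's last outputs."""
        ...
```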
### Does this PR introduce _any_ user-facing change?
No.
### How was this patch tested?
Unit tests and existing tests.
- vLLM version: release/v0.13.0
- vLLM main: 81786c8774
---------
Signed-off-by: lilinsiman <lilinsiman@gmail.com>
@@ -1,6 +1,6 @@
#
# Copyright (c) 2025 Huawei Technologies Co., Ltd. All Rights Reserved.
# Copyright 2023 The vLLM team.
# Copyright 2025 The vLLM team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -54,6 +54,7 @@ from vllm.utils.math_utils import cdiv
from vllm.utils.mem_utils import DeviceMemoryProfiler
from vllm.v1.attention.backends.gdn_attn import GDNAttentionMetadataBuilder
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.kv_cache_interface import (AttentionSpec,
                                        EncoderOnlyAttentionSpec,
                                        FullAttentionSpec, KVCacheConfig,
@@ -113,7 +114,6 @@ from vllm_ascend.worker.pcp_utils import PCPManager
from vllm_ascend.ascend_forward_context import (  # isort: skip
    MoECommType, get_mc2_tokens_capacity, select_moe_comm_method,
    set_ascend_forward_context, set_mc2_mask, set_mc2_tokens_capacity)

if TYPE_CHECKING:
    import xgrammar as xgr  # type: ignore[import-untyped]
    from vllm.v1.core.sched.output import GrammarOutput, SchedulerOutput
@@ -1257,6 +1257,7 @@ class NPUModelRunner(GPUModelRunner):
            logits_indices=logits_indices,
        )

    # TODO: Once the PCP features are complete, this method will fully
    # inherit from the vLLM community classes.
    def propose_draft_token_ids(
        self,
        valid_sampled_token_ids: torch.Tensor | list[list[int]],
@@ -1273,10 +1274,147 @@ class NPUModelRunner(GPUModelRunner):
            # Speculative decoding is not enabled.
            draft_token_ids = None
        else:
            draft_token_ids = self.drafter.generate_token_ids(
                valid_sampled_token_ids, sampling_metadata, scheduler_output,
                spec_decode_metadata, positions, num_scheduled_tokens,
                hidden_states, aux_hidden_states)
            if self.speculative_config.method in ("suffix", "ngram"):
                draft_token_ids = self.drafter.generate_token_ids(
                    valid_sampled_token_ids, sampling_metadata,
                    scheduler_output, spec_decode_metadata, positions,
                    num_scheduled_tokens, hidden_states, aux_hidden_states)

            elif self.speculative_config.use_eagle():
                common_attn_metadata = self.spec_decode_common_attn_metadata
                sampled_token_ids = valid_sampled_token_ids

                if self.vllm_config.speculative_config.disable_padded_drafter_batch:
                    # When padded-batch is disabled, the sampled_token_ids should be
                    # the cpu-side list[list[int]] of valid sampled tokens for each
                    # request, with invalid requests having empty lists.
                    assert isinstance(sampled_token_ids, list), \
                        "sampled_token_ids should be a python list when " \
                        "padded-batch is disabled."
                    assert self.drafter is not None
                    next_token_ids = self.drafter.prepare_next_token_ids_cpu(
                        sampled_token_ids, self.requests, self.input_batch,
                        scheduler_output.num_scheduled_tokens)
                else:
                    # When using padded-batch, the sampled_token_ids should be
                    # the gpu tensor of sampled tokens for each request, of shape
                    # (num_reqs, num_spec_tokens + 1) with rejected tokens having
                    # value -1.
                    assert isinstance(sampled_token_ids, torch.Tensor), \
                        "sampled_token_ids should be a torch.Tensor when " \
                        "padded-batch is enabled."
                    assert self.drafter is not None
                    next_token_ids, valid_sampled_tokens_count = \
                        self.drafter.prepare_next_token_ids_padded(
                            common_attn_metadata,
                            sampled_token_ids,
                            self.requests,
                            self.input_batch,
                            self.discard_request_indices.gpu,
                            self.num_discarded_requests)
                    self._copy_valid_sampled_token_count(
                        next_token_ids, valid_sampled_tokens_count)

                req_scheduled_tokens = scheduler_output.num_scheduled_tokens
                if self.pcp_size * self.dcp_size > 1:
                    long_seq_metadata = self.long_seq_metadata  # type: ignore
                    input_ids_pcp_full = self.pcp_manager.input_ids_pcp_full.gpu
                    query_start_loc_pcp_full = self.pcp_manager.query_start_loc_pcp_full.gpu
                    query_start_loc_pcp_full_cpu = self.pcp_manager.query_start_loc_pcp_full.cpu
                    num_reqs = self.input_batch.num_reqs
                    ori_query_lens = query_start_loc_pcp_full_cpu[1:num_reqs + 1] - \
                        query_start_loc_pcp_full_cpu[:num_reqs]
                    num_prefill_reqs = (ori_query_lens
                                        > self.decode_threshold).sum().item()
                    num_decode_reqs = num_reqs - num_prefill_reqs
                else:
                    long_seq_metadata = None  # type: ignore
                    num_prefill_reqs = 0
                    num_decode_reqs = 0
                if spec_decode_metadata is None:
                    # update pcp related params
                    if self.pcp_size > 1:
                        token_indices_to_sample = \
                            query_start_loc_pcp_full[1:num_reqs + 1] - 1
                        target_token_ids = input_ids_pcp_full[:num_scheduled_tokens]
                        target_positions = positions[:num_scheduled_tokens]
                        target_hidden_states = hidden_states
                    else:
                        token_indices_to_sample = None
                        # input_ids can be None for multimodal models.
                        target_token_ids = self.input_ids.gpu[:num_scheduled_tokens]
                        target_positions = positions[:num_scheduled_tokens]
                        if self.use_aux_hidden_state_outputs:
                            target_hidden_states = torch.cat(
                                [h[:num_scheduled_tokens] for h in aux_hidden_states],
                                dim=-1)
                        else:
                            target_hidden_states = hidden_states[:num_scheduled_tokens]
                else:
                    if self.pcp_size > 1:
                        assert common_attn_metadata is not None
                        common_attn_metadata.query_start_loc_cpu[:num_reqs + 1] = \
                            query_start_loc_pcp_full_cpu[:num_reqs + 1]
                        common_attn_metadata.query_start_loc[:num_reqs + 1] = \
                            query_start_loc_pcp_full[:num_reqs + 1]
                    if self.vllm_config.speculative_config.disable_padded_drafter_batch:
                        # NOTE: Currently, MTP-fullgraph is incompatible with PCP.
                        token_indices_to_sample = None
                        assert self.drafter is not None
                        common_attn_metadata, token_indices = \
                            self.drafter.prepare_inputs(
                                common_attn_metadata,
                                sampled_token_ids,
                                spec_decode_metadata.num_draft_tokens)
                    else:
                        assert self.drafter is not None
                        common_attn_metadata, token_indices, \
                            token_indices_to_sample = \
                            self.drafter.prepare_inputs_padded(
                                common_attn_metadata,
                                spec_decode_metadata,
                                valid_sampled_tokens_count)
                    if self.pcp_size > 1:
                        target_token_ids = input_ids_pcp_full[token_indices]
                        target_positions = positions
                        target_hidden_states = hidden_states
                    else:
                        target_token_ids = self.input_ids.gpu[token_indices]
                        target_positions = positions[token_indices]
                        if self.use_aux_hidden_state_outputs:
                            target_hidden_states = torch.cat(
                                [h[token_indices] for h in aux_hidden_states],
                                dim=-1)
                        else:
                            target_hidden_states = hidden_states[token_indices]
                assert self.drafter is not None
                draft_token_ids = self.drafter._propose(
                    target_token_ids=target_token_ids,
                    target_positions=target_positions,
                    target_hidden_states=target_hidden_states,
                    next_token_ids=next_token_ids,
                    last_token_indices=token_indices_to_sample,
                    common_attn_metadata=common_attn_metadata,
                    sampling_metadata=sampling_metadata,
                    req_scheduled_tokens=req_scheduled_tokens,
                    long_seq_metadata=long_seq_metadata,
                    num_prefill_reqs=num_prefill_reqs,
                    num_decode_reqs=num_decode_reqs,
                    scheduler_output=scheduler_output,
                    num_scheduled_tokens=num_scheduled_tokens,
                )

            else:
                raise ValueError("Unknown speculative decoding method: "
                                 f"{self.speculative_config.method}")

        return draft_token_ids

    @staticmethod
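The padded and unpadded branches in the diff above expect two different `sampled_token_ids` representations. Below is a self-contained toy sketch of the conventions described in the code comments; it is illustrative Python, not the actual `prepare_next_token_ids_cpu`/`prepare_next_token_ids_padded` implementations, and it assumes every request keeps at least its bonus token.

```python
import torch

# Padded path: (num_reqs, num_spec_tokens + 1) tensor, rejected slots are -1.
sampled_padded = torch.tensor([
    [11, 12, -1],   # request 0: second draft token rejected
    [21, -1, -1],   # request 1: all draft tokens rejected
    [31, 32, 33],   # request 2: all draft tokens accepted
])
valid_counts = (sampled_padded != -1).sum(dim=1)  # tensor([2, 1, 3])
# The next token to feed the drafter is the last accepted one per request
# (valid_counts >= 1 here because the bonus token is always kept).
last_idx = valid_counts - 1
next_token_ids = sampled_padded.gather(1, last_idx.unsqueeze(1)).squeeze(1)
print(next_token_ids.tolist())  # [12, 21, 33]

# Unpadded path: cpu-side list[list[int]]; invalid requests get empty lists.
sampled_cpu = [[11, 12], [21], [31, 32, 33], []]
next_tokens_cpu = [toks[-1] if toks else None for toks in sampled_cpu]
print(next_tokens_cpu)  # [12, 21, 33, None]
```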
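Likewise, the prefill/decode split in the PCP branch is plain cumulative-offset arithmetic on `query_start_loc`. A toy sketch with made-up numbers; `decode_threshold` is a stand-in scalar for the runner's `self.decode_threshold` attribute:

```python
import torch

# query_start_loc is a cumulative offset per request: request i owns tokens
# query_start_loc[i]:query_start_loc[i + 1] of the flattened batch.
query_start_loc = torch.tensor([0, 5, 6, 7, 15])  # 4 requests
num_reqs = 4
query_lens = query_start_loc[1:num_reqs + 1] - query_start_loc[:num_reqs]
print(query_lens.tolist())  # [5, 1, 1, 8]

# Requests with more scheduled tokens than the threshold count as prefills.
decode_threshold = 1  # stand-in value
num_prefill_reqs = int((query_lens > decode_threshold).sum().item())
num_decode_reqs = num_reqs - num_prefill_reqs
print(num_prefill_reqs, num_decode_reqs)  # 2 2
```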