[Refactor] optimize _prepare_inputs method in eagle_proposer (#3296)

### What this PR does / why we need it?

We optimized the _prepare_input method in eagle_proposer and no longer
use the _prepare_eagle_input_sequential method, improving the
performance of eagle-3.

### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```
python3 -m vllm.entrypoints.openai.api_server  
--host 0.0.0.0 
--port 13963
--dtype bfloat16 
--model meta-llama/Llama-3.1-8B-Instruct
--served-model-name Llama-3.1-8B-Instruct 
--tensor-parallel-size 1 
--gpu-memory-utilization 0.85   
--max-model-len  32768 
--trust-remote-code  
--seed 42  
--no-enable-prefix-caching 
--speculative_config '{"method":"eagle3","model":"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B","num_speculative_tokens":2,"draft_tensor_parallel_size":1}'
```

Co-authored-by: QilaiZhang (245706640@qq.com )


- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0

Signed-off-by: lio <1983142975@qq.com>
This commit is contained in:
lio
2025-10-25 09:49:42 +08:00
committed by GitHub
parent d30bb95b90
commit 9e150e5009

View File

@@ -12,13 +12,15 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
AscendMetadata)
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
from vllm_ascend.utils import vllm_version_is
@@ -78,6 +80,9 @@ class EagleProposer(Proposer):
self.hidden_size),
dtype=self.vllm_config.model_config.dtype,
device=device)
self.max_num_tokens = (
vllm_config.scheduler_config.max_num_batched_tokens)
self.token_arange_np = np.arange(self.max_num_tokens)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -197,10 +202,8 @@ class EagleProposer(Proposer):
dtype=torch.int32,
device=self.device,
)
num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
cu_num_tokens, token_indices = self._prepare_inputs(
eagle_attn_metadata.query_start_loc, num_rejected_tokens,
num_tokens)
cu_num_tokens, token_indices =\
self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
target_token_ids = self.runner.input_ids[token_indices]
target_positions = positions[token_indices]
if self.name == SpecDcodeType.EAGLE3:
@@ -605,72 +608,88 @@ class EagleProposer(Proposer):
def _prepare_inputs(
self,
# [batch_size + 1]
cu_target_query_lens: torch.Tensor,
eagle_attn_metadata: AscendMetadata,
# [batch_size]
num_rejected_tokens: torch.Tensor,
num_tokens: int,
) -> tuple[torch.Tensor, torch.Tensor]:
# cu_target_query_lens: [0, a, a + b, a + b + c]
# num_rejected_tokens: [n1, n2, n3]
# num_tokens_per_req: [a - n1, b - n2, c - n3]
# cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
# token_indices: [0, 1, ..., a - n1 - 1,
# a, a + 1, ..., a + b - n2 - 1,
# a + b, a + b + 1, ..., a + b + c - n3 - 1]
"""
This function is used to prepare the inputs for the spec decode.
It updates to the common_attn_metadata to account for the rejected
tokens (and newly sampled tokens). It also returns the token indices
of the tokens that should be fed to the speculator.
"""
# E.g.
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1, q1 + q2, q1 + q2 + q3]
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
# num_rejected_tokens: [n1, n2, n3]
# This function computes the intermediate values:
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
# And returns:
# common_attn_metadata.query_start_loc{_cpu}:
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
# common_attn_metadata.seq_lens{_cpu}:
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
# token_indices: [0, 1, ..., q1 - n1 - 1,
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
cu_target_query_lens = eagle_attn_metadata.query_start_loc
device = eagle_attn_metadata.query_start_loc.device
query_start_loc_cpu = cu_target_query_lens.to("cpu")
# [0, a, a + b, a + b + c] -> [a, b, c]
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
new_query_len_per_req = (query_start_loc_cpu[1:] -
query_start_loc_cpu[:-1])
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
# [q1 - n1, q2 - n2, q3 - n3] ->
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
new_query_start_loc_cpu = torch.zeros(
query_start_loc_cpu.shape,
dtype=torch.int32,
pin_memory=is_pin_memory_available())
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
total_num_tokens = new_query_start_loc_np[-1]
# Example assuming num_tokens_per_req_np = [2, 4, 3]
# this implies that `new_query_start_locs` is:
# [0, 2, 6, 9] ->
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
# _r1_ ____r2____ ___r3__
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
new_num_tokens_per_req_np)
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
# _r1_ ____r2____ ___r3__
token_offests = self.token_arange_np[:total_num_tokens] \
- new_query_start_locs_expanded
# Expand starting positions to match token pattern
# [0, q1, q1 + q2] ->
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
# _r1_ _____r2_______ ___________r3____________
old_query_start_locs_expanded = np.repeat(
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
# Final token indices are:
# [0, 1, // req 1
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
token_indices_np = token_offests + old_query_start_locs_expanded
token_indices = torch.from_numpy(token_indices_np).to(
device, non_blocking=True)
# need use npu
query_len_per_req = (cu_target_query_lens[1:] -
cu_target_query_lens[:-1])
# [a, b, c] -> [a - n1, b - n2, c - n3]
num_tokens_per_req = query_len_per_req - num_rejected_tokens
# [a - n1, b - n2, c - n3] ->
# [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
cu_num_tokens = torch.zeros_like(cu_target_query_lens)
torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
token_indices = torch.empty(
num_tokens,
dtype=torch.int32,
device=cu_target_query_lens.device,
)
BLOCK_SIZE = 1024
self._prepare_eagle_input_sequential(
token_indices,
cu_target_query_lens,
cu_num_tokens,
block_size=BLOCK_SIZE,
)
return cu_num_tokens, token_indices
def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
cu_query_lens: torch.Tensor,
cu_num_tokens: torch.Tensor,
block_size: int):
num_programs = len(cu_num_tokens) - 1
for pid in range(num_programs):
start_pos = cu_num_tokens[pid].item()
end_pos = cu_num_tokens[pid + 1].item()
num_tokens = end_pos - start_pos
index_start = cu_query_lens[pid].item()
num_blocks = int(
torch.ceil(torch.tensor(num_tokens / block_size)).item())
for i in range(num_blocks):
offset_tensor = torch.arange(0,
block_size,
dtype=torch.int32,
device=out_tensor.device)
global_start_offset = i * block_size
target_indices = torch.tensor(
start_pos + global_start_offset,
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
values_to_store = torch.tensor(
index_start + global_start_offset,
dtype=torch.int32,
device=out_tensor.device) + offset_tensor
mask = (target_indices >= start_pos) & \
(target_indices < end_pos) & \
(offset_tensor < num_tokens)
out_tensor[target_indices[mask]] = values_to_store[mask]