[Refactor] optimize _prepare_inputs method in eagle_proposer (#3296)
### What this PR does / why we need it?
We optimized the `_prepare_inputs` method in `eagle_proposer` and no longer
use the `_prepare_eagle_input_sequential` method: the token indices fed to the
draft model are now computed in a single vectorized NumPy pass on the host
instead of a per-request Python loop, improving the performance of EAGLE-3.
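For context, the removed `_prepare_eagle_input_sequential` filled `token_indices` with a per-request Python loop on the host. The following is a simplified, runnable sketch of its effect (the inner block tiling is omitted); the new `_prepare_inputs` computes the same indices in one NumPy pass, as shown in the diff below and in the worked example after it:
```python
import torch

def prepare_eagle_input_sequential(cu_query_lens: torch.Tensor,
                                   cu_num_tokens: torch.Tensor) -> torch.Tensor:
    """Simplified sketch of the removed per-request loop (block tiling omitted)."""
    out = torch.empty(int(cu_num_tokens[-1]), dtype=torch.int32)
    for req in range(len(cu_num_tokens) - 1):
        start, end = int(cu_num_tokens[req]), int(cu_num_tokens[req + 1])
        # Each request's kept tokens are a contiguous run that begins at the
        # request's old start offset in the target sequence.
        out[start:end] = int(cu_query_lens[req]) + torch.arange(
            end - start, dtype=torch.int32)
    return out

# query_start_loc [0, q1, q1+q2, q1+q2+q3] and kept-token prefix sums.
print(prepare_eagle_input_sequential(torch.tensor([0, 3, 8, 12]),
                                     torch.tensor([0, 2, 6, 9])))
# -> tensor([ 0,  1,  3,  4,  5,  6,  8,  9, 10], dtype=torch.int32)
```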
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
```
python3 -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --port 13963 \
    --dtype bfloat16 \
    --model meta-llama/Llama-3.1-8B-Instruct \
    --served-model-name Llama-3.1-8B-Instruct \
    --tensor-parallel-size 1 \
    --gpu-memory-utilization 0.85 \
    --max-model-len 32768 \
    --trust-remote-code \
    --seed 42 \
    --no-enable-prefix-caching \
    --speculative_config '{"method":"eagle3","model":"yuhuili/EAGLE3-LLaMA3.1-Instruct-8B","num_speculative_tokens":2,"draft_tensor_parallel_size":1}'
```
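For a quick sanity check against the launched server, a request along these lines could be used (this snippet is illustrative and assumes the `openai` Python client; the port and served model name are taken from the command above):
```python
from openai import OpenAI

# Point the client at the locally launched api_server.
client = OpenAI(base_url="http://0.0.0.0:13963/v1", api_key="EMPTY")
resp = client.completions.create(
    model="Llama-3.1-8B-Instruct",
    prompt="The capital of France is",
    max_tokens=32,
    temperature=0.0,
)
print(resp.choices[0].text)
```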
Co-authored-by: QilaiZhang <245706640@qq.com>
- vLLM version: v0.11.0rc3
- vLLM main: https://github.com/vllm-project/vllm/commit/v0.11.0
Signed-off-by: lio <1983142975@qq.com>
```diff
@@ -12,13 +12,15 @@ from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase
 from vllm.model_executor.model_loader import get_model
 from vllm.model_executor.models import supports_multimodal
 from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
+from vllm.utils import is_pin_memory_available
 from vllm.v1.core.sched.output import SchedulerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
 from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
 
 from vllm_ascend.ascend_forward_context import set_ascend_forward_context
 from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
-from vllm_ascend.attention.attention_v1 import AscendAttentionState
+from vllm_ascend.attention.attention_v1 import (AscendAttentionState,
+                                                AscendMetadata)
 from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
 from vllm_ascend.spec_decode.interface import Proposer, SpecDcodeType
 from vllm_ascend.utils import vllm_version_is
```
```diff
@@ -78,6 +80,9 @@ class EagleProposer(Proposer):
                                         self.hidden_size),
                                        dtype=self.vllm_config.model_config.dtype,
                                        device=device)
+        self.max_num_tokens = (
+            vllm_config.scheduler_config.max_num_batched_tokens)
+        self.token_arange_np = np.arange(self.max_num_tokens)
         # We need +1 here because the arange is used to set query_start_loc,
         # which has one more element than batch_size.
         self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
```
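The new `token_arange_np` buffer is allocated once at init so that each proposer step can slice it instead of building a fresh `np.arange`; a minimal illustration (the sizes here are made up):
```python
import numpy as np

max_num_tokens = 16  # stand-in for scheduler_config.max_num_batched_tokens
token_arange_np = np.arange(max_num_tokens)  # allocated once at init

# Per step: basic slicing returns a view, so no new allocation is needed.
total_num_tokens = 9
step_arange = token_arange_np[:total_num_tokens]
assert step_arange.base is token_arange_np  # a view, not a copy
```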
```diff
@@ -197,10 +202,8 @@ class EagleProposer(Proposer):
             dtype=torch.int32,
             device=self.device,
         )
-        num_tokens = num_scheduled_tokens - sum(num_rejected_tokens)
-        cu_num_tokens, token_indices = self._prepare_inputs(
-            eagle_attn_metadata.query_start_loc, num_rejected_tokens,
-            num_tokens)
+        cu_num_tokens, token_indices =\
+            self._prepare_inputs(eagle_attn_metadata, num_rejected_tokens)
         target_token_ids = self.runner.input_ids[token_indices]
         target_positions = positions[token_indices]
         if self.name == SpecDcodeType.EAGLE3:
```
```diff
@@ -605,72 +608,88 @@ class EagleProposer(Proposer):
 
     def _prepare_inputs(
         self,
-        # [batch_size + 1]
-        cu_target_query_lens: torch.Tensor,
+        eagle_attn_metadata: AscendMetadata,
         # [batch_size]
         num_rejected_tokens: torch.Tensor,
-        num_tokens: int,
     ) -> tuple[torch.Tensor, torch.Tensor]:
-        # cu_target_query_lens: [0, a, a + b, a + b + c]
-        # num_rejected_tokens: [n1, n2, n3]
-        # num_tokens_per_req: [a - n1, b - n2, c - n3]
-        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
-        # token_indices: [0, 1, ..., a - n1 - 1,
-        #                 a, a + 1, ..., a + b - n2 - 1,
-        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]
-        # [0, a, a + b, a + b + c] -> [a, b, c]
+        """
+        This function is used to prepare the inputs for the spec decode.
+        It updates the common_attn_metadata to account for the rejected
+        tokens (and newly sampled tokens). It also returns the token indices
+        of the tokens that should be fed to the speculator.
+        """
+        # E.g.
+        # common_attn_metadata.query_start_loc{_cpu}:
+        #     [0, q1, q1 + q2, q1 + q2 + q3]
+        # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
+        # num_rejected_tokens: [n1, n2, n3]
+        # This function computes the intermediate values:
+        # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
+        # And returns:
+        # common_attn_metadata.query_start_loc{_cpu}:
+        #     [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        # common_attn_metadata.seq_lens{_cpu}:
+        #     [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
+        # token_indices: [0, 1, ..., q1 - n1 - 1,
+        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
+        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
+        num_rejected_tokens_cpu = num_rejected_tokens.to("cpu")
+        cu_target_query_lens = eagle_attn_metadata.query_start_loc
+        device = eagle_attn_metadata.query_start_loc.device
+        query_start_loc_cpu = cu_target_query_lens.to("cpu")
+
+        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
+        new_query_len_per_req = (query_start_loc_cpu[1:] -
+                                 query_start_loc_cpu[:-1])
+        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
+        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens_cpu
+        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
+
+        # [q1 - n1, q2 - n2, q3 - n3] ->
+        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
+        new_query_start_loc_cpu = torch.zeros(
+            query_start_loc_cpu.shape,
+            dtype=torch.int32,
+            pin_memory=is_pin_memory_available())
+        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
+        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
+
+        total_num_tokens = new_query_start_loc_np[-1]
+        # Example assuming num_tokens_per_req_np = [2, 4, 3]
+        # this implies that `new_query_start_locs` is:
+        # [0, 2, 6, 9] ->
+        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
+        #  _r1_ ____r2____ ___r3__
+        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
+                                                  new_num_tokens_per_req_np)
+        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
+        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
+        #  _r1_ ____r2____ ___r3__
+        token_offsets = self.token_arange_np[:total_num_tokens] \
+            - new_query_start_locs_expanded
+
+        # Expand starting positions to match token pattern
+        # [0, q1, q1 + q2] ->
+        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
+        #  _r1_ _____r2_______ ___________r3____________
+        old_query_start_locs_expanded = np.repeat(
+            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
+        # Final token indices are:
+        # [0, 1,                                   // req 1
+        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
+        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
+        token_indices_np = token_offsets + old_query_start_locs_expanded
+        token_indices = torch.from_numpy(token_indices_np).to(
+            device, non_blocking=True)
+
+        # The following must be computed on the NPU.
         query_len_per_req = (cu_target_query_lens[1:] -
                              cu_target_query_lens[:-1])
         # [a, b, c] -> [a - n1, b - n2, c - n3]
         num_tokens_per_req = query_len_per_req - num_rejected_tokens
+
         # [a - n1, b - n2, c - n3] ->
         # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
         cu_num_tokens = torch.zeros_like(cu_target_query_lens)
         torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])
-        token_indices = torch.empty(
-            num_tokens,
-            dtype=torch.int32,
-            device=cu_target_query_lens.device,
-        )
-        BLOCK_SIZE = 1024
-        self._prepare_eagle_input_sequential(
-            token_indices,
-            cu_target_query_lens,
-            cu_num_tokens,
-            block_size=BLOCK_SIZE,
-        )
+
         return cu_num_tokens, token_indices
-
-    def _prepare_eagle_input_sequential(self, out_tensor: torch.Tensor,
-                                        cu_query_lens: torch.Tensor,
-                                        cu_num_tokens: torch.Tensor,
-                                        block_size: int):
-        num_programs = len(cu_num_tokens) - 1
-        for pid in range(num_programs):
-            start_pos = cu_num_tokens[pid].item()
-            end_pos = cu_num_tokens[pid + 1].item()
-            num_tokens = end_pos - start_pos
-            index_start = cu_query_lens[pid].item()
-            num_blocks = int(
-                torch.ceil(torch.tensor(num_tokens / block_size)).item())
-
-            for i in range(num_blocks):
-                offset_tensor = torch.arange(0,
-                                             block_size,
-                                             dtype=torch.int32,
-                                             device=out_tensor.device)
-                global_start_offset = i * block_size
-                target_indices = torch.tensor(
-                    start_pos + global_start_offset,
-                    dtype=torch.int32,
-                    device=out_tensor.device) + offset_tensor
-                values_to_store = torch.tensor(
-                    index_start + global_start_offset,
-                    dtype=torch.int32,
-                    device=out_tensor.device) + offset_tensor
-                mask = (target_indices >= start_pos) & \
-                       (target_indices < end_pos) & \
-                       (offset_tensor < num_tokens)
-                out_tensor[target_indices[mask]] = values_to_store[mask]
```
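Putting the pieces together, here is a runnable NumPy walkthrough of the new index computation, following the variable names in the diff and using the [2, 4, 3] example from its comments (the input values are made up):
```python
import numpy as np

query_start_loc = np.array([0, 3, 8, 12])  # [0, q1, q1+q2, q1+q2+q3]
num_rejected_tokens = np.array([1, 1, 1])  # [n1, n2, n3]

# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
num_tokens_per_req = np.diff(query_start_loc) - num_rejected_tokens  # [2, 4, 3]

# Prefix sums of kept tokens: [0, 2, 6, 9]
new_query_start_loc = np.zeros_like(query_start_loc)
np.cumsum(num_tokens_per_req, out=new_query_start_loc[1:])
total_num_tokens = new_query_start_loc[-1]

# Per-token offsets within each request: [0, 1, 0, 1, 2, 3, 0, 1, 2]
starts_expanded = np.repeat(new_query_start_loc[:-1], num_tokens_per_req)
token_offsets = np.arange(total_num_tokens) - starts_expanded

# Shift each request's offsets by its old start position.
old_starts_expanded = np.repeat(query_start_loc[:-1], num_tokens_per_req)
token_indices = token_offsets + old_starts_expanded
print(token_indices)  # [ 0  1  3  4  5  6  8  9 10]
```
The sequential loop and this vectorized pass produce identical indices; the gain is a single NumPy pass in place of O(batch) Python iterations and per-block tensor allocations.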