Files
enginex-vastai-va16-vllm/vllm_vacc/vllm/v1/spec_decode/eagle.py
2026-04-02 04:55:00 +00:00

789 lines
35 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import ast
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
CommonAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.distributed import get_tensor_model_parallel_rank
# Sentinel slot id for padding tokens in slot mappings (presumably skipped
# by the KV-cache write; verify against the VACC attention backend).
PADDING_SLOT_ID = -1
from vacc_tools.trace_logger import get_trace_api
# Tracing helpers bound to the "deepseek" trace tag. `trace_time` is the
# decorator referenced (currently commented out) on functions in this file.
(trace_time, register_module_trace, trace_autograd_function,
 register_optimizer_trace) = get_trace_api("deepseek")
# @trace_time('prepare_eagle_input_python')
def prepare_eagle_input_python(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
):
    """Pure-Python replacement for the Triton ``prepare_eagle_input_kernel``.

    For every request ``pid`` this writes the contiguous run of source token
    indices ``cu_query_lens_ptr[pid], cu_query_lens_ptr[pid] + 1, ...`` into
    ``out_ptr[cu_num_tokens_ptr[pid]:cu_num_tokens_ptr[pid + 1]]``.

    Args:
        out_ptr: Index-assignable output buffer (e.g. a Python list or a 1-D
            tensor) with at least ``cu_num_tokens_ptr[-1]`` elements. It is
            written in place.
        cu_query_lens_ptr: Start index of each request in the original
            (untrimmed) token sequence. Only the first
            ``len(cu_num_tokens_ptr) - 1`` entries are read.
        cu_num_tokens_ptr: Cumulative token counts ``[0, n1, n1 + n2, ...]``
            of the trimmed sequence. They define where each request's run is
            written in ``out_ptr``.

    Returns:
        None. ``out_ptr`` is filled in place.
    """
    num_queries = len(cu_num_tokens_ptr) - 1
    for pid in range(num_queries):
        start_pos = cu_num_tokens_ptr[pid]
        num_tokens = cu_num_tokens_ptr[pid + 1] - start_pos
        index_start = cu_query_lens_ptr[pid]
        # Element-wise writes keep whatever element type the inputs carry
        # (ints for lists, 0-d tensors for tensor inputs).
        for offset in range(num_tokens):
            out_ptr[start_pos + offset] = index_start + offset
def _select_rows(src: torch.Tensor, indices) -> torch.Tensor:
    """Return ``src[indices]`` along dim 0 while avoiding a gather where possible.

    A plain ``src[indices]`` was noted in this file to cost copy + copy +
    index_out on this backend, so cheaper paths are tried first:

    * ``len(indices) == src.shape[0]``: ``src`` itself is returned. This is
      only correct if the indices are exactly ``0 .. src.shape[0] - 1``,
      which is the assumption the original call sites made.
    * a one-element Python list: a zero-copy slice view is returned.
    * otherwise: a single ``aten.index`` writes into a preallocated tensor.

    Args:
        src: Tensor whose first dimension is indexed.
        indices: Python ``list[int]`` or an integer index tensor.

    Returns:
        Tensor of shape ``(len(indices),) + src.shape[1:]``. It may alias
        ``src`` (first two paths).
    """
    if isinstance(indices, list):
        if len(indices) == src.shape[0]:
            return src
        if len(indices) == 1:
            row = indices[0]
            return src[row:row + 1]
        index = torch.tensor(indices, dtype=torch.int32)
    else:
        assert isinstance(indices, torch.Tensor)
        if indices.shape[0] == src.shape[0]:
            return src
        index = indices
    out = src.new_empty(index.shape + src.shape[1:])
    torch.ops.aten.index(src, [index], out=out)
    return out


class EagleProposer:
    """EAGLE / MTP draft-proposer methods adapted for the VACC backend.

    ``propose`` supports only single-token drafting
    (``num_speculative_tokens == 1``). Token indices may be passed around as
    Python lists rather than device tensors.
    """

    # @trace_time('EagleProposer_propose')
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        """Run one draft-model step and return greedy draft tokens.

        Args:
            target_token_ids: Token ids that were fed to the target model.
            target_positions: Their positions. They are passed straight to
                the draft model.
            target_hidden_states: Target-model hidden states for those tokens.
            next_token_ids: Token sampled for each request. It replaces that
                request's last input token.
            last_token_indices: Index of each request's last token in the
                flattened token dimension. This is a 1-D tensor, a Python
                ``list[int]``, or ``None``. When ``None``, it is derived as
                ``query_start_loc[1:] - 1``.
            common_attn_metadata: Used to build the drafting attention
                metadata.
            sampling_metadata: Unused here. It is kept for interface
                compatibility.
            mm_embeds: Multimodal embeddings, used only for multimodal draft
                models.

        Returns:
            Tensor of shape ``[len(last_token_indices), 1]`` holding the
            argmax draft token per sampled position.

        Raises:
            ValueError: If ``num_speculative_tokens != 1``. Multi-step and
                tree drafting are not implemented on this backend.
        """
        # Fail fast, before mutating buffers or running a wasted forward pass.
        if self.num_speculative_tokens != 1:
            raise ValueError(f'not support self.num_speculative_tokens > 1, but get {self.num_speculative_tokens}')

        num_tokens = target_token_ids.shape[0]
        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size

        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        if isinstance(last_token_indices, list) and len(last_token_indices) == 1:
            # Single request: write through a slice view instead of
            # advanced indexing.
            idx = last_token_indices[0]
            self.input_ids[idx:idx + 1] = next_token_ids
        else:
            self.input_ids[last_token_indices] = next_token_ids

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        ubatch_id = dbo_current_ubatch_id()
        attn_metadata_builder = \
            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0)
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                ))
        else:
            draft_indexer_metadata = None

        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        # No padding of the token count: the model runs on exactly num_tokens.
        num_input_tokens = num_tokens
        if self.is_multimodal_model:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = self.model.get_input_embeddings(
                input_ids,
                multimodal_embeddings=mm_embeds or None,
            )
            self.inputs_embeds[:num_tokens] = inputs_embeds
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            inputs_embeds = None
            input_ids = self.input_ids[:num_input_tokens]

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            # Positions and hidden states are passed directly instead of
            # being copied into persistent buffers first.
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=target_positions,
                hidden_states=target_hidden_states,
                inputs_embeds=inputs_embeds,
            )
        # MTP models return a single tensor; EAGLE models return a 2-tuple
        # whose first element is the one to sample from.
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
        else:
            last_hidden_states, _ = ret_hidden_states

        sample_hidden_states = _select_rows(last_hidden_states,
                                            last_token_indices)
        logits = self.model.compute_logits(sample_hidden_states)
        # Only one draft token is generated (checked above). Upstream vLLM
        # continues here with multi-step / tree-attention drafting, which is
        # not implemented for this backend.
        draft_token_ids = logits.argmax(dim=-1)
        return draft_token_ids.view(-1, 1)

    def prepare_next_token_ids_padded(self,
                                      common_attn_metadata: CommonAttentionMetadata,
                                      sampled_token_ids: torch.Tensor,
                                      requests: dict[str, CachedRequestState],
                                      gpu_input_batch: InputBatch,
                                      discard_request_indices: torch.Tensor,
                                      num_discarded_requests: int) -> \
            tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i])
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = \
            discard_request_indices[:num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                         dtype=torch.bool)
        else:
            valid_mask = (
                (valid_sampled_token_ids_gpu != -1) &
                (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1, selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size])

        return next_token_ids, valid_sampled_tokens_count

    # @trace_time('prepare_inputs_padded')
    def prepare_inputs_padded(self,
                              common_attn_metadata: CommonAttentionMetadata,
                              spec_decode_metadata: SpecDecodeMetadata,
                              valid_sampled_tokens_count: torch.Tensor) -> \
            tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_draft_tokens_gpu = torch.cat([
            spec_decode_metadata.cu_num_draft_tokens[0:1],
            spec_decode_metadata.cu_num_draft_tokens[1:] -
            spec_decode_metadata.cu_num_draft_tokens[:-1]
        ])

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu))

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=max(common_attn_metadata.seq_lens_cpu),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
            - num_rejected_tokens_gpu

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    # @trace_time('prepare_inputs')
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, list[int]]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates to the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.

        VACC specifics: the returned token indices are a Python ``list[int]``,
        ``seq_lens``/``seq_lens_cpu`` hold one entry per kept token (not one
        per request), and ``query_start_loc`` is a CPU tensor.
        """
        # E.g.
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1, q1 + q2, q1 + q2 + q3]
        #  num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        #  num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        #  common_attn_metadata.query_start_loc{_cpu}:
        #       [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        #  token_indices: [0, 1, ..., q1 - n1 - 1,
        #                  q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                  q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1]).tolist()
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = [
            q - n for q, n in zip(new_query_len_per_req, num_rejected_tokens)
        ]
        new_num_tokens_per_req_np = np.array(new_num_tokens_per_req)

        # common_attn_metadata.seq_lens_cpu is a list[int] of length
        # max_seq_num, produced by VACCModelRunner._prepare_inputs with k + 1
        # already added. E.g. seq=[31, 63], bs=2, max_seq_num=4 gives
        # seq_lens_cpu=[31, 63, 0, 0].
        # new_seq_lens_cpu holds one entry per kept token. With 2 query
        # tokens per request:
        #   num_rejected_tokens=[0, 1] -> [30, 31, 62]
        #       (req 1 accepted; req 2 rejected, only its recovered token left)
        #   num_rejected_tokens=[1, 0] -> [30, 62, 63]
        #       (req 2 accepted; req 1 rejected, only its recovered token left)
        #   num_rejected_tokens=[0, 0] -> [30, 31, 62, 63]  (all accepted)
        #   num_rejected_tokens=[1, 1] -> [30, 62]          (all rejected)
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu
        new_seq_lens_cpu = []
        for i, num_rejected in enumerate(num_rejected_tokens):
            num_kept = new_num_tokens_per_req[i]
            # Sequence length seen by the first kept token of request i;
            # each following kept token sees one more.
            first_seq_len = seq_lens_cpu[i] - num_kept + 1 - num_rejected
            new_seq_lens_cpu.extend(first_seq_len + j for j in range(num_kept))

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,       // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
        # Kept as a Python list, e.g. [0] or [0, 1] for bs=1, and
        # [0, 2, 4, 6] or [0, 1, 2, 3, 4, 6, 7] for bs=4.
        token_indices = (token_offsets + old_query_start_locs_expanded).tolist()

        # Equivalent to slot_mapping[token_indices], but avoids the extra
        # copies of a plain advanced-index on this backend.
        slot_mapping = _select_rows(common_attn_metadata.slot_mapping,
                                    token_indices)

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu,
            seq_lens=new_seq_lens_cpu,
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=max(new_query_len_per_req),
            max_seq_len=None,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=slot_mapping,
            causal=True,
        )

        return spec_common_attn_metadata, token_indices

    # @trace_time('EagleProposer_prepare_inputs')
    @staticmethod
    def prepare_inputs_9_2(
        self,
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[list, list]:
        """List-based port of vLLM 0.9.2's ``EagleProposer.prepare_inputs``.

        NOTE(review): declared ``@staticmethod`` yet takes a leading ``self``
        parameter, so it is never bound automatically. Callers must pass an
        explicit placeholder as the first positional argument. ``self`` is not
        used. The signature is kept unchanged for compatibility.

        Layout:
            cu_target_query_lens: [0, a, a + b, a + b + c]
            num_rejected_tokens:  [n1, n2, n3]
            num_tokens_per_req:   [a - n1, b - n2, c - n3]
            cu_num_tokens:        [0, a - n1, a + b - n1 - n2,
                                   a + b + c - n1 - n2 - n3]
            token_indices:        [0, 1, ..., a - n1 - 1,
                                   a, a + 1, ..., a + b - n2 - 1,
                                   a + b, a + b + 1, ..., a + b + c - n3 - 1]

        Returns:
            ``(cu_num_tokens, token_indices)`` as Python lists.
            ``token_indices`` has ``num_tokens`` entries.
        """
        num_reqs = len(cu_target_query_lens) - 1
        # [0, a, a + b, a + b + c] -> [a - n1, b - n2, c - n3]
        num_tokens_per_req = [
            cu_target_query_lens[i + 1] - cu_target_query_lens[i]
            - num_rejected_tokens[i]
            for i in range(num_reqs)
        ]
        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        cu_num_tokens = [0] * len(cu_target_query_lens)
        for i in range(num_reqs):
            cu_num_tokens[i + 1] = cu_num_tokens[i] + num_tokens_per_req[i]

        # Python stand-in for the Triton prepare_eagle_input_kernel launch.
        token_indices = [0] * num_tokens
        prepare_eagle_input_python(
            token_indices,
            cu_target_query_lens,
            cu_num_tokens
        )
        return cu_num_tokens, token_indices
def EagleProposer_init_(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    runner=None,
):
    """Initialize EAGLE/MTP draft-proposer state for the VACC backend.

    Defined at module level but takes ``self``. NOTE(review): presumably
    assigned to ``EagleProposer.__init__`` elsewhere; confirm.

    The function does the following:

    * stores config-derived sizes;
    * allocates persistent device buffers (``input_ids``, ``positions``,
      ``arange`` and, for multimodal models, ``inputs_embeds``);
    * allocates a CPU/GPU buffer for backup next-token ids;
    * precomputes per-level properties of the speculative token tree.

    Args:
        vllm_config: Global vLLM config. Its ``speculative_config`` must be
            set (asserted below).
        device: Device on which the persistent tensors are allocated.
        runner: Model runner. ``propose`` later reads
            ``self.runner.attn_groups``. It may be ``None`` here.
    """
    self.vllm_config = vllm_config
    self.speculative_config = vllm_config.speculative_config
    assert self.speculative_config is not None
    self.draft_model_config = self.speculative_config.draft_model_config
    self.method = self.speculative_config.method
    self.runner = runner
    self.device = device
    self.dtype = vllm_config.model_config.dtype
    self.max_model_len = vllm_config.model_config.max_model_len
    self.block_size = vllm_config.cache_config.block_size
    self.num_speculative_tokens = (
        self.speculative_config.num_speculative_tokens)
    self.max_num_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)
    # Host-side arange used by prepare_inputs to compute per-token offsets.
    self.token_arange_np = np.arange(self.max_num_tokens)
    # We need to get the hidden size from the draft model config because
    # the draft model's hidden size can be different from the target model's
    # hidden size (e.g., Llama 3.3 70B).
    self.hidden_size = self.draft_model_config.get_hidden_size()
    self.is_multimodal_model = vllm_config.model_config \
        .is_multimodal_model
    self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
    self.draft_indexer_metadata_builder: Optional[
        AttentionMetadataBuilder] = None
    # Layer-name lists start empty; nothing in this function fills them.
    self.attn_layer_names: list[str] = []
    self.indexer_layer_names: list[str] = []
    # CUDA graphs are disabled here, so there are no capture batch sizes.
    self.use_cuda_graph = False
    self.cudagraph_batch_sizes = []
    # persistent buffers for cuda graph
    self.input_ids = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int32,
                                 device=device)
    self.positions = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int64,
                                 device=device)
    # self.hidden_states = torch.zeros(
    #     (self.max_num_tokens, self.hidden_size),
    #     dtype=self.dtype,
    #     device=device)
    # We need +1 here because the arange is used to set query_start_loc,
    # which has one more element than batch_size.
    max_batch_size = vllm_config.scheduler_config.max_num_seqs
    max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
    self.arange = torch.arange(max_num_slots_for_arange,
                               device=device,
                               dtype=torch.int32)
    if self.is_multimodal_model:
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
    else:
        self.inputs_embeds = None
    # Fallback next-token ids per request (one slot per possible request),
    # filled on the CPU side and copied to the device.
    self.backup_next_token_ids = CpuGpuBuffer(
        max_batch_size,
        dtype=torch.int32,
        pin_memory=is_pin_memory_available(),
        device=device,
        with_numpy=True)
    # Determine allowed attention backends once during initialization.
    # None: no restriction on the attention-metadata type is recorded.
    self.allowed_attn_types: Optional[tuple] = None
    # Parse the speculative token tree.
    spec_token_tree = self.speculative_config.speculative_token_tree
    self.tree_choices: list[tuple[int,
                                  ...]] = ast.literal_eval(spec_token_tree)
    # NOTE(review): depth is read from the last choice, which assumes the
    # choices are non-empty and listed in non-decreasing depth order.
    tree_depth = len(self.tree_choices[-1])
    # Precompute per-level properties of the tree.
    num_drafts_per_level = [0] * tree_depth
    for node in self.tree_choices:
        num_drafts_per_level[len(node) - 1] += 1
    # Cumulative draft count up to each level, and drafts per parent node at
    # each level (integer division assumes a uniform branching factor).
    self.cu_drafts_per_level = [num_drafts_per_level[0]]
    self.child_drafts_per_level = [num_drafts_per_level[0]]
    for level in range(1, tree_depth):
        self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
                                        num_drafts_per_level[level])
        self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                           num_drafts_per_level[level - 1])
    # Precompute draft position offsets in flattened tree.
    # Shape: [max_batch_size, len(tree_choices)], each row = 1..len(tree_choices).
    self.tree_draft_pos_offsets = torch.arange(
        1,
        len(self.tree_choices) + 1,
        device=device,
        dtype=torch.int32,
    ).repeat(max_batch_size, 1)