import ast
from typing import Optional

import numpy as np
import torch
import torch.nn as nn

from vllm.attention.layer import Attention
from vllm.config import (CompilationLevel, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils import is_pin_memory_available

from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.utils import (AttentionMetadataBuilder,
                                              CommonAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
                                                  TreeAttentionMetadataBuilder)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.utils import CpuGpuBuffer
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm.distributed import get_tensor_model_parallel_rank

PADDING_SLOT_ID = -1

from vacc_tools.trace_logger import get_trace_api

trace_time, register_module_trace, trace_autograd_function, register_optimizer_trace = (
    get_trace_api("deepseek")
)

# @trace_time('prepare_eagle_input_python')
def prepare_eagle_input_python(
    out_ptr,
    cu_query_lens_ptr,
    cu_num_tokens_ptr,
    # BLOCK_SIZE
):
    """
    Python implementation of prepare_eagle_input_kernel.

    Args:
        out_ptr: output buffer for the gathered token indices
        cu_query_lens_ptr: cumulative start index of each query
        cu_num_tokens_ptr: cumulative number of kept tokens per query
        BLOCK_SIZE: block size (unused in this Python version)
    """
    cu_query_lens_ptr_list = cu_query_lens_ptr
    cu_num_tokens_ptr_list = cu_num_tokens_ptr
    num_queries = len(cu_num_tokens_ptr) - 1

    # out_ptr_list = np.zeros(cu_query_lens_ptr_list.shape, cu_query_lens_ptr_list.dtype)
    for pid in range(num_queries):
        start_pos = cu_num_tokens_ptr_list[pid]  # .item()
        end_pos = cu_num_tokens_ptr_list[pid + 1]  # .item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens_ptr_list[pid]  # .item()

        # Vectorized alternative:
        # offset = np.array([i for i in range(num_tokens)], dtype=cu_num_tokens_ptr_list.dtype)
        # values = index_start + offset
        # Store into the output tensor:
        # out_ptr[start_pos + offset] = values

        for i in range(num_tokens):
            out_ptr[start_pos + i] = index_start + i

    return
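
    # Worked example (illustrative, using plain Python lists as in the
    # list-based caller prepare_inputs_9_2 below):
    #   out = [0] * 5
    #   prepare_eagle_input_python(out, [0, 3, 7], [0, 2, 5])
    #   out == [0, 1, 3, 4, 5]
    # i.e. request 0 keeps target rows 0..1 and request 1 keeps rows 3..5.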

    # NOTE: the code below is unreachable (the function returns above); it is
    # the torch-based reference for the original Triton kernel, kept for
    # documentation. BLOCK_SIZE mirrors the commented-out parameter above
    # (1024 in the disabled Triton call in prepare_inputs_9_2).
    BLOCK_SIZE = 1024

    num_queries = len(cu_num_tokens_ptr) - 1

    for pid in range(num_queries):
        # [start_pos, end_pos)
        start_pos = cu_num_tokens_ptr[pid].item()
        end_pos = cu_num_tokens_ptr[pid + 1].item()
        num_tokens = end_pos - start_pos

        index_start = cu_query_lens_ptr[pid].item()

        num_blocks = (num_tokens + BLOCK_SIZE - 1) // BLOCK_SIZE

        for i in range(num_blocks):
            offset_start = i * BLOCK_SIZE
            offset_end = min(offset_start + BLOCK_SIZE, num_tokens)

            # Offsets covered by the current block.
            offset = torch.arange(offset_start, offset_end, device=out_ptr.device, dtype=out_ptr.dtype)

            # The values to store.
            values = index_start + offset

            # Store into the output tensor.
            out_ptr[start_pos + offset] = values


class EagleProposer:

    # @trace_time('EagleProposer_propose')
    def propose(
        self,
        # [num_tokens]
        target_token_ids: torch.Tensor,
        # [num_tokens]
        target_positions: torch.Tensor,
        # [num_tokens, hidden_size]
        target_hidden_states: torch.Tensor,
        # [batch_size]
        next_token_ids: torch.Tensor,
        # May also be a plain Python list of indices on this platform.
        last_token_indices: Optional[torch.Tensor],
        common_attn_metadata: CommonAttentionMetadata,
        sampling_metadata: SamplingMetadata,
        mm_embeds: Optional[list[torch.Tensor]] = None,
    ) -> torch.Tensor:
        num_tokens = target_token_ids.shape[0]
        batch_size = next_token_ids.shape[0]

        if last_token_indices is None:
            last_token_indices = common_attn_metadata.query_start_loc[1:] - 1

        if self.method == "eagle3":
            assert isinstance(self.model, Eagle3LlamaForCausalLM)
            target_hidden_states = self.model.combine_hidden_states(
                target_hidden_states)
            assert target_hidden_states.shape[-1] == self.hidden_size
        # Shift the input ids by one token.
        # E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
        self.input_ids[:num_tokens - 1] = target_token_ids[1:]
        # Replace the last token with the next token.
        # E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
        # self.input_ids[last_token_indices] = next_token_ids
        # A single-element Python list takes a plain slice assignment instead
        # of an advanced-indexing scatter on the device.
        if isinstance(last_token_indices, list) and len(last_token_indices) == 1:
            self.input_ids[last_token_indices[0]:last_token_indices[0] + 1] = next_token_ids
        else:
            self.input_ids[last_token_indices] = next_token_ids
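
        # Example (illustrative): for a single decode request,
        # last_token_indices == [5] takes the slice-assignment fast path:
        #   self.input_ids[5:6] = next_token_ids
        # which avoids an index_put_ with a device-side index tensor.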

        assert self.runner is not None

        # FIXME: need to consider multiple kv_cache_groups
        ubatch_id = dbo_current_ubatch_id()
        attn_metadata_builder = \
            self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
        attn_metadata = attn_metadata_builder.build_for_drafting(
            common_attn_metadata=common_attn_metadata, draft_index=0)
        # FIXME: support hybrid kv for draft model (remove separate indexer)
        if self.draft_indexer_metadata_builder:
            draft_indexer_metadata = (
                self.draft_indexer_metadata_builder.build_for_drafting(
                    common_attn_metadata=common_attn_metadata,
                    draft_index=0,
                ))
        else:
            draft_indexer_metadata = None
        # At this moment, we assume all eagle layers belong to the same KV
        # cache group, thus using the same attention metadata.
        per_layer_attn_metadata = {}
        for layer_name in self.attn_layer_names:
            per_layer_attn_metadata[layer_name] = attn_metadata
        for layer_name in self.indexer_layer_names:
            assert draft_indexer_metadata is not None
            per_layer_attn_metadata[layer_name] = draft_indexer_metadata

        num_input_tokens = num_tokens
        # copy inputs to buffer for cudagraph
        # self.positions[:num_tokens] = target_positions
        # self.hidden_states[:num_tokens] = target_hidden_states
        if self.is_multimodal_model:
            input_ids = self.input_ids[:num_tokens]
            inputs_embeds = self.model.get_input_embeddings(
                input_ids,
                multimodal_embeddings=mm_embeds or None,
            )
            self.inputs_embeds[:num_tokens] = inputs_embeds
            inputs_embeds = self.inputs_embeds[:num_input_tokens]
            input_ids = None
        else:
            inputs_embeds = None
            input_ids = self.input_ids[:num_input_tokens]

        with set_forward_context(per_layer_attn_metadata,
                                 self.vllm_config,
                                 num_tokens=num_input_tokens):
            ret_hidden_states = self.model(
                input_ids=input_ids,
                positions=target_positions,
                hidden_states=target_hidden_states,  # self.hidden_states[:num_input_tokens]
                inputs_embeds=inputs_embeds,
            )
        if self.method == "mtp":
            last_hidden_states = ret_hidden_states
            hidden_states = last_hidden_states
        else:
            last_hidden_states, hidden_states = ret_hidden_states
        # sample_hidden_states = last_hidden_states[last_token_indices]

        # Gather the rows of last_hidden_states to sample from; the trivial
        # cases (all rows, or a single row) avoid a device-side gather.
        if isinstance(last_token_indices, list):
            if len(last_token_indices) == last_hidden_states.shape[0]:
                sample_hidden_states = last_hidden_states
            elif len(last_token_indices) == 1:
                sample_hidden_states = last_hidden_states[last_token_indices[0]:last_token_indices[0] + 1]
            else:
                sample_hidden_states = last_hidden_states.new_empty(
                    (len(last_token_indices), ) + last_hidden_states.shape[1:])
                torch.ops.aten.index(
                    last_hidden_states,
                    [torch.tensor(last_token_indices, dtype=torch.int32)],
                    out=sample_hidden_states)
        else:
            assert isinstance(last_token_indices, torch.Tensor)
            if last_token_indices.shape[0] == last_hidden_states.shape[0]:
                sample_hidden_states = last_hidden_states
            else:
                sample_hidden_states = last_hidden_states.new_empty(
                    last_token_indices.shape + last_hidden_states.shape[1:])
                torch.ops.aten.index(last_hidden_states, [last_token_indices],
                                     out=sample_hidden_states)

        logits = self.model.compute_logits(sample_hidden_states)

        # Early exit if there is only one draft token to be generated.
        if self.num_speculative_tokens == 1:
            draft_token_ids = logits.argmax(dim=-1)
            return draft_token_ids.view(-1, 1)
        else:
            raise ValueError(
                "num_speculative_tokens > 1 is not supported, but got "
                f"{self.num_speculative_tokens}")

        '''
        positions = target_positions[last_token_indices]
        if self.method in ("deepseek_mtp", "ernie_mtp", "longcat_flash_mtp"):
            hidden_states = None  # self.hidden_states[last_token_indices]
        else:
            hidden_states = None  # hidden_states[last_token_indices]

        if isinstance(attn_metadata, TreeAttentionMetadata):
            # Draft using tree attention.
            draft_token_ids_list = self.propose_tree(
                batch_size=batch_size,
                logits=logits,
                positions=positions,
                hidden_states=hidden_states,
                common_attn_metadata=common_attn_metadata,
            )
            # [batch_size, num_tree_tokens]
            return torch.cat(draft_token_ids_list, dim=1)

        draft_token_ids = logits.argmax(dim=-1)

        if self.allowed_attn_types is not None and \
                not isinstance(attn_metadata, self.allowed_attn_types):
            raise ValueError(
                f"Unsupported attention metadata type for speculative "
                "decoding with num_speculative_tokens > 1: "
                f"{type(attn_metadata)}. Supported types are: "
                f"{self.allowed_attn_types}")

        # Generate the remaining draft tokens.
        draft_token_ids_list = [draft_token_ids]

        if self.use_cuda_graph and \
                batch_size <= self.cudagraph_batch_sizes[-1]:
            input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
        else:
            input_batch_size = batch_size

        common_attn_metadata.num_actual_tokens = batch_size
        common_attn_metadata.max_query_len = 1
        common_attn_metadata.query_start_loc = self.arange[:batch_size + 1]
        common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
            self.token_arange_np[:batch_size + 1]).clone()
        for token_index in range(self.num_speculative_tokens - 1):
            # Update the inputs.
            # cast to int32 is crucial when eagle model is compiled.
            # tensor.argmax() returns int64 by default.
            input_ids = draft_token_ids_list[-1].int()
            positions += 1

            # NOTE(woosuk): We should handle the case where the draft model
            # generates tokens beyond the max model length. Since it is complex
            # to remove such requests from the batch, we keep them in the batch
            # but adjust the position ids and slot mappings to avoid the
            # out-of-range access during the model execution. The draft tokens
            # generated with this adjustment should be ignored.
            exceeds_max_model_len = positions >= self.max_model_len
            # Mask out the position ids that exceed the max model length.
            # Otherwise, we may get out-of-range error in RoPE.
            clamped_positions = torch.where(exceeds_max_model_len, 0,
                                            positions)

            # Increment the sequence lengths.
            common_attn_metadata.seq_lens += 1
            common_attn_metadata.seq_lens_cpu += 1
            # For the requests that exceed the max model length, we set the
            # sequence length to 1 to minimize their overheads in attention.
            common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
                                                       1)

            common_attn_metadata.num_computed_tokens_cpu = \
                common_attn_metadata.seq_lens_cpu - 1

            # Compute the slot mapping.
            block_numbers = clamped_positions // self.block_size
            block_ids = common_attn_metadata.block_table_tensor.gather(
                dim=1, index=block_numbers.view(-1, 1))
            block_ids = block_ids.view(-1)
            common_attn_metadata.slot_mapping = (
                block_ids * self.block_size +
                clamped_positions % self.block_size)
            # Mask out the slot mappings that exceed the max model length.
            # Otherwise, the KV cache will be inadvertently updated with the
            # padding tokens.
            common_attn_metadata.slot_mapping.masked_fill_(
                exceeds_max_model_len, PADDING_SLOT_ID)

            # Rebuild attention metadata
            attn_metadata = attn_metadata_builder.build_for_drafting(  # type: ignore
                common_attn_metadata=common_attn_metadata,
                draft_index=token_index + 1)
            for layer_name in self.attn_layer_names:
                per_layer_attn_metadata[layer_name] = attn_metadata

            # copy inputs to buffer for cudagraph
            self.input_ids[:batch_size] = input_ids
            self.positions[:batch_size] = clamped_positions
            # self.hidden_states[:batch_size] = hidden_states
            if self.is_multimodal_model:
                inputs_embeds = self.model.get_input_embeddings(input_ids)
                self.inputs_embeds[:batch_size] = inputs_embeds
                inputs_embeds = self.inputs_embeds[:input_batch_size]
                input_ids = None
            else:
                inputs_embeds = None
                input_ids = self.input_ids[:input_batch_size]

            # Run the model.
            with set_forward_context(per_layer_attn_metadata,
                                     self.vllm_config,
                                     num_tokens=input_batch_size):
                ret_hidden_states = self.model(
                    input_ids=input_ids,
                    positions=self.positions[:input_batch_size],
                    # hidden_states=self.hidden_states[:input_batch_size],
                    inputs_embeds=inputs_embeds,
                )
            if self.method == "mtp":
                last_hidden_states = ret_hidden_states
                hidden_states = ret_hidden_states
            else:
                last_hidden_states, hidden_states = ret_hidden_states
            hidden_states = hidden_states[:batch_size]
            logits = self.model.compute_logits(last_hidden_states[:batch_size])
            draft_token_ids = logits.argmax(dim=-1)
            draft_token_ids_list.append(draft_token_ids)

        # [batch_size, num_speculative_tokens]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids
        '''

    def prepare_next_token_ids_padded(self,
                                      common_attn_metadata: CommonAttentionMetadata,
                                      sampled_token_ids: torch.Tensor,
                                      requests: dict[str, CachedRequestState],
                                      gpu_input_batch: InputBatch,
                                      discard_request_indices: torch.Tensor,
                                      num_discarded_requests: int) -> \
            tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i])
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled tokens indices that should not be sampled.
        discard_sampled_tokens_req_indices = \
            discard_request_indices[:num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests
        max_gen_len = sampled_token_ids.shape[-1]
        if max_gen_len == 1:
            valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
                                         dtype=torch.bool)
        else:
            valid_mask = (
                (valid_sampled_token_ids_gpu != -1) &
                (valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)
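
        # Example (illustrative): sampled_token_ids = [[7, 9, -1], [4, -1, -1]]
        # gives valid_sampled_tokens_count = [2, 1]; below, the last valid
        # token per row (9 and 4) is gathered, falling back to the precomputed
        # backup when a row has no valid token at all.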

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1, selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size])

        return next_token_ids, valid_sampled_tokens_count

    # @trace_time('prepare_inputs_padded')
    def prepare_inputs_padded(self,
                              common_attn_metadata: CommonAttentionMetadata,
                              spec_decode_metadata: SpecDecodeMetadata,
                              valid_sampled_tokens_count: torch.Tensor) -> \
            tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        # Recover the per-request draft counts from the cumulative sums.
        num_draft_tokens_gpu = torch.cat([
            spec_decode_metadata.cu_num_draft_tokens[0:1],
            spec_decode_metadata.cu_num_draft_tokens[1:] -
            spec_decode_metadata.cu_num_draft_tokens[:-1]
        ])
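
        # Example (illustrative): cu_num_draft_tokens = [2, 5, 7]
        # -> num_draft_tokens_gpu = [2, 3, 2]: the first entry is kept as-is,
        # the rest come from adjacent differences.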

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu))

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = (query_start_loc_cpu[1:] -
                                 query_start_loc_cpu[:-1])

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            seq_lens=common_attn_metadata.seq_lens,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            max_seq_len=max(common_attn_metadata.seq_lens_cpu),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping[token_indices],
            causal=True,
        )

        token_indices_to_sample = common_attn_metadata.query_start_loc[1:] - 1 \
            - num_rejected_tokens_gpu

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    # @trace_time('prepare_inputs')
    def prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        # common_attn_metadata.query_start_loc{_cpu}:
        #     [0, q1, q1 + q2, q1 + q2 + q3]
        # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        # num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        # common_attn_metadata.query_start_loc{_cpu}:
        #     [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        # common_attn_metadata.seq_lens{_cpu}:
        #     [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        # token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        # num_rejected_tokens = torch.tensor(num_rejected_tokens,
        #                                    dtype=torch.int32)

        # device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        # new_seq_lens_cpu = common_attn_metadata.seq_lens_cpu \
        #     - num_rejected_tokens
        # new_seq_lens_cpu = [i - j for i, j in zip(common_attn_metadata.seq_lens_cpu, num_rejected_tokens)]

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        # new_query_len_per_req = (query_start_loc_cpu[1:] -
        #                          query_start_loc_cpu[:-1])
        # new_query_len_per_req = [query_start_loc_cpu[i + 1] - query_start_loc_cpu[i] for i in range(len(query_start_loc_cpu) - 1)]
        new_query_len_per_req = (query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]).tolist()  # [2] * (bs + 1)

        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        # new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req = [i - j for i, j in zip(new_query_len_per_req, num_rejected_tokens)]

        new_num_tokens_per_req_np = np.array(new_num_tokens_per_req)

        # common_attn_metadata.seq_lens_cpu is a list[int] of length
        # max_seq_num; it comes from VACCModelRunner._prepare_inputs, where
        # k + 1 has already been added.
        # E.g. seq = [31, 63], bs = 2, max_seq_num = 4 -> seq_lens_cpu = [31, 63, 0, 0]
        # new_seq_lens_cpu must list every real sequence length:
        # if num_rejected_tokens = [0, 1], new_seq_lens_cpu = [30, 31, 62]:
        #     request 1 accepted, request 2 rejected, only one recovered token
        # if num_rejected_tokens = [1, 0], new_seq_lens_cpu = [30, 62, 63]:
        #     request 2 accepted, request 1 rejected, only one recovered token
        # if num_rejected_tokens = [0, 0], new_seq_lens_cpu = [30, 31, 62, 63]: both accepted
        # if num_rejected_tokens = [1, 1], new_seq_lens_cpu = [30, 62]: both rejected
        new_seq_lens_cpu = []
        for i in range(len(num_rejected_tokens)):
            for j in range(new_num_tokens_per_req[i]):
                new_seq_lens_cpu.append(common_attn_metadata.seq_lens_cpu[i] - new_num_tokens_per_req[i] + 1 - num_rejected_tokens[i] + j)

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available())
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        #   [0, 2, 6, 9] ->
        #   [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #    _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offsets = self.token_arange_np[:total_num_tokens] \
            - new_query_start_locs_expanded

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                   // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,         // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]  // req 3
        token_indices_np = token_offsets + old_query_start_locs_expanded
        token_indices = token_indices_np.tolist()
        # token_indices = torch.from_numpy(token_indices_np).to(
        #     device, non_blocking=True)
        # if get_tensor_model_parallel_rank() == 0:
        #     print('token_indices', token_indices, common_attn_metadata.slot_mapping.shape)
        # [0] or [0, 1] for bs1
        # [0, 2, 4, 6] ... or [0, 1, 2, 3, 4, 6, 7] for bs4

        # Optimized slot_mapping slice:
        # common_attn_metadata.slot_mapping[token_indices] would cost
        # copy + copy + index_out, so the trivial cases are special-cased.
        if len(token_indices) == common_attn_metadata.slot_mapping.shape[0]:  # no slice needed
            slot_mapping = common_attn_metadata.slot_mapping
        elif len(token_indices) == 1:
            slot_mapping = common_attn_metadata.slot_mapping[token_indices[0]:token_indices[0] + 1]
        else:
            slot_mapping = common_attn_metadata.slot_mapping.new_empty(len(token_indices))
            torch.ops.aten.index(common_attn_metadata.slot_mapping,
                                 [torch.tensor(token_indices, dtype=torch.int32)],
                                 out=slot_mapping)
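
        # Example (illustrative): with 6 slots, token_indices = [0, 1, 3, 4]
        # takes the aten.index gather path, while token_indices = [0, ..., 5]
        # reuses slot_mapping unchanged.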

        spec_common_attn_metadata = CommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu,
            # query_start_loc=new_query_start_loc_cpu.to(device,
            #                                            non_blocking=True),
            seq_lens=new_seq_lens_cpu,  # .to(device, non_blocking=True)
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            max_query_len=max(new_query_len_per_req),  # new_query_len_per_req.max().item()
            max_seq_len=None,  # max(new_seq_lens_cpu)
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=slot_mapping,  # common_attn_metadata.slot_mapping[token_indices]
            causal=True,
        )

        return spec_common_attn_metadata, token_indices


    # @trace_time('EagleProposer_prepare_inputs')
    @staticmethod
    def prepare_inputs_9_2(
        # [batch_size + 1]
        cu_target_query_lens: torch.Tensor,
        # [batch_size]
        num_rejected_tokens: torch.Tensor,
        num_tokens: int,
    ) -> tuple[list[int], list[int]]:
        # cu_target_query_lens: [0, a, a + b, a + b + c]
        # num_rejected_tokens: [n1, n2, n3]
        # num_tokens_per_req: [a - n1, b - n2, c - n3]
        # cu_num_tokens: [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # token_indices: [0, 1, ..., a - n1 - 1,
        #                 a, a + 1, ..., a + b - n2 - 1,
        #                 a + b, a + b + 1, ..., a + b + c - n3 - 1]

        # [0, a, a + b, a + b + c] -> [a, b, c]
        # torch style:
        # query_len_per_req = (cu_target_query_lens[1:] -
        #                      cu_target_query_lens[:-1])
        # [a, b, c] -> [a - n1, b - n2, c - n3]
        # num_tokens_per_req = query_len_per_req - num_rejected_tokens

        # list style:
        num_tokens_per_req = [cu_target_query_lens[i + 1] - cu_target_query_lens[i] - num_rejected_tokens[i] for i in range(len(cu_target_query_lens) - 1)]

        # [a - n1, b - n2, c - n3] ->
        # [0, a - n1, a + b - n1 - n2, a + b + c - n1 - n2 - n3]
        # torch style:
        # cu_num_tokens = torch.zeros_like(cu_target_query_lens)
        # torch.cumsum(num_tokens_per_req, dim=0, out=cu_num_tokens[1:])

        # list style:
        cu_num_tokens = [0] * len(cu_target_query_lens)
        for i in range(len(cu_target_query_lens) - 1):
            cu_num_tokens[i + 1] = cu_num_tokens[i] + num_tokens_per_req[i]

        # token_indices = torch.empty(
        #     num_tokens,
        #     dtype=torch.int32,
        #     device=cu_target_query_lens.device,
        # )
        # batch_size = num_rejected_tokens.shape[0]
        # BLOCK_SIZE = 1024
        # prepare_eagle_input_kernel[(batch_size, )](
        #     token_indices,
        #     cu_target_query_lens,
        #     cu_num_tokens,
        #     BLOCK_SIZE=BLOCK_SIZE,
        # )
        token_indices = [0] * num_tokens
        prepare_eagle_input_python(
            token_indices,
            cu_target_query_lens,
            cu_num_tokens
        )
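
        # Example (illustrative): cu_target_query_lens = [0, 3, 7],
        # num_rejected_tokens = [1, 2]
        # -> num_tokens_per_req = [2, 2], cu_num_tokens = [0, 2, 4],
        #    token_indices = [0, 1, 3, 4].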

        return cu_num_tokens, token_indices


def EagleProposer_init_(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    runner=None,
):
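    # NOTE (assumption): this module-level function mirrors
    # EagleProposer.__init__ and is presumably bound onto the class by the
    # surrounding vacc integration; it is not called directly in this file.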
    self.vllm_config = vllm_config
    self.speculative_config = vllm_config.speculative_config
    assert self.speculative_config is not None
    self.draft_model_config = self.speculative_config.draft_model_config
    self.method = self.speculative_config.method

    self.runner = runner
    self.device = device
    self.dtype = vllm_config.model_config.dtype
    self.max_model_len = vllm_config.model_config.max_model_len
    self.block_size = vllm_config.cache_config.block_size
    self.num_speculative_tokens = (
        self.speculative_config.num_speculative_tokens)
    self.max_num_tokens = (
        vllm_config.scheduler_config.max_num_batched_tokens)
    self.token_arange_np = np.arange(self.max_num_tokens)
    # We need to get the hidden size from the draft model config because
    # the draft model's hidden size can be different from the target model's
    # hidden size (e.g., Llama 3.3 70B).
    self.hidden_size = self.draft_model_config.get_hidden_size()

    self.is_multimodal_model = vllm_config.model_config \
        .is_multimodal_model

    self.attn_metadata_builder: Optional[AttentionMetadataBuilder] = None
    self.draft_indexer_metadata_builder: Optional[
        AttentionMetadataBuilder] = None
    self.attn_layer_names: list[str] = []
    self.indexer_layer_names: list[str] = []

    self.use_cuda_graph = False
    self.cudagraph_batch_sizes = []

    # persistent buffers for cuda graph
    self.input_ids = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int32,
                                 device=device)
    self.positions = torch.zeros(self.max_num_tokens,
                                 dtype=torch.int64,
                                 device=device)
    # self.hidden_states = torch.zeros(
    #     (self.max_num_tokens, self.hidden_size),
    #     dtype=self.dtype,
    #     device=device)
    # We need +1 here because the arange is used to set query_start_loc,
    # which has one more element than batch_size.
    max_batch_size = vllm_config.scheduler_config.max_num_seqs
    max_num_slots_for_arange = max(max_batch_size + 1, self.max_num_tokens)
    self.arange = torch.arange(max_num_slots_for_arange,
                               device=device,
                               dtype=torch.int32)

    if self.is_multimodal_model:
        self.inputs_embeds = torch.zeros(
            (self.max_num_tokens, self.hidden_size),
            dtype=self.dtype,
            device=device)
    else:
        self.inputs_embeds = None

    self.backup_next_token_ids = CpuGpuBuffer(
        max_batch_size,
        dtype=torch.int32,
        pin_memory=is_pin_memory_available(),
        device=device,
        with_numpy=True)

    # Determine allowed attention backends once during initialization.
    self.allowed_attn_types: Optional[tuple] = None

    # Parse the speculative token tree.
    spec_token_tree = self.speculative_config.speculative_token_tree
    self.tree_choices: list[tuple[int,
                                  ...]] = ast.literal_eval(spec_token_tree)
    tree_depth = len(self.tree_choices[-1])
    # Precompute per-level properties of the tree.
    num_drafts_per_level = [0] * tree_depth
    for node in self.tree_choices:
        num_drafts_per_level[len(node) - 1] += 1
    self.cu_drafts_per_level = [num_drafts_per_level[0]]
    self.child_drafts_per_level = [num_drafts_per_level[0]]
    for level in range(1, tree_depth):
        self.cu_drafts_per_level.append(self.cu_drafts_per_level[-1] +
                                        num_drafts_per_level[level])
        self.child_drafts_per_level.append(num_drafts_per_level[level] //
                                           num_drafts_per_level[level - 1])
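    # Example (illustrative): spec_token_tree = "[(0,), (1,), (0, 0), (1, 0)]"
    # gives num_drafts_per_level = [2, 2], cu_drafts_per_level = [2, 4],
    # child_drafts_per_level = [2, 1].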
    # Precompute draft position offsets in flattened tree.
    self.tree_draft_pos_offsets = torch.arange(
        1,
        len(self.tree_choices) + 1,
        device=device,
        dtype=torch.int32,
    ).repeat(max_batch_size, 1)