Files
xc-llm-kunlun/vllm_kunlun/v1/sample/spec_decode/eagle.py
2026-01-06 21:37:21 +08:00

312 lines
14 KiB
Python

# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
logger = init_logger(__name__)
PADDING_SLOT_ID = -1
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None:
attn_metadata.decode.spec_num_seq_len = -1
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
hidden_states = self.hidden_states[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
common_attn_metadata.num_computed_tokens_cpu = \
common_attn_metadata.seq_lens_cpu - 1
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_next_token_ids_padded(self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int) -> \
tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = \
discard_request_indices[:num_discarded_requests]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# valid_sampled_token_ids_gpu.index_fill_(
# 0, discard_sampled_tokens_req_indices, -1)
# ---- FIX START ----
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
if num_discarded_requests > 0:
# make sure index is on same device and is int64
idx = discard_sampled_tokens_req_indices
if idx.device != valid_sampled_token_ids_gpu.device:
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
if idx.dtype != torch.long:
idx = idx.to(torch.long)
if idx.numel() > 0:
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
# ---- FIX END ----
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
dtype=torch.bool)
else:
valid_mask = (
(valid_sampled_token_ids_gpu != -1) &
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1, selected_tokens,
self.backup_next_token_ids.gpu[:batch_size])
return next_token_ids, valid_sampled_tokens_count
EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded