Files
xc-llm-kunlun/vllm_kunlun/v1/sample/spec_decode/eagle.py

342 lines
14 KiB
Python
Raw Normal View History

2026-01-06 21:37:21 +08:00
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Optional
import numpy as np
import torch
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
2026-01-06 21:37:21 +08:00
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
2026-01-06 21:37:21 +08:00
logger = init_logger(__name__)
PADDING_SLOT_ID = -1
def propose(
self,
# [num_tokens]
target_token_ids: torch.Tensor,
# [num_tokens]
target_positions: torch.Tensor,
# [num_tokens, hidden_size]
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
2026-01-06 21:37:21 +08:00
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
2026-01-06 21:37:21 +08:00
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
2026-01-06 21:37:21 +08:00
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
2026-01-06 21:37:21 +08:00
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
if (
hasattr(attn_metadata, "decode")
and attn_metadata.decode is not None
and hasattr(attn_metadata.decode, "spec_num_seq_len")
and attn_metadata.decode.spec_num_seq_len is not None
):
attn_metadata.decode.spec_num_seq_len = -1
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
)
else:
draft_indexer_metadata = None
2026-01-06 21:37:21 +08:00
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
2026-01-06 21:37:21 +08:00
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
if self.is_multimodal_model:
input_ids = self.input_ids[:num_tokens]
inputs_embeds = self.model.get_input_embeddings(
input_ids,
multimodal_embeddings=mm_embeds or None,
)
self.inputs_embeds[:num_tokens] = inputs_embeds
inputs_embeds = self.inputs_embeds[:num_input_tokens]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
):
2026-01-06 21:37:21 +08:00
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
2026-01-06 21:37:21 +08:00
last_hidden_states = ret_hidden_states
hidden_states = self.hidden_states[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
if self.method == "mtp":
logits = self.model.compute_logits(sample_hidden_states, 0)
else:
logits = self.model.compute_logits(sample_hidden_states, None)
2026-01-06 21:37:21 +08:00
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if isinstance(attn_metadata, TreeAttentionMetadata):
# Draft using tree attention.
draft_token_ids_list = self.propose_tree(
batch_size=batch_size,
logits=logits,
positions=positions,
hidden_states=hidden_states,
common_attn_metadata=common_attn_metadata,
)
# [batch_size, num_tree_tokens]
return torch.cat(draft_token_ids_list, dim=1)
draft_token_ids = logits.argmax(dim=-1)
# Early exit if there is only one draft token to be generated.
if self.num_speculative_tokens == 1:
# [batch_size, 1]
return draft_token_ids.view(-1, 1)
# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
2026-01-06 21:37:21 +08:00
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = (
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32)
)
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder
for token_index in range(self.num_speculative_tokens - 1):
2026-01-06 21:37:21 +08:00
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1
# NOTE(woosuk): We should handle the case where the draft model
# generates tokens beyond the max model length. Since it is complex
# to remove such requests from the batch, we keep them in the batch
# but adjust the position ids and slot mappings to avoid the
# out-of-range access during the model execution. The draft tokens
# generated with this adjustment should be ignored.
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
2026-01-06 21:37:21 +08:00
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata.num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu - 1
)
2026-01-06 21:37:21 +08:00
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1)
)
2026-01-06 21:37:21 +08:00
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size + clamped_positions % self.block_size
)
2026-01-06 21:37:21 +08:00
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID
)
2026-01-06 21:37:21 +08:00
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
2026-01-06 21:37:21 +08:00
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
2026-01-06 21:37:21 +08:00
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
self.hidden_states[:batch_size] = hidden_states
if self.is_multimodal_model:
inputs_embeds = self.model.get_input_embeddings(input_ids)
self.inputs_embeds[:batch_size] = inputs_embeds
inputs_embeds = self.inputs_embeds[:input_batch_size]
input_ids = None
else:
inputs_embeds = None
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
):
ret_hidden_states = self.model(
2026-01-06 21:37:21 +08:00
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
2026-01-06 21:37:21 +08:00
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
# [batch_size, num_speculative_tokens]
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
2026-01-06 21:37:21 +08:00
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item()
)
2026-01-06 21:37:21 +08:00
for i in range(num_reqs)
]
)
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = discard_request_indices[
:num_discarded_requests
]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# valid_sampled_token_ids_gpu.index_fill_(
# 0, discard_sampled_tokens_req_indices, -1)
# ---- FIX START ----
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
if num_discarded_requests > 0:
# make sure index is on same device and is int64
idx = discard_sampled_tokens_req_indices
if idx.device != valid_sampled_token_ids_gpu.device:
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
if idx.dtype != torch.long:
idx = idx.to(torch.long)
if idx.numel() > 0:
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
# ---- FIX END ----
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
else:
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
2026-01-06 21:37:21 +08:00
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
2026-01-06 21:37:21 +08:00
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
).squeeze(1)
2026-01-06 21:37:21 +08:00
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
)
2026-01-06 21:37:21 +08:00
return next_token_ids, valid_sampled_tokens_count
2026-01-06 21:37:21 +08:00
EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded