clean pr for ds.2 mtp support (#164)

* Add MTP support in eagle.py

Signed-off-by: wanghao129 <wanghao129@baidu.com>

* new pr for mtp

Signed-off-by: wanghao129 <wanghao129@baidu.com>

* Revert formatting changes in deepseek_v2.py

Signed-off-by: wanghao129 <wanghao129@baidu.com>

---------

Signed-off-by: wanghao129 <wanghao129@baidu.com>
Co-authored-by: wanghao129 <wanghao129@baidu.com>
This commit is contained in:
WANG HAO
2026-02-02 15:23:33 +08:00
committed by GitHub
parent 42a2d38f47
commit 6f30bc439d
3 changed files with 295 additions and 229 deletions

View File

@@ -15,20 +15,22 @@
# This file is a part of the vllm-ascend project.
#
# embedding
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.vocab_parallel_embedding
# quantization
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.quantization.moe_wna16
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
import vllm_kunlun.ops.fused_moe.layer
# base layers
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.fused_moe.layer
# quantization
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
import vllm_kunlun.ops.quantization.moe_wna16
# embedding
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.v1.sample.spec_decode.eagle

View File

@@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
DeepseekV32IndexerMetadataBuilder,
kv_spans_from_batches)
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadataBuilder,
kv_spans_from_batches,
split_prefill_chunks,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
logger = init_logger(__name__)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
@@ -32,6 +38,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
@@ -70,26 +77,36 @@ class DeepseekV32IndexerMetadata:
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
query_start_loc_cpu, seq_lens_cpu,
block_table):
prefill_query_start_loc = query_start_loc_cpu[
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
def kunlun_build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1] - query_start_loc_cpu[reqs_start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
self.device)
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = torch.cat([
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
]).to(torch.int32).to(self.device)
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)
seq_len_q = token_end - token_start
seq_len_kv = total_seq_lens
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor(
[0, seq_len_kv], dtype=torch.int32, device=self.device
)
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
@@ -107,85 +124,103 @@ def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
context_k_lens=context_k_lens,
context_k_lens_cpu=context_k_lens_cpu,
)
def kunlun_build(self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
def kunlun_build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False,
) -> DeepseekV32IndexerMetadata:
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
prefill_metadata = None
if num_prefills > 0:
chunk_seq_ids = split_prefill_chunks(
common_attn_metadata.seq_lens_cpu,
self.max_prefill_buffer_size,
num_decodes,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start,
reqs_end,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
self.max_prefill_buffer_size,
num_decodes,
common_attn_metadata.block_table_tensor,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start, reqs_end, query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks, )
decode_metadata = None
if num_decodes > 0:
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes])
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max()
> decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.
block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
decode_metadata = None
if num_decodes > 0:
torch.diff(
common_attn_metadata.query_start_loc[: num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes],
)
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
decode_lens=decode_lens,
requires_padding=requires_padding,
schedule_metadata=self.scheduler_metadata_buffer,
)
attn_metadata = DeepseekV32IndexerMetadata(
seq_lens=common_attn_metadata.seq_lens,
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
num_reqs=common_attn_metadata.num_reqs,
max_query_len=common_attn_metadata.max_query_len,
max_seq_len=common_attn_metadata.max_seq_len,
num_actual_tokens=common_attn_metadata.num_actual_tokens,
query_start_loc=common_attn_metadata.query_start_loc,
slot_mapping=common_attn_metadata.slot_mapping,
head_dim=128,
num_decodes=num_decodes,
num_decode_tokens=num_decode_tokens,
num_prefills=num_prefills,
num_prefill_tokens=num_prefill_tokens,
prefill=prefill_metadata,
decode=decode_metadata,
)
# if get_tensor_model_parallel_rank() == 0:
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
kunlun_build_one_prefill_chunk
)
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
# Monkey patch: Upgrade cudagraph_support to UNIFORM_BATCH for spec-decode compatibility
from vllm.v1.attention.backends.utils import AttentionCGSupport
DeepseekV32IndexerMetadataBuilder.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

View File

@@ -1,26 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
@@ -37,47 +29,65 @@ def propose(
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None:
attn_metadata.decode.spec_num_seq_len = -1
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
if (
hasattr(attn_metadata, "decode")
and attn_metadata.decode is not None
and hasattr(attn_metadata.decode, "spec_num_seq_len")
and attn_metadata.decode.spec_num_seq_len is not None
):
attn_metadata.decode.spec_num_seq_len = -1
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
)
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
@@ -94,22 +104,25 @@ def propose(
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = self.hidden_states[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
if self.method == "mtp":
logits = self.model.compute_logits(sample_hidden_states, 0)
else:
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
@@ -136,21 +149,22 @@ def propose(
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
for _ in range(self.num_speculative_tokens - 1):
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = (
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32)
)
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@@ -166,38 +180,39 @@ def propose(
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
common_attn_metadata.num_computed_tokens_cpu = \
common_attn_metadata.seq_lens_cpu - 1
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata.num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu - 1
)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
dim=1, index=block_numbers.view(-1, 1)
)
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
block_ids * self.block_size + clamped_positions % self.block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)
exceeds_max_model_len, PADDING_SLOT_ID
)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions
@@ -212,17 +227,23 @@ def propose(
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states = self.model(
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@@ -230,83 +251,91 @@ def propose(
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_next_token_ids_padded(self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int) -> \
tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
for each request, considering the "discarded" requests whose next token
is not sampled and comes from `request.get_token_id()` instead.
It also accounts for the rejected tokens in `sampled_token_ids`.
This function must use device functions to operate on the inputs, and
should not introduce any blocking CPU-GPU synchronization.
"""
# TODO(Ben): Combine this into a custom fused kernel
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
])
self.backup_next_token_ids.copy_to_gpu(num_reqs)
]
)
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = \
discard_request_indices[:num_discarded_requests]
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = discard_request_indices[
:num_discarded_requests
]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# valid_sampled_token_ids_gpu.index_fill_(
# 0, discard_sampled_tokens_req_indices, -1)
# ---- FIX START ----
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
if num_discarded_requests > 0:
# make sure index is on same device and is int64
idx = discard_sampled_tokens_req_indices
if idx.device != valid_sampled_token_ids_gpu.device:
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
if idx.dtype != torch.long:
idx = idx.to(torch.long)
if idx.numel() > 0:
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
# ---- FIX END ----
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
dtype=torch.bool)
else:
valid_mask = (
(valid_sampled_token_ids_gpu != -1) &
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# valid_sampled_token_ids_gpu.index_fill_(
# 0, discard_sampled_tokens_req_indices, -1)
# ---- FIX START ----
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
if num_discarded_requests > 0:
# make sure index is on same device and is int64
idx = discard_sampled_tokens_req_indices
if idx.device != valid_sampled_token_ids_gpu.device:
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
if idx.dtype != torch.long:
idx = idx.to(torch.long)
if idx.numel() > 0:
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
# ---- FIX END ----
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
else:
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get the rightmost valid index per row
last_valid_indices = valid_sampled_tokens_count - 1
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1, selected_tokens,
self.backup_next_token_ids.gpu[:batch_size])
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
)
return next_token_ids, valid_sampled_tokens_count
return next_token_ids, valid_sampled_tokens_count
EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded