clean pr for ds.2 mtp support (#164)

* Add MTP support in eagle.py

Signed-off-by: wanghao129 <wanghao129@baidu.com>

* new pr for mtp

Signed-off-by: wanghao129 <wanghao129@baidu.com>

* Revert formatting changes in deepseek_v2.py

Signed-off-by: wanghao129 <wanghao129@baidu.com>

---------

Signed-off-by: wanghao129 <wanghao129@baidu.com>
Co-authored-by: wanghao129 <wanghao129@baidu.com>
This commit is contained in:
WANG HAO
2026-02-02 15:23:33 +08:00
committed by GitHub
parent 42a2d38f47
commit 6f30bc439d
3 changed files with 295 additions and 229 deletions

View File

@@ -15,20 +15,22 @@
# This file is a part of the vllm-ascend project.
#
# embedding
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.vocab_parallel_embedding
# quantization
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.quantization.moe_wna16
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
import vllm_kunlun.ops.fused_moe.layer
# base layers
import vllm_kunlun.ops.layernorm
import vllm_kunlun.ops.linear
import vllm_kunlun.ops.fused_moe.layer
# quantization
import vllm_kunlun.ops.quantization.awq
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
import vllm_kunlun.ops.quantization.gptq
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
import vllm_kunlun.ops.quantization.moe_wna16
# embedding
import vllm_kunlun.ops.rotary_embedding
import vllm_kunlun.ops.vocab_parallel_embedding
import vllm_kunlun.v1.sample.spec_decode.eagle

View File

@@ -1,17 +1,23 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import torch
from dataclasses import dataclass
from typing import ClassVar, Optional
from typing import Optional
import torch
from vllm.logger import init_logger
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
split_decodes_and_prefills)
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
from vllm.v1.attention.backends.mla.indexer import (
DeepseekV32IndexerMetadataBuilder,
kv_spans_from_batches)
kv_spans_from_batches,
split_prefill_chunks,
)
from vllm.v1.attention.backends.utils import (
CommonAttentionMetadata,
split_decodes_and_prefills,
)
logger = init_logger(__name__)
@dataclass
class DeepseekV32IndexerPrefillChunkMetadata:
block_table: torch.Tensor
@@ -32,6 +38,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
class DeepseekV32IndexerPrefillMetadata:
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
@dataclass
class DeepSeekV32IndexerDecodeMetadata:
block_table: torch.Tensor
@@ -70,26 +77,36 @@ class DeepseekV32IndexerMetadata:
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
query_start_loc_cpu, seq_lens_cpu,
block_table):
prefill_query_start_loc = query_start_loc_cpu[
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
def kunlun_build_one_prefill_chunk(
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
):
prefill_query_start_loc = (
query_start_loc_cpu[reqs_start : reqs_end + 1] - query_start_loc_cpu[reqs_start]
)
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
self.device)
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
)
token_start = query_start_loc_cpu[reqs_start].item()
token_end = query_start_loc_cpu[reqs_end].item()
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
assert total_seq_lens <= self.max_prefill_buffer_size
cu_seq_lens = torch.cat([
cu_seq_lens = (
torch.cat(
[
torch.zeros(1, dtype=torch.int32),
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
]).to(torch.int32).to(self.device)
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
]
)
.to(torch.int32)
.to(self.device)
)
seq_len_q = token_end - token_start
seq_len_kv = total_seq_lens
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
context_k_lens = torch.tensor(
[0, seq_len_kv], dtype=torch.int32, device=self.device
)
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
@@ -107,19 +124,24 @@ def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
context_k_lens=context_k_lens,
context_k_lens_cpu=context_k_lens_cpu,
)
def kunlun_build(self,
def kunlun_build(
self,
common_prefix_len: int,
common_attn_metadata: CommonAttentionMetadata,
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
fast_build: bool = False,
) -> DeepseekV32IndexerMetadata:
num_reqs = common_attn_metadata.num_reqs
num_tokens = common_attn_metadata.num_actual_tokens
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
split_decodes_and_prefills(
common_attn_metadata,
decode_threshold=self.reorder_batch_threshold)
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
)
)
assert num_decodes + num_prefills == num_reqs
assert num_decode_tokens + num_prefill_tokens == num_tokens
@@ -133,31 +155,36 @@ def kunlun_build(self,
)
chunks = [
self.build_one_prefill_chunk(
reqs_start, reqs_end, query_start_loc_cpu,
reqs_start,
reqs_end,
query_start_loc_cpu,
common_attn_metadata.seq_lens_cpu,
common_attn_metadata.block_table_tensor)
common_attn_metadata.block_table_tensor,
)
for reqs_start, reqs_end in chunk_seq_ids
]
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
chunks=chunks, )
chunks=chunks,
)
decode_metadata = None
if num_decodes > 0:
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes])
torch.diff(
common_attn_metadata.query_start_loc[: num_decodes + 1],
out=self.decode_lens_buffer[:num_decodes],
)
decode_lens = self.decode_lens_buffer[:num_decodes]
decode_lens_cpu = torch.diff(
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
)
# Use CPU to avoid GPU sync; breaking async scheduling
requires_padding = (decode_lens_cpu.max()
> decode_lens_cpu.min()).item()
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
block_table=common_attn_metadata.
block_table_tensor[:num_decodes, ...],
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
decode_lens=decode_lens,
@@ -187,5 +214,13 @@ def kunlun_build(self,
# logger.info(f"attn_metadata: {attn_metadata}")
return attn_metadata
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
kunlun_build_one_prefill_chunk
)
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
# Monkey patch: Upgrade cudagraph_support to UNIFORM_BATCH for spec-decode compatibility
from vllm.v1.attention.backends.utils import AttentionCGSupport
DeepseekV32IndexerMetadataBuilder.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH

View File

@@ -1,26 +1,18 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import ast
from typing import Optional
import numpy as np
import torch
import torch.nn as nn
from vllm.attention.layer import Attention
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.platforms import current_platform
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.rocm_aiter_fa import (
AiterFlashAttentionMetadata)
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
TreeAttentionMetadataBuilder)
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.eagle import EagleProposer
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
logger = init_logger(__name__)
@@ -37,47 +29,65 @@ def propose(
target_hidden_states: torch.Tensor,
# [batch_size]
next_token_ids: torch.Tensor,
last_token_indices: Optional[torch.Tensor],
common_attn_metadata: CommonAttentionMetadata,
sampling_metadata: SamplingMetadata,
mm_embeds: Optional[list[torch.Tensor]] = None,
) -> torch.Tensor:
num_tokens = target_token_ids.shape[0]
batch_size = next_token_ids.shape[0]
if last_token_indices is None:
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
self.input_ids[last_token_indices] = next_token_ids
assert self.runner is not None
# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None:
ubatch_id = dbo_current_ubatch_id()
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=0
)
if (
hasattr(attn_metadata, "decode")
and attn_metadata.decode is not None
and hasattr(attn_metadata.decode, "spec_num_seq_len")
and attn_metadata.decode.spec_num_seq_len is not None
):
attn_metadata.decode.spec_num_seq_len = -1
if self.draft_indexer_metadata_builder:
draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata,
draft_index=0,
)
else:
draft_indexer_metadata = None
# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
for layer_name in self.indexer_layer_names:
assert draft_indexer_metadata is not None
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states
@@ -94,21 +104,24 @@ def propose(
inputs_embeds = None
input_ids = self.input_ids[:num_input_tokens]
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
inputs_embeds=inputs_embeds,
)
if self.method == "deepseek_mtp":
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = self.hidden_states[:num_input_tokens]
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
if self.method == "mtp":
logits = self.model.compute_logits(sample_hidden_states, 0)
else:
logits = self.model.compute_logits(sample_hidden_states, None)
positions = target_positions[last_token_indices]
hidden_states = hidden_states[last_token_indices]
@@ -136,21 +149,22 @@ def propose(
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.
# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
common_attn_metadata.num_actual_tokens = batch_size
common_attn_metadata.max_query_len = 1
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
for _ in range(self.num_speculative_tokens - 1):
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32)
common_attn_metadata.query_start_loc_cpu = (
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32)
)
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder
for token_index in range(self.num_speculative_tokens - 1):
# Update the inputs.
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
@@ -166,36 +180,37 @@ def propose(
exceeds_max_model_len = positions >= self.max_model_len
# Mask out the position ids that exceed the max model length.
# Otherwise, we may get out-of-range error in RoPE.
clamped_positions = torch.where(exceeds_max_model_len, 0,
positions)
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
# Increment the sequence lengths.
common_attn_metadata.seq_lens += 1
common_attn_metadata.seq_lens_cpu += 1
# For the requests that exceed the max model length, we set the
# sequence length to 1 to minimize their overheads in attention.
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
1)
common_attn_metadata.num_computed_tokens_cpu = \
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
common_attn_metadata.num_computed_tokens_cpu = (
common_attn_metadata.seq_lens_cpu - 1
)
# Compute the slot mapping.
block_numbers = clamped_positions // self.block_size
block_ids = common_attn_metadata.block_table_tensor.gather(
dim=1, index=block_numbers.view(-1, 1))
dim=1, index=block_numbers.view(-1, 1)
)
block_ids = block_ids.view(-1)
common_attn_metadata.slot_mapping = (
block_ids * self.block_size +
clamped_positions % self.block_size)
block_ids * self.block_size + clamped_positions % self.block_size
)
# Mask out the slot mappings that exceed the max model length.
# Otherwise, the KV cache will be inadvertently updated with the
# padding tokens.
common_attn_metadata.slot_mapping.masked_fill_(
exceeds_max_model_len, PADDING_SLOT_ID)
exceeds_max_model_len, PADDING_SLOT_ID
)
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
.build_for_drafting(common_attn_metadata=common_attn_metadata,
draft_index=0)
attn_metadata = attn_metadata_builder.build_for_drafting(
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
)
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
# copy inputs to buffer for cudagraph
@@ -212,17 +227,23 @@ def propose(
input_ids = self.input_ids[:input_batch_size]
# Run the model.
with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states = self.model(
with set_forward_context(
per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
):
ret_hidden_states = self.model(
input_ids=input_ids,
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
inputs_embeds=inputs_embeds,
)
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
if self.method == "mtp":
last_hidden_states = ret_hidden_states
hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size])
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)
@@ -230,14 +251,16 @@ def propose(
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
return draft_token_ids
def prepare_next_token_ids_padded(self,
def prepare_next_token_ids_padded(
self,
common_attn_metadata: CommonAttentionMetadata,
sampled_token_ids: torch.Tensor,
requests: dict[str, CachedRequestState],
gpu_input_batch: InputBatch,
discard_request_indices: torch.Tensor,
num_discarded_requests: int) -> \
tuple[torch.Tensor, torch.Tensor]:
num_discarded_requests: int,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
This function is used to prepare the inputs for speculative decoding.
It calculates the next token ids and the number of valid sampled tokens
@@ -251,16 +274,20 @@ def prepare_next_token_ids_padded(self,
# Precompute get_token_id for when there is no valid next token
num_reqs = gpu_input_batch.num_reqs
self.backup_next_token_ids.np[:num_reqs] = np.array([
self.backup_next_token_ids.np[:num_reqs] = np.array(
[
requests[gpu_input_batch.req_ids[i]].get_token_id(
common_attn_metadata.seq_lens_cpu[i].item())
common_attn_metadata.seq_lens_cpu[i].item()
)
for i in range(num_reqs)
])
]
)
self.backup_next_token_ids.copy_to_gpu(num_reqs)
# Mask out the sampled tokens indices that should not be sampled.
discard_sampled_tokens_req_indices = \
discard_request_indices[:num_discarded_requests]
discard_sampled_tokens_req_indices = discard_request_indices[
:num_discarded_requests
]
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
# valid_sampled_token_ids_gpu.index_fill_(
@@ -280,12 +307,11 @@ def prepare_next_token_ids_padded(self,
# Generate a mask for all valid tokens within those requests
max_gen_len = sampled_token_ids.shape[-1]
if max_gen_len == 1:
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
dtype=torch.bool)
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
else:
valid_mask = (
(valid_sampled_token_ids_gpu != -1) &
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
)
# Count the number of valid tokens in each request
valid_sampled_tokens_count = valid_mask.sum(dim=1)
@@ -297,16 +323,19 @@ def prepare_next_token_ids_padded(self,
# Get last valid token from each row
# (assume undefined state where there is no valid token)
selected_tokens = torch.gather(
valid_sampled_token_ids_gpu, 1,
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
).squeeze(1)
# Use last token if valid, pre-computed backup if not
batch_size = valid_sampled_token_ids_gpu.shape[0]
next_token_ids = torch.where(
last_valid_indices != -1, selected_tokens,
self.backup_next_token_ids.gpu[:batch_size])
last_valid_indices != -1,
selected_tokens,
self.backup_next_token_ids.gpu[:batch_size],
)
return next_token_ids, valid_sampled_tokens_count
EagleProposer.propose = propose
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded