clean pr for ds.2 mtp support (#164)
* Add MTP support in eagle.py Signed-off-by: wanghao129 <wanghao129@baidu.com> * new pr for mtp Signed-off-by: wanghao129 <wanghao129@baidu.com> * Revert formatting changes in deepseek_v2.py Signed-off-by: wanghao129 <wanghao129@baidu.com> --------- Signed-off-by: wanghao129 <wanghao129@baidu.com> Co-authored-by: wanghao129 <wanghao129@baidu.com>
This commit is contained in:
@@ -15,20 +15,22 @@
|
|||||||
# This file is a part of the vllm-ascend project.
|
# This file is a part of the vllm-ascend project.
|
||||||
#
|
#
|
||||||
|
|
||||||
# embedding
|
import vllm_kunlun.ops.fused_moe.layer
|
||||||
import vllm_kunlun.ops.rotary_embedding
|
|
||||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
|
||||||
|
|
||||||
# quantization
|
|
||||||
import vllm_kunlun.ops.quantization.awq
|
|
||||||
import vllm_kunlun.ops.quantization.gptq
|
|
||||||
import vllm_kunlun.ops.quantization.moe_wna16
|
|
||||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
|
|
||||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
|
||||||
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
|
||||||
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
|
|
||||||
|
|
||||||
# base layers
|
# base layers
|
||||||
import vllm_kunlun.ops.layernorm
|
import vllm_kunlun.ops.layernorm
|
||||||
import vllm_kunlun.ops.linear
|
import vllm_kunlun.ops.linear
|
||||||
import vllm_kunlun.ops.fused_moe.layer
|
|
||||||
|
# quantization
|
||||||
|
import vllm_kunlun.ops.quantization.awq
|
||||||
|
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
|
||||||
|
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
||||||
|
import vllm_kunlun.ops.quantization.gptq
|
||||||
|
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
|
||||||
|
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
||||||
|
import vllm_kunlun.ops.quantization.moe_wna16
|
||||||
|
|
||||||
|
# embedding
|
||||||
|
import vllm_kunlun.ops.rotary_embedding
|
||||||
|
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||||
|
import vllm_kunlun.v1.sample.spec_decode.eagle
|
||||||
|
|||||||
@@ -1,17 +1,23 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import torch
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import ClassVar, Optional
|
from typing import Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
from vllm.v1.attention.backends.mla.indexer import (
|
||||||
split_decodes_and_prefills)
|
DeepseekV32IndexerMetadataBuilder,
|
||||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
kv_spans_from_batches,
|
||||||
DeepseekV32IndexerMetadataBuilder,
|
split_prefill_chunks,
|
||||||
kv_spans_from_batches)
|
)
|
||||||
|
from vllm.v1.attention.backends.utils import (
|
||||||
|
CommonAttentionMetadata,
|
||||||
|
split_decodes_and_prefills,
|
||||||
|
)
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -32,6 +38,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
|||||||
class DeepseekV32IndexerPrefillMetadata:
|
class DeepseekV32IndexerPrefillMetadata:
|
||||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class DeepSeekV32IndexerDecodeMetadata:
|
class DeepSeekV32IndexerDecodeMetadata:
|
||||||
block_table: torch.Tensor
|
block_table: torch.Tensor
|
||||||
@@ -70,26 +77,36 @@ class DeepseekV32IndexerMetadata:
|
|||||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||||
|
|
||||||
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
|
||||||
query_start_loc_cpu, seq_lens_cpu,
|
def kunlun_build_one_prefill_chunk(
|
||||||
block_table):
|
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||||
prefill_query_start_loc = query_start_loc_cpu[
|
):
|
||||||
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
prefill_query_start_loc = (
|
||||||
|
query_start_loc_cpu[reqs_start : reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||||
|
)
|
||||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||||
self.device)
|
)
|
||||||
token_start = query_start_loc_cpu[reqs_start].item()
|
token_start = query_start_loc_cpu[reqs_start].item()
|
||||||
token_end = query_start_loc_cpu[reqs_end].item()
|
token_end = query_start_loc_cpu[reqs_end].item()
|
||||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||||
cu_seq_lens = torch.cat([
|
cu_seq_lens = (
|
||||||
torch.zeros(1, dtype=torch.int32),
|
torch.cat(
|
||||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
[
|
||||||
]).to(torch.int32).to(self.device)
|
torch.zeros(1, dtype=torch.int32),
|
||||||
|
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||||
|
]
|
||||||
|
)
|
||||||
|
.to(torch.int32)
|
||||||
|
.to(self.device)
|
||||||
|
)
|
||||||
seq_len_q = token_end - token_start
|
seq_len_q = token_end - token_start
|
||||||
seq_len_kv = total_seq_lens
|
seq_len_kv = total_seq_lens
|
||||||
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
|
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
|
||||||
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
|
context_k_lens = torch.tensor(
|
||||||
|
[0, seq_len_kv], dtype=torch.int32, device=self.device
|
||||||
|
)
|
||||||
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
|
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
|
||||||
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
|
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
|
||||||
|
|
||||||
@@ -107,85 +124,103 @@ def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
|||||||
context_k_lens=context_k_lens,
|
context_k_lens=context_k_lens,
|
||||||
context_k_lens_cpu=context_k_lens_cpu,
|
context_k_lens_cpu=context_k_lens_cpu,
|
||||||
)
|
)
|
||||||
def kunlun_build(self,
|
|
||||||
common_prefix_len: int,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
|
|
||||||
|
|
||||||
num_reqs = common_attn_metadata.num_reqs
|
|
||||||
num_tokens = common_attn_metadata.num_actual_tokens
|
|
||||||
|
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
def kunlun_build(
|
||||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
self,
|
||||||
split_decodes_and_prefills(
|
common_prefix_len: int,
|
||||||
common_attn_metadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
decode_threshold=self.reorder_batch_threshold)
|
fast_build: bool = False,
|
||||||
|
) -> DeepseekV32IndexerMetadata:
|
||||||
|
|
||||||
assert num_decodes + num_prefills == num_reqs
|
num_reqs = common_attn_metadata.num_reqs
|
||||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
num_tokens = common_attn_metadata.num_actual_tokens
|
||||||
|
|
||||||
prefill_metadata = None
|
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||||
if num_prefills > 0:
|
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||||
chunk_seq_ids = split_prefill_chunks(
|
split_decodes_and_prefills(
|
||||||
|
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
assert num_decodes + num_prefills == num_reqs
|
||||||
|
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||||
|
|
||||||
|
prefill_metadata = None
|
||||||
|
if num_prefills > 0:
|
||||||
|
chunk_seq_ids = split_prefill_chunks(
|
||||||
|
common_attn_metadata.seq_lens_cpu,
|
||||||
|
self.max_prefill_buffer_size,
|
||||||
|
num_decodes,
|
||||||
|
)
|
||||||
|
chunks = [
|
||||||
|
self.build_one_prefill_chunk(
|
||||||
|
reqs_start,
|
||||||
|
reqs_end,
|
||||||
|
query_start_loc_cpu,
|
||||||
common_attn_metadata.seq_lens_cpu,
|
common_attn_metadata.seq_lens_cpu,
|
||||||
self.max_prefill_buffer_size,
|
common_attn_metadata.block_table_tensor,
|
||||||
num_decodes,
|
|
||||||
)
|
)
|
||||||
chunks = [
|
for reqs_start, reqs_end in chunk_seq_ids
|
||||||
self.build_one_prefill_chunk(
|
]
|
||||||
reqs_start, reqs_end, query_start_loc_cpu,
|
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||||
common_attn_metadata.seq_lens_cpu,
|
chunks=chunks,
|
||||||
common_attn_metadata.block_table_tensor)
|
|
||||||
for reqs_start, reqs_end in chunk_seq_ids
|
|
||||||
]
|
|
||||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
|
||||||
chunks=chunks, )
|
|
||||||
|
|
||||||
decode_metadata = None
|
|
||||||
if num_decodes > 0:
|
|
||||||
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
|
|
||||||
out=self.decode_lens_buffer[:num_decodes])
|
|
||||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
|
||||||
decode_lens_cpu = torch.diff(
|
|
||||||
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
|
|
||||||
|
|
||||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
|
||||||
requires_padding = (decode_lens_cpu.max()
|
|
||||||
> decode_lens_cpu.min()).item()
|
|
||||||
|
|
||||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
|
||||||
|
|
||||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
|
||||||
block_table=common_attn_metadata.
|
|
||||||
block_table_tensor[:num_decodes, ...],
|
|
||||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
|
||||||
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
|
|
||||||
decode_lens=decode_lens,
|
|
||||||
requires_padding=requires_padding,
|
|
||||||
schedule_metadata=self.scheduler_metadata_buffer,
|
|
||||||
)
|
|
||||||
|
|
||||||
attn_metadata = DeepseekV32IndexerMetadata(
|
|
||||||
seq_lens=common_attn_metadata.seq_lens,
|
|
||||||
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
|
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
|
||||||
max_query_len=common_attn_metadata.max_query_len,
|
|
||||||
max_seq_len=common_attn_metadata.max_seq_len,
|
|
||||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
|
||||||
query_start_loc=common_attn_metadata.query_start_loc,
|
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
|
||||||
head_dim=128,
|
|
||||||
num_decodes=num_decodes,
|
|
||||||
num_decode_tokens=num_decode_tokens,
|
|
||||||
num_prefills=num_prefills,
|
|
||||||
num_prefill_tokens=num_prefill_tokens,
|
|
||||||
prefill=prefill_metadata,
|
|
||||||
decode=decode_metadata,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# if get_tensor_model_parallel_rank() == 0:
|
decode_metadata = None
|
||||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
if num_decodes > 0:
|
||||||
return attn_metadata
|
torch.diff(
|
||||||
|
common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||||
|
out=self.decode_lens_buffer[:num_decodes],
|
||||||
|
)
|
||||||
|
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||||
|
decode_lens_cpu = torch.diff(
|
||||||
|
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||||
|
)
|
||||||
|
|
||||||
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
|
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||||
|
|
||||||
|
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||||
|
|
||||||
|
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||||
|
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
|
||||||
|
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||||
|
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
|
||||||
|
decode_lens=decode_lens,
|
||||||
|
requires_padding=requires_padding,
|
||||||
|
schedule_metadata=self.scheduler_metadata_buffer,
|
||||||
|
)
|
||||||
|
|
||||||
|
attn_metadata = DeepseekV32IndexerMetadata(
|
||||||
|
seq_lens=common_attn_metadata.seq_lens,
|
||||||
|
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
|
||||||
|
num_reqs=common_attn_metadata.num_reqs,
|
||||||
|
max_query_len=common_attn_metadata.max_query_len,
|
||||||
|
max_seq_len=common_attn_metadata.max_seq_len,
|
||||||
|
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||||
|
query_start_loc=common_attn_metadata.query_start_loc,
|
||||||
|
slot_mapping=common_attn_metadata.slot_mapping,
|
||||||
|
head_dim=128,
|
||||||
|
num_decodes=num_decodes,
|
||||||
|
num_decode_tokens=num_decode_tokens,
|
||||||
|
num_prefills=num_prefills,
|
||||||
|
num_prefill_tokens=num_prefill_tokens,
|
||||||
|
prefill=prefill_metadata,
|
||||||
|
decode=decode_metadata,
|
||||||
|
)
|
||||||
|
|
||||||
|
# if get_tensor_model_parallel_rank() == 0:
|
||||||
|
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||||
|
return attn_metadata
|
||||||
|
|
||||||
|
|
||||||
|
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
|
||||||
|
kunlun_build_one_prefill_chunk
|
||||||
|
)
|
||||||
|
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||||
|
|
||||||
|
# Monkey patch: Upgrade cudagraph_support to UNIFORM_BATCH for spec-decode compatibility
|
||||||
|
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||||
|
|
||||||
|
DeepseekV32IndexerMetadataBuilder.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||||
|
|||||||
@@ -1,26 +1,18 @@
|
|||||||
# SPDX-License-Identifier: Apache-2.0
|
# SPDX-License-Identifier: Apache-2.0
|
||||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||||
import ast
|
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
|
||||||
|
|
||||||
from vllm.attention.layer import Attention
|
|
||||||
from vllm.forward_context import set_forward_context
|
from vllm.forward_context import set_forward_context
|
||||||
from vllm.logger import init_logger
|
from vllm.logger import init_logger
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.platforms import current_platform
|
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
|
||||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
|
||||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
|
||||||
AiterFlashAttentionMetadata)
|
|
||||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
|
||||||
TreeAttentionMetadataBuilder)
|
|
||||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||||
|
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||||
|
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
logger = init_logger(__name__)
|
||||||
|
|
||||||
@@ -37,47 +29,65 @@ def propose(
|
|||||||
target_hidden_states: torch.Tensor,
|
target_hidden_states: torch.Tensor,
|
||||||
# [batch_size]
|
# [batch_size]
|
||||||
next_token_ids: torch.Tensor,
|
next_token_ids: torch.Tensor,
|
||||||
|
last_token_indices: Optional[torch.Tensor],
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
sampling_metadata: SamplingMetadata,
|
sampling_metadata: SamplingMetadata,
|
||||||
mm_embeds: Optional[list[torch.Tensor]] = None,
|
mm_embeds: Optional[list[torch.Tensor]] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
num_tokens = target_token_ids.shape[0]
|
num_tokens = target_token_ids.shape[0]
|
||||||
batch_size = next_token_ids.shape[0]
|
batch_size = next_token_ids.shape[0]
|
||||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
|
||||||
|
if last_token_indices is None:
|
||||||
|
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||||
|
|
||||||
if self.method == "eagle3":
|
if self.method == "eagle3":
|
||||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||||
target_hidden_states = self.model.combine_hidden_states(
|
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
|
||||||
target_hidden_states)
|
|
||||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||||
|
|
||||||
# Shift the input ids by one token.
|
# Shift the input ids by one token.
|
||||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
|
||||||
# Replace the last token with the next token.
|
# Replace the last token with the next token.
|
||||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||||
self.input_ids[last_token_indices] = next_token_ids
|
self.input_ids[last_token_indices] = next_token_ids
|
||||||
|
|
||||||
assert self.runner is not None
|
assert self.runner is not None
|
||||||
|
|
||||||
# FIXME: need to consider multiple kv_cache_groups
|
ubatch_id = dbo_current_ubatch_id()
|
||||||
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
|
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||||
draft_index=0)
|
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||||
if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None:
|
)
|
||||||
attn_metadata.decode.spec_num_seq_len = -1
|
if (
|
||||||
|
hasattr(attn_metadata, "decode")
|
||||||
|
and attn_metadata.decode is not None
|
||||||
|
and hasattr(attn_metadata.decode, "spec_num_seq_len")
|
||||||
|
and attn_metadata.decode.spec_num_seq_len is not None
|
||||||
|
):
|
||||||
|
attn_metadata.decode.spec_num_seq_len = -1
|
||||||
|
|
||||||
|
if self.draft_indexer_metadata_builder:
|
||||||
|
draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting(
|
||||||
|
common_attn_metadata=common_attn_metadata,
|
||||||
|
draft_index=0,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
draft_indexer_metadata = None
|
||||||
|
|
||||||
# At this moment, we assume all eagle layers belong to the same KV
|
# At this moment, we assume all eagle layers belong to the same KV
|
||||||
# cache group, thus using the same attention metadata.
|
# cache group, thus using the same attention metadata.
|
||||||
per_layer_attn_metadata = {}
|
per_layer_attn_metadata = {}
|
||||||
for layer_name in self.attn_layer_names:
|
for layer_name in self.attn_layer_names:
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
if self.use_cuda_graph and \
|
for layer_name in self.indexer_layer_names:
|
||||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
assert draft_indexer_metadata is not None
|
||||||
|
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||||
|
if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||||
else:
|
else:
|
||||||
num_input_tokens = num_tokens
|
num_input_tokens = num_tokens
|
||||||
|
|
||||||
|
|
||||||
# copy inputs to buffer for cudagraph
|
# copy inputs to buffer for cudagraph
|
||||||
self.positions[:num_tokens] = target_positions
|
self.positions[:num_tokens] = target_positions
|
||||||
self.hidden_states[:num_tokens] = target_hidden_states
|
self.hidden_states[:num_tokens] = target_hidden_states
|
||||||
@@ -94,22 +104,25 @@ def propose(
|
|||||||
inputs_embeds = None
|
inputs_embeds = None
|
||||||
input_ids = self.input_ids[:num_input_tokens]
|
input_ids = self.input_ids[:num_input_tokens]
|
||||||
|
|
||||||
with set_forward_context(per_layer_attn_metadata,
|
with set_forward_context(
|
||||||
self.vllm_config,
|
per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
|
||||||
num_tokens=num_input_tokens):
|
):
|
||||||
ret_hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=self.positions[:num_input_tokens],
|
positions=self.positions[:num_input_tokens],
|
||||||
hidden_states=self.hidden_states[:num_input_tokens],
|
hidden_states=self.hidden_states[:num_input_tokens],
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
if self.method == "deepseek_mtp":
|
if self.method == "mtp":
|
||||||
last_hidden_states = ret_hidden_states
|
last_hidden_states = ret_hidden_states
|
||||||
hidden_states = self.hidden_states[:num_input_tokens]
|
hidden_states = self.hidden_states[:num_input_tokens]
|
||||||
else:
|
else:
|
||||||
last_hidden_states, hidden_states = ret_hidden_states
|
last_hidden_states, hidden_states = ret_hidden_states
|
||||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
if self.method == "mtp":
|
||||||
|
logits = self.model.compute_logits(sample_hidden_states, 0)
|
||||||
|
else:
|
||||||
|
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||||
positions = target_positions[last_token_indices]
|
positions = target_positions[last_token_indices]
|
||||||
hidden_states = hidden_states[last_token_indices]
|
hidden_states = hidden_states[last_token_indices]
|
||||||
|
|
||||||
@@ -136,21 +149,22 @@ def propose(
|
|||||||
# one layer. Adapt this code to support multiple layers once
|
# one layer. Adapt this code to support multiple layers once
|
||||||
# there's a multi-layer MTP module.
|
# there's a multi-layer MTP module.
|
||||||
|
|
||||||
|
|
||||||
# Generate the remaining draft tokens.
|
# Generate the remaining draft tokens.
|
||||||
draft_token_ids_list = [draft_token_ids]
|
draft_token_ids_list = [draft_token_ids]
|
||||||
if self.use_cuda_graph and \
|
if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
|
||||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||||
else:
|
else:
|
||||||
input_batch_size = batch_size
|
input_batch_size = batch_size
|
||||||
|
|
||||||
common_attn_metadata.num_actual_tokens = batch_size
|
common_attn_metadata.num_actual_tokens = batch_size
|
||||||
common_attn_metadata.max_query_len = 1
|
common_attn_metadata.max_query_len = 1
|
||||||
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32)
|
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32)
|
||||||
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
common_attn_metadata.query_start_loc_cpu = (
|
||||||
self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
|
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32)
|
||||||
for _ in range(self.num_speculative_tokens - 1):
|
)
|
||||||
|
|
||||||
|
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder
|
||||||
|
for token_index in range(self.num_speculative_tokens - 1):
|
||||||
# Update the inputs.
|
# Update the inputs.
|
||||||
# cast to int32 is crucial when eagle model is compiled.
|
# cast to int32 is crucial when eagle model is compiled.
|
||||||
# tensor.argmax() returns int64 by default.
|
# tensor.argmax() returns int64 by default.
|
||||||
@@ -166,38 +180,39 @@ def propose(
|
|||||||
exceeds_max_model_len = positions >= self.max_model_len
|
exceeds_max_model_len = positions >= self.max_model_len
|
||||||
# Mask out the position ids that exceed the max model length.
|
# Mask out the position ids that exceed the max model length.
|
||||||
# Otherwise, we may get out-of-range error in RoPE.
|
# Otherwise, we may get out-of-range error in RoPE.
|
||||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
|
||||||
positions)
|
|
||||||
|
|
||||||
# Increment the sequence lengths.
|
# Increment the sequence lengths.
|
||||||
common_attn_metadata.seq_lens += 1
|
common_attn_metadata.seq_lens += 1
|
||||||
common_attn_metadata.seq_lens_cpu += 1
|
common_attn_metadata.seq_lens_cpu += 1
|
||||||
# For the requests that exceed the max model length, we set the
|
# For the requests that exceed the max model length, we set the
|
||||||
# sequence length to 1 to minimize their overheads in attention.
|
# sequence length to 1 to minimize their overheads in attention.
|
||||||
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
|
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||||
1)
|
common_attn_metadata.num_computed_tokens_cpu = (
|
||||||
common_attn_metadata.num_computed_tokens_cpu = \
|
common_attn_metadata.seq_lens_cpu - 1
|
||||||
common_attn_metadata.seq_lens_cpu - 1
|
)
|
||||||
|
|
||||||
# Compute the slot mapping.
|
# Compute the slot mapping.
|
||||||
block_numbers = clamped_positions // self.block_size
|
block_numbers = clamped_positions // self.block_size
|
||||||
block_ids = common_attn_metadata.block_table_tensor.gather(
|
block_ids = common_attn_metadata.block_table_tensor.gather(
|
||||||
dim=1, index=block_numbers.view(-1, 1))
|
dim=1, index=block_numbers.view(-1, 1)
|
||||||
|
)
|
||||||
block_ids = block_ids.view(-1)
|
block_ids = block_ids.view(-1)
|
||||||
common_attn_metadata.slot_mapping = (
|
common_attn_metadata.slot_mapping = (
|
||||||
block_ids * self.block_size +
|
block_ids * self.block_size + clamped_positions % self.block_size
|
||||||
clamped_positions % self.block_size)
|
)
|
||||||
# Mask out the slot mappings that exceed the max model length.
|
# Mask out the slot mappings that exceed the max model length.
|
||||||
# Otherwise, the KV cache will be inadvertently updated with the
|
# Otherwise, the KV cache will be inadvertently updated with the
|
||||||
# padding tokens.
|
# padding tokens.
|
||||||
common_attn_metadata.slot_mapping.masked_fill_(
|
common_attn_metadata.slot_mapping.masked_fill_(
|
||||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
exceeds_max_model_len, PADDING_SLOT_ID
|
||||||
|
)
|
||||||
|
|
||||||
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
|
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
|
||||||
draft_index=0)
|
)
|
||||||
for layer_name in self.attn_layer_names:
|
for layer_name in self.attn_layer_names:
|
||||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||||
# copy inputs to buffer for cudagraph
|
# copy inputs to buffer for cudagraph
|
||||||
self.input_ids[:batch_size] = input_ids
|
self.input_ids[:batch_size] = input_ids
|
||||||
self.positions[:batch_size] = clamped_positions
|
self.positions[:batch_size] = clamped_positions
|
||||||
@@ -212,17 +227,23 @@ def propose(
|
|||||||
input_ids = self.input_ids[:input_batch_size]
|
input_ids = self.input_ids[:input_batch_size]
|
||||||
|
|
||||||
# Run the model.
|
# Run the model.
|
||||||
with set_forward_context(per_layer_attn_metadata,
|
with set_forward_context(
|
||||||
self.vllm_config,
|
per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
|
||||||
num_tokens=input_batch_size):
|
):
|
||||||
last_hidden_states = self.model(
|
ret_hidden_states = self.model(
|
||||||
input_ids=input_ids,
|
input_ids=input_ids,
|
||||||
positions=self.positions[:input_batch_size],
|
positions=self.positions[:input_batch_size],
|
||||||
hidden_states=self.hidden_states[:input_batch_size],
|
hidden_states=self.hidden_states[:input_batch_size],
|
||||||
inputs_embeds=inputs_embeds,
|
inputs_embeds=inputs_embeds,
|
||||||
)
|
)
|
||||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
if self.method == "mtp":
|
||||||
None)
|
last_hidden_states = ret_hidden_states
|
||||||
|
hidden_states = ret_hidden_states
|
||||||
|
else:
|
||||||
|
last_hidden_states, hidden_states = ret_hidden_states
|
||||||
|
|
||||||
|
hidden_states = hidden_states[:batch_size]
|
||||||
|
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||||
draft_token_ids = logits.argmax(dim=-1)
|
draft_token_ids = logits.argmax(dim=-1)
|
||||||
draft_token_ids_list.append(draft_token_ids)
|
draft_token_ids_list.append(draft_token_ids)
|
||||||
|
|
||||||
@@ -230,83 +251,91 @@ def propose(
|
|||||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def prepare_next_token_ids_padded(self,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
requests: dict[str, CachedRequestState],
|
|
||||||
gpu_input_batch: InputBatch,
|
|
||||||
discard_request_indices: torch.Tensor,
|
|
||||||
num_discarded_requests: int) -> \
|
|
||||||
tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding.
|
|
||||||
It calculates the next token ids and the number of valid sampled tokens
|
|
||||||
for each request, considering the "discarded" requests whose next token
|
|
||||||
is not sampled and comes from `request.get_token_id()` instead.
|
|
||||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
|
||||||
This function must use device functions to operate on the inputs, and
|
|
||||||
should not introduce any blocking CPU-GPU synchronization.
|
|
||||||
"""
|
|
||||||
# TODO(Ben): Combine this into a custom fused kernel
|
|
||||||
|
|
||||||
# Precompute get_token_id for when there is no valid next token
|
def prepare_next_token_ids_padded(
|
||||||
num_reqs = gpu_input_batch.num_reqs
|
self,
|
||||||
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
|
sampled_token_ids: torch.Tensor,
|
||||||
|
requests: dict[str, CachedRequestState],
|
||||||
|
gpu_input_batch: InputBatch,
|
||||||
|
discard_request_indices: torch.Tensor,
|
||||||
|
num_discarded_requests: int,
|
||||||
|
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||||
|
"""
|
||||||
|
This function is used to prepare the inputs for speculative decoding.
|
||||||
|
It calculates the next token ids and the number of valid sampled tokens
|
||||||
|
for each request, considering the "discarded" requests whose next token
|
||||||
|
is not sampled and comes from `request.get_token_id()` instead.
|
||||||
|
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||||
|
This function must use device functions to operate on the inputs, and
|
||||||
|
should not introduce any blocking CPU-GPU synchronization.
|
||||||
|
"""
|
||||||
|
# TODO(Ben): Combine this into a custom fused kernel
|
||||||
|
|
||||||
|
# Precompute get_token_id for when there is no valid next token
|
||||||
|
num_reqs = gpu_input_batch.num_reqs
|
||||||
|
self.backup_next_token_ids.np[:num_reqs] = np.array(
|
||||||
|
[
|
||||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||||
common_attn_metadata.seq_lens_cpu[i].item())
|
common_attn_metadata.seq_lens_cpu[i].item()
|
||||||
|
)
|
||||||
for i in range(num_reqs)
|
for i in range(num_reqs)
|
||||||
])
|
]
|
||||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
)
|
||||||
|
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||||
|
|
||||||
# Mask out the sampled tokens indices that should not be sampled.
|
# Mask out the sampled tokens indices that should not be sampled.
|
||||||
discard_sampled_tokens_req_indices = \
|
discard_sampled_tokens_req_indices = discard_request_indices[
|
||||||
discard_request_indices[:num_discarded_requests]
|
:num_discarded_requests
|
||||||
|
]
|
||||||
|
|
||||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||||
# valid_sampled_token_ids_gpu.index_fill_(
|
# valid_sampled_token_ids_gpu.index_fill_(
|
||||||
# 0, discard_sampled_tokens_req_indices, -1)
|
# 0, discard_sampled_tokens_req_indices, -1)
|
||||||
# ---- FIX START ----
|
# ---- FIX START ----
|
||||||
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
|
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
|
||||||
if num_discarded_requests > 0:
|
if num_discarded_requests > 0:
|
||||||
# make sure index is on same device and is int64
|
# make sure index is on same device and is int64
|
||||||
idx = discard_sampled_tokens_req_indices
|
idx = discard_sampled_tokens_req_indices
|
||||||
if idx.device != valid_sampled_token_ids_gpu.device:
|
if idx.device != valid_sampled_token_ids_gpu.device:
|
||||||
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
|
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
|
||||||
if idx.dtype != torch.long:
|
if idx.dtype != torch.long:
|
||||||
idx = idx.to(torch.long)
|
idx = idx.to(torch.long)
|
||||||
if idx.numel() > 0:
|
if idx.numel() > 0:
|
||||||
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
|
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
|
||||||
# ---- FIX END ----
|
# ---- FIX END ----
|
||||||
# Generate a mask for all valid tokens within those requests
|
# Generate a mask for all valid tokens within those requests
|
||||||
max_gen_len = sampled_token_ids.shape[-1]
|
max_gen_len = sampled_token_ids.shape[-1]
|
||||||
if max_gen_len == 1:
|
if max_gen_len == 1:
|
||||||
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
|
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
|
||||||
dtype=torch.bool)
|
else:
|
||||||
else:
|
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
||||||
valid_mask = (
|
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
|
||||||
(valid_sampled_token_ids_gpu != -1) &
|
)
|
||||||
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
|
|
||||||
|
|
||||||
# Count the number of valid tokens in each request
|
# Count the number of valid tokens in each request
|
||||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||||
|
|
||||||
# Get the rightmost valid index per row
|
# Get the rightmost valid index per row
|
||||||
last_valid_indices = valid_sampled_tokens_count - 1
|
last_valid_indices = valid_sampled_tokens_count - 1
|
||||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
||||||
|
|
||||||
# Get last valid token from each row
|
# Get last valid token from each row
|
||||||
# (assume undefined state where there is no valid token)
|
# (assume undefined state where there is no valid token)
|
||||||
selected_tokens = torch.gather(
|
selected_tokens = torch.gather(
|
||||||
valid_sampled_token_ids_gpu, 1,
|
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
|
||||||
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
|
).squeeze(1)
|
||||||
|
|
||||||
# Use last token if valid, pre-computed backup if not
|
# Use last token if valid, pre-computed backup if not
|
||||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
||||||
next_token_ids = torch.where(
|
next_token_ids = torch.where(
|
||||||
last_valid_indices != -1, selected_tokens,
|
last_valid_indices != -1,
|
||||||
self.backup_next_token_ids.gpu[:batch_size])
|
selected_tokens,
|
||||||
|
self.backup_next_token_ids.gpu[:batch_size],
|
||||||
|
)
|
||||||
|
|
||||||
|
return next_token_ids, valid_sampled_tokens_count
|
||||||
|
|
||||||
return next_token_ids, valid_sampled_tokens_count
|
|
||||||
|
|
||||||
EagleProposer.propose = propose
|
EagleProposer.propose = propose
|
||||||
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
||||||
|
|||||||
Reference in New Issue
Block a user