clean pr for ds.2 mtp support (#164)
* Add MTP support in eagle.py Signed-off-by: wanghao129 <wanghao129@baidu.com> * new pr for mtp Signed-off-by: wanghao129 <wanghao129@baidu.com> * Revert formatting changes in deepseek_v2.py Signed-off-by: wanghao129 <wanghao129@baidu.com> --------- Signed-off-by: wanghao129 <wanghao129@baidu.com> Co-authored-by: wanghao129 <wanghao129@baidu.com>
This commit is contained in:
@@ -15,20 +15,22 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
# embedding
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
|
||||
# quantization
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.quantization.moe_wna16
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
||||
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
||||
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
|
||||
import vllm_kunlun.ops.fused_moe.layer
|
||||
|
||||
# base layers
|
||||
import vllm_kunlun.ops.layernorm
|
||||
import vllm_kunlun.ops.linear
|
||||
import vllm_kunlun.ops.fused_moe.layer
|
||||
|
||||
# quantization
|
||||
import vllm_kunlun.ops.quantization.awq
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors
|
||||
import vllm_kunlun.ops.quantization.compressed_tensors.compressed_tensors_moe
|
||||
import vllm_kunlun.ops.quantization.gptq
|
||||
import vllm_kunlun.ops.quantization.kernels.kunlun_exllama_linear
|
||||
import vllm_kunlun.ops.quantization.kernels.kunlun_scale_mm
|
||||
import vllm_kunlun.ops.quantization.moe_wna16
|
||||
|
||||
# embedding
|
||||
import vllm_kunlun.ops.rotary_embedding
|
||||
import vllm_kunlun.ops.vocab_parallel_embedding
|
||||
import vllm_kunlun.v1.sample.spec_decode.eagle
|
||||
|
||||
@@ -1,17 +1,23 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import torch
|
||||
from dataclasses import dataclass
|
||||
from typing import ClassVar, Optional
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from vllm.logger import init_logger
|
||||
from vllm.v1.attention.backends.utils import ( CommonAttentionMetadata,
|
||||
split_decodes_and_prefills)
|
||||
from vllm.v1.attention.backends.mla.indexer import (split_prefill_chunks,
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
kv_spans_from_batches)
|
||||
from vllm.v1.attention.backends.mla.indexer import (
|
||||
DeepseekV32IndexerMetadataBuilder,
|
||||
kv_spans_from_batches,
|
||||
split_prefill_chunks,
|
||||
)
|
||||
from vllm.v1.attention.backends.utils import (
|
||||
CommonAttentionMetadata,
|
||||
split_decodes_and_prefills,
|
||||
)
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
block_table: torch.Tensor
|
||||
@@ -32,6 +38,7 @@ class DeepseekV32IndexerPrefillChunkMetadata:
|
||||
class DeepseekV32IndexerPrefillMetadata:
|
||||
chunks: list[DeepseekV32IndexerPrefillChunkMetadata]
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeepSeekV32IndexerDecodeMetadata:
|
||||
block_table: torch.Tensor
|
||||
@@ -70,26 +77,36 @@ class DeepseekV32IndexerMetadata:
|
||||
decode: Optional[DeepSeekV32IndexerDecodeMetadata] = None
|
||||
prefill: Optional[DeepseekV32IndexerPrefillMetadata] = None
|
||||
|
||||
def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||
query_start_loc_cpu, seq_lens_cpu,
|
||||
block_table):
|
||||
prefill_query_start_loc = query_start_loc_cpu[
|
||||
reqs_start:reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||
|
||||
def kunlun_build_one_prefill_chunk(
|
||||
self, reqs_start, reqs_end, query_start_loc_cpu, seq_lens_cpu, block_table
|
||||
):
|
||||
prefill_query_start_loc = (
|
||||
query_start_loc_cpu[reqs_start : reqs_end + 1] - query_start_loc_cpu[reqs_start]
|
||||
)
|
||||
cu_seqlen_ks, cu_seqlen_ke = kv_spans_from_batches(
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end],
|
||||
self.device)
|
||||
prefill_query_start_loc, seq_lens_cpu[reqs_start:reqs_end], self.device
|
||||
)
|
||||
token_start = query_start_loc_cpu[reqs_start].item()
|
||||
token_end = query_start_loc_cpu[reqs_end].item()
|
||||
total_seq_lens = seq_lens_cpu[reqs_start:reqs_end].sum()
|
||||
assert total_seq_lens <= self.max_prefill_buffer_size
|
||||
cu_seq_lens = torch.cat([
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0)
|
||||
]).to(torch.int32).to(self.device)
|
||||
cu_seq_lens = (
|
||||
torch.cat(
|
||||
[
|
||||
torch.zeros(1, dtype=torch.int32),
|
||||
seq_lens_cpu[reqs_start:reqs_end].cumsum(dim=0),
|
||||
]
|
||||
)
|
||||
.to(torch.int32)
|
||||
.to(self.device)
|
||||
)
|
||||
seq_len_q = token_end - token_start
|
||||
seq_len_kv = total_seq_lens
|
||||
context_q_lens = torch.tensor([0, seq_len_q], dtype=torch.int32, device=self.device)
|
||||
context_k_lens = torch.tensor([0, seq_len_kv], dtype=torch.int32, device=self.device)
|
||||
context_k_lens = torch.tensor(
|
||||
[0, seq_len_kv], dtype=torch.int32, device=self.device
|
||||
)
|
||||
context_q_lens_cpu = torch.tensor([0, seq_len_q], dtype=torch.int32, device="cpu")
|
||||
context_k_lens_cpu = torch.tensor([0, seq_len_kv], dtype=torch.int32, device="cpu")
|
||||
|
||||
@@ -107,85 +124,103 @@ def kunlun_build_one_prefill_chunk(self, reqs_start, reqs_end,
|
||||
context_k_lens=context_k_lens,
|
||||
context_k_lens_cpu=context_k_lens_cpu,
|
||||
)
|
||||
def kunlun_build(self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False) -> DeepseekV32IndexerMetadata:
|
||||
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = \
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata,
|
||||
decode_threshold=self.reorder_batch_threshold)
|
||||
def kunlun_build(
|
||||
self,
|
||||
common_prefix_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
fast_build: bool = False,
|
||||
) -> DeepseekV32IndexerMetadata:
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
num_reqs = common_attn_metadata.num_reqs
|
||||
num_tokens = common_attn_metadata.num_actual_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
||||
num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
|
||||
split_decodes_and_prefills(
|
||||
common_attn_metadata, decode_threshold=self.reorder_batch_threshold
|
||||
)
|
||||
)
|
||||
|
||||
assert num_decodes + num_prefills == num_reqs
|
||||
assert num_decode_tokens + num_prefill_tokens == num_tokens
|
||||
|
||||
prefill_metadata = None
|
||||
if num_prefills > 0:
|
||||
chunk_seq_ids = split_prefill_chunks(
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start,
|
||||
reqs_end,
|
||||
query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
self.max_prefill_buffer_size,
|
||||
num_decodes,
|
||||
common_attn_metadata.block_table_tensor,
|
||||
)
|
||||
chunks = [
|
||||
self.build_one_prefill_chunk(
|
||||
reqs_start, reqs_end, query_start_loc_cpu,
|
||||
common_attn_metadata.seq_lens_cpu,
|
||||
common_attn_metadata.block_table_tensor)
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks, )
|
||||
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(common_attn_metadata.query_start_loc[:num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes])
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[:num_decodes + 1])
|
||||
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max()
|
||||
> decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.
|
||||
block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
for reqs_start, reqs_end in chunk_seq_ids
|
||||
]
|
||||
prefill_metadata = DeepseekV32IndexerPrefillMetadata(
|
||||
chunks=chunks,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
decode_metadata = None
|
||||
if num_decodes > 0:
|
||||
torch.diff(
|
||||
common_attn_metadata.query_start_loc[: num_decodes + 1],
|
||||
out=self.decode_lens_buffer[:num_decodes],
|
||||
)
|
||||
decode_lens = self.decode_lens_buffer[:num_decodes]
|
||||
decode_lens_cpu = torch.diff(
|
||||
common_attn_metadata.query_start_loc_cpu[: num_decodes + 1]
|
||||
)
|
||||
|
||||
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk= kunlun_build_one_prefill_chunk
|
||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
# Use CPU to avoid GPU sync; breaking async scheduling
|
||||
requires_padding = (decode_lens_cpu.max() > decode_lens_cpu.min()).item()
|
||||
|
||||
seq_lens = common_attn_metadata.seq_lens[:num_decodes]
|
||||
|
||||
decode_metadata = DeepSeekV32IndexerDecodeMetadata(
|
||||
block_table=common_attn_metadata.block_table_tensor[:num_decodes, ...],
|
||||
seq_lens=common_attn_metadata.seq_lens[:num_decodes],
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens[:num_decodes].cpu(),
|
||||
decode_lens=decode_lens,
|
||||
requires_padding=requires_padding,
|
||||
schedule_metadata=self.scheduler_metadata_buffer,
|
||||
)
|
||||
|
||||
attn_metadata = DeepseekV32IndexerMetadata(
|
||||
seq_lens=common_attn_metadata.seq_lens,
|
||||
seq_lens_cpu=common_attn_metadata.seq_lens.cpu(),
|
||||
num_reqs=common_attn_metadata.num_reqs,
|
||||
max_query_len=common_attn_metadata.max_query_len,
|
||||
max_seq_len=common_attn_metadata.max_seq_len,
|
||||
num_actual_tokens=common_attn_metadata.num_actual_tokens,
|
||||
query_start_loc=common_attn_metadata.query_start_loc,
|
||||
slot_mapping=common_attn_metadata.slot_mapping,
|
||||
head_dim=128,
|
||||
num_decodes=num_decodes,
|
||||
num_decode_tokens=num_decode_tokens,
|
||||
num_prefills=num_prefills,
|
||||
num_prefill_tokens=num_prefill_tokens,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
)
|
||||
|
||||
# if get_tensor_model_parallel_rank() == 0:
|
||||
# logger.info(f"attn_metadata: {attn_metadata}")
|
||||
return attn_metadata
|
||||
|
||||
|
||||
DeepseekV32IndexerMetadataBuilder.build_one_prefill_chunk = (
|
||||
kunlun_build_one_prefill_chunk
|
||||
)
|
||||
DeepseekV32IndexerMetadataBuilder.build = kunlun_build
|
||||
|
||||
# Monkey patch: Upgrade cudagraph_support to UNIFORM_BATCH for spec-decode compatibility
|
||||
from vllm.v1.attention.backends.utils import AttentionCGSupport
|
||||
|
||||
DeepseekV32IndexerMetadataBuilder.cudagraph_support = AttentionCGSupport.UNIFORM_BATCH
|
||||
|
||||
@@ -1,26 +1,18 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
|
||||
import ast
|
||||
from typing import Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from vllm.attention.layer import Attention
|
||||
from vllm.forward_context import set_forward_context
|
||||
from vllm.logger import init_logger
|
||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||
from vllm.platforms import current_platform
|
||||
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
|
||||
from vllm.v1.attention.backends.rocm_aiter_fa import (
|
||||
AiterFlashAttentionMetadata)
|
||||
from vllm.v1.attention.backends.tree_attn import (TreeAttentionMetadata,
|
||||
TreeAttentionMetadataBuilder)
|
||||
from vllm.v1.attention.backends.triton_attn import TritonAttentionMetadata
|
||||
from vllm.v1.attention.backends.tree_attn import TreeAttentionMetadata
|
||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
from vllm.v1.spec_decode.eagle import EagleProposer
|
||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
||||
from vllm.v1.worker.ubatching import dbo_current_ubatch_id
|
||||
|
||||
logger = init_logger(__name__)
|
||||
|
||||
@@ -37,47 +29,65 @@ def propose(
|
||||
target_hidden_states: torch.Tensor,
|
||||
# [batch_size]
|
||||
next_token_ids: torch.Tensor,
|
||||
last_token_indices: Optional[torch.Tensor],
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampling_metadata: SamplingMetadata,
|
||||
mm_embeds: Optional[list[torch.Tensor]] = None,
|
||||
) -> torch.Tensor:
|
||||
num_tokens = target_token_ids.shape[0]
|
||||
batch_size = next_token_ids.shape[0]
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if last_token_indices is None:
|
||||
last_token_indices = common_attn_metadata.query_start_loc[1:] - 1
|
||||
|
||||
if self.method == "eagle3":
|
||||
assert isinstance(self.model, Eagle3LlamaForCausalLM)
|
||||
target_hidden_states = self.model.combine_hidden_states(
|
||||
target_hidden_states)
|
||||
target_hidden_states = self.model.combine_hidden_states(target_hidden_states)
|
||||
assert target_hidden_states.shape[-1] == self.hidden_size
|
||||
|
||||
# Shift the input ids by one token.
|
||||
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
|
||||
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
|
||||
self.input_ids[: num_tokens - 1] = target_token_ids[1:]
|
||||
# Replace the last token with the next token.
|
||||
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
|
||||
self.input_ids[last_token_indices] = next_token_ids
|
||||
|
||||
assert self.runner is not None
|
||||
|
||||
# FIXME: need to consider multiple kv_cache_groups
|
||||
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
|
||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0)
|
||||
if attn_metadata.decode is not None and attn_metadata.decode.spec_num_seq_len is not None:
|
||||
attn_metadata.decode.spec_num_seq_len = -1
|
||||
ubatch_id = dbo_current_ubatch_id()
|
||||
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builders[ubatch_id]
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=0
|
||||
)
|
||||
if (
|
||||
hasattr(attn_metadata, "decode")
|
||||
and attn_metadata.decode is not None
|
||||
and hasattr(attn_metadata.decode, "spec_num_seq_len")
|
||||
and attn_metadata.decode.spec_num_seq_len is not None
|
||||
):
|
||||
attn_metadata.decode.spec_num_seq_len = -1
|
||||
|
||||
if self.draft_indexer_metadata_builder:
|
||||
draft_indexer_metadata = self.draft_indexer_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0,
|
||||
)
|
||||
else:
|
||||
draft_indexer_metadata = None
|
||||
|
||||
# At this moment, we assume all eagle layers belong to the same KV
|
||||
# cache group, thus using the same attention metadata.
|
||||
per_layer_attn_metadata = {}
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
if self.use_cuda_graph and \
|
||||
num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
for layer_name in self.indexer_layer_names:
|
||||
assert draft_indexer_metadata is not None
|
||||
per_layer_attn_metadata[layer_name] = draft_indexer_metadata
|
||||
if self.use_cuda_graph and num_tokens <= self.cudagraph_batch_sizes[-1]:
|
||||
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
|
||||
else:
|
||||
num_input_tokens = num_tokens
|
||||
|
||||
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.positions[:num_tokens] = target_positions
|
||||
self.hidden_states[:num_tokens] = target_hidden_states
|
||||
@@ -94,22 +104,25 @@ def propose(
|
||||
inputs_embeds = None
|
||||
input_ids = self.input_ids[:num_input_tokens]
|
||||
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=num_input_tokens):
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata, self.vllm_config, num_tokens=num_input_tokens
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:num_input_tokens],
|
||||
hidden_states=self.hidden_states[:num_input_tokens],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
if self.method == "deepseek_mtp":
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = self.hidden_states[:num_input_tokens]
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
sample_hidden_states = last_hidden_states[last_token_indices]
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
if self.method == "mtp":
|
||||
logits = self.model.compute_logits(sample_hidden_states, 0)
|
||||
else:
|
||||
logits = self.model.compute_logits(sample_hidden_states, None)
|
||||
positions = target_positions[last_token_indices]
|
||||
hidden_states = hidden_states[last_token_indices]
|
||||
|
||||
@@ -136,21 +149,22 @@ def propose(
|
||||
# one layer. Adapt this code to support multiple layers once
|
||||
# there's a multi-layer MTP module.
|
||||
|
||||
|
||||
# Generate the remaining draft tokens.
|
||||
draft_token_ids_list = [draft_token_ids]
|
||||
if self.use_cuda_graph and \
|
||||
batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
if self.use_cuda_graph and batch_size <= self.cudagraph_batch_sizes[-1]:
|
||||
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
|
||||
else:
|
||||
input_batch_size = batch_size
|
||||
|
||||
common_attn_metadata.num_actual_tokens = batch_size
|
||||
common_attn_metadata.max_query_len = 1
|
||||
common_attn_metadata.query_start_loc = self.arange[:batch_size + 1].to(torch.int32)
|
||||
common_attn_metadata.query_start_loc_cpu = torch.from_numpy(
|
||||
self.token_arange_np[:batch_size + 1]).clone().to(torch.int32)
|
||||
for _ in range(self.num_speculative_tokens - 1):
|
||||
common_attn_metadata.query_start_loc = self.arange[: batch_size + 1].to(torch.int32)
|
||||
common_attn_metadata.query_start_loc_cpu = (
|
||||
torch.from_numpy(self.token_arange_np[: batch_size + 1]).clone().to(torch.int32)
|
||||
)
|
||||
|
||||
attn_metadata_builder = self.runner.attn_groups[0][0].metadata_builder
|
||||
for token_index in range(self.num_speculative_tokens - 1):
|
||||
# Update the inputs.
|
||||
# cast to int32 is crucial when eagle model is compiled.
|
||||
# tensor.argmax() returns int64 by default.
|
||||
@@ -166,38 +180,39 @@ def propose(
|
||||
exceeds_max_model_len = positions >= self.max_model_len
|
||||
# Mask out the position ids that exceed the max model length.
|
||||
# Otherwise, we may get out-of-range error in RoPE.
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0,
|
||||
positions)
|
||||
clamped_positions = torch.where(exceeds_max_model_len, 0, positions)
|
||||
|
||||
# Increment the sequence lengths.
|
||||
common_attn_metadata.seq_lens += 1
|
||||
common_attn_metadata.seq_lens_cpu += 1
|
||||
# For the requests that exceed the max model length, we set the
|
||||
# sequence length to 1 to minimize their overheads in attention.
|
||||
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len,
|
||||
1)
|
||||
common_attn_metadata.num_computed_tokens_cpu = \
|
||||
common_attn_metadata.seq_lens_cpu - 1
|
||||
common_attn_metadata.seq_lens.masked_fill_(exceeds_max_model_len, 1)
|
||||
common_attn_metadata.num_computed_tokens_cpu = (
|
||||
common_attn_metadata.seq_lens_cpu - 1
|
||||
)
|
||||
|
||||
# Compute the slot mapping.
|
||||
block_numbers = clamped_positions // self.block_size
|
||||
block_ids = common_attn_metadata.block_table_tensor.gather(
|
||||
dim=1, index=block_numbers.view(-1, 1))
|
||||
dim=1, index=block_numbers.view(-1, 1)
|
||||
)
|
||||
block_ids = block_ids.view(-1)
|
||||
common_attn_metadata.slot_mapping = (
|
||||
block_ids * self.block_size +
|
||||
clamped_positions % self.block_size)
|
||||
block_ids * self.block_size + clamped_positions % self.block_size
|
||||
)
|
||||
# Mask out the slot mappings that exceed the max model length.
|
||||
# Otherwise, the KV cache will be inadvertently updated with the
|
||||
# padding tokens.
|
||||
common_attn_metadata.slot_mapping.masked_fill_(
|
||||
exceeds_max_model_len, PADDING_SLOT_ID)
|
||||
exceeds_max_model_len, PADDING_SLOT_ID
|
||||
)
|
||||
|
||||
attn_metadata = self.runner.attn_groups[0][0].metadata_builder\
|
||||
.build_for_drafting(common_attn_metadata=common_attn_metadata,
|
||||
draft_index=0)
|
||||
attn_metadata = attn_metadata_builder.build_for_drafting(
|
||||
common_attn_metadata=common_attn_metadata, draft_index=token_index + 1
|
||||
)
|
||||
for layer_name in self.attn_layer_names:
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
per_layer_attn_metadata[layer_name] = attn_metadata
|
||||
# copy inputs to buffer for cudagraph
|
||||
self.input_ids[:batch_size] = input_ids
|
||||
self.positions[:batch_size] = clamped_positions
|
||||
@@ -212,17 +227,23 @@ def propose(
|
||||
input_ids = self.input_ids[:input_batch_size]
|
||||
|
||||
# Run the model.
|
||||
with set_forward_context(per_layer_attn_metadata,
|
||||
self.vllm_config,
|
||||
num_tokens=input_batch_size):
|
||||
last_hidden_states = self.model(
|
||||
with set_forward_context(
|
||||
per_layer_attn_metadata, self.vllm_config, num_tokens=input_batch_size
|
||||
):
|
||||
ret_hidden_states = self.model(
|
||||
input_ids=input_ids,
|
||||
positions=self.positions[:input_batch_size],
|
||||
hidden_states=self.hidden_states[:input_batch_size],
|
||||
inputs_embeds=inputs_embeds,
|
||||
)
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size],
|
||||
None)
|
||||
if self.method == "mtp":
|
||||
last_hidden_states = ret_hidden_states
|
||||
hidden_states = ret_hidden_states
|
||||
else:
|
||||
last_hidden_states, hidden_states = ret_hidden_states
|
||||
|
||||
hidden_states = hidden_states[:batch_size]
|
||||
logits = self.model.compute_logits(last_hidden_states[:batch_size])
|
||||
draft_token_ids = logits.argmax(dim=-1)
|
||||
draft_token_ids_list.append(draft_token_ids)
|
||||
|
||||
@@ -230,83 +251,91 @@ def propose(
|
||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||
return draft_token_ids
|
||||
|
||||
def prepare_next_token_ids_padded(self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_indices: torch.Tensor,
|
||||
num_discarded_requests: int) -> \
|
||||
tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It calculates the next token ids and the number of valid sampled tokens
|
||||
for each request, considering the "discarded" requests whose next token
|
||||
is not sampled and comes from `request.get_token_id()` instead.
|
||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||
This function must use device functions to operate on the inputs, and
|
||||
should not introduce any blocking CPU-GPU synchronization.
|
||||
"""
|
||||
# TODO(Ben): Combine this into a custom fused kernel
|
||||
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
||||
def prepare_next_token_ids_padded(
|
||||
self,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
sampled_token_ids: torch.Tensor,
|
||||
requests: dict[str, CachedRequestState],
|
||||
gpu_input_batch: InputBatch,
|
||||
discard_request_indices: torch.Tensor,
|
||||
num_discarded_requests: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
This function is used to prepare the inputs for speculative decoding.
|
||||
It calculates the next token ids and the number of valid sampled tokens
|
||||
for each request, considering the "discarded" requests whose next token
|
||||
is not sampled and comes from `request.get_token_id()` instead.
|
||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
||||
This function must use device functions to operate on the inputs, and
|
||||
should not introduce any blocking CPU-GPU synchronization.
|
||||
"""
|
||||
# TODO(Ben): Combine this into a custom fused kernel
|
||||
|
||||
# Precompute get_token_id for when there is no valid next token
|
||||
num_reqs = gpu_input_batch.num_reqs
|
||||
self.backup_next_token_ids.np[:num_reqs] = np.array(
|
||||
[
|
||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
||||
common_attn_metadata.seq_lens_cpu[i].item())
|
||||
common_attn_metadata.seq_lens_cpu[i].item()
|
||||
)
|
||||
for i in range(num_reqs)
|
||||
])
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
]
|
||||
)
|
||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
||||
|
||||
# Mask out the sampled tokens indices that should not be sampled.
|
||||
discard_sampled_tokens_req_indices = \
|
||||
discard_request_indices[:num_discarded_requests]
|
||||
# Mask out the sampled tokens indices that should not be sampled.
|
||||
discard_sampled_tokens_req_indices = discard_request_indices[
|
||||
:num_discarded_requests
|
||||
]
|
||||
|
||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||
# valid_sampled_token_ids_gpu.index_fill_(
|
||||
# 0, discard_sampled_tokens_req_indices, -1)
|
||||
# ---- FIX START ----
|
||||
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
|
||||
if num_discarded_requests > 0:
|
||||
# make sure index is on same device and is int64
|
||||
idx = discard_sampled_tokens_req_indices
|
||||
if idx.device != valid_sampled_token_ids_gpu.device:
|
||||
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
|
||||
if idx.dtype != torch.long:
|
||||
idx = idx.to(torch.long)
|
||||
if idx.numel() > 0:
|
||||
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
|
||||
# ---- FIX END ----
|
||||
# Generate a mask for all valid tokens within those requests
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu,
|
||||
dtype=torch.bool)
|
||||
else:
|
||||
valid_mask = (
|
||||
(valid_sampled_token_ids_gpu != -1) &
|
||||
(valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size))
|
||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
||||
# valid_sampled_token_ids_gpu.index_fill_(
|
||||
# 0, discard_sampled_tokens_req_indices, -1)
|
||||
# ---- FIX START ----
|
||||
# XPU/XMLIR index_fill_ does NOT accept empty index tensor.
|
||||
if num_discarded_requests > 0:
|
||||
# make sure index is on same device and is int64
|
||||
idx = discard_sampled_tokens_req_indices
|
||||
if idx.device != valid_sampled_token_ids_gpu.device:
|
||||
idx = idx.to(valid_sampled_token_ids_gpu.device, non_blocking=True)
|
||||
if idx.dtype != torch.long:
|
||||
idx = idx.to(torch.long)
|
||||
if idx.numel() > 0:
|
||||
valid_sampled_token_ids_gpu.index_fill_(0, idx, -1)
|
||||
# ---- FIX END ----
|
||||
# Generate a mask for all valid tokens within those requests
|
||||
max_gen_len = sampled_token_ids.shape[-1]
|
||||
if max_gen_len == 1:
|
||||
valid_mask = torch.ones_like(valid_sampled_token_ids_gpu, dtype=torch.bool)
|
||||
else:
|
||||
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
||||
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size
|
||||
)
|
||||
|
||||
# Count the number of valid tokens in each request
|
||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||
# Count the number of valid tokens in each request
|
||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
||||
|
||||
# Get the rightmost valid index per row
|
||||
last_valid_indices = valid_sampled_tokens_count - 1
|
||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
||||
# Get the rightmost valid index per row
|
||||
last_valid_indices = valid_sampled_tokens_count - 1
|
||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
||||
|
||||
# Get last valid token from each row
|
||||
# (assume undefined state where there is no valid token)
|
||||
selected_tokens = torch.gather(
|
||||
valid_sampled_token_ids_gpu, 1,
|
||||
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
|
||||
# Get last valid token from each row
|
||||
# (assume undefined state where there is no valid token)
|
||||
selected_tokens = torch.gather(
|
||||
valid_sampled_token_ids_gpu, 1, last_valid_indices_safe.unsqueeze(1)
|
||||
).squeeze(1)
|
||||
|
||||
# Use last token if valid, pre-computed backup if not
|
||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
||||
next_token_ids = torch.where(
|
||||
last_valid_indices != -1, selected_tokens,
|
||||
self.backup_next_token_ids.gpu[:batch_size])
|
||||
# Use last token if valid, pre-computed backup if not
|
||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
||||
next_token_ids = torch.where(
|
||||
last_valid_indices != -1,
|
||||
selected_tokens,
|
||||
self.backup_next_token_ids.gpu[:batch_size],
|
||||
)
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
return next_token_ids, valid_sampled_tokens_count
|
||||
|
||||
EagleProposer.propose = propose
|
||||
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
||||
EagleProposer.prepare_next_token_ids_padded = prepare_next_token_ids_padded
|
||||
|
||||
Reference in New Issue
Block a user