[Refactor][EAGLE] 3/N delete redundant methods in mtp_proposer (#5420)
### What this PR does / why we need it?
This PR deletes redundant methods from mtp_proposer; all of the deleted
methods are now available in eagle_proposer. We also remove some methods
from eagle_proposer, since they are identical to those in vllm-eagle.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
Tested by CI.
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
---------
Signed-off-by: Zetong Li <slippersss@126.com>
@@ -262,6 +262,7 @@ class TestMtpProposer:
                                              device=torch.device("cpu"))
        assert torch.equal(next_token_ids, expected_next_tokens)

    @patch("vllm_ascend.spec_decode.eagle_proposer.HAS_TRITON", False)
    @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
    def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
        mock_buffer_instance = MagicMock()

@@ -4,6 +4,7 @@ from typing import Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
                         get_layers_from_vllm_config)
from vllm.distributed.parallel_state import get_pp_group
@@ -15,6 +16,7 @@ from vllm.model_executor.models import supports_multimodal
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import HAS_TRITON, triton
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
@@ -28,7 +30,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               update_attn_params)
                                               update_attn_dcp_pcp_params,
                                               update_attn_params,
                                               update_mla_attn_dcp_pcp_params,
                                               update_mla_attn_params)
from vllm_ascend.ops.rotary_embedding import update_cos_sin
from vllm_ascend.ops.triton.spec_decode.utils import \
    prepare_inputs_padded_kernel
@@ -37,10 +42,6 @@ from vllm_ascend.utils import shared_expert_dp_enabled

PADDING_SLOT_ID = -1

_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'

_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}

# Currently we fix the block size to a small value since `num_reqs` can't be too large
_PREPARE_INPUTS_BLOCK_SIZE = 4

@@ -93,27 +94,6 @@ class EagleProposer(VllmEagleProposer):
        self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
                                  "index_topk")

    def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
        """
        NOTE(2025-12-18): This is an explicit copy from vLLM EagleProposer, only added
        to align with its logic.

        Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
        hidden states and directly use the last layer output just like eagle1.
        They might indicate this by setting "use_aux_hidden_state" to False
        inside the "eagle_config" dict of their hf_config.
        """
        if self.method != "eagle3":
            return False
        # Assume that eagle3 heads use aux hidden states by default
        use_aux_hidden_state = True
        eagle_config = getattr(self.draft_model_config.hf_config,
                               "eagle_config", None)
        if eagle_config is not None:
            use_aux_hidden_state = eagle_config.get("use_aux_hidden_state",
                                                    True)
        return use_aux_hidden_state
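
# Illustration (not from the patch): the config lookup above on a toy
# hf_config stand-in. The default is True; an explicit "use_aux_hidden_state"
# key in eagle_config overrides it.
class _ToyHFConfig:
    eagle_config = {"use_aux_hidden_state": False}

use_aux = True
_cfg = getattr(_ToyHFConfig, "eagle_config", None)
if _cfg is not None:
    use_aux = _cfg.get("use_aux_hidden_state", True)
assert use_aux is False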

    def load_model(self, model: nn.Module) -> None:
        target_attn_layer_names = set(
            get_layers_from_vllm_config(self.vllm_config,
@@ -512,48 +492,6 @@ class EagleProposer(VllmEagleProposer):
        draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
        return draft_token_ids

    def _get_attn_metadata(self, attn_metadata):
        if attn_metadata is not None and isinstance(attn_metadata, dict):
            architecture = self.vllm_config.model_config.architecture
            layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
            attn_metadata = attn_metadata[layer_name]

        return attn_metadata

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
                    req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.input_ids.device)
        return next_token_ids
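
# Illustration (not from the patch): the fallback above on assumed toy
# values. A request whose sampled list is empty takes its next token from
# request state (stubbed here) instead of the sampler output.
sampled_token_ids = [[11, 12], []]   # req 0 accepted two tokens, req 1 none
backup_from_req_state = [0, 99]      # stands in for req_state.get_token_id(seq_len)
next_ids = [
    toks[-1] if toks else backup_from_req_state[i]
    for i, toks in enumerate(sampled_token_ids)
]
assert next_ids == [12, 99]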

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
@@ -829,3 +767,92 @@ class EagleProposer(VllmEagleProposer):
            max_seq_len=0)

        return spec_common_attn_metadata, token_indices, token_indices_to_sample

    def _split_pcp_input(self, req_scheduled_tokens, input_ids,
                         target_hidden_states):
        """
        Split prefill input_ids and target_hidden_states in the pcp group.
        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
        3. split target_hidden_states (which already includes pcp padding):
           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
        4. also update max_query_len, seq_lens, cu_num_tokens according to the pcp split.
        """
        if len(req_scheduled_tokens) == 0:
            # no prefill inputs to split, return an empty result
            return (
                0,
                torch.zeros([0], device='npu'),
                torch.zeros([0, target_hidden_states.size(1)], device='npu'),
                0,
                torch.zeros([0]),
                torch.tensor([0], dtype=torch.int32),
            )

        def _pcp_pad_and_split(num_tokens):
            num_pcp_padded_scheduled_tokens = cdiv(
                num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
            pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
            chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)

            # split position_ids (and use the split position_ids to split input_ids afterwards)
            req_position_cp: list[int] = []
            req_position_cp.extend(
                self.full_indices[self.pcp_rank *
                                  chunk_size:(self.pcp_rank + 1) * chunk_size])
            req_position_cp.extend(
                self.full_indices[num_pcp_padded_scheduled_tokens -
                                  (self.pcp_rank + 1) *
                                  chunk_size:num_pcp_padded_scheduled_tokens -
                                  self.pcp_rank * chunk_size])

            return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad

        num_pcp_scheduled_tokens = []
        ori_start_index = 0
        pad_start_index = 0
        pcp_split_input_ids_list = []
        pcp_split_hidden_states_list = []
        for ori_num_tokens in req_scheduled_tokens.values():
            req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
                _pcp_pad_and_split(ori_num_tokens)
            actual_num_tokens = len(req_position_pcp)
            num_pcp_scheduled_tokens.append(actual_num_tokens)
            pad_input_ids = F.pad(
                input_ids[ori_start_index:ori_start_index + ori_num_tokens],
                (0, num_pcp_pad))
            ori_start_index += ori_num_tokens
            pcp_chunk_indices = [
                pad_start_index + pos for pos in req_position_pcp
            ]
            pcp_split_input_ids = pad_input_ids[req_position_pcp]
            pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
            pcp_split_input_ids_list.append(pcp_split_input_ids)
            pcp_split_hidden_states_list.append(pcp_split_hidden_states)
            pad_start_index += num_pcp_padded_scheduled_tokens
        num_tokens = sum(num_pcp_scheduled_tokens)
        input_ids = torch.cat(pcp_split_input_ids_list)
        target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
        max_query_len = max(num_pcp_scheduled_tokens)
        seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
        cu_num_tokens = torch.tensor(
            np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
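
# Illustration (not from the patch): the pad-and-split indexing above, as a
# plain function with assumed pcp_size=2. Each request is padded to a
# multiple of 2 * pcp_size; every rank then takes one chunk from the front
# half and the mirrored chunk from the back half.
from math import ceil

def pcp_positions(num_tokens, pcp_size, pcp_rank):
    padded = ceil(num_tokens / (2 * pcp_size)) * (2 * pcp_size)
    chunk = padded // (2 * pcp_size)
    full_indices = list(range(padded))
    head = full_indices[pcp_rank * chunk:(pcp_rank + 1) * chunk]
    tail = full_indices[padded - (pcp_rank + 1) * chunk:
                        padded - pcp_rank * chunk]
    return head + tail

assert pcp_positions(6, 2, 0) == [0, 1, 6, 7]  # pcp0: [t0, t1, pad, pad]
assert pcp_positions(6, 2, 1) == [2, 3, 4, 5]  # pcp1: [t2, t3, t4, t5]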

    # update full-graph params for one spec token
    def _update_full_graph_params(self, forward_context, num_tokens):
        if self.vllm_config.model_config.use_mla:
            if self.pcp_size * self.dcp_size > 1:
                update_mla_attn_dcp_pcp_params(self.update_stream,
                                               forward_context, num_tokens)
            else:
                update_mla_attn_params(self.update_stream, forward_context,
                                       num_tokens,
                                       self.vllm_config.speculative_config)
        else:
            if self.pcp_size * self.dcp_size > 1:
                update_attn_dcp_pcp_params(self.update_stream, forward_context,
                                           num_tokens)
            else:
                update_attn_params(self.update_stream, forward_context,
                                   num_tokens, self.vllm_config)

@@ -1,73 +1,31 @@
from typing import Optional, Union

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.config import CUDAGraphMode
from vllm.distributed import get_pcp_group
from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.utils.math_utils import cdiv
from vllm.utils.platform_utils import is_pin_memory_available
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch

from vllm_ascend.ascend_forward_context import set_ascend_forward_context
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
                                               update_attn_dcp_pcp_params,
                                               update_attn_params,
                                               update_mla_attn_dcp_pcp_params,
                                               update_mla_attn_params)
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable

logger = init_logger(__name__)

PADDING_SLOT_ID = -1

_MTP_MODELS = {
    "DeepseekV3ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
    "PanguUltraMoEForCausalLM":
    ("vllm.model_executor.models.openpangu_mtp", "OpenPanguMTP"),
    "DeepseekV32ForCausalLM":
    ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
    "Qwen3NextForCausalLM":
    ("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
}


class MtpProposer(EagleProposer):

    # TODO: Find out why ModelRunner does not need this explicit typing.
    model: Union[nn.Module, ACLGraphWrapper]

    # update full-graph params for one spec token
    def _update_full_graph_params(self, forward_context, num_tokens):
        if self.vllm_config.model_config.use_mla:
            if self.pcp_size * self.dcp_size > 1:
                update_mla_attn_dcp_pcp_params(self.update_stream,
                                               forward_context, num_tokens)
            else:
                update_mla_attn_params(self.update_stream, forward_context,
                                       num_tokens,
                                       self.vllm_config.speculative_config)
        else:
            if self.pcp_size * self.dcp_size > 1:
                update_attn_dcp_pcp_params(self.update_stream, forward_context,
                                           num_tokens)
            else:
                update_attn_params(self.update_stream, forward_context,
                                   num_tokens, self.vllm_config)

    @torch.inference_mode()
    def dummy_run(self,
                  num_tokens: int,
@@ -180,124 +138,6 @@ class MtpProposer(EagleProposer):
            if with_prefill:
                break

    def _prepare_inputs(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: list[list[int]],
        num_draft_tokens: list[int],
    ) -> tuple[CommonAttentionMetadata, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata to account for the rejected
        tokens (and newly sampled tokens). It also returns the token indices
        of the tokens that should be fed to the speculator.
        """
        # E.g.
        # common_attn_metadata.query_start_loc{_cpu}:
        # [0, q1, q1 + q2, q1 + q2 + q3]
        # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
        # num_rejected_tokens: [n1, n2, n3]
        # This function computes the intermediate values:
        # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
        # And returns:
        # common_attn_metadata.query_start_loc{_cpu}:
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        # common_attn_metadata.seq_lens{_cpu}:
        # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
        # token_indices: [0, 1, ..., q1 - n1 - 1,
        #                 q1, q1 + 1, ..., q1 + q2 - n2 - 1,
        #                 q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]

        num_actual_reqs = len(num_draft_tokens)
        num_rejected_tokens = [
            n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
            for i, n in enumerate(num_draft_tokens)
        ]
        num_rejected_tokens = torch.tensor(num_rejected_tokens,
                                           dtype=torch.int32)

        device = common_attn_metadata.query_start_loc.device
        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
                                                                       num_actual_reqs
                                                                       + 1]
        seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
        new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens

        # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
        new_query_len_per_req = query_start_loc_cpu[
            1:] - query_start_loc_cpu[:-1]
        # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
        new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
        new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()

        # [q1 - n1, q2 - n2, q3 - n3] ->
        # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
        new_query_start_loc_cpu = torch.zeros(
            query_start_loc_cpu.shape,
            dtype=torch.int32,
            pin_memory=is_pin_memory_available(),
        )
        new_query_start_loc_np = new_query_start_loc_cpu.numpy()
        np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])

        total_num_tokens = new_query_start_loc_np[-1]
        # Example assuming num_tokens_per_req_np = [2, 4, 3]
        # this implies that `new_query_start_locs` is:
        # [0, 2, 6, 9] ->
        # [0, 0, 2, 2, 2, 2, 6, 6, 6]
        #  _r1_  ____r2____  ___r3__
        new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
                                                  new_num_tokens_per_req_np)
        # [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
        # [0, 1, 0, 1, 2, 3, 0, 1, 2]
        #  _r1_  ____r2____  ___r3__
        token_offests = (self.token_arange_np[:total_num_tokens] -
                         new_query_start_locs_expanded)

        # Expand starting positions to match token pattern
        # [0, q1, q1 + q2] ->
        # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
        #  _r1_  _____r2_______  ___________r3____________
        old_query_start_locs_expanded = np.repeat(
            query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
        # Final token indices are:
        # [0, 1,                                    // req 1
        #  q1 + 0, q1 + 1, q1 + 2, q1 + 3,          // req 2
        #  q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2]   // req 3
        token_indices_np = token_offests + old_query_start_locs_expanded
        token_indices = torch.from_numpy(token_indices_np).to(
            device, non_blocking=True)

        common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
            common_attn_metadata.slot_mapping[token_indices])
        common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)

        # NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
        # so we do not need to fix them. But if they are used in the future,
        # we should fix them.
        spec_common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=new_query_start_loc_cpu.to(device,
                                                       non_blocking=True),
            query_start_loc_cpu=new_query_start_loc_cpu,
            seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
            seq_lens_cpu=new_seq_lens_cpu,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
            positions=common_attn_metadata.positions[token_indices],
            attn_mask=self.runner.attn_mask,
            spec_attn_mask=self.runner.spec_attn_mask,
            attn_state=self.runner.attn_state,
            decode_token_per_req=self.runner.decode_token_per_req,
            max_seq_len=0)
        return spec_common_attn_metadata, token_indices
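
# Illustration (not from the patch): a self-contained numpy sketch of the
# index arithmetic above, using assumed toy values. Given per-request query
# start locations and rejected-token counts, it recovers the indices of the
# tokens to keep.
import numpy as np

query_start_loc = np.array([0, 3, 8, 12])  # [0, q1, q1+q2, q1+q2+q3]
num_rejected = np.array([1, 1, 1])         # [n1, n2, n3]

new_lens = np.diff(query_start_loc) - num_rejected       # [2, 4, 3]
new_query_start_loc = np.zeros_like(query_start_loc)
np.cumsum(new_lens, out=new_query_start_loc[1:])
total = new_query_start_loc[-1]                          # 9

# offset of each kept token within its request: [0, 1, 0, 1, 2, 3, 0, 1, 2]
offsets = np.arange(total) - np.repeat(new_query_start_loc[:-1], new_lens)
# index of each kept token in the original flat buffer
token_indices = offsets + np.repeat(query_start_loc[:-1], new_lens)
assert token_indices.tolist() == [0, 1, 3, 4, 5, 6, 8, 9, 10]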

    def _propose(
        self,
        # [num_tokens]
@@ -731,257 +571,3 @@ class MtpProposer(EagleProposer):
        # mtp>1: [batch_size, k]
        draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
        return draft_token_ids

    # TODO: Using torch instead of triton may result in poor performance
    def _prepare_input_kernel(self, out_ptr: torch.Tensor,
                              cu_query_lens: torch.Tensor,
                              cu_num_tokens: torch.Tensor, block_size: int):
        device = cu_query_lens.device
        dtype = out_ptr.dtype

        offsets = torch.arange(block_size, device=device, dtype=dtype)
        start_pos = cu_num_tokens[:-1]
        end_pos = cu_num_tokens[1:]
        num_tokens = end_pos - start_pos

        global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
        values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))

        mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))

        global_indices_flat = global_indices[mask]
        values_flat = values[mask]
        out_ptr[global_indices_flat] = values_flat
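
# Illustration (not from the patch): a standalone, CPU-runnable copy of the
# torch fallback above with assumed toy inputs. For each request i it writes
# cu_query_lens[i] + 0..n_i-1 into out[cu_num_tokens[i]:cu_num_tokens[i+1]].
import torch

def _prepare_input_torch(out, cu_query_lens, cu_num_tokens, block_size):
    offsets = torch.arange(block_size, dtype=out.dtype)
    num_tokens = cu_num_tokens[1:] - cu_num_tokens[:-1]
    global_indices = cu_num_tokens[:-1].view(-1, 1) + offsets.view(1, -1)
    values = cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)
    mask = offsets.view(1, -1) < num_tokens.view(-1, 1)
    out[global_indices[mask]] = values[mask]

out = torch.zeros(5, dtype=torch.int64)
_prepare_input_torch(out,
                     cu_query_lens=torch.tensor([0, 3, 8]),
                     cu_num_tokens=torch.tensor([0, 2, 5]),
                     block_size=4)
assert out.tolist() == [0, 1, 3, 4, 5]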

    def prepare_next_token_ids_cpu(
        self,
        sampled_token_ids: list[list[int]],
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        num_scheduled_tokens: dict[str, int],
    ) -> torch.Tensor:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids for each request based on the sampled
        token ids from the CPU. If a request has no sampled token ids (e.g.,
        during the initial decoding steps), it falls back to using the request
        state to get the next token id.
        """
        req_ids = gpu_input_batch.req_ids
        next_token_ids: list[int] = []
        for i, token_ids in enumerate(sampled_token_ids):
            if token_ids:
                # Common case.
                next_token_id = token_ids[-1]
            else:
                # Partial prefill (rare case).
                # Get the next token id from the request state.
                req_id = req_ids[i]
                req_state = requests[req_id]
                seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
                    req_id]
                next_token_id = req_state.get_token_id(seq_len)
            next_token_ids.append(next_token_id)
        next_token_ids = torch.tensor(next_token_ids,
                                      dtype=torch.int32,
                                      device=self.input_ids.device)
        return next_token_ids

    def prepare_next_token_ids_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        sampled_token_ids: torch.Tensor,
        requests: dict[str, CachedRequestState],
        gpu_input_batch: InputBatch,
        discard_request_indices: torch.Tensor,
        num_discarded_requests: int,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It calculates the next token ids and the number of valid sampled tokens
        for each request, considering the "discarded" requests whose next token
        is not sampled and comes from `request.get_token_id()` instead.
        It also accounts for the rejected tokens in `sampled_token_ids`.
        This function must use device functions to operate on the inputs, and
        should not introduce any blocking CPU-GPU synchronization.
        """
        # TODO(Ben): Combine this into a custom fused kernel

        # Precompute get_token_id for when there is no valid next token
        num_reqs = gpu_input_batch.num_reqs
        self.backup_next_token_ids.np[:num_reqs] = np.array([
            requests[gpu_input_batch.req_ids[i]].get_token_id(
                common_attn_metadata.seq_lens_cpu[i].item())
            for i in range(num_reqs)
        ])
        self.backup_next_token_ids.copy_to_gpu(num_reqs)

        # Mask out the sampled token indices that should not be sampled.
        discard_sampled_tokens_req_indices = discard_request_indices[:
                                                                     num_discarded_requests]

        valid_sampled_token_ids_gpu = sampled_token_ids.clone()
        valid_sampled_token_ids_gpu.index_fill_(
            0, discard_sampled_tokens_req_indices, -1)

        # Generate a mask for all valid tokens within those requests
        valid_mask = (valid_sampled_token_ids_gpu != -1) & (
            valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)

        # Count the number of valid tokens in each request
        valid_sampled_tokens_count = valid_mask.sum(dim=1)

        # Get the rightmost valid index per row
        last_valid_indices = valid_sampled_tokens_count - 1
        last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)

        # Get last valid token from each row
        # (assume undefined state where there is no valid token)
        selected_tokens = torch.gather(
            valid_sampled_token_ids_gpu, 1,
            last_valid_indices_safe.unsqueeze(1)).squeeze(1)

        # Use last token if valid, pre-computed backup if not
        batch_size = valid_sampled_token_ids_gpu.shape[0]
        next_token_ids = torch.where(
            last_valid_indices != -1,
            selected_tokens,
            self.backup_next_token_ids.gpu[:batch_size],
        )

        return next_token_ids, valid_sampled_tokens_count
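
# Illustration (not from the patch): the masking/gather logic above on toy
# tensors (assumed values). Row 1 has no valid token, so it falls back to the
# precomputed backup id.
import torch

sampled = torch.tensor([[5, 7, -1], [-1, -1, -1]])
backup = torch.tensor([42, 99])
vocab_size = 100

valid_mask = (sampled != -1) & (sampled < vocab_size)
valid_count = valid_mask.sum(dim=1)                 # [2, 0]
last_valid = valid_count - 1                        # [1, -1]
picked = sampled.gather(1, last_valid.clamp(min=0).unsqueeze(1)).squeeze(1)
next_ids = torch.where(last_valid != -1, picked, backup)
assert next_ids.tolist() == [7, 99]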

    def prepare_inputs_padded(
        self,
        common_attn_metadata: CommonAttentionMetadata,
        spec_decode_metadata: SpecDecodeMetadata,
        valid_sampled_tokens_count: torch.Tensor,
    ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
        """
        This function is used to prepare the inputs for speculative decoding.
        It updates the common_attn_metadata for speculative decoding,
        but does not consider the rejected tokens. Instead, all tokens
        are included as inputs to the speculator, with the rejected tokens
        used as padding and filtered out later by `token_indices_to_sample`.
        No blocking CPU operations should be introduced in this function.
        """
        num_draft_tokens_gpu = torch.cat([
            spec_decode_metadata.cu_num_draft_tokens[0:1],
            spec_decode_metadata.cu_num_draft_tokens[1:] -
            spec_decode_metadata.cu_num_draft_tokens[:-1],
        ])

        num_rejected_tokens_gpu = torch.where(
            num_draft_tokens_gpu > 0,
            num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
            torch.zeros_like(num_draft_tokens_gpu),
        )

        query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu

        new_query_len_per_req = query_start_loc_cpu[
            1:] - query_start_loc_cpu[:-1]

        total_num_tokens = query_start_loc_cpu[-1].item()
        token_indices = self.arange[:total_num_tokens]

        # NOTE: Currently positions and seq_lens are not used in mla_v1 forward,
        # so we do not need to fix them. But if they are used in the future,
        # we should fix them.
        spec_common_attn_metadata = AscendCommonAttentionMetadata(
            query_start_loc=common_attn_metadata.query_start_loc,
            query_start_loc_cpu=query_start_loc_cpu,
            seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
            num_reqs=common_attn_metadata.num_reqs,
            num_actual_tokens=total_num_tokens,
            num_input_tokens=common_attn_metadata.num_input_tokens,
            max_query_len=new_query_len_per_req.max().item(),
            actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
            block_table_tensor=common_attn_metadata.block_table_tensor,
            slot_mapping=common_attn_metadata.slot_mapping,
            positions=common_attn_metadata.positions,
            attn_mask=self.runner.attn_mask,
            spec_attn_mask=self.runner.spec_attn_mask,
            attn_state=self.runner.attn_state,
            decode_token_per_req=self.runner.decode_token_per_req,
            num_computed_tokens_cpu=common_attn_metadata.
            num_computed_tokens_cpu,
            seq_lens=common_attn_metadata.seq_lens,
            max_seq_len=0)

        query_start_loc = common_attn_metadata.query_start_loc[
            1:1 + num_rejected_tokens_gpu.shape[0]]
        token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu

        return spec_common_attn_metadata, token_indices, token_indices_to_sample
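
# Illustration (not from the patch): the rejected-token arithmetic above on
# assumed toy values. Per-request draft counts are recovered from the
# cumulative tensor; rejections are drafts + 1 - accepted, and 0 for
# requests without drafts.
import torch

cu_num_draft_tokens = torch.tensor([2, 5, 5])
valid_sampled_tokens_count = torch.tensor([1, 4, 1])

num_draft = torch.cat([
    cu_num_draft_tokens[0:1],
    cu_num_draft_tokens[1:] - cu_num_draft_tokens[:-1],
])                                                   # [2, 3, 0]
num_rejected = torch.where(num_draft > 0,
                           num_draft + 1 - valid_sampled_tokens_count,
                           torch.zeros_like(num_draft))
assert num_rejected.tolist() == [2, 0, 0]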

    def _split_pcp_input(self, req_scheduled_tokens, input_ids,
                         target_hidden_states):
        """
        Split prefill input_ids and target_hidden_states in the pcp group.
        1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
        2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
        3. split target_hidden_states (which already includes pcp padding):
           [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
        4. also update max_query_len, seq_lens, cu_num_tokens according to the pcp split.
        """
        if len(req_scheduled_tokens) == 0:
            # no prefill inputs to split, return an empty result
            return (
                0,
                torch.zeros([0], device='npu'),
                torch.zeros([0, target_hidden_states.size(1)], device='npu'),
                0,
                torch.zeros([0]),
                torch.tensor([0], dtype=torch.int32),
            )

        def _pcp_pad_and_split(num_tokens):
            num_pcp_padded_scheduled_tokens = cdiv(
                num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
            pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
            chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)

            # split position_ids (and use the split position_ids to split input_ids afterwards)
            req_position_cp: list[int] = []
            req_position_cp.extend(
                self.full_indices[self.pcp_rank *
                                  chunk_size:(self.pcp_rank + 1) * chunk_size])
            req_position_cp.extend(
                self.full_indices[num_pcp_padded_scheduled_tokens -
                                  (self.pcp_rank + 1) *
                                  chunk_size:num_pcp_padded_scheduled_tokens -
                                  self.pcp_rank * chunk_size])

            return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad

        num_pcp_scheduled_tokens = []
        ori_start_index = 0
        pad_start_index = 0
        pcp_split_input_ids_list = []
        pcp_split_hidden_states_list = []
        for ori_num_tokens in req_scheduled_tokens.values():
            req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
                _pcp_pad_and_split(ori_num_tokens)
            actual_num_tokens = len(req_position_pcp)
            num_pcp_scheduled_tokens.append(actual_num_tokens)
            pad_input_ids = F.pad(
                input_ids[ori_start_index:ori_start_index + ori_num_tokens],
                (0, num_pcp_pad))
            ori_start_index += ori_num_tokens
            pcp_chunk_indices = [
                pad_start_index + pos for pos in req_position_pcp
            ]
            pcp_split_input_ids = pad_input_ids[req_position_pcp]
            pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
            pcp_split_input_ids_list.append(pcp_split_input_ids)
            pcp_split_hidden_states_list.append(pcp_split_hidden_states)
            pad_start_index += num_pcp_padded_scheduled_tokens
        num_tokens = sum(num_pcp_scheduled_tokens)
        input_ids = torch.cat(pcp_split_input_ids_list)
        target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
        max_query_len = max(num_pcp_scheduled_tokens)
        seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
        cu_num_tokens = torch.tensor(
            np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
        return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens