From fe3f2c7702603b5894c96ea9bb4a38e839f27547 Mon Sep 17 00:00:00 2001 From: Zetong Li <48438720+slippersss@users.noreply.github.com> Date: Tue, 6 Jan 2026 16:47:39 +0800 Subject: [PATCH] [Refactor][EAGLE] 3/N delete redundant methods in mtp_proposer (#5420) ### What this PR does / why we need it? This PR aims to delete redundant methods in mtp_proposer. All the deleted methods now can be found in eagle_proposer. We also remove some methods in eagle_proposer since they are identical to those in vllm-eagle. ### Does this PR introduce _any_ user-facing change? N/A ### How was this patch tested? by ci - vLLM version: release/v0.13.0 - vLLM main: https://github.com/vllm-project/vllm/commit/81786c87748b0177111dfdc07af5351d8389baa1 --------- Signed-off-by: Zetong Li --- tests/ut/spec_decode/test_mtp_proposer.py | 1 + vllm_ascend/spec_decode/eagle_proposer.py | 163 +++++---- vllm_ascend/spec_decode/mtp_proposer.py | 416 +--------------------- 3 files changed, 97 insertions(+), 483 deletions(-) diff --git a/tests/ut/spec_decode/test_mtp_proposer.py b/tests/ut/spec_decode/test_mtp_proposer.py index 703c1597..918b6efb 100644 --- a/tests/ut/spec_decode/test_mtp_proposer.py +++ b/tests/ut/spec_decode/test_mtp_proposer.py @@ -262,6 +262,7 @@ class TestMtpProposer: device=torch.device("cpu")) assert torch.equal(next_token_ids, expected_next_tokens) + @patch("vllm_ascend.spec_decode.eagle_proposer.HAS_TRITON", False) @patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer") def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer): mock_buffer_instance = MagicMock() diff --git a/vllm_ascend/spec_decode/eagle_proposer.py b/vllm_ascend/spec_decode/eagle_proposer.py index 4fbf8532..b88a5ba9 100644 --- a/vllm_ascend/spec_decode/eagle_proposer.py +++ b/vllm_ascend/spec_decode/eagle_proposer.py @@ -4,6 +4,7 @@ from typing import Optional import numpy as np import torch import torch.nn as nn +import torch.nn.functional as F from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig, get_layers_from_vllm_config) from vllm.distributed.parallel_state import get_pp_group @@ -15,6 +16,7 @@ from vllm.model_executor.models import supports_multimodal from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM from vllm.triton_utils import HAS_TRITON, triton +from vllm.utils.math_utils import cdiv from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput @@ -28,7 +30,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - update_attn_params) + update_attn_dcp_pcp_params, + update_attn_params, + update_mla_attn_dcp_pcp_params, + update_mla_attn_params) from vllm_ascend.ops.rotary_embedding import update_cos_sin from vllm_ascend.ops.triton.spec_decode.utils import \ prepare_inputs_padded_kernel @@ -37,10 +42,6 @@ from vllm_ascend.utils import shared_expert_dp_enabled PADDING_SLOT_ID = -1 -_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn' - -_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'} - # Currently we will fix block size to a small one since `num_reqs` can't be too large _PREPARE_INPUTS_BLOCK_SIZE = 4 @@ -93,27 +94,6 @@ class 
EagleProposer(VllmEagleProposer): self.use_sparse = hasattr(vllm_config.model_config.hf_text_config, "index_topk") - def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool: - """ - NOTE(2025-12-18): This is an explicit copy from vLLM EagleProposer, only added - to align with its logics. - - Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary - hidden states and directly uses the last layer output just like eagle1. - They might indicate this by setting "use_aux_hidden_state" to False - inside the "eagle_config" dict of their hf_config. - """ - if self.method != "eagle3": - return False - # Assume that eagle3 heads use aux hidden states by default - use_aux_hidden_state = True - eagle_config = getattr(self.draft_model_config.hf_config, - "eagle_config", None) - if eagle_config is not None: - use_aux_hidden_state = eagle_config.get("use_aux_hidden_state", - True) - return use_aux_hidden_state - def load_model(self, model: nn.Module) -> None: target_attn_layer_names = set( get_layers_from_vllm_config(self.vllm_config, @@ -512,48 +492,6 @@ class EagleProposer(VllmEagleProposer): draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1) return draft_token_ids - def _get_attn_metadata(self, attn_metadata): - if attn_metadata is not None and isinstance(attn_metadata, dict): - architecture = self.vllm_config.model_config.architecture - layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER) - attn_metadata = attn_metadata[layer_name] - - return attn_metadata - - def prepare_next_token_ids_cpu( - self, - sampled_token_ids: list[list[int]], - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int], - ) -> torch.Tensor: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids for each request based on the sampled - token ids from the CPU. If a request has no sampled token ids (e.g., - during the initial decoding steps), it falls back to using the request - state to get the next token id. - """ - req_ids = gpu_input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = requests[req_id] - seq_len = req_state.num_computed_tokens + num_scheduled_tokens[ - req_id] - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) - return next_token_ids - def prepare_next_token_ids_padded( self, common_attn_metadata: CommonAttentionMetadata, @@ -829,3 +767,92 @@ class EagleProposer(VllmEagleProposer): max_seq_len=0) return spec_common_attn_metadata, token_indices, token_indices_to_sample + + def _split_pcp_input(self, req_scheduled_tokens, input_ids, + target_hidden_states): + """ + Split prefill input_ids and target_hidden_states in pcp group. + 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] + 2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5] + 3. split target_hidden_states (already include pcp padding): + [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5] + 4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split. 
+ """ + if len(req_scheduled_tokens) == 0: + # no prefill inputs to split, return empty result + return ( + 0, + torch.zeros([0], device='npu'), + torch.zeros([0, target_hidden_states.size(1)], device='npu'), + 0, + torch.zeros([0]), + torch.tensor([0], dtype=torch.int32), + ) + + def _pcp_pad_and_split(num_tokens): + num_pcp_padded_scheduled_tokens = cdiv( + num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size + pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens + chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size) + + # split position_ids (and use split position_ids to split input_ids afterwards) + req_position_cp: list[int] = [] + req_position_cp.extend( + self.full_indices[self.pcp_rank * + chunk_size:(self.pcp_rank + 1) * chunk_size]) + req_position_cp.extend( + self.full_indices[num_pcp_padded_scheduled_tokens - + (self.pcp_rank + 1) * + chunk_size:num_pcp_padded_scheduled_tokens - + self.pcp_rank * chunk_size]) + + return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad + + num_pcp_scheduled_tokens = [] + ori_start_index = 0 + pad_start_index = 0 + pcp_split_input_ids_list = [] + pcp_split_hidden_states_list = [] + for ori_num_tokens in req_scheduled_tokens.values(): + req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \ + _pcp_pad_and_split(ori_num_tokens) + actual_num_tokens = len(req_position_pcp) + num_pcp_scheduled_tokens.append(actual_num_tokens) + pad_input_ids = F.pad( + input_ids[ori_start_index:ori_start_index + ori_num_tokens], + (0, num_pcp_pad)) + ori_start_index += ori_num_tokens + pcp_chunk_indices = [ + pad_start_index + pos for pos in req_position_pcp + ] + pcp_split_input_ids = pad_input_ids[req_position_pcp] + pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices] + pcp_split_input_ids_list.append(pcp_split_input_ids) + pcp_split_hidden_states_list.append(pcp_split_hidden_states) + pad_start_index += num_pcp_padded_scheduled_tokens + num_tokens = sum(num_pcp_scheduled_tokens) + input_ids = torch.cat(pcp_split_input_ids_list) + target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0) + max_query_len = max(num_pcp_scheduled_tokens) + seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32) + cu_num_tokens = torch.tensor( + np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) + return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens + + # update full-graph params for one spec token + def _update_full_graph_params(self, forward_context, num_tokens): + if self.vllm_config.model_config.use_mla: + if self.pcp_size * self.dcp_size > 1: + update_mla_attn_dcp_pcp_params(self.update_stream, + forward_context, num_tokens) + else: + update_mla_attn_params(self.update_stream, forward_context, + num_tokens, + self.vllm_config.speculative_config) + else: + if self.pcp_size * self.dcp_size > 1: + update_attn_dcp_pcp_params(self.update_stream, forward_context, + num_tokens) + else: + update_attn_params(self.update_stream, forward_context, + num_tokens, self.vllm_config) diff --git a/vllm_ascend/spec_decode/mtp_proposer.py b/vllm_ascend/spec_decode/mtp_proposer.py index 21533d43..6f68b488 100644 --- a/vllm_ascend/spec_decode/mtp_proposer.py +++ b/vllm_ascend/spec_decode/mtp_proposer.py @@ -1,73 +1,31 @@ from typing import Optional, Union -import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F from vllm.config import CUDAGraphMode from vllm.distributed import get_pcp_group from vllm.forward_context import get_forward_context -from 
vllm.logger import init_logger from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM -from vllm.utils.math_utils import cdiv -from vllm.utils.platform_utils import is_pin_memory_available from vllm.v1.attention.backends.utils import CommonAttentionMetadata from vllm.v1.core.sched.output import SchedulerOutput from vllm.v1.sample.metadata import SamplingMetadata -from vllm.v1.spec_decode.metadata import SpecDecodeMetadata -from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch from vllm_ascend.ascend_forward_context import set_ascend_forward_context from vllm_ascend.attention.attention_v1 import AscendAttentionState from vllm_ascend.attention.utils import AscendCommonAttentionMetadata -from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper, - update_attn_dcp_pcp_params, - update_attn_params, - update_mla_attn_dcp_pcp_params, - update_mla_attn_params) +from vllm_ascend.compilation.acl_graph import ACLGraphWrapper from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla from vllm_ascend.spec_decode.eagle_proposer import EagleProposer from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable -logger = init_logger(__name__) - PADDING_SLOT_ID = -1 -_MTP_MODELS = { - "DeepseekV3ForCausalLM": - ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), - "PanguUltraMoEForCausalLM": - ("vllm.model_executor.models.openpangu_mtp", "OpenPanguMTP"), - "DeepseekV32ForCausalLM": - ("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"), - "Qwen3NextForCausalLM": - ("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP") -} - class MtpProposer(EagleProposer): # TODO: Find out why ModelRunner does not this explicit typing? model: Union[nn.Module, ACLGraphWrapper] - # update full-graph params for one spec token - def _update_full_graph_params(self, forward_context, num_tokens): - if self.vllm_config.model_config.use_mla: - if self.pcp_size * self.dcp_size > 1: - update_mla_attn_dcp_pcp_params(self.update_stream, - forward_context, num_tokens) - else: - update_mla_attn_params(self.update_stream, forward_context, - num_tokens, - self.vllm_config.speculative_config) - else: - if self.pcp_size * self.dcp_size > 1: - update_attn_dcp_pcp_params(self.update_stream, forward_context, - num_tokens) - else: - update_attn_params(self.update_stream, forward_context, - num_tokens, self.vllm_config) - @torch.inference_mode() def dummy_run(self, num_tokens: int, @@ -180,124 +138,6 @@ class MtpProposer(EagleProposer): if with_prefill: break - def _prepare_inputs( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: list[list[int]], - num_draft_tokens: list[int], - ) -> tuple[CommonAttentionMetadata, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. - It updates to the common_attn_metadata to account for the rejected - tokens (and newly sampled tokens). It also returns the token indices - of the tokens that should be fed to the speculator. - """ - # E.g. 
- # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1, q1 + q2, q1 + q2 + q3] - # common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3] - # num_rejected_tokens: [n1, n2, n3] - # This function computes the intermediate values: - # num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3] - # And returns: - # common_attn_metadata.query_start_loc{_cpu}: - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - # common_attn_metadata.seq_lens{_cpu}: - # [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1] - # token_indices: [0, 1, ..., q1 - n1 - 1, - # q1, q1 + 1, ..., q1 + q2 - n2 - 1, - # q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1] - - num_actual_reqs = len(num_draft_tokens) - num_rejected_tokens = [ - n + 1 - len(sampled_token_ids[i]) if n > 0 else 0 - for i, n in enumerate(num_draft_tokens) - ] - num_rejected_tokens = torch.tensor(num_rejected_tokens, - dtype=torch.int32) - - device = common_attn_metadata.query_start_loc.device - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[: - num_actual_reqs - + 1] - seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs] - new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens - - # [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3] - new_query_len_per_req = query_start_loc_cpu[ - 1:] - query_start_loc_cpu[:-1] - # [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3] - new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens - new_num_tokens_per_req_np = new_num_tokens_per_req.numpy() - - # [q1 - n1, q2 - n2, q3 - n3] -> - # [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3] - new_query_start_loc_cpu = torch.zeros( - query_start_loc_cpu.shape, - dtype=torch.int32, - pin_memory=is_pin_memory_available(), - ) - new_query_start_loc_np = new_query_start_loc_cpu.numpy() - np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:]) - - total_num_tokens = new_query_start_loc_np[-1] - # Example assuming num_tokens_per_req_np = [2, 4, 3] - # this implies that `new_query_start_locs` is: - # [0, 2, 6, 9] -> - # [0, 0, 2, 2, 2, 2, 6, 6, 6] - # _r1_ ____r2____ ___r3__ - new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1], - new_num_tokens_per_req_np) - # [0, 1, 2, 3, 4, 5, 6, 7, 8] -> - # [0, 1, 0, 1, 2, 3, 0, 1, 2] - # _r1_ ____r2____ ___r3__ - token_offests = (self.token_arange_np[:total_num_tokens] - - new_query_start_locs_expanded) - - # Expand starting positions to match token pattern - # [0, q1, q1 + q2] -> - # [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2] - # _r1_ _____r2_______ ___________r3____________ - old_query_start_locs_expanded = np.repeat( - query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np) - # Final token indices are: - # [0, 1, // req 1 - # q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2 - # q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3 - token_indices_np = token_offests + old_query_start_locs_expanded - token_indices = torch.from_numpy(token_indices_np).to( - device, non_blocking=True) - - common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_( - common_attn_metadata.slot_mapping[token_indices]) - common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1) - - # NOTE: Currently positions and seq_lens are not used in mla_v1 forward - # so we do not need to fixed them. But if they are used in the future, - # we should fixed them. 
- spec_common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=new_query_start_loc_cpu.to(device, - non_blocking=True), - query_start_loc_cpu=new_query_start_loc_cpu, - seq_lens=new_seq_lens_cpu.to(device, non_blocking=True), - seq_lens_cpu=new_seq_lens_cpu, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - num_input_tokens=common_attn_metadata.num_input_tokens, - max_query_len=new_query_len_per_req.max().item(), - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping, - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - positions=common_attn_metadata.positions[token_indices], - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - max_seq_len=0) - return spec_common_attn_metadata, token_indices - def _propose( self, # [num_tokens] @@ -731,257 +571,3 @@ class MtpProposer(EagleProposer): # mtp>1: [batch_size, k] draft_token_ids = torch.stack(draft_token_ids_list, dim=1) return draft_token_ids - - # TODO Using torch instead of triton may result in poor performance - def _prepare_input_kernel(self, out_ptr: torch.Tensor, - cu_query_lens: torch.Tensor, - cu_num_tokens: torch.Tensor, block_size: int): - device = cu_query_lens.device - dtype = out_ptr.dtype - - offsets = torch.arange(block_size, device=device, dtype=dtype) - start_pos = cu_num_tokens[:-1] - end_pos = cu_num_tokens[1:] - num_tokens = end_pos - start_pos - - global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1)) - values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1)) - - mask = (offsets.view(1, -1) < num_tokens.view(-1, 1)) - - global_indices_flat = global_indices[mask] - values_flat = values[mask] - out_ptr[global_indices_flat] = values_flat - - def prepare_next_token_ids_cpu( - self, - sampled_token_ids: list[list[int]], - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - num_scheduled_tokens: dict[str, int], - ) -> torch.Tensor: - """ - This function is used to prepare the inputs for speculative decoding. - It calculates the next token ids for each request based on the sampled - token ids from the CPU. If a request has no sampled token ids (e.g., - during the initial decoding steps), it falls back to using the request - state to get the next token id. - """ - req_ids = gpu_input_batch.req_ids - next_token_ids: list[int] = [] - for i, token_ids in enumerate(sampled_token_ids): - if token_ids: - # Common case. - next_token_id = token_ids[-1] - else: - # Partial prefill (rare case). - # Get the next token id from the request state. - req_id = req_ids[i] - req_state = requests[req_id] - seq_len = req_state.num_computed_tokens + num_scheduled_tokens[ - req_id] - next_token_id = req_state.get_token_id(seq_len) - next_token_ids.append(next_token_id) - next_token_ids = torch.tensor(next_token_ids, - dtype=torch.int32, - device=self.input_ids.device) - return next_token_ids - - def prepare_next_token_ids_padded( - self, - common_attn_metadata: CommonAttentionMetadata, - sampled_token_ids: torch.Tensor, - requests: dict[str, CachedRequestState], - gpu_input_batch: InputBatch, - discard_request_indices: torch.Tensor, - num_discarded_requests: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding. 
- It calculates the next token ids and the number of valid sampled tokens - for each request, considering the "discarded" requests whose next token - is not sampled and comes from `request.get_token_id()` instead. - It also accounts for the rejected tokens in `sampled_token_ids`. - This function must use device functions to operate on the inputs, and - should not introduce any blocking CPU-GPU synchronization. - """ - # TODO(Ben): Combine this into a custom fused kernel - - # Precompute get_token_id for when there is no valid next token - num_reqs = gpu_input_batch.num_reqs - self.backup_next_token_ids.np[:num_reqs] = np.array([ - requests[gpu_input_batch.req_ids[i]].get_token_id( - common_attn_metadata.seq_lens_cpu[i].item()) - for i in range(num_reqs) - ]) - self.backup_next_token_ids.copy_to_gpu(num_reqs) - - # Mask out the sampled tokens indices that should not be sampled. - discard_sampled_tokens_req_indices = discard_request_indices[: - num_discarded_requests] - - valid_sampled_token_ids_gpu = sampled_token_ids.clone() - valid_sampled_token_ids_gpu.index_fill_( - 0, discard_sampled_tokens_req_indices, -1) - - # Generate a mask for all valid tokens within those requests - valid_mask = (valid_sampled_token_ids_gpu != -1) & ( - valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size) - - # Count the number of valid tokens in each request - valid_sampled_tokens_count = valid_mask.sum(dim=1) - - # Get the rightmost valid index per row - last_valid_indices = valid_sampled_tokens_count - 1 - last_valid_indices_safe = torch.clamp(last_valid_indices, min=0) - - # Get last valid token from each row - # (assume undefined state where there is no valid token) - selected_tokens = torch.gather( - valid_sampled_token_ids_gpu, 1, - last_valid_indices_safe.unsqueeze(1)).squeeze(1) - - # Use last token if valid, pre-computed backup if not - batch_size = valid_sampled_token_ids_gpu.shape[0] - next_token_ids = torch.where( - last_valid_indices != -1, - selected_tokens, - self.backup_next_token_ids.gpu[:batch_size], - ) - - return next_token_ids, valid_sampled_tokens_count - - def prepare_inputs_padded( - self, - common_attn_metadata: CommonAttentionMetadata, - spec_decode_metadata: SpecDecodeMetadata, - valid_sampled_tokens_count: torch.Tensor, - ) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]: - """ - This function is used to prepare the inputs for speculative decoding - It updates the common_attn_metadata for speculative decoding, - but does not consider the rejected tokens. Instead, all tokens - are included as inputs to the speculator, with the rejected tokens - used as padding and filtered out later by `token_indices_to_sample`. - No blocking CPU operations should be introduced in this function. - """ - num_draft_tokens_gpu = torch.cat([ - spec_decode_metadata.cu_num_draft_tokens[0:1], - spec_decode_metadata.cu_num_draft_tokens[1:] - - spec_decode_metadata.cu_num_draft_tokens[:-1], - ]) - - num_rejected_tokens_gpu = torch.where( - num_draft_tokens_gpu > 0, - num_draft_tokens_gpu + 1 - valid_sampled_tokens_count, - torch.zeros_like(num_draft_tokens_gpu), - ) - - query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu - - new_query_len_per_req = query_start_loc_cpu[ - 1:] - query_start_loc_cpu[:-1] - - total_num_tokens = query_start_loc_cpu[-1].item() - token_indices = self.arange[:total_num_tokens] - - # NOTE: Currently positions and seq_lens are not used in mla_v1 forward - # so we do not need to fixed them. But if they are used in the future, - # we should fixed them. 
- spec_common_attn_metadata = AscendCommonAttentionMetadata( - query_start_loc=common_attn_metadata.query_start_loc, - query_start_loc_cpu=query_start_loc_cpu, - seq_lens_cpu=common_attn_metadata.seq_lens_cpu, - num_reqs=common_attn_metadata.num_reqs, - num_actual_tokens=total_num_tokens, - num_input_tokens=common_attn_metadata.num_input_tokens, - max_query_len=new_query_len_per_req.max().item(), - actual_seq_lengths_q=self.runner.actual_seq_lengths_q, - block_table_tensor=common_attn_metadata.block_table_tensor, - slot_mapping=common_attn_metadata.slot_mapping, - positions=common_attn_metadata.positions, - attn_mask=self.runner.attn_mask, - spec_attn_mask=self.runner.spec_attn_mask, - attn_state=self.runner.attn_state, - decode_token_per_req=self.runner.decode_token_per_req, - num_computed_tokens_cpu=common_attn_metadata. - num_computed_tokens_cpu, - seq_lens=common_attn_metadata.seq_lens, - max_seq_len=0) - - query_start_loc = common_attn_metadata.query_start_loc[ - 1:1 + num_rejected_tokens_gpu.shape[0]] - token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu - - return spec_common_attn_metadata, token_indices, token_indices_to_sample - - def _split_pcp_input(self, req_scheduled_tokens, input_ids, - target_hidden_states): - """ - Split prefill input_ids and target_hidden_states in pcp group. - 1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad] - 2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5] - 3. split target_hidden_states (already include pcp padding): - [h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5] - 4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split. - """ - if len(req_scheduled_tokens) == 0: - # no prefill inputs to split, return empty result - return ( - 0, - torch.zeros([0], device='npu'), - torch.zeros([0, target_hidden_states.size(1)], device='npu'), - 0, - torch.zeros([0]), - torch.tensor([0], dtype=torch.int32), - ) - - def _pcp_pad_and_split(num_tokens): - num_pcp_padded_scheduled_tokens = cdiv( - num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size - pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens - chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size) - - # split position_ids (and use split position_ids to split input_ids afterwards) - req_position_cp: list[int] = [] - req_position_cp.extend( - self.full_indices[self.pcp_rank * - chunk_size:(self.pcp_rank + 1) * chunk_size]) - req_position_cp.extend( - self.full_indices[num_pcp_padded_scheduled_tokens - - (self.pcp_rank + 1) * - chunk_size:num_pcp_padded_scheduled_tokens - - self.pcp_rank * chunk_size]) - - return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad - - num_pcp_scheduled_tokens = [] - ori_start_index = 0 - pad_start_index = 0 - pcp_split_input_ids_list = [] - pcp_split_hidden_states_list = [] - for ori_num_tokens in req_scheduled_tokens.values(): - req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \ - _pcp_pad_and_split(ori_num_tokens) - actual_num_tokens = len(req_position_pcp) - num_pcp_scheduled_tokens.append(actual_num_tokens) - pad_input_ids = F.pad( - input_ids[ori_start_index:ori_start_index + ori_num_tokens], - (0, num_pcp_pad)) - ori_start_index += ori_num_tokens - pcp_chunk_indices = [ - pad_start_index + pos for pos in req_position_pcp - ] - pcp_split_input_ids = pad_input_ids[req_position_pcp] - pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices] - pcp_split_input_ids_list.append(pcp_split_input_ids) - 
pcp_split_hidden_states_list.append(pcp_split_hidden_states) - pad_start_index += num_pcp_padded_scheduled_tokens - num_tokens = sum(num_pcp_scheduled_tokens) - input_ids = torch.cat(pcp_split_input_ids_list) - target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0) - max_query_len = max(num_pcp_scheduled_tokens) - seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32) - cu_num_tokens = torch.tensor( - np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0)) - return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
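
For readers tracing the pcp logic this patch consolidates into `eagle_proposer.py`, the head/tail split described in the `_split_pcp_input` docstring can be reproduced with a small standalone sketch. The helper name `pcp_split_indices` and its plain-Python form are illustrative only and are not part of the patch; the index arithmetic is intended to mirror `_pcp_pad_and_split` above.

```python
from math import ceil


def pcp_split_indices(num_tokens: int, pcp_size: int, pcp_rank: int) -> list[int]:
    # Pad the request length up to a multiple of 2 * pcp_size, as
    # cdiv(num_tokens, 2 * pcp_size) * 2 * pcp_size does in the patch.
    padded = ceil(num_tokens / (2 * pcp_size)) * 2 * pcp_size
    chunk = padded // (2 * pcp_size)
    full = list(range(padded))
    # Each rank takes one chunk from the front and the mirrored chunk from
    # the back of the padded sequence.
    return (full[pcp_rank * chunk:(pcp_rank + 1) * chunk] +
            full[padded - (pcp_rank + 1) * chunk:padded - pcp_rank * chunk])


# With 6 tokens and pcp_size = 2, the padded length is 8:
#   rank 0 -> [0, 1, 6, 7]   (positions 6 and 7 are padding)
#   rank 1 -> [2, 3, 4, 5]
# which matches the [t0, t1, pad, pad] / [t2, t3, t4, t5] example in the
# _split_pcp_input docstring.
```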