[Refactor][EAGLE] 3/N delete redundant methods in mtp_proposer (#5420)
### What this PR does / why we need it?
This PR aims to delete redundant methods in mtp_proposer. All the
deleted methods now can be found in eagle_proposer. We also remove some
methods in eagle_proposer since they are identical to those in
vllm-eagle.
### Does this PR introduce _any_ user-facing change?
N/A
### How was this patch tested?
by ci
- vLLM version: release/v0.13.0
- vLLM main:
81786c8774
---------
Signed-off-by: Zetong Li <slippersss@126.com>
This commit is contained in:
@@ -262,6 +262,7 @@ class TestMtpProposer:
|
|||||||
device=torch.device("cpu"))
|
device=torch.device("cpu"))
|
||||||
assert torch.equal(next_token_ids, expected_next_tokens)
|
assert torch.equal(next_token_ids, expected_next_tokens)
|
||||||
|
|
||||||
|
@patch("vllm_ascend.spec_decode.eagle_proposer.HAS_TRITON", False)
|
||||||
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
@patch("vllm.v1.spec_decode.eagle.CpuGpuBuffer")
|
||||||
def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
|
def test_prepare_inputs_padded(self, mock_cpu_gpu_buffer):
|
||||||
mock_buffer_instance = MagicMock()
|
mock_buffer_instance = MagicMock()
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from typing import Optional
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
|
import torch.nn.functional as F
|
||||||
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
from vllm.config import (CompilationMode, CUDAGraphMode, VllmConfig,
|
||||||
get_layers_from_vllm_config)
|
get_layers_from_vllm_config)
|
||||||
from vllm.distributed.parallel_state import get_pp_group
|
from vllm.distributed.parallel_state import get_pp_group
|
||||||
@@ -15,6 +16,7 @@ from vllm.model_executor.models import supports_multimodal
|
|||||||
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
from vllm.model_executor.models.deepseek_v2 import DeepseekV32IndexerCache
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.triton_utils import HAS_TRITON, triton
|
from vllm.triton_utils import HAS_TRITON, triton
|
||||||
|
from vllm.utils.math_utils import cdiv
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
from vllm.utils.platform_utils import is_pin_memory_available
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
@@ -28,7 +30,10 @@ from vllm_ascend.attention.attention_mask import AttentionMaskBuilder
|
|||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
||||||
update_attn_params)
|
update_attn_dcp_pcp_params,
|
||||||
|
update_attn_params,
|
||||||
|
update_mla_attn_dcp_pcp_params,
|
||||||
|
update_mla_attn_params)
|
||||||
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
from vllm_ascend.ops.rotary_embedding import update_cos_sin
|
||||||
from vllm_ascend.ops.triton.spec_decode.utils import \
|
from vllm_ascend.ops.triton.spec_decode.utils import \
|
||||||
prepare_inputs_padded_kernel
|
prepare_inputs_padded_kernel
|
||||||
@@ -37,10 +42,6 @@ from vllm_ascend.utils import shared_expert_dp_enabled
|
|||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
_DEFAULT_FIRST_LAYER = 'model.layers.0.self_attn.attn'
|
|
||||||
|
|
||||||
_FIRST_LAYERS = {"Qwen3NextForCausalLM": 'model.layers.3.self_attn.attn'}
|
|
||||||
|
|
||||||
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
# Currently we will fix block size to a small one since `num_reqs` can't be too large
|
||||||
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
_PREPARE_INPUTS_BLOCK_SIZE = 4
|
||||||
|
|
||||||
@@ -93,27 +94,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
self.use_sparse = hasattr(vllm_config.model_config.hf_text_config,
|
||||||
"index_topk")
|
"index_topk")
|
||||||
|
|
||||||
def _get_eagle3_use_aux_hidden_state_from_config(self) -> bool:
|
|
||||||
"""
|
|
||||||
NOTE(2025-12-18): This is an explicit copy from vLLM EagleProposer, only added
|
|
||||||
to align with its logics.
|
|
||||||
|
|
||||||
Some eagle3 heads (e.g., nvidia/gpt-oss-120b-Eagle3-v2) do not use auxiliary
|
|
||||||
hidden states and directly uses the last layer output just like eagle1.
|
|
||||||
They might indicate this by setting "use_aux_hidden_state" to False
|
|
||||||
inside the "eagle_config" dict of their hf_config.
|
|
||||||
"""
|
|
||||||
if self.method != "eagle3":
|
|
||||||
return False
|
|
||||||
# Assume that eagle3 heads use aux hidden states by default
|
|
||||||
use_aux_hidden_state = True
|
|
||||||
eagle_config = getattr(self.draft_model_config.hf_config,
|
|
||||||
"eagle_config", None)
|
|
||||||
if eagle_config is not None:
|
|
||||||
use_aux_hidden_state = eagle_config.get("use_aux_hidden_state",
|
|
||||||
True)
|
|
||||||
return use_aux_hidden_state
|
|
||||||
|
|
||||||
def load_model(self, model: nn.Module) -> None:
|
def load_model(self, model: nn.Module) -> None:
|
||||||
target_attn_layer_names = set(
|
target_attn_layer_names = set(
|
||||||
get_layers_from_vllm_config(self.vllm_config,
|
get_layers_from_vllm_config(self.vllm_config,
|
||||||
@@ -512,48 +492,6 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
|
draft_token_ids = draft_token_ids_tensor.swapaxes(0, 1)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
def _get_attn_metadata(self, attn_metadata):
|
|
||||||
if attn_metadata is not None and isinstance(attn_metadata, dict):
|
|
||||||
architecture = self.vllm_config.model_config.architecture
|
|
||||||
layer_name = _FIRST_LAYERS.get(architecture, _DEFAULT_FIRST_LAYER)
|
|
||||||
attn_metadata = attn_metadata[layer_name]
|
|
||||||
|
|
||||||
return attn_metadata
|
|
||||||
|
|
||||||
def prepare_next_token_ids_cpu(
|
|
||||||
self,
|
|
||||||
sampled_token_ids: list[list[int]],
|
|
||||||
requests: dict[str, CachedRequestState],
|
|
||||||
gpu_input_batch: InputBatch,
|
|
||||||
num_scheduled_tokens: dict[str, int],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding.
|
|
||||||
It calculates the next token ids for each request based on the sampled
|
|
||||||
token ids from the CPU. If a request has no sampled token ids (e.g.,
|
|
||||||
during the initial decoding steps), it falls back to using the request
|
|
||||||
state to get the next token id.
|
|
||||||
"""
|
|
||||||
req_ids = gpu_input_batch.req_ids
|
|
||||||
next_token_ids: list[int] = []
|
|
||||||
for i, token_ids in enumerate(sampled_token_ids):
|
|
||||||
if token_ids:
|
|
||||||
# Common case.
|
|
||||||
next_token_id = token_ids[-1]
|
|
||||||
else:
|
|
||||||
# Partial prefill (rare case).
|
|
||||||
# Get the next token id from the request state.
|
|
||||||
req_id = req_ids[i]
|
|
||||||
req_state = requests[req_id]
|
|
||||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
|
|
||||||
req_id]
|
|
||||||
next_token_id = req_state.get_token_id(seq_len)
|
|
||||||
next_token_ids.append(next_token_id)
|
|
||||||
next_token_ids = torch.tensor(next_token_ids,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.input_ids.device)
|
|
||||||
return next_token_ids
|
|
||||||
|
|
||||||
def prepare_next_token_ids_padded(
|
def prepare_next_token_ids_padded(
|
||||||
self,
|
self,
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
common_attn_metadata: CommonAttentionMetadata,
|
||||||
@@ -829,3 +767,92 @@ class EagleProposer(VllmEagleProposer):
|
|||||||
max_seq_len=0)
|
max_seq_len=0)
|
||||||
|
|
||||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
||||||
|
|
||||||
|
def _split_pcp_input(self, req_scheduled_tokens, input_ids,
|
||||||
|
target_hidden_states):
|
||||||
|
"""
|
||||||
|
Split prefill input_ids and target_hidden_states in pcp group.
|
||||||
|
1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
|
||||||
|
2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
|
||||||
|
3. split target_hidden_states (already include pcp padding):
|
||||||
|
[h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
|
||||||
|
4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
|
||||||
|
"""
|
||||||
|
if len(req_scheduled_tokens) == 0:
|
||||||
|
# no prefill inputs to split, return empty result
|
||||||
|
return (
|
||||||
|
0,
|
||||||
|
torch.zeros([0], device='npu'),
|
||||||
|
torch.zeros([0, target_hidden_states.size(1)], device='npu'),
|
||||||
|
0,
|
||||||
|
torch.zeros([0]),
|
||||||
|
torch.tensor([0], dtype=torch.int32),
|
||||||
|
)
|
||||||
|
|
||||||
|
def _pcp_pad_and_split(num_tokens):
|
||||||
|
num_pcp_padded_scheduled_tokens = cdiv(
|
||||||
|
num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
|
||||||
|
pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
|
||||||
|
chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)
|
||||||
|
|
||||||
|
# split position_ids (and use split position_ids to split input_ids afterwards)
|
||||||
|
req_position_cp: list[int] = []
|
||||||
|
req_position_cp.extend(
|
||||||
|
self.full_indices[self.pcp_rank *
|
||||||
|
chunk_size:(self.pcp_rank + 1) * chunk_size])
|
||||||
|
req_position_cp.extend(
|
||||||
|
self.full_indices[num_pcp_padded_scheduled_tokens -
|
||||||
|
(self.pcp_rank + 1) *
|
||||||
|
chunk_size:num_pcp_padded_scheduled_tokens -
|
||||||
|
self.pcp_rank * chunk_size])
|
||||||
|
|
||||||
|
return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
|
||||||
|
|
||||||
|
num_pcp_scheduled_tokens = []
|
||||||
|
ori_start_index = 0
|
||||||
|
pad_start_index = 0
|
||||||
|
pcp_split_input_ids_list = []
|
||||||
|
pcp_split_hidden_states_list = []
|
||||||
|
for ori_num_tokens in req_scheduled_tokens.values():
|
||||||
|
req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
|
||||||
|
_pcp_pad_and_split(ori_num_tokens)
|
||||||
|
actual_num_tokens = len(req_position_pcp)
|
||||||
|
num_pcp_scheduled_tokens.append(actual_num_tokens)
|
||||||
|
pad_input_ids = F.pad(
|
||||||
|
input_ids[ori_start_index:ori_start_index + ori_num_tokens],
|
||||||
|
(0, num_pcp_pad))
|
||||||
|
ori_start_index += ori_num_tokens
|
||||||
|
pcp_chunk_indices = [
|
||||||
|
pad_start_index + pos for pos in req_position_pcp
|
||||||
|
]
|
||||||
|
pcp_split_input_ids = pad_input_ids[req_position_pcp]
|
||||||
|
pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
|
||||||
|
pcp_split_input_ids_list.append(pcp_split_input_ids)
|
||||||
|
pcp_split_hidden_states_list.append(pcp_split_hidden_states)
|
||||||
|
pad_start_index += num_pcp_padded_scheduled_tokens
|
||||||
|
num_tokens = sum(num_pcp_scheduled_tokens)
|
||||||
|
input_ids = torch.cat(pcp_split_input_ids_list)
|
||||||
|
target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
|
||||||
|
max_query_len = max(num_pcp_scheduled_tokens)
|
||||||
|
seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
|
||||||
|
cu_num_tokens = torch.tensor(
|
||||||
|
np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
|
||||||
|
return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
|
||||||
|
|
||||||
|
# update full-graph params for one spec token
|
||||||
|
def _update_full_graph_params(self, forward_context, num_tokens):
|
||||||
|
if self.vllm_config.model_config.use_mla:
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
update_mla_attn_dcp_pcp_params(self.update_stream,
|
||||||
|
forward_context, num_tokens)
|
||||||
|
else:
|
||||||
|
update_mla_attn_params(self.update_stream, forward_context,
|
||||||
|
num_tokens,
|
||||||
|
self.vllm_config.speculative_config)
|
||||||
|
else:
|
||||||
|
if self.pcp_size * self.dcp_size > 1:
|
||||||
|
update_attn_dcp_pcp_params(self.update_stream, forward_context,
|
||||||
|
num_tokens)
|
||||||
|
else:
|
||||||
|
update_attn_params(self.update_stream, forward_context,
|
||||||
|
num_tokens, self.vllm_config)
|
||||||
|
|||||||
@@ -1,73 +1,31 @@
|
|||||||
from typing import Optional, Union
|
from typing import Optional, Union
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
|
||||||
from vllm.config import CUDAGraphMode
|
from vllm.config import CUDAGraphMode
|
||||||
from vllm.distributed import get_pcp_group
|
from vllm.distributed import get_pcp_group
|
||||||
from vllm.forward_context import get_forward_context
|
from vllm.forward_context import get_forward_context
|
||||||
from vllm.logger import init_logger
|
|
||||||
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
|
||||||
from vllm.utils.math_utils import cdiv
|
|
||||||
from vllm.utils.platform_utils import is_pin_memory_available
|
|
||||||
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
from vllm.v1.attention.backends.utils import CommonAttentionMetadata
|
||||||
from vllm.v1.core.sched.output import SchedulerOutput
|
from vllm.v1.core.sched.output import SchedulerOutput
|
||||||
from vllm.v1.sample.metadata import SamplingMetadata
|
from vllm.v1.sample.metadata import SamplingMetadata
|
||||||
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
|
|
||||||
from vllm.v1.worker.gpu_input_batch import CachedRequestState, InputBatch
|
|
||||||
|
|
||||||
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
from vllm_ascend.ascend_forward_context import set_ascend_forward_context
|
||||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||||
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
from vllm_ascend.attention.utils import AscendCommonAttentionMetadata
|
||||||
from vllm_ascend.compilation.acl_graph import (ACLGraphWrapper,
|
from vllm_ascend.compilation.acl_graph import ACLGraphWrapper
|
||||||
update_attn_dcp_pcp_params,
|
|
||||||
update_attn_params,
|
|
||||||
update_mla_attn_dcp_pcp_params,
|
|
||||||
update_mla_attn_params)
|
|
||||||
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
from vllm_ascend.ops.rotary_embedding import get_cos_and_sin_mla
|
||||||
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
from vllm_ascend.spec_decode.eagle_proposer import EagleProposer
|
||||||
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
from vllm_ascend.utils import ProfileExecuteDuration, lmhead_tp_enable
|
||||||
|
|
||||||
logger = init_logger(__name__)
|
|
||||||
|
|
||||||
PADDING_SLOT_ID = -1
|
PADDING_SLOT_ID = -1
|
||||||
|
|
||||||
_MTP_MODELS = {
|
|
||||||
"DeepseekV3ForCausalLM":
|
|
||||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
|
||||||
"PanguUltraMoEForCausalLM":
|
|
||||||
("vllm.model_executor.models.openpangu_mtp", "OpenPanguMTP"),
|
|
||||||
"DeepseekV32ForCausalLM":
|
|
||||||
("vllm.model_executor.models.deepseek_mtp", "DeepSeekMTP"),
|
|
||||||
"Qwen3NextForCausalLM":
|
|
||||||
("vllm.model_executor.models.qwen3_next_mtp", "Qwen3NextMTP")
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
class MtpProposer(EagleProposer):
|
class MtpProposer(EagleProposer):
|
||||||
|
|
||||||
# TODO: Find out why ModelRunner does not this explicit typing?
|
# TODO: Find out why ModelRunner does not this explicit typing?
|
||||||
model: Union[nn.Module, ACLGraphWrapper]
|
model: Union[nn.Module, ACLGraphWrapper]
|
||||||
|
|
||||||
# update full-graph params for one spec token
|
|
||||||
def _update_full_graph_params(self, forward_context, num_tokens):
|
|
||||||
if self.vllm_config.model_config.use_mla:
|
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
|
||||||
update_mla_attn_dcp_pcp_params(self.update_stream,
|
|
||||||
forward_context, num_tokens)
|
|
||||||
else:
|
|
||||||
update_mla_attn_params(self.update_stream, forward_context,
|
|
||||||
num_tokens,
|
|
||||||
self.vllm_config.speculative_config)
|
|
||||||
else:
|
|
||||||
if self.pcp_size * self.dcp_size > 1:
|
|
||||||
update_attn_dcp_pcp_params(self.update_stream, forward_context,
|
|
||||||
num_tokens)
|
|
||||||
else:
|
|
||||||
update_attn_params(self.update_stream, forward_context,
|
|
||||||
num_tokens, self.vllm_config)
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
@torch.inference_mode()
|
||||||
def dummy_run(self,
|
def dummy_run(self,
|
||||||
num_tokens: int,
|
num_tokens: int,
|
||||||
@@ -180,124 +138,6 @@ class MtpProposer(EagleProposer):
|
|||||||
if with_prefill:
|
if with_prefill:
|
||||||
break
|
break
|
||||||
|
|
||||||
def _prepare_inputs(
|
|
||||||
self,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
sampled_token_ids: list[list[int]],
|
|
||||||
num_draft_tokens: list[int],
|
|
||||||
) -> tuple[CommonAttentionMetadata, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding.
|
|
||||||
It updates to the common_attn_metadata to account for the rejected
|
|
||||||
tokens (and newly sampled tokens). It also returns the token indices
|
|
||||||
of the tokens that should be fed to the speculator.
|
|
||||||
"""
|
|
||||||
# E.g.
|
|
||||||
# common_attn_metadata.query_start_loc{_cpu}:
|
|
||||||
# [0, q1, q1 + q2, q1 + q2 + q3]
|
|
||||||
# common_attn_metadata.seq_lens{_cpu}: [s1, s2, s3]
|
|
||||||
# num_rejected_tokens: [n1, n2, n3]
|
|
||||||
# This function computes the intermediate values:
|
|
||||||
# num_tokens_per_req: [q1 - n1, q2 - n2, q3 - n3]
|
|
||||||
# And returns:
|
|
||||||
# common_attn_metadata.query_start_loc{_cpu}:
|
|
||||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
|
||||||
# common_attn_metadata.seq_lens{_cpu}:
|
|
||||||
# [s1 - n1 + 1, s2 - n2 + 1, s3 - n3 + 1]
|
|
||||||
# token_indices: [0, 1, ..., q1 - n1 - 1,
|
|
||||||
# q1, q1 + 1, ..., q1 + q2 - n2 - 1,
|
|
||||||
# q1 + q2, q1 + q2 + 1, ..., q1 + q2 + q3 - n3 - 1]
|
|
||||||
|
|
||||||
num_actual_reqs = len(num_draft_tokens)
|
|
||||||
num_rejected_tokens = [
|
|
||||||
n + 1 - len(sampled_token_ids[i]) if n > 0 else 0
|
|
||||||
for i, n in enumerate(num_draft_tokens)
|
|
||||||
]
|
|
||||||
num_rejected_tokens = torch.tensor(num_rejected_tokens,
|
|
||||||
dtype=torch.int32)
|
|
||||||
|
|
||||||
device = common_attn_metadata.query_start_loc.device
|
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu[:
|
|
||||||
num_actual_reqs
|
|
||||||
+ 1]
|
|
||||||
seq_lens_cpu = common_attn_metadata.seq_lens_cpu[:num_actual_reqs]
|
|
||||||
new_seq_lens_cpu = seq_lens_cpu - num_rejected_tokens
|
|
||||||
|
|
||||||
# [0, q1, q1 + q2, q1 + q2 + q3] -> [q1, q2, q3]
|
|
||||||
new_query_len_per_req = query_start_loc_cpu[
|
|
||||||
1:] - query_start_loc_cpu[:-1]
|
|
||||||
# [q1, q2, q3] -> [q1 - n1, q2 - n2, q3 - n3]
|
|
||||||
new_num_tokens_per_req = new_query_len_per_req - num_rejected_tokens
|
|
||||||
new_num_tokens_per_req_np = new_num_tokens_per_req.numpy()
|
|
||||||
|
|
||||||
# [q1 - n1, q2 - n2, q3 - n3] ->
|
|
||||||
# [0, q1 - n1, q1 + q2 - n1 - n2, q1 + q2 + q3 - n1 - n2 - n3]
|
|
||||||
new_query_start_loc_cpu = torch.zeros(
|
|
||||||
query_start_loc_cpu.shape,
|
|
||||||
dtype=torch.int32,
|
|
||||||
pin_memory=is_pin_memory_available(),
|
|
||||||
)
|
|
||||||
new_query_start_loc_np = new_query_start_loc_cpu.numpy()
|
|
||||||
np.cumsum(new_num_tokens_per_req_np, out=new_query_start_loc_np[1:])
|
|
||||||
|
|
||||||
total_num_tokens = new_query_start_loc_np[-1]
|
|
||||||
# Example assuming num_tokens_per_req_np = [2, 4, 3]
|
|
||||||
# this implies that `new_query_start_locs` is:
|
|
||||||
# [0, 2, 6, 9] ->
|
|
||||||
# [0, 0, 2, 2, 2, 2, 6, 6, 6]
|
|
||||||
# _r1_ ____r2____ ___r3__
|
|
||||||
new_query_start_locs_expanded = np.repeat(new_query_start_loc_np[:-1],
|
|
||||||
new_num_tokens_per_req_np)
|
|
||||||
# [0, 1, 2, 3, 4, 5, 6, 7, 8] ->
|
|
||||||
# [0, 1, 0, 1, 2, 3, 0, 1, 2]
|
|
||||||
# _r1_ ____r2____ ___r3__
|
|
||||||
token_offests = (self.token_arange_np[:total_num_tokens] -
|
|
||||||
new_query_start_locs_expanded)
|
|
||||||
|
|
||||||
# Expand starting positions to match token pattern
|
|
||||||
# [0, q1, q1 + q2] ->
|
|
||||||
# [0, 0, q1, q1, q1, q1, q1 + q2, q1 + q2, q1 + q2]
|
|
||||||
# _r1_ _____r2_______ ___________r3____________
|
|
||||||
old_query_start_locs_expanded = np.repeat(
|
|
||||||
query_start_loc_cpu[:-1].numpy(), new_num_tokens_per_req_np)
|
|
||||||
# Final token indices are:
|
|
||||||
# [0, 1, // req 1
|
|
||||||
# q1 + 0, q1 + 1, q1 + 2, q1 + 3, // req 2
|
|
||||||
# q1 + q2 + 0, q1 + q2 + 1, q1 + q2 + 2] // req 3
|
|
||||||
token_indices_np = token_offests + old_query_start_locs_expanded
|
|
||||||
token_indices = torch.from_numpy(token_indices_np).to(
|
|
||||||
device, non_blocking=True)
|
|
||||||
|
|
||||||
common_attn_metadata.slot_mapping[:token_indices.shape[0]].copy_(
|
|
||||||
common_attn_metadata.slot_mapping[token_indices])
|
|
||||||
common_attn_metadata.slot_mapping[token_indices.shape[0]:].fill_(-1)
|
|
||||||
|
|
||||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
|
||||||
# so we do not need to fixed them. But if they are used in the future,
|
|
||||||
# we should fixed them.
|
|
||||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
|
||||||
query_start_loc=new_query_start_loc_cpu.to(device,
|
|
||||||
non_blocking=True),
|
|
||||||
query_start_loc_cpu=new_query_start_loc_cpu,
|
|
||||||
seq_lens=new_seq_lens_cpu.to(device, non_blocking=True),
|
|
||||||
seq_lens_cpu=new_seq_lens_cpu,
|
|
||||||
num_computed_tokens_cpu=common_attn_metadata.
|
|
||||||
num_computed_tokens_cpu,
|
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
|
||||||
num_actual_tokens=total_num_tokens,
|
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
|
||||||
max_query_len=new_query_len_per_req.max().item(),
|
|
||||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
|
||||||
positions=common_attn_metadata.positions[token_indices],
|
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
|
||||||
max_seq_len=0)
|
|
||||||
return spec_common_attn_metadata, token_indices
|
|
||||||
|
|
||||||
def _propose(
|
def _propose(
|
||||||
self,
|
self,
|
||||||
# [num_tokens]
|
# [num_tokens]
|
||||||
@@ -731,257 +571,3 @@ class MtpProposer(EagleProposer):
|
|||||||
# mtp>1: [batch_size, k]
|
# mtp>1: [batch_size, k]
|
||||||
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
draft_token_ids = torch.stack(draft_token_ids_list, dim=1)
|
||||||
return draft_token_ids
|
return draft_token_ids
|
||||||
|
|
||||||
# TODO Using torch instead of triton may result in poor performance
|
|
||||||
def _prepare_input_kernel(self, out_ptr: torch.Tensor,
|
|
||||||
cu_query_lens: torch.Tensor,
|
|
||||||
cu_num_tokens: torch.Tensor, block_size: int):
|
|
||||||
device = cu_query_lens.device
|
|
||||||
dtype = out_ptr.dtype
|
|
||||||
|
|
||||||
offsets = torch.arange(block_size, device=device, dtype=dtype)
|
|
||||||
start_pos = cu_num_tokens[:-1]
|
|
||||||
end_pos = cu_num_tokens[1:]
|
|
||||||
num_tokens = end_pos - start_pos
|
|
||||||
|
|
||||||
global_indices = (start_pos.view(-1, 1) + offsets.view(1, -1))
|
|
||||||
values = (cu_query_lens[:-1].view(-1, 1) + offsets.view(1, -1))
|
|
||||||
|
|
||||||
mask = (offsets.view(1, -1) < num_tokens.view(-1, 1))
|
|
||||||
|
|
||||||
global_indices_flat = global_indices[mask]
|
|
||||||
values_flat = values[mask]
|
|
||||||
out_ptr[global_indices_flat] = values_flat
|
|
||||||
|
|
||||||
def prepare_next_token_ids_cpu(
|
|
||||||
self,
|
|
||||||
sampled_token_ids: list[list[int]],
|
|
||||||
requests: dict[str, CachedRequestState],
|
|
||||||
gpu_input_batch: InputBatch,
|
|
||||||
num_scheduled_tokens: dict[str, int],
|
|
||||||
) -> torch.Tensor:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding.
|
|
||||||
It calculates the next token ids for each request based on the sampled
|
|
||||||
token ids from the CPU. If a request has no sampled token ids (e.g.,
|
|
||||||
during the initial decoding steps), it falls back to using the request
|
|
||||||
state to get the next token id.
|
|
||||||
"""
|
|
||||||
req_ids = gpu_input_batch.req_ids
|
|
||||||
next_token_ids: list[int] = []
|
|
||||||
for i, token_ids in enumerate(sampled_token_ids):
|
|
||||||
if token_ids:
|
|
||||||
# Common case.
|
|
||||||
next_token_id = token_ids[-1]
|
|
||||||
else:
|
|
||||||
# Partial prefill (rare case).
|
|
||||||
# Get the next token id from the request state.
|
|
||||||
req_id = req_ids[i]
|
|
||||||
req_state = requests[req_id]
|
|
||||||
seq_len = req_state.num_computed_tokens + num_scheduled_tokens[
|
|
||||||
req_id]
|
|
||||||
next_token_id = req_state.get_token_id(seq_len)
|
|
||||||
next_token_ids.append(next_token_id)
|
|
||||||
next_token_ids = torch.tensor(next_token_ids,
|
|
||||||
dtype=torch.int32,
|
|
||||||
device=self.input_ids.device)
|
|
||||||
return next_token_ids
|
|
||||||
|
|
||||||
def prepare_next_token_ids_padded(
|
|
||||||
self,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
sampled_token_ids: torch.Tensor,
|
|
||||||
requests: dict[str, CachedRequestState],
|
|
||||||
gpu_input_batch: InputBatch,
|
|
||||||
discard_request_indices: torch.Tensor,
|
|
||||||
num_discarded_requests: int,
|
|
||||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding.
|
|
||||||
It calculates the next token ids and the number of valid sampled tokens
|
|
||||||
for each request, considering the "discarded" requests whose next token
|
|
||||||
is not sampled and comes from `request.get_token_id()` instead.
|
|
||||||
It also accounts for the rejected tokens in `sampled_token_ids`.
|
|
||||||
This function must use device functions to operate on the inputs, and
|
|
||||||
should not introduce any blocking CPU-GPU synchronization.
|
|
||||||
"""
|
|
||||||
# TODO(Ben): Combine this into a custom fused kernel
|
|
||||||
|
|
||||||
# Precompute get_token_id for when there is no valid next token
|
|
||||||
num_reqs = gpu_input_batch.num_reqs
|
|
||||||
self.backup_next_token_ids.np[:num_reqs] = np.array([
|
|
||||||
requests[gpu_input_batch.req_ids[i]].get_token_id(
|
|
||||||
common_attn_metadata.seq_lens_cpu[i].item())
|
|
||||||
for i in range(num_reqs)
|
|
||||||
])
|
|
||||||
self.backup_next_token_ids.copy_to_gpu(num_reqs)
|
|
||||||
|
|
||||||
# Mask out the sampled tokens indices that should not be sampled.
|
|
||||||
discard_sampled_tokens_req_indices = discard_request_indices[:
|
|
||||||
num_discarded_requests]
|
|
||||||
|
|
||||||
valid_sampled_token_ids_gpu = sampled_token_ids.clone()
|
|
||||||
valid_sampled_token_ids_gpu.index_fill_(
|
|
||||||
0, discard_sampled_tokens_req_indices, -1)
|
|
||||||
|
|
||||||
# Generate a mask for all valid tokens within those requests
|
|
||||||
valid_mask = (valid_sampled_token_ids_gpu != -1) & (
|
|
||||||
valid_sampled_token_ids_gpu < gpu_input_batch.vocab_size)
|
|
||||||
|
|
||||||
# Count the number of valid tokens in each request
|
|
||||||
valid_sampled_tokens_count = valid_mask.sum(dim=1)
|
|
||||||
|
|
||||||
# Get the rightmost valid index per row
|
|
||||||
last_valid_indices = valid_sampled_tokens_count - 1
|
|
||||||
last_valid_indices_safe = torch.clamp(last_valid_indices, min=0)
|
|
||||||
|
|
||||||
# Get last valid token from each row
|
|
||||||
# (assume undefined state where there is no valid token)
|
|
||||||
selected_tokens = torch.gather(
|
|
||||||
valid_sampled_token_ids_gpu, 1,
|
|
||||||
last_valid_indices_safe.unsqueeze(1)).squeeze(1)
|
|
||||||
|
|
||||||
# Use last token if valid, pre-computed backup if not
|
|
||||||
batch_size = valid_sampled_token_ids_gpu.shape[0]
|
|
||||||
next_token_ids = torch.where(
|
|
||||||
last_valid_indices != -1,
|
|
||||||
selected_tokens,
|
|
||||||
self.backup_next_token_ids.gpu[:batch_size],
|
|
||||||
)
|
|
||||||
|
|
||||||
return next_token_ids, valid_sampled_tokens_count
|
|
||||||
|
|
||||||
def prepare_inputs_padded(
|
|
||||||
self,
|
|
||||||
common_attn_metadata: CommonAttentionMetadata,
|
|
||||||
spec_decode_metadata: SpecDecodeMetadata,
|
|
||||||
valid_sampled_tokens_count: torch.Tensor,
|
|
||||||
) -> tuple[CommonAttentionMetadata, torch.Tensor, torch.Tensor]:
|
|
||||||
"""
|
|
||||||
This function is used to prepare the inputs for speculative decoding
|
|
||||||
It updates the common_attn_metadata for speculative decoding,
|
|
||||||
but does not consider the rejected tokens. Instead, all tokens
|
|
||||||
are included as inputs to the speculator, with the rejected tokens
|
|
||||||
used as padding and filtered out later by `token_indices_to_sample`.
|
|
||||||
No blocking CPU operations should be introduced in this function.
|
|
||||||
"""
|
|
||||||
num_draft_tokens_gpu = torch.cat([
|
|
||||||
spec_decode_metadata.cu_num_draft_tokens[0:1],
|
|
||||||
spec_decode_metadata.cu_num_draft_tokens[1:] -
|
|
||||||
spec_decode_metadata.cu_num_draft_tokens[:-1],
|
|
||||||
])
|
|
||||||
|
|
||||||
num_rejected_tokens_gpu = torch.where(
|
|
||||||
num_draft_tokens_gpu > 0,
|
|
||||||
num_draft_tokens_gpu + 1 - valid_sampled_tokens_count,
|
|
||||||
torch.zeros_like(num_draft_tokens_gpu),
|
|
||||||
)
|
|
||||||
|
|
||||||
query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
|
|
||||||
|
|
||||||
new_query_len_per_req = query_start_loc_cpu[
|
|
||||||
1:] - query_start_loc_cpu[:-1]
|
|
||||||
|
|
||||||
total_num_tokens = query_start_loc_cpu[-1].item()
|
|
||||||
token_indices = self.arange[:total_num_tokens]
|
|
||||||
|
|
||||||
# NOTE: Currently positions and seq_lens are not used in mla_v1 forward
|
|
||||||
# so we do not need to fixed them. But if they are used in the future,
|
|
||||||
# we should fixed them.
|
|
||||||
spec_common_attn_metadata = AscendCommonAttentionMetadata(
|
|
||||||
query_start_loc=common_attn_metadata.query_start_loc,
|
|
||||||
query_start_loc_cpu=query_start_loc_cpu,
|
|
||||||
seq_lens_cpu=common_attn_metadata.seq_lens_cpu,
|
|
||||||
num_reqs=common_attn_metadata.num_reqs,
|
|
||||||
num_actual_tokens=total_num_tokens,
|
|
||||||
num_input_tokens=common_attn_metadata.num_input_tokens,
|
|
||||||
max_query_len=new_query_len_per_req.max().item(),
|
|
||||||
actual_seq_lengths_q=self.runner.actual_seq_lengths_q,
|
|
||||||
block_table_tensor=common_attn_metadata.block_table_tensor,
|
|
||||||
slot_mapping=common_attn_metadata.slot_mapping,
|
|
||||||
positions=common_attn_metadata.positions,
|
|
||||||
attn_mask=self.runner.attn_mask,
|
|
||||||
spec_attn_mask=self.runner.spec_attn_mask,
|
|
||||||
attn_state=self.runner.attn_state,
|
|
||||||
decode_token_per_req=self.runner.decode_token_per_req,
|
|
||||||
num_computed_tokens_cpu=common_attn_metadata.
|
|
||||||
num_computed_tokens_cpu,
|
|
||||||
seq_lens=common_attn_metadata.seq_lens,
|
|
||||||
max_seq_len=0)
|
|
||||||
|
|
||||||
query_start_loc = common_attn_metadata.query_start_loc[
|
|
||||||
1:1 + num_rejected_tokens_gpu.shape[0]]
|
|
||||||
token_indices_to_sample = query_start_loc - 1 - num_rejected_tokens_gpu
|
|
||||||
|
|
||||||
return spec_common_attn_metadata, token_indices, token_indices_to_sample
|
|
||||||
|
|
||||||
def _split_pcp_input(self, req_scheduled_tokens, input_ids,
|
|
||||||
target_hidden_states):
|
|
||||||
"""
|
|
||||||
Split prefill input_ids and target_hidden_states in pcp group.
|
|
||||||
1. input_ids padding: [t0, t1, t2, t3, t4, t5] -> [t0, t1, t2, t3, t4, t5, pad, pad]
|
|
||||||
2. split input_ids: pcp0 [t0, t1, pad, pad], pcp1 [t2, t3, t4, t5]
|
|
||||||
3. split target_hidden_states (already include pcp padding):
|
|
||||||
[h0, h1, h2, h3, h4, h5, pad, pad] -> pcp0 [h0, h1, pad, pad], pcp1 [h2, h3, h4, h5]
|
|
||||||
4. also update max_query_len, seq_lens, cu_num_tokens according to pcp split.
|
|
||||||
"""
|
|
||||||
if len(req_scheduled_tokens) == 0:
|
|
||||||
# no prefill inputs to split, return empty result
|
|
||||||
return (
|
|
||||||
0,
|
|
||||||
torch.zeros([0], device='npu'),
|
|
||||||
torch.zeros([0, target_hidden_states.size(1)], device='npu'),
|
|
||||||
0,
|
|
||||||
torch.zeros([0]),
|
|
||||||
torch.tensor([0], dtype=torch.int32),
|
|
||||||
)
|
|
||||||
|
|
||||||
def _pcp_pad_and_split(num_tokens):
|
|
||||||
num_pcp_padded_scheduled_tokens = cdiv(
|
|
||||||
num_tokens, 2 * self.pcp_size) * 2 * self.pcp_size
|
|
||||||
pcp_pad = num_pcp_padded_scheduled_tokens - num_tokens
|
|
||||||
chunk_size = num_pcp_padded_scheduled_tokens // (2 * self.pcp_size)
|
|
||||||
|
|
||||||
# split position_ids (and use split position_ids to split input_ids afterwards)
|
|
||||||
req_position_cp: list[int] = []
|
|
||||||
req_position_cp.extend(
|
|
||||||
self.full_indices[self.pcp_rank *
|
|
||||||
chunk_size:(self.pcp_rank + 1) * chunk_size])
|
|
||||||
req_position_cp.extend(
|
|
||||||
self.full_indices[num_pcp_padded_scheduled_tokens -
|
|
||||||
(self.pcp_rank + 1) *
|
|
||||||
chunk_size:num_pcp_padded_scheduled_tokens -
|
|
||||||
self.pcp_rank * chunk_size])
|
|
||||||
|
|
||||||
return req_position_cp, num_pcp_padded_scheduled_tokens, pcp_pad
|
|
||||||
|
|
||||||
num_pcp_scheduled_tokens = []
|
|
||||||
ori_start_index = 0
|
|
||||||
pad_start_index = 0
|
|
||||||
pcp_split_input_ids_list = []
|
|
||||||
pcp_split_hidden_states_list = []
|
|
||||||
for ori_num_tokens in req_scheduled_tokens.values():
|
|
||||||
req_position_pcp, num_pcp_padded_scheduled_tokens, num_pcp_pad = \
|
|
||||||
_pcp_pad_and_split(ori_num_tokens)
|
|
||||||
actual_num_tokens = len(req_position_pcp)
|
|
||||||
num_pcp_scheduled_tokens.append(actual_num_tokens)
|
|
||||||
pad_input_ids = F.pad(
|
|
||||||
input_ids[ori_start_index:ori_start_index + ori_num_tokens],
|
|
||||||
(0, num_pcp_pad))
|
|
||||||
ori_start_index += ori_num_tokens
|
|
||||||
pcp_chunk_indices = [
|
|
||||||
pad_start_index + pos for pos in req_position_pcp
|
|
||||||
]
|
|
||||||
pcp_split_input_ids = pad_input_ids[req_position_pcp]
|
|
||||||
pcp_split_hidden_states = target_hidden_states[pcp_chunk_indices]
|
|
||||||
pcp_split_input_ids_list.append(pcp_split_input_ids)
|
|
||||||
pcp_split_hidden_states_list.append(pcp_split_hidden_states)
|
|
||||||
pad_start_index += num_pcp_padded_scheduled_tokens
|
|
||||||
num_tokens = sum(num_pcp_scheduled_tokens)
|
|
||||||
input_ids = torch.cat(pcp_split_input_ids_list)
|
|
||||||
target_hidden_states = torch.cat(pcp_split_hidden_states_list, dim=0)
|
|
||||||
max_query_len = max(num_pcp_scheduled_tokens)
|
|
||||||
seq_lens = torch.tensor(num_pcp_scheduled_tokens, dtype=torch.int32)
|
|
||||||
cu_num_tokens = torch.tensor(
|
|
||||||
np.insert(np.cumsum(np.array(num_pcp_scheduled_tokens)), 0, 0))
|
|
||||||
return num_tokens, input_ids, target_hidden_states, max_query_len, seq_lens, cu_num_tokens
|
|
||||||
|
|||||||
Reference in New Issue
Block a user