[1/N][UT][v1 MTP] add basic v1 mtp features (#890)

### What this PR does / why we need it?
Add basic v1 MTP features.
Please merge it after
https://github.com/vllm-project/vllm-ascend/pull/874 and
https://github.com/vllm-project/vllm-ascend/pull/844.

### Does this PR introduce _any_ user-facing change?
Basic v1 MTP is now supported, currently limited to TP only, eager mode,
and k=1.
We will continue to expand support to more scenarios.

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
XWFAlone
2025-05-30 08:59:58 +08:00
committed by GitHub
parent 5903547d09
commit 3442fbdb23
7 changed files with 477 additions and 13 deletions

View File

@@ -16,13 +16,26 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm_ascend.attention.attention_v1 import AscendAttentionState
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
if TYPE_CHECKING:
from vllm.v1.core.sched.output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
@dataclass
class CommonAttentionMetadata:
    """
    Attention metadata attributes that can be shared by layers in different KV
    cache groups and thus having different block table.
    """

    query_start_loc: torch.Tensor
    """(batch_size + 1,), the start location of each request in query Tensor"""

    seq_lens: torch.Tensor
    """(batch_size,), the length of each request including both computed tokens
    and newly scheduled tokens"""
class AscendMLABackend(AttentionBackend):
accept_output_buffer: bool = True
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
seq_lens: list[int]
context_lens: torch.Tensor
input_positions: torch.Tensor
query_start_loc: torch.Tensor
block_table: torch.Tensor
max_query_len: int
max_seq_lens: int
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
num_actual_tokens: int # Number of tokens excluding padding.
slot_mapping: torch.Tensor
query_start_loc: torch.Tensor
seq_lens: torch.Tensor
block_tables: torch.Tensor
# New for MLA (compared to FlashAttention)
# For handling prefill decode split
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
# _attn_mask_builder = None
def __init__(self,
runner: "NPUModelRunner",
runner,
metadata_cls: Optional[AscendMLAMetadata] = None):
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
if metadata_cls is not None else AscendMLAMetadata # type: ignore
@@ -230,6 +247,7 @@ class AscendMLAMetadataBuilder:
num_reqs: int,
num_actual_tokens: int,
max_query_len: int,
common_attn_metadata: CommonAttentionMetadata,
common_prefix_len: Optional[int] = None,
graph_pad_size: int = -1) -> AscendMLAMetadata:
assert self._num_decodes + self._num_prefills == num_reqs
@@ -252,6 +270,7 @@ class AscendMLAMetadataBuilder:
seq_lens = seq_lens_cpu
max_query_len = query_lens.max().item()
max_seq_lens = seq_lens.max().item()
query_start_loc = None
prefill_metadata = None
if self._num_prefills > 0:
@@ -259,6 +278,9 @@ class AscendMLAMetadataBuilder:
tokens_start = self._num_decode_tokens
max_query_len = query_lens[tokens_start:].max().item()
max_seq_lens = seq_lens[tokens_start:].max().item()
query_start_loc = common_attn_metadata.query_start_loc
prefill_query_start_loc = query_start_loc[
reqs_start:] - query_start_loc[reqs_start]
prefill_metadata = AscendMLAPrefillMetadata(
attn_mask=self.runner.attn_mask,
@@ -269,6 +291,7 @@ class AscendMLAMetadataBuilder:
block_table=block_table[reqs_start:, ...],
max_query_len=max_query_len,
max_seq_lens=max_seq_lens,
query_start_loc=prefill_query_start_loc,
)
decode_metadata = None
@@ -325,6 +348,9 @@ class AscendMLAMetadataBuilder:
attn_state=self.runner.attn_state,
prefill=prefill_metadata,
decode=decode_metadata,
query_start_loc=query_start_loc,
block_tables=block_table,
seq_lens=seq_lens,
)