[1/N][UT][v1 MTP] add basic v1 mtp features (#890)
### What this PR does / why we need it? add basic v1 mtp features please merge it after https://github.com/vllm-project/vllm-ascend/pull/874 and https://github.com/vllm-project/vllm-ascend/pull/844. ### Does this PR introduce _any_ user-facing change? now, we supported basic v1 mtp, only supported tp only、eager mode and k=1 we will continue to expand more scenarios. ### How was this patch tested? local tested Signed-off-by: XWFAlone <xuewenfei2@huawei.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
@@ -16,13 +16,26 @@ from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
|
||||
|
||||
from vllm_ascend.attention.attention_v1 import AscendAttentionState
|
||||
from vllm_ascend.ops.attention import vanilla_chunked_prefill_mla
|
||||
from vllm_ascend.worker.model_runner_v1 import NPUModelRunner
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from vllm.v1.core.sched.output import SchedulerOutput
|
||||
from vllm.v1.worker.gpu_input_batch import InputBatch
|
||||
|
||||
|
||||
@dataclass
|
||||
class CommonAttentionMetadata:
|
||||
"""
|
||||
Attention metadata attributes that can be shared by layers in different KV
|
||||
cache groups and thus having different block table.
|
||||
"""
|
||||
|
||||
query_start_loc: torch.Tensor
|
||||
"""(batch_size + 1,), the start location of each request in query Tensor"""
|
||||
seq_lens: torch.Tensor
|
||||
"""(batch_size,), the length of each request including both computed tokens
|
||||
and newly scheduled tokens"""
|
||||
|
||||
|
||||
class AscendMLABackend(AttentionBackend):
|
||||
|
||||
accept_output_buffer: bool = True
|
||||
@@ -57,6 +70,7 @@ class AscendMLAPrefillMetadata:
|
||||
seq_lens: list[int]
|
||||
context_lens: torch.Tensor
|
||||
input_positions: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
block_table: torch.Tensor
|
||||
max_query_len: int
|
||||
max_seq_lens: int
|
||||
@@ -90,6 +104,9 @@ class AscendMLAMetadata:
|
||||
|
||||
num_actual_tokens: int # Number of tokens excluding padding.
|
||||
slot_mapping: torch.Tensor
|
||||
query_start_loc: torch.Tensor
|
||||
seq_lens: torch.Tensor
|
||||
block_tables: torch.Tensor
|
||||
|
||||
# New for MLA (compared to FlashAttention)
|
||||
# For handling prefill decode split
|
||||
@@ -130,7 +147,7 @@ class AscendMLAMetadataBuilder:
|
||||
|
||||
# _attn_mask_builder = None
|
||||
def __init__(self,
|
||||
runner: "NPUModelRunner",
|
||||
runner,
|
||||
metadata_cls: Optional[AscendMLAMetadata] = None):
|
||||
self.metadata_cls: Optional[AscendMLAMetadata] = metadata_cls \
|
||||
if metadata_cls is not None else AscendMLAMetadata # type: ignore
|
||||
@@ -230,6 +247,7 @@ class AscendMLAMetadataBuilder:
|
||||
num_reqs: int,
|
||||
num_actual_tokens: int,
|
||||
max_query_len: int,
|
||||
common_attn_metadata: CommonAttentionMetadata,
|
||||
common_prefix_len: Optional[int] = None,
|
||||
graph_pad_size: int = -1) -> AscendMLAMetadata:
|
||||
assert self._num_decodes + self._num_prefills == num_reqs
|
||||
@@ -252,6 +270,7 @@ class AscendMLAMetadataBuilder:
|
||||
seq_lens = seq_lens_cpu
|
||||
max_query_len = query_lens.max().item()
|
||||
max_seq_lens = seq_lens.max().item()
|
||||
query_start_loc = None
|
||||
|
||||
prefill_metadata = None
|
||||
if self._num_prefills > 0:
|
||||
@@ -259,6 +278,9 @@ class AscendMLAMetadataBuilder:
|
||||
tokens_start = self._num_decode_tokens
|
||||
max_query_len = query_lens[tokens_start:].max().item()
|
||||
max_seq_lens = seq_lens[tokens_start:].max().item()
|
||||
query_start_loc = common_attn_metadata.query_start_loc
|
||||
prefill_query_start_loc = query_start_loc[
|
||||
reqs_start:] - query_start_loc[reqs_start]
|
||||
|
||||
prefill_metadata = AscendMLAPrefillMetadata(
|
||||
attn_mask=self.runner.attn_mask,
|
||||
@@ -269,6 +291,7 @@ class AscendMLAMetadataBuilder:
|
||||
block_table=block_table[reqs_start:, ...],
|
||||
max_query_len=max_query_len,
|
||||
max_seq_lens=max_seq_lens,
|
||||
query_start_loc=prefill_query_start_loc,
|
||||
)
|
||||
|
||||
decode_metadata = None
|
||||
@@ -325,6 +348,9 @@ class AscendMLAMetadataBuilder:
|
||||
attn_state=self.runner.attn_state,
|
||||
prefill=prefill_metadata,
|
||||
decode=decode_metadata,
|
||||
query_start_loc=query_start_loc,
|
||||
block_tables=block_table,
|
||||
seq_lens=seq_lens,
|
||||
)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user