[1/N][UT][v1 MTP] add basic v1 mtp features (#890)

### What this PR does / why we need it?
Add basic v1 MTP (multi-token prediction) features.
Please merge this after
https://github.com/vllm-project/vllm-ascend/pull/874 and
https://github.com/vllm-project/vllm-ascend/pull/844.

### Does this PR introduce _any_ user-facing change?
Yes. Basic v1 MTP is now supported, currently limited to tensor parallelism only, eager mode, and k=1 (one speculative token per step); see the sketch below.
We will continue to expand support to more scenarios.
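
For orientation, a minimal sketch of how a user might exercise this under the stated constraints via vLLM's offline API. The `speculative_config` keys (`method`, `num_speculative_tokens`) and the model name are assumptions based on vLLM's general speculative-decoding interface, not an interface verified by this PR:

```python
# Hypothetical sketch: enabling basic v1 MTP under the current constraints
# (tensor parallel only, eager mode, k=1). Config keys are assumptions.
from vllm import LLM, SamplingParams

llm = LLM(
    model="deepseek-ai/DeepSeek-V3",  # assumed MTP-capable model
    tensor_parallel_size=8,           # TP only; other parallel modes unsupported
    enforce_eager=True,               # eager mode only for now
    speculative_config={
        "method": "deepseek_mtp",     # assumed method name
        "num_speculative_tokens": 1,  # k=1 is the only supported depth
    },
)

outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```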

### How was this patch tested?
Tested locally.

Signed-off-by: XWFAlone <xuewenfei2@huawei.com>
Co-authored-by: mengwei805 <mengwei25@huawei.com>
Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
Commit 3442fbdb23 (parent 5903547d09), authored by XWFAlone on 2025-05-30 08:59:58 +08:00 and committed via GitHub.
7 changed files with 477 additions and 13 deletions.


@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
         device="npu",
         dtype=value.dtype,
     )
+    num_query = torch.sum(q_mask).item()
+    num_add_query = num_query - query.size(0)
+    # MTP speculative tokens can make q_mask select more slots than query has rows
+    if num_add_query > 0:
+        add_query_size = query.size()
+        add_query_size = list(add_query_size)
+        add_query_size[0] = num_add_query
+        pad_tensor = torch.zeros(add_query_size,
+                                 dtype=query.dtype,
+                                 device=query.device)
+        query = torch.cat([query, pad_tensor], dim=0)
     pad_q[q_mask] = query
     pad_k[kv_c_mask] = key[kv_c_mask]
     pad_v[kv_c_mask] = value[kv_c_mask]
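
To illustrate the padding step above: when MTP speculative tokens make `q_mask` select more slots than `query` has rows, `query` is zero-padded along dim 0 so the masked scatter into `pad_q` stays shape-consistent. A minimal self-contained sketch with invented toy shapes (not the real NPU tensors):

```python
import torch

# Toy shapes: 3 real query tokens, but the mask selects 5 slots because
# MTP added 2 speculative positions.
query = torch.randn(3, 2, 4)  # [num_tokens, num_heads, head_dim]
q_mask = torch.tensor([True, True, True, True, True, False])
pad_q = torch.zeros(6, 2, 4)

num_query = torch.sum(q_mask).item()       # 5 masked slots
num_add_query = num_query - query.size(0)  # 2 extra slots to fill
if num_add_query > 0:
    pad_shape = list(query.size())
    pad_shape[0] = num_add_query
    pad_tensor = torch.zeros(pad_shape, dtype=query.dtype, device=query.device)
    query = torch.cat([query, pad_tensor], dim=0)  # now 5 rows, matches mask

pad_q[q_mask] = query  # shape-consistent masked scatter
```

The zero rows are placeholders only; they are sliced away again before the result is written back, which is what the next hunk does.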
@@ -247,8 +258,8 @@ def vanilla_chunked_prefill_mla(
     attn_output = (attn_output[q_mask].view([-1, num_heads,
                                              v_head_dim]).to(output.dtype))
-    output = output.view_as(attn_output)
-    output.copy_(attn_output)
+    output = output.view([-1, num_heads, v_head_dim])
+    output.copy_(attn_output[:query.size(0) - num_add_query])
     return attn_output
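
The slice in the hunk above drops those zero-padded rows before copying into the caller's `output` buffer: after padding, `query.size(0)` equals the masked slot count, so `query.size(0) - num_add_query` recovers the original token count. A toy continuation of the sketch above:

```python
import torch

num_heads, v_head_dim = 2, 4
# 5 rows came out of attention: 3 real tokens + 2 zero-padded MTP slots.
attn_output = torch.randn(5, num_heads, v_head_dim)
num_add_query = 2
padded_len = attn_output.size(0)  # equals query.size(0) after padding

output = torch.empty(3, num_heads, v_head_dim)
output = output.view([-1, num_heads, v_head_dim])
# Copy only the real tokens; the trailing num_add_query rows are padding.
output.copy_(attn_output[:padded_len - num_add_query])
assert output.shape == (3, num_heads, v_head_dim)
```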