[1/N][UT][v1 MTP] add basic v1 mtp features (#890)
### What this PR does / why we need it? add basic v1 mtp features please merge it after https://github.com/vllm-project/vllm-ascend/pull/874 and https://github.com/vllm-project/vllm-ascend/pull/844. ### Does this PR introduce _any_ user-facing change? now, we supported basic v1 mtp, only supported tp only、eager mode and k=1 we will continue to expand more scenarios. ### How was this patch tested? local tested Signed-off-by: XWFAlone <xuewenfei2@huawei.com> Co-authored-by: mengwei805 <mengwei25@huawei.com> Co-authored-by: JC-ut0 <xuyexiong@huawei.com>
This commit is contained in:
@@ -222,6 +222,17 @@ def vanilla_chunked_prefill_mla(
|
||||
device="npu",
|
||||
dtype=value.dtype,
|
||||
)
|
||||
num_query = torch.sum(q_mask).item()
|
||||
num_add_query = num_query - query.size(0)
|
||||
# mtp will come in
|
||||
if num_add_query > 0:
|
||||
add_query_size = query.size()
|
||||
add_query_size = list(add_query_size)
|
||||
add_query_size[0] = num_add_query
|
||||
pad_tensor = torch.zeros(add_query_size,
|
||||
dtype=query.dtype,
|
||||
device=query.device)
|
||||
query = torch.cat([query, pad_tensor], dim=0)
|
||||
pad_q[q_mask] = query
|
||||
pad_k[kv_c_mask] = key[kv_c_mask]
|
||||
pad_v[kv_c_mask] = value[kv_c_mask]
|
||||
@@ -247,8 +258,8 @@ def vanilla_chunked_prefill_mla(
|
||||
|
||||
attn_output = (attn_output[q_mask].view([-1, num_heads,
|
||||
v_head_dim]).to(output.dtype))
|
||||
output = output.view_as(attn_output)
|
||||
output.copy_(attn_output)
|
||||
output = output.view([-1, num_heads, v_head_dim])
|
||||
output.copy_(attn_output[:query.size(0) - num_add_query])
|
||||
return attn_output
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user