[Refactor]refactor 310p ops and add ut (#6591)
### What this PR does / why we need it?
This pull request focuses on a significant refactoring effort within the
vllm-ascend project, specifically targeting operations optimized for the
Ascend 310P hardware. The changes aim to streamline the implementation
of core components like quantization and multi-head attention, making
the codebase more maintainable and robust. Concurrently, new unit tests
have been introduced to ensure the correctness and reliability of these
refactored modules.
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
E2E test with qwen3-32b w8a8
- vLLM version: v0.15.0
- vLLM main:
d7e17aaacd
---------
Signed-off-by: pu-zhe <zpuaa@outlook.com>
This commit is contained in:
@@ -15,7 +15,6 @@
|
||||
# This file is a part of the vllm-ascend project.
|
||||
#
|
||||
|
||||
import einops
|
||||
import torch
|
||||
import torch_npu
|
||||
|
||||
@@ -37,31 +36,26 @@ class AscendMMEncoderAttention310(AscendMMEncoderAttention):
|
||||
):
|
||||
bsz, q_len = query.size()[:2]
|
||||
kv_len = key.size(1)
|
||||
|
||||
q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
|
||||
query = query.view(bsz * q_len, self.num_heads, self.head_size)
|
||||
key = key.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
value = value.view(bsz * kv_len, self.num_kv_heads, self.head_size)
|
||||
|
||||
if cu_seqlens is None:
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
(bsz + 1) * q_len,
|
||||
step=q_len,
|
||||
dtype=torch.int32,
|
||||
device=query.device,
|
||||
)
|
||||
seq_len = torch.tensor([q_len] * bsz, device="cpu", dtype=torch.int32)
|
||||
else:
|
||||
seq_len = torch.diff(cu_seqlens.to("cpu", dtype=torch.int32))
|
||||
|
||||
seq_len = torch.diff(cu_seqlens).to("cpu", dtype=torch.int32)
|
||||
|
||||
context_layer = torch.empty_like(q)
|
||||
output = torch.empty_like(query)
|
||||
torch_npu._npu_flash_attention_unpad(
|
||||
query=q,
|
||||
key=k,
|
||||
value=v,
|
||||
query=query,
|
||||
key=key,
|
||||
value=value,
|
||||
seq_len=seq_len,
|
||||
scale_value=self.head_size**-0.5,
|
||||
num_heads=self.num_heads,
|
||||
num_kv_heads=self.num_kv_heads,
|
||||
out=context_layer,
|
||||
out=output,
|
||||
)
|
||||
|
||||
context_layer = einops.rearrange(context_layer, "(b s) h d -> b s h d", b=bsz).contiguous()
|
||||
return context_layer
|
||||
output = output.view(bsz, -1, self.num_heads, self.head_size)
|
||||
return output
|
||||
|
||||
Reference in New Issue
Block a user