### What this PR does / why we need it?

Optimize DeepSeekOCR2 `RelPosAttention` and `CustomQwen2Decoder`, and add documentation in DeepSeekOCR2.md.

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

Tested with:

- vllm 0.18.0
- vllm-ascend main

Measured improvements:

1. `_create_custom_4d_mask`: 141ms49us620ns --> `_create_npu_optimized_mask`: 1ms227us780ns
2. conv2d: 27ms --> matmul: <1ms (see the sketch below)
3. `RelPosAttention`: SDPA --> `npu_prompt_flash_attention`

---------

Signed-off-by: Wangbei25 <wangbei41@huawie.com>
Signed-off-by: Wangbei25 <wangbei41@huawei.com>
Co-authored-by: Wangbei25 <wangbei41@huawie.com>
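For item 2, the usual way to turn a patch-embedding style `conv2d` (kernel size equal to stride) into a single matmul is to unfold the input into non-overlapping patches and multiply by the flattened weight. The exact layer rewritten in DeepSeekOCR2 is not shown on this page, so the snippet below is only a minimal sketch of the idea under that assumption; `patch_conv_as_matmul` and the shapes used are hypothetical, not code from this PR.

```python
import torch
import torch.nn.functional as F


def patch_conv_as_matmul(x: torch.Tensor, weight: torch.Tensor,
                         bias: torch.Tensor, patch: int) -> torch.Tensor:
    """Equivalent of F.conv2d(x, weight, bias, stride=patch) when kernel size == stride."""
    B, _, H, W = x.shape
    C_out = weight.shape[0]
    # Cut the image into non-overlapping patches: (B, num_patches, C_in * patch * patch).
    patches = F.unfold(x, kernel_size=patch, stride=patch).transpose(1, 2)
    # A single large matmul replaces the convolution.
    out = patches @ weight.reshape(C_out, -1).t() + bias
    # Restore the conv2d output layout (B, C_out, H // patch, W // patch).
    return out.transpose(1, 2).reshape(B, C_out, H // patch, W // patch)


# Sanity check against the convolution it replaces (shapes are made up for the demo):
x = torch.randn(2, 3, 224, 224)
w = torch.randn(768, 3, 16, 16)
b = torch.randn(768)
ref = F.conv2d(x, w, b, stride=16)
assert torch.allclose(patch_conv_as_matmul(x, w, b, 16), ref, atol=1e-3)
```

Assuming the hot spot is a patch-embedding convolution in the vision encoder, folding it into one matmul like this is consistent with the reported 27ms --> <1ms improvement.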
import torch
import torch_npu

from vllm.model_executor.models.deepencoder import RelPosAttention, add_decomposed_rel_pos


class AscendRelPosAttention(RelPosAttention):
    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, add relative positional embeddings to the attention map.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__(dim, num_heads, qkv_bias, use_rel_pos, rel_pos_zero_init, input_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, H, W, _ = x.shape
        # qkv with shape (3, B, nHead, H * W, C)
        qkv = self.qkv(x).reshape(B, H * W, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
        # q, k, v with shape (B * nHead, H * W, C)
        q, k, v = qkv.reshape(3, B * self.num_heads, H * W, -1).unbind(0)

        rel_h, rel_w = None, None
        if self.use_rel_pos:
            # Decomposed relative position terms along height and width.
            rel_h, rel_w = add_decomposed_rel_pos(q, self.rel_pos_h, self.rel_pos_w, (H, W), (H, W))

        # Reshape to the BNSD layout expected by npu_prompt_flash_attention.
        q = q.view(B, self.num_heads, H * W, -1)
        k = k.view(B, self.num_heads, H * W, -1)
        v = v.view(B, self.num_heads, H * W, -1)

        if self.use_rel_pos:
            assert rel_h is not None and rel_w is not None
            rel_h = rel_h.view(B, self.num_heads, rel_h.size(1), rel_h.size(2), rel_h.size(3))
            rel_w = rel_w.view(B, self.num_heads, rel_w.size(1), rel_w.size(2), rel_w.size(3))
            # Fold the decomposed relative position terms into a single additive attention bias
            # and pass it to the fused NPU kernel as pse_shift (previously added inside SDPA).
            attn_bias = (rel_h + rel_w).view(B, self.num_heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4))
            x = torch_npu.npu_prompt_flash_attention(
                q,
                k,
                v,
                pse_shift=attn_bias,
                input_layout="BNSD",
                scale_value=self.scale,
                num_heads=self.num_heads,
            )
        else:
            x = torch_npu.npu_prompt_flash_attention(
                q,
                k,
                v,
                input_layout="BNSD",
                scale_value=self.scale,
                num_heads=self.num_heads,
            )

        x = x.view(B, self.num_heads, H, W, -1).permute(0, 2, 3, 1, 4).reshape(B, H, W, -1)
        x = self.proj(x)
        return x
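

# Illustrative usage sketch (not part of the original file). The sizes and dtype below are
# assumptions for demonstration only; the real values come from the DeepSeekOCR2 encoder config,
# and npu_prompt_flash_attention typically expects half precision on Ascend devices.
if __name__ == "__main__":
    attn = AscendRelPosAttention(dim=768, num_heads=12, use_rel_pos=True,
                                 input_size=(14, 14)).npu().half().eval()
    x = torch.randn(1, 14, 14, 768).npu().half()  # (B, H, W, C) feature-map layout
    with torch.no_grad():
        out = attn(x)
    assert out.shape == x.shape  # the attention block preserves the (B, H, W, C) shape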