Files
xc-llm-ascend/vllm_ascend/ops/rel_pos_attention.py
Wangbei25 4f259d4fd8 [Performance]Optimize DeepSeekOCR2 RelPosAttention and CustomQwen2Decoder (#7737)
### What this PR does / why we need it?
Optimize the DeepSeekOCR2 RelPosAttention and CustomQwen2Decoder, and add
documentation for DeepSeekOCR2 (DeepSeekOCR2.md).

### Does this PR introduce _any_ user-facing change?

### How was this patch tested?
- vllm 0.18.0
- vllm-ascend main

1. `_create_custom_4d_mask` (141ms 49us 620ns) replaced by
`_create_npu_optimized_mask` (1ms 227us 780ns)
2. conv2d (27ms) replaced by matmul (<1ms)
3. RelPosAttention: SDPA replaced by prompt_flash_attention

---------

Signed-off-by: Wangbei25 <wangbei41@huawie.com>
Signed-off-by: Wangbei25 <wangbei41@huawei.com>
Co-authored-by: Wangbei25 <wangbei41@huawie.com>
2026-03-31 14:49:29 +08:00

69 lines
2.6 KiB
Python

import torch
import torch_npu
from vllm.model_executor.models.deepencoder import RelPosAttention, add_decomposed_rel_pos
class AscendRelPosAttention(RelPosAttention):
    """``RelPosAttention`` whose ``forward`` runs on the Ascend NPU.

    All parameters and projections (``qkv``, ``proj``, ``scale``,
    ``rel_pos_h``/``rel_pos_w``) are created by the parent class; this
    subclass only overrides ``forward`` so that the attention itself is
    computed by ``torch_npu.npu_prompt_flash_attention``.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        use_rel_pos: bool = False,
        rel_pos_zero_init: bool = True,
        input_size: tuple[int, int] | None = None,
    ) -> None:
        """
        Args:
            dim (int): Number of input channels.
            num_heads (int): Number of attention heads.
            qkv_bias (bool): If True, add a learnable bias to query, key, value.
            use_rel_pos (bool): If True, ``forward`` computes a decomposed
                relative-position bias and passes it to the attention kernel.
            rel_pos_zero_init (bool): If True, zero initialize relative positional parameters.
            input_size (tuple(int, int) or None): Input resolution for calculating the relative
                positional parameter size.
        """
        super().__init__(
            dim,
            num_heads,
            qkv_bias,
            use_rel_pos,
            rel_pos_zero_init,
            input_size,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply multi-head self-attention over a 2-D grid of tokens.

        Args:
            x (torch.Tensor): Input of shape ``(B, H, W, C)``; the ``H * W``
                grid positions are attended to as a single sequence.

        Returns:
            torch.Tensor: ``self.proj`` applied to the attention output,
            which is laid out as ``(B, H, W, nHead * C')`` before projection.
        """
        batch, height, width, _ = x.shape
        heads = self.num_heads
        seq_len = height * width

        # (B, H, W, C) -> (B, HW, 3, nHead, C') -> (3, B, nHead, HW, C')
        qkv = self.qkv(x).reshape(batch, seq_len, 3, heads, -1)
        qkv = qkv.permute(2, 0, 3, 1, 4)
        # Fold batch and heads together: q, k, v are each (B * nHead, HW, C').
        q, k, v = qkv.reshape(3, batch * heads, seq_len, -1).unbind(0)

        # NOTE(review): when relative positions are disabled this stays None,
        # relying on torch_npu's ``pse_shift`` argument defaulting to None.
        pse_shift = None
        if self.use_rel_pos:
            # The helper is fed q in its folded (B * nHead, HW, C') form.
            rel_h, rel_w = add_decomposed_rel_pos(
                q, self.rel_pos_h, self.rel_pos_w, (height, width), (height, width)
            )
            # Split the folded leading dim back into (B, nHead).
            rel_h = rel_h.view(
                batch, heads, rel_h.size(1), rel_h.size(2), rel_h.size(3)
            )
            rel_w = rel_w.view(
                batch, heads, rel_w.size(1), rel_w.size(2), rel_w.size(3)
            )
            # Sum the height and width terms (presumably via broadcasting over
            # singleton dims -- verify against add_decomposed_rel_pos) and merge
            # their trailing axes into the key axis of a 4-D BNSD bias.
            pse_shift = (rel_h + rel_w).view(
                batch, heads, rel_h.size(2), rel_h.size(3) * rel_w.size(4)
            )

        # Unfold to the BNSD layout requested below: (B, nHead, HW, C').
        q, k, v = [t.view(batch, heads, seq_len, -1) for t in (q, k, v)]

        attn = torch_npu.npu_prompt_flash_attention(
            q,
            k,
            v,
            pse_shift=pse_shift,
            input_layout="BNSD",
            scale_value=self.scale,
            num_heads=heads,
        )

        # (B, nHead, HW, C') -> (B, H, W, nHead, C') -> (B, H, W, nHead * C')
        attn = attn.view(batch, heads, height, width, -1).permute(0, 2, 3, 1, 4)
        return self.proj(attn.reshape(batch, height, width, -1))