【A5】【Qwen VL】Qwen VL adapt for A5 (#7046)

### What this PR does / why we need it?
Replace the '_npu_flash_attention_unpad' operator with the
'npu_fusion_attention' operator to ensure that the Qwen VL model can run
in the A5 environment and remove the 'mrope' operator call restriction
for A5.
### Does this PR introduce _any_ user-facing change?

### How was this patch tested?

- vLLM version: v0.16.0
- vLLM main:
4034c3d32e

Signed-off-by: 汪越 <wangyue361@h-partners.com>
This commit is contained in:
yesyue-w
2026-03-20 16:56:12 +08:00
committed by GitHub
parent f39f566e22
commit c860535246
2 changed files with 10 additions and 11 deletions

View File

@@ -118,19 +118,18 @@ class AscendMMEncoderAttention(MMEncoderAttention):
k = F.pad(k, (0, pad_len), mode="constant", value=0)
v = F.pad(v, (0, pad_len), mode="constant", value=0)
context_layer = torch.empty_like(q)
seq_lens_cpu = list(seq_lens_cpu.cumsum(0))
# operator requires pta version >= 2.5.1
torch_npu._npu_flash_attention_unpad(
context_layer = torch_npu.npu_fusion_attention(
query=q,
key=k,
value=v,
seq_len=seq_lens_cpu,
scale_value=self.scale_value,
num_heads=self.num_heads,
num_kv_heads=self.num_kv_heads,
out=context_layer,
)
actual_seq_qlen=seq_lens_cpu,
actual_seq_kvlen=seq_lens_cpu,
head_num=self.num_heads,
scale=self.scale_value,
input_layout="TND",
)[0]
if self.enable_pad:
context_layer = context_layer[..., :origin_shape]

View File

@@ -32,7 +32,7 @@ from vllm.triton_utils import HAS_TRITON
from vllm_ascend.ascend_forward_context import _EXTRA_CTX
from vllm_ascend.platform import NPUPlatform
from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
from vllm_ascend.utils import has_rope, is_vl_model
if HAS_TRITON:
from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
@@ -519,7 +519,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
# todo: need cann update in 8.5.0
return self.forward_triton(positions, query, key)
if self.mrope_section != [16, 24, 24] or get_ascend_device_type() == AscendDeviceType.A5:
if self.mrope_section != [16, 24, 24]:
return super().forward_oot(positions, query, key)
import torch_npu