diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 9d497059..1ffd8679 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -118,19 +118,18 @@ class AscendMMEncoderAttention(MMEncoderAttention):
             k = F.pad(k, (0, pad_len), mode="constant", value=0)
             v = F.pad(v, (0, pad_len), mode="constant", value=0)
 
-        context_layer = torch.empty_like(q)
+        seq_lens_cpu = list(seq_lens_cpu.cumsum(0))
 
-        # operator requires pta version >= 2.5.1
-        torch_npu._npu_flash_attention_unpad(
+        context_layer = torch_npu.npu_fusion_attention(
             query=q,
             key=k,
             value=v,
-            seq_len=seq_lens_cpu,
-            scale_value=self.scale_value,
-            num_heads=self.num_heads,
-            num_kv_heads=self.num_kv_heads,
-            out=context_layer,
-        )
+            actual_seq_qlen=seq_lens_cpu,
+            actual_seq_kvlen=seq_lens_cpu,
+            head_num=self.num_heads,
+            scale=self.scale_value,
+            input_layout="TND",
+        )[0]
 
         if self.enable_pad:
             context_layer = context_layer[..., :origin_shape]
diff --git a/vllm_ascend/ops/rotary_embedding.py b/vllm_ascend/ops/rotary_embedding.py
index 9e81fa3c..3df4cad3 100644
--- a/vllm_ascend/ops/rotary_embedding.py
+++ b/vllm_ascend/ops/rotary_embedding.py
@@ -32,7 +32,7 @@
 from vllm.triton_utils import HAS_TRITON
 from vllm_ascend.ascend_forward_context import _EXTRA_CTX
 from vllm_ascend.platform import NPUPlatform
-from vllm_ascend.utils import AscendDeviceType, get_ascend_device_type, has_rope, is_vl_model
+from vllm_ascend.utils import has_rope, is_vl_model
 
 if HAS_TRITON:
     from vllm.model_executor.layers.rotary_embedding.mrope import triton_mrope
@@ -519,7 +519,7 @@ class AscendMRotaryEmbedding(MRotaryEmbedding):
             # todo: need cann update in 8.5.0
             return self.forward_triton(positions, query, key)
 
-        if self.mrope_section != [16, 24, 24] or get_ascend_device_type() == AscendDeviceType.A5:
+        if self.mrope_section != [16, 24, 24]:
             return super().forward_oot(positions, query, key)
 
         import torch_npu
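
Note on the seq-len change in the first hunk: the replaced `_npu_flash_attention_unpad` call took per-sequence lengths via `seq_len`, whereas the new `npu_fusion_attention` call with `input_layout="TND"` is passed cumulative end offsets through `actual_seq_qlen`/`actual_seq_kvlen`, which is what the `seq_lens_cpu.cumsum(0)` line produces. A minimal torch-only sketch of that conversion, with made-up sample lengths for illustration:

    # Per-sequence lengths -> cumulative end offsets, as the new call implies.
    # Illustrative values only; seq_lens_cpu in the real code comes from the model.
    import torch

    seq_lens_cpu = torch.tensor([3, 5, 2])    # tokens per sequence
    offsets = list(seq_lens_cpu.cumsum(0))    # [tensor(3), tensor(8), tensor(10)]
    assert [int(x) for x in offsets] == [3, 8, 10]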