diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 081bc45f..19f44066 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -106,10 +106,12 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         if enable_pad:
             origin_shape = q.shape[-1]
             pad_len = MAX_PAD_SIZE - origin_shape
-            # q, k, v: [b * s, head, head_dim] -> [b * s, head, MAX_PAD_SIZE]
-            q = F.pad(q, (0, pad_len), mode="constant", value=0)
-            k = F.pad(k, (0, pad_len), mode="constant", value=0)
-            v = F.pad(v, (0, pad_len), mode="constant", value=0)
+            # Merge q, k, v into one tensor so a single NPU pad kernel
+            # launch covers all three instead of three separate launches.
+            qkv = torch.stack([q, k, v], dim=0)
+            # qkv: [3, b * s, head, head_dim] -> [3, b * s, head, MAX_PAD_SIZE]
+            qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
+            q, k, v = qkv.unbind(dim=0)
 
         context_layer = torch.empty_like(q)
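
Standalone sketch of the stack-pad-unbind pattern in the diff above (the
shapes and the MAX_PAD_SIZE value here are hypothetical, chosen only for
illustration; they are not the module's actual configuration):

import torch
import torch.nn.functional as F

MAX_PAD_SIZE = 128                    # hypothetical padded head_dim
b_s, heads, head_dim = 16, 8, 80      # hypothetical [b * s, head, head_dim]

q = torch.randn(b_s, heads, head_dim)
k = torch.randn(b_s, heads, head_dim)
v = torch.randn(b_s, heads, head_dim)

pad_len = MAX_PAD_SIZE - head_dim
# Stacking first means one pad kernel launch instead of three.
qkv = torch.stack([q, k, v], dim=0)   # [3, b * s, head, head_dim]
qkv = F.pad(qkv, (0, pad_len), mode="constant", value=0)
q, k, v = qkv.unbind(dim=0)           # each [b * s, head, MAX_PAD_SIZE]

assert q.shape == (b_s, heads, MAX_PAD_SIZE)

Note that torch.stack copies q, k, and v into one buffer, so the win
assumed here is that a pad kernel launch on the NPU costs more than the
extra copy, which is the rationale stated in the diff's comment.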