From 58adf7c8ac239d10fa1b9851c3d38af8930a17f3 Mon Sep 17 00:00:00 2001
From: Li Wang
Date: Sat, 27 Dec 2025 18:42:46 +0800
Subject: [PATCH] [Bugfix] Correctly handle the output shape in multimodal attention (#5443)

### What this PR does / why we need it?
Fix https://github.com/vllm-project/vllm-ascend/issues/5297, for
`AscendMMEncoderAttention` forward, we should keep the output shape
consistent with the input

- vLLM version: release/v0.13.0
- vLLM main: https://github.com/vllm-project/vllm/commit/81786c87748b0177111dfdc07af5351d8389baa1

---------

Signed-off-by: wangli
---
 tests/e2e/conftest.py                   |  5 +++++
 vllm_ascend/ops/mm_encoder_attention.py | 12 +++++++++---
 2 files changed, 14 insertions(+), 3 deletions(-)

diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py
index 1d993c1c..9f861a48 100644
--- a/tests/e2e/conftest.py
+++ b/tests/e2e/conftest.py
@@ -781,6 +781,11 @@ PROMPT_CONFIGS = {
             "fps": 1,
         },
     },
+    "hunyuan-vl": {
+        "model": "Tencent-Hunyuan/HunyuanOCR",
+        "prompt_fn": hunyuan_prompt,
+        "mm_processor_kwargs": {},
+    },
 }
 
diff --git a/vllm_ascend/ops/mm_encoder_attention.py b/vllm_ascend/ops/mm_encoder_attention.py
index 6f21a5ce..38f97b29 100644
--- a/vllm_ascend/ops/mm_encoder_attention.py
+++ b/vllm_ascend/ops/mm_encoder_attention.py
@@ -93,6 +93,7 @@ class AscendMMEncoderAttention(MMEncoderAttention):
     ):
         bsz, q_len = query.size()[:2]
         kv_len = key.size(1)
+        is_reshaped = query.dim() == 4
 
         # q, k, v: [b, s, head, head_dim] -> [b * s, head, head_dim]
         q, k, v = self.reshape_qkv_to_3d(query, key, value, bsz, q_len, kv_len)
@@ -134,7 +135,12 @@ class AscendMMEncoderAttention(MMEncoderAttention):
         if enable_pad:
             context_layer = context_layer[..., :origin_shape]
 
-        context_layer = einops.rearrange(context_layer,
-                                         "(b s) h d -> b s h d",
-                                         b=bsz).contiguous()
+        if is_reshaped:
+            context_layer = einops.rearrange(context_layer,
+                                             "(b s) h d -> b s h d",
+                                             b=bsz).contiguous()
+        else:
+            context_layer = einops.rearrange(context_layer,
+                                             "(b s) h d -> b s (h d)",
+                                             b=bsz).contiguous()
         return context_layer