From 12aa7115b58e6def5603e4eae6744f0af8e05634 Mon Sep 17 00:00:00 2001 From: zouyida2002 Date: Wed, 12 Mar 2025 08:39:50 +0800 Subject: [PATCH] bugfix for qwen2_vl (#301) ### What this PR does / why we need it? This PR fixes an error that occurred while running inference with Qwen2_VL. ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? We've tested it on a benchmark; the results meet our expectations and match the GPU results. --------- Signed-off-by: zouyida --- vllm_ascend/models/qwen2_vl.py | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index d4b108b..7879220 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -40,6 +40,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY class CustomQwen2VisionAttention(Qwen2VisionAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.cu_seqlens = None + def forward( self, x: torch.Tensor, @@ -47,6 +64,8 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention): rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: + self.cu_seqlens = cu_seqlens + # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -72,7 +91,7 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention): query=q, key=k, value=v, - seq_len=cu_seqlens, + seq_len=self.cu_seqlens, scale_value=self.hidden_size_per_attention_head**-0.5, num_heads=self.num_attention_heads_per_partition, num_kv_heads=self.num_attention_heads_per_partition,