diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index d4b108b..7879220 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -40,6 +40,23 @@ from vllm.multimodal import MULTIMODAL_REGISTRY class CustomQwen2VisionAttention(Qwen2VisionAttention): + def __init__( + self, + embed_dim: int, + num_heads: int, + projection_size: int, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = "", + ) -> None: + super().__init__( + embed_dim, + num_heads, + projection_size, + quant_config, + prefix, + ) + self.cu_seqlens = None + def forward( self, x: torch.Tensor, @@ -47,6 +64,8 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention): rotary_pos_emb: torch.Tensor, ) -> torch.Tensor: + self.cu_seqlens = cu_seqlens + # [s, b, c] --> [s, b, 3 * head * head_dim] x, _ = self.qkv(x) @@ -72,7 +91,7 @@ class CustomQwen2VisionAttention(Qwen2VisionAttention): query=q, key=k, value=v, - seq_len=cu_seqlens, + seq_len=self.cu_seqlens, scale_value=self.hidden_size_per_attention_head**-0.5, num_heads=self.num_attention_heads_per_partition, num_kv_heads=self.num_attention_heads_per_partition,