diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py
index 7879220..a3b4dc1 100644
--- a/vllm_ascend/models/qwen2_vl.py
+++ b/vllm_ascend/models/qwen2_vl.py
@@ -30,10 +30,10 @@ from vllm.config import VllmConfig
 from vllm.model_executor.layers.activation import QuickGELU
 from vllm.model_executor.layers.quantization import QuantizationConfig
 from vllm.model_executor.models.qwen2_vl import (
-    Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer,
-    Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration,
-    Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
-    apply_rotary_pos_emb_vision)
+    Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
+    Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
+    Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
+    Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision)
 from vllm.model_executor.models.utils import maybe_prefix
 from vllm.multimodal import MULTIMODAL_REGISTRY
 
@@ -125,6 +125,14 @@ class CustomQwen2VisionBlock(Qwen2VisionBlock):
                                            prefix=f"{prefix}.attn")
 
 
+class CustomQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x = x.matmul(
+            self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
+        return x
+
+
 class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
 
     def __init__(
@@ -135,6 +143,14 @@ class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
         prefix: str = "",
     ) -> None:
         super().__init__(vision_config, norm_eps, quant_config, prefix)
+
+        self.patch_embed = CustomQwen2VisionPatchEmbed(
+            patch_size=vision_config.patch_size,
+            temporal_patch_size=vision_config.temporal_patch_size,
+            in_channels=vision_config.in_channels,
+            embed_dim=vision_config.embed_dim,
+        )
+
         self.blocks = nn.ModuleList([
             CustomQwen2VisionBlock(dim=self.embed_dim,
                                    num_heads=self.num_heads,
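
Note (not part of the patch): the new CustomQwen2VisionPatchEmbed.forward replaces
the Conv3d projection of the upstream Qwen2VisionPatchEmbed with a single matmul
over the already-flattened patches. The sketch below is a minimal standalone check
of that equivalence, assuming the upstream projection is a bias-free Conv3d whose
kernel and stride span one (temporal_patch_size, patch_size, patch_size) patch;
the shape constants are illustrative Qwen2-VL defaults, not taken from this diff.

import torch
import torch.nn as nn

in_channels, temporal_patch_size, patch_size, embed_dim = 3, 2, 14, 1280
num_patches = 16

# Bias-free Conv3d whose kernel and stride cover exactly one patch, mirroring
# how the upstream patch embedding is assumed to be constructed.
proj = nn.Conv3d(in_channels,
                 embed_dim,
                 kernel_size=(temporal_patch_size, patch_size, patch_size),
                 stride=(temporal_patch_size, patch_size, patch_size),
                 bias=False)

# Flattened patches of shape (L, C * T * H * W), the layout the vision tower
# receives.
x = torch.randn(num_patches,
                in_channels * temporal_patch_size * patch_size * patch_size)

# Conv3d path: unflatten each patch to (C, T, H, W), convolve (output is
# (L, embed_dim, 1, 1, 1)), then squeeze back to (L, embed_dim).
ref = proj(
    x.view(num_patches, in_channels, temporal_patch_size, patch_size,
           patch_size)).view(num_patches, embed_dim)

# Matmul path used by the patch: one (L, CTHW) x (CTHW, embed_dim) GEMM against
# the reshaped Conv3d weight. The flattening order (c, t, h, w) matches on both
# sides, so the two reductions compute the same sum.
out = x.matmul(proj.weight.data.view(embed_dim, -1).transpose(0, 1))

torch.testing.assert_close(out, ref, rtol=1e-4, atol=1e-4)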