From 122505208ff6284f409846ca7294f4a4b9883285 Mon Sep 17 00:00:00 2001 From: BAI Fan <36036759+baifanxxx@users.noreply.github.com> Date: Wed, 26 Mar 2025 14:28:20 +0800 Subject: [PATCH] FastPatch: Optimized Patch Embedding for Qwen2VL (#345) ### What this PR does / why we need it? We propose the FastPatch method, which optimizes patch embedding (Conv3D) for Qwen2VL. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We've tested it on benchmarks; it meets our expectations and performs better than the original patch_embed layer. --------- Signed-off-by: baifanxxx Signed-off-by: zouyida Co-authored-by: zouyida --- vllm_ascend/models/qwen2_vl.py | 24 ++++++++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/vllm_ascend/models/qwen2_vl.py b/vllm_ascend/models/qwen2_vl.py index 7879220..a3b4dc1 100644 --- a/vllm_ascend/models/qwen2_vl.py +++ b/vllm_ascend/models/qwen2_vl.py @@ -30,10 +30,10 @@ from vllm.config import VllmConfig from vllm.model_executor.layers.activation import QuickGELU from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.models.qwen2_vl import ( - Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer, - Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration, - Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo, - apply_rotary_pos_emb_vision) + Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed, + Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder, + Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor, + Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision) from vllm.model_executor.models.utils import maybe_prefix from vllm.multimodal import MULTIMODAL_REGISTRY @@ -125,6 +125,14 @@ class CustomQwen2VisionBlock(Qwen2VisionBlock): prefix=f"{prefix}.attn") +class CustomQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed): + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = x.matmul( + self.proj.weight.data.view(self.embed_dim, 
-1).transpose(0, 1)) + return x + + class CustomQwen2VisionTransformer(Qwen2VisionTransformer): def __init__( @@ -135,6 +143,14 @@ class CustomQwen2VisionTransformer(Qwen2VisionTransformer): prefix: str = "", ) -> None: super().__init__(vision_config, norm_eps, quant_config, prefix) + + self.patch_embed = CustomQwen2VisionPatchEmbed( + patch_size=vision_config.patch_size, + temporal_patch_size=vision_config.temporal_patch_size, + in_channels=vision_config.in_channels, + embed_dim=vision_config.embed_dim, + ) + self.blocks = nn.ModuleList([ CustomQwen2VisionBlock(dim=self.embed_dim, num_heads=self.num_heads,