FastPatch: Optimized Patch Embedding for Qwen2VL (#345)

### What this PR does / why we need it?
We proposed the FastPatch method, which optimized patch embedding
(Conv3D) for Qwen2VL.


### Does this PR introduce _any_ user-facing change?
No.

### How was this patch tested?
We've tested it on benchmark, it meets our satisfaction and is better
than original patch_embed layer.


---------

Signed-off-by: baifanxxx <baifanxxx@gmail.com>
Signed-off-by: zouyida <zouyida@huawei.com>
Co-authored-by: zouyida <zouyida@huawei.com>
This commit is contained in:
BAI Fan
2025-03-26 14:28:20 +08:00
committed by GitHub
parent d4accf4ec2
commit 122505208f

View File

@@ -30,10 +30,10 @@ from vllm.config import VllmConfig
from vllm.model_executor.layers.activation import QuickGELU
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.models.qwen2_vl import (
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer,
Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration,
Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
apply_rotary_pos_emb_vision)
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision)
from vllm.model_executor.models.utils import maybe_prefix
from vllm.multimodal import MULTIMODAL_REGISTRY
@@ -125,6 +125,14 @@ class CustomQwen2VisionBlock(Qwen2VisionBlock):
prefix=f"{prefix}.attn")
class CustomQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = x.matmul(
self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
return x
class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
def __init__(
@@ -135,6 +143,14 @@ class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
prefix: str = "",
) -> None:
super().__init__(vision_config, norm_eps, quant_config, prefix)
self.patch_embed = CustomQwen2VisionPatchEmbed(
patch_size=vision_config.patch_size,
temporal_patch_size=vision_config.temporal_patch_size,
in_channels=vision_config.in_channels,
embed_dim=vision_config.embed_dim,
)
self.blocks = nn.ModuleList([
CustomQwen2VisionBlock(dim=self.embed_dim,
num_heads=self.num_heads,