FastPatch: Optimized Patch Embedding for Qwen2VL (#345)
### What this PR does / why we need it? We proposed the FastPatch method, which optimized patch embedding (Conv3D) for Qwen2VL. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We've tested it on benchmark, it meets our satisfaction and is better than original patch_embed layer. --------- Signed-off-by: baifanxxx <baifanxxx@gmail.com> Signed-off-by: zouyida <zouyida@huawei.com> Co-authored-by: zouyida <zouyida@huawei.com>
This commit is contained in:
@@ -30,10 +30,10 @@ from vllm.config import VllmConfig
|
||||
from vllm.model_executor.layers.activation import QuickGELU
|
||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||
from vllm.model_executor.models.qwen2_vl import (
|
||||
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer,
|
||||
Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration,
|
||||
Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
|
||||
apply_rotary_pos_emb_vision)
|
||||
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
|
||||
Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
|
||||
Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
|
||||
Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision)
|
||||
from vllm.model_executor.models.utils import maybe_prefix
|
||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||
|
||||
@@ -125,6 +125,14 @@ class CustomQwen2VisionBlock(Qwen2VisionBlock):
|
||||
prefix=f"{prefix}.attn")
|
||||
|
||||
|
||||
class CustomQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = x.matmul(
|
||||
self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
|
||||
return x
|
||||
|
||||
|
||||
class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
|
||||
|
||||
def __init__(
|
||||
@@ -135,6 +143,14 @@ class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
|
||||
prefix: str = "",
|
||||
) -> None:
|
||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||
|
||||
self.patch_embed = CustomQwen2VisionPatchEmbed(
|
||||
patch_size=vision_config.patch_size,
|
||||
temporal_patch_size=vision_config.temporal_patch_size,
|
||||
in_channels=vision_config.in_channels,
|
||||
embed_dim=vision_config.embed_dim,
|
||||
)
|
||||
|
||||
self.blocks = nn.ModuleList([
|
||||
CustomQwen2VisionBlock(dim=self.embed_dim,
|
||||
num_heads=self.num_heads,
|
||||
|
||||
Reference in New Issue
Block a user