FastPatch: Optimized Patch Embedding for Qwen2VL (#345)
### What this PR does / why we need it? We proposed the FastPatch method, which optimized patch embedding (Conv3D) for Qwen2VL. ### Does this PR introduce _any_ user-facing change? No. ### How was this patch tested? We've tested it on benchmark, it meets our satisfaction and is better than original patch_embed layer. --------- Signed-off-by: baifanxxx <baifanxxx@gmail.com> Signed-off-by: zouyida <zouyida@huawei.com> Co-authored-by: zouyida <zouyida@huawei.com>
This commit is contained in:
@@ -30,10 +30,10 @@ from vllm.config import VllmConfig
|
|||||||
from vllm.model_executor.layers.activation import QuickGELU
|
from vllm.model_executor.layers.activation import QuickGELU
|
||||||
from vllm.model_executor.layers.quantization import QuantizationConfig
|
from vllm.model_executor.layers.quantization import QuantizationConfig
|
||||||
from vllm.model_executor.models.qwen2_vl import (
|
from vllm.model_executor.models.qwen2_vl import (
|
||||||
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionTransformer,
|
Qwen2VisionAttention, Qwen2VisionBlock, Qwen2VisionPatchEmbed,
|
||||||
Qwen2VLDummyInputsBuilder, Qwen2VLForConditionalGeneration,
|
Qwen2VisionTransformer, Qwen2VLDummyInputsBuilder,
|
||||||
Qwen2VLMultiModalProcessor, Qwen2VLProcessingInfo,
|
Qwen2VLForConditionalGeneration, Qwen2VLMultiModalProcessor,
|
||||||
apply_rotary_pos_emb_vision)
|
Qwen2VLProcessingInfo, apply_rotary_pos_emb_vision)
|
||||||
from vllm.model_executor.models.utils import maybe_prefix
|
from vllm.model_executor.models.utils import maybe_prefix
|
||||||
from vllm.multimodal import MULTIMODAL_REGISTRY
|
from vllm.multimodal import MULTIMODAL_REGISTRY
|
||||||
|
|
||||||
@@ -125,6 +125,14 @@ class CustomQwen2VisionBlock(Qwen2VisionBlock):
|
|||||||
prefix=f"{prefix}.attn")
|
prefix=f"{prefix}.attn")
|
||||||
|
|
||||||
|
|
||||||
|
class CustomQwen2VisionPatchEmbed(Qwen2VisionPatchEmbed):
|
||||||
|
|
||||||
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
|
x = x.matmul(
|
||||||
|
self.proj.weight.data.view(self.embed_dim, -1).transpose(0, 1))
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
|
class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -135,6 +143,14 @@ class CustomQwen2VisionTransformer(Qwen2VisionTransformer):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
super().__init__(vision_config, norm_eps, quant_config, prefix)
|
||||||
|
|
||||||
|
self.patch_embed = CustomQwen2VisionPatchEmbed(
|
||||||
|
patch_size=vision_config.patch_size,
|
||||||
|
temporal_patch_size=vision_config.temporal_patch_size,
|
||||||
|
in_channels=vision_config.in_channels,
|
||||||
|
embed_dim=vision_config.embed_dim,
|
||||||
|
)
|
||||||
|
|
||||||
self.blocks = nn.ModuleList([
|
self.blocks = nn.ModuleList([
|
||||||
CustomQwen2VisionBlock(dim=self.embed_dim,
|
CustomQwen2VisionBlock(dim=self.embed_dim,
|
||||||
num_heads=self.num_heads,
|
num_heads=self.num_heads,
|
||||||
|
|||||||
Reference in New Issue
Block a user