From a391f73adc627fce5ae265a76e8ee77b223313c0 Mon Sep 17 00:00:00 2001 From: Kevin Xiang Li Date: Sun, 31 Aug 2025 04:08:28 -0700 Subject: [PATCH] Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661) Signed-off-by: Xinyuan Tong Co-authored-by: Xiang (Kevin) Li Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: Xinyuan Tong --- python/sglang/srt/models/qwen2_5_vl.py | 44 ++++++++++++-------------- 1 file changed, 21 insertions(+), 23 deletions(-) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 7ffb2e89b..20165c3c7 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import ( from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm -from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear +from sglang.srt.layers.linear import ( + ColumnParallelLinear, + MergedColumnParallelLinear, + RowParallelLinear, +) from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.pooler import Pooler, PoolingType from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -62,7 +66,6 @@ logger = logging.getLogger(__name__) class Qwen2_5_VLMLP(nn.Module): - def __init__( self, in_features: int, @@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module): prefix: str = "", ): super().__init__() - self.gate_proj = ColumnParallelLinear( - in_features, - hidden_features, + self.gate_up_proj = MergedColumnParallelLinear( + input_size=in_features, + output_sizes=[hidden_features] * 2, # [gate_proj, up_proj] bias=bias, quant_config=quant_config, - prefix=add_prefix("gate_proj", prefix), - ) - self.up_proj = ColumnParallelLinear( - in_features, - hidden_features, - bias=bias, - quant_config=quant_config, - prefix=add_prefix("up_proj", prefix), + prefix=add_prefix("gate_up_proj", prefix), ) self.down_proj = RowParallelLinear( hidden_features, @@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module): self.act = ACT2FN[hidden_act] def forward(self, x: torch.Tensor) -> torch.Tensor: - x_parallel_gate, _ = self.gate_proj(x) - x_parallel_gate = self.act(x_parallel_gate) - x_parallel_up, _ = self.up_proj(x) - x_parallel = x_parallel_gate * x_parallel_up - x, _ = self.down_proj(x_parallel) - return x + gate_up, _ = self.gate_up_proj(x) + gate, up = gate_up.chunk(2, dim=-1) + x = self.act(gate) * up + x_down, _ = self.down_proj(x) + return x_down class Qwen2_5_VisionBlock(nn.Module): @@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module): @property def device(self) -> torch.device: - return self.blocks[0].mlp.gate_proj.weight.device + return self.patch_embed.proj.weight.device def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor: pos_ids = [] @@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor) class Qwen2_5_VLForConditionalGeneration(nn.Module): # BitandBytes specific attributes default_bitsandbytes_target_modules = [ - ".gate_proj.", + ".gate_up_proj.", ".down_proj.", - ".up_proj.", ".q_proj.", ".k_proj.", ".v_proj.", @@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module): for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue - if "visual" in name: + if ( + "visual" in name + and "up_proj" not in name + and "gate_proj" not in name + ): continue name = name.replace(weight_name, param_name)