Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xiang (Kevin) Li <lik@nvidia.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
Kevin Xiang Li
2025-08-31 04:08:28 -07:00
committed by GitHub
parent 25c7395934
commit a391f73adc

View File

@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
from sglang.srt.hf_transformers_utils import get_processor
from sglang.srt.layers.attention.vision import VisionAttention
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
from sglang.srt.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
class Qwen2_5_VLMLP(nn.Module):
def __init__(
self,
in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
prefix: str = "",
):
super().__init__()
self.gate_proj = ColumnParallelLinear(
in_features,
hidden_features,
self.gate_up_proj = MergedColumnParallelLinear(
input_size=in_features,
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
bias=bias,
quant_config=quant_config,
prefix=add_prefix("gate_proj", prefix),
)
self.up_proj = ColumnParallelLinear(
in_features,
hidden_features,
bias=bias,
quant_config=quant_config,
prefix=add_prefix("up_proj", prefix),
prefix=add_prefix("gate_up_proj", prefix),
)
self.down_proj = RowParallelLinear(
hidden_features,
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
self.act = ACT2FN[hidden_act]
def forward(self, x: torch.Tensor) -> torch.Tensor:
x_parallel_gate, _ = self.gate_proj(x)
x_parallel_gate = self.act(x_parallel_gate)
x_parallel_up, _ = self.up_proj(x)
x_parallel = x_parallel_gate * x_parallel_up
x, _ = self.down_proj(x_parallel)
return x
gate_up, _ = self.gate_up_proj(x)
gate, up = gate_up.chunk(2, dim=-1)
x = self.act(gate) * up
x_down, _ = self.down_proj(x)
return x_down
class Qwen2_5_VisionBlock(nn.Module):
@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
@property
def device(self) -> torch.device:
return self.blocks[0].mlp.gate_proj.weight.device
return self.patch_embed.proj.weight.device
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
pos_ids = []
@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
class Qwen2_5_VLForConditionalGeneration(nn.Module):
# BitandBytes specific attributes
default_bitsandbytes_target_modules = [
".gate_proj.",
".gate_up_proj.",
".down_proj.",
".up_proj.",
".q_proj.",
".k_proj.",
".v_proj.",
@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
if "visual" in name:
if (
"visual" in name
and "up_proj" not in name
and "gate_proj" not in name
):
continue
name = name.replace(weight_name, param_name)