diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py
index 7ffb2e89b..20165c3c7 100644
--- a/python/sglang/srt/models/qwen2_5_vl.py
+++ b/python/sglang/srt/models/qwen2_5_vl.py
@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 
 
 class Qwen2_5_VLMLP(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=add_prefix("gate_proj", prefix),
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel_gate, _ = self.gate_proj(x)
-        x_parallel_gate = self.act(x_parallel_gate)
-        x_parallel_up, _ = self.up_proj(x)
-        x_parallel = x_parallel_gate * x_parallel_up
-        x, _ = self.down_proj(x_parallel)
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down
 
 
 class Qwen2_5_VisionBlock(nn.Module):
@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.gate_proj.weight.device
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".gate_proj.",
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
         for param_name, weight_name, shard_id in stacked_params_mapping:
             if weight_name not in name:
                 continue
-            if "visual" in name:
+            if (
+                "visual" in name
+                and "up_proj" not in name
+                and "gate_proj" not in name
+            ):
                 continue
             name = name.replace(weight_name, param_name)
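
For reference, a minimal sketch of the equivalence behind the gate_up_proj fusion above. Plain torch.nn.Linear stands in for sglang's ColumnParallelLinear/MergedColumnParallelLinear (a simplification for illustration only): a single matmul over the row-wise concatenated [gate; up] weights, followed by chunk(2, dim=-1), matches the original two-projection SwiGLU path. This is also why the weight-loader hunk now lets visual gate_proj/up_proj checkpoint shards fall through to the stacked_params_mapping, so they land in the merged parameter.

import torch
import torch.nn as nn

in_features, hidden_features = 8, 16

# Original unfused layers (random weights for the demonstration).
gate_proj = nn.Linear(in_features, hidden_features, bias=False)
up_proj = nn.Linear(in_features, hidden_features, bias=False)

# Fused layer: one matmul producing both halves, mirroring
# MergedColumnParallelLinear(output_sizes=[hidden_features] * 2).
gate_up_proj = nn.Linear(in_features, 2 * hidden_features, bias=False)
with torch.no_grad():
    # nn.Linear weights are (out_features, in_features), so the merged
    # weight is the row-wise concatenation of the gate and up weights.
    gate_up_proj.weight.copy_(
        torch.cat([gate_proj.weight, up_proj.weight], dim=0)
    )

x = torch.randn(4, in_features)
act = nn.SiLU()  # stand-in for ACT2FN[hidden_act]

ref = act(gate_proj(x)) * up_proj(x)         # old two-projection path
gate, up = gate_up_proj(x).chunk(2, dim=-1)  # fused path from the diff
fused = act(gate) * up

assert torch.allclose(ref, fused, atol=1e-6)

Fusing the two projections into one GEMM halves the kernel launches for this step and reads the input activations once instead of twice; the numerical result is unchanged.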