Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661)
Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xiang (Kevin) Li <lik@nvidia.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
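In short, the vision MLP's two column-parallel projections (gate_proj and up_proj) are fused into a single projection whose output is split in half before the gated activation, so both matmuls run as one GEMM per layer. A minimal plain-PyTorch sketch of the before/after pattern (nn.Linear and F.silu stand in for the parallel layers and ACT2FN[hidden_act]; these classes are illustrative, not the sglang ones):

import torch
import torch.nn as nn
import torch.nn.functional as F


class SplitMLP(nn.Module):
    """Old pattern: separate gate and up projections, then gated activation."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_proj = nn.Linear(dim, hidden)
        self.up_proj = nn.Linear(dim, hidden)
        self.down_proj = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.down_proj(F.silu(self.gate_proj(x)) * self.up_proj(x))


class FusedMLP(nn.Module):
    """New pattern: one fused projection producing [gate | up], split before gating."""

    def __init__(self, dim: int, hidden: int):
        super().__init__()
        self.gate_up_proj = nn.Linear(dim, 2 * hidden)  # gate rows first, then up rows
        self.down_proj = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate, up = self.gate_up_proj(x).chunk(2, dim=-1)
        return self.down_proj(F.silu(gate) * up)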
@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.attention.vision import VisionAttention
 from sglang.srt.layers.layernorm import RMSNorm
-from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    MergedColumnParallelLinear,
+    RowParallelLinear,
+)
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
 
 
 class Qwen2_5_VLMLP(nn.Module):
-
     def __init__(
         self,
         in_features: int,
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
         prefix: str = "",
     ):
         super().__init__()
-        self.gate_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
+        self.gate_up_proj = MergedColumnParallelLinear(
+            input_size=in_features,
+            output_sizes=[hidden_features] * 2,  # [gate_proj, up_proj]
             bias=bias,
             quant_config=quant_config,
-            prefix=add_prefix("gate_proj", prefix),
-        )
-        self.up_proj = ColumnParallelLinear(
-            in_features,
-            hidden_features,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=add_prefix("up_proj", prefix),
+            prefix=add_prefix("gate_up_proj", prefix),
         )
         self.down_proj = RowParallelLinear(
             hidden_features,
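Note on the fused layout: output_sizes=[hidden_features] * 2 means the merged weight stacks the gate rows first and the up rows second along the output dimension, which is what lets the forward pass recover the two halves with a single chunk. A hedged sketch of folding two existing projections into one fused nn.Linear (illustrative only; the real MergedColumnParallelLinear additionally shards each half across tensor-parallel ranks):

import torch
import torch.nn as nn


def fuse_gate_up(gate_proj: nn.Linear, up_proj: nn.Linear) -> nn.Linear:
    """Stack gate and up weights along the output dimension: rows are [gate | up]."""
    fused = nn.Linear(
        gate_proj.in_features,
        gate_proj.out_features + up_proj.out_features,
        bias=gate_proj.bias is not None,
    )
    with torch.no_grad():
        fused.weight.copy_(torch.cat([gate_proj.weight, up_proj.weight], dim=0))
        if fused.bias is not None:
            fused.bias.copy_(torch.cat([gate_proj.bias, up_proj.bias], dim=0))
    return fused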
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
         self.act = ACT2FN[hidden_act]
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
-        x_parallel_gate, _ = self.gate_proj(x)
-        x_parallel_gate = self.act(x_parallel_gate)
-        x_parallel_up, _ = self.up_proj(x)
-        x_parallel = x_parallel_gate * x_parallel_up
-        x, _ = self.down_proj(x_parallel)
-        return x
+        gate_up, _ = self.gate_up_proj(x)
+        gate, up = gate_up.chunk(2, dim=-1)
+        x = self.act(gate) * up
+        x_down, _ = self.down_proj(x)
+        return x_down
 
 
 class Qwen2_5_VisionBlock(nn.Module):
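The rewritten forward should be numerically equivalent to the old split path, since the same two matmuls are simply batched into one. A quick sanity check along these lines, again with plain nn.Linear stand-ins rather than the parallel layers:

import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
dim, hidden = 32, 64
gate = nn.Linear(dim, hidden)
up = nn.Linear(dim, hidden)
down = nn.Linear(hidden, dim)

# Build the fused projection from the same parameters: gate rows first, then up rows.
gate_up = nn.Linear(dim, 2 * hidden)
with torch.no_grad():
    gate_up.weight.copy_(torch.cat([gate.weight, up.weight], dim=0))
    gate_up.bias.copy_(torch.cat([gate.bias, up.bias], dim=0))

x = torch.randn(4, dim)
old_out = down(F.silu(gate(x)) * up(x))      # previous split path
g, u = gate_up(x).chunk(2, dim=-1)           # new fused path
new_out = down(F.silu(g) * u)
assert torch.allclose(old_out, new_out, atol=1e-6)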
@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
 
     @property
     def device(self) -> torch.device:
-        return self.blocks[0].mlp.gate_proj.weight.device
+        return self.patch_embed.proj.weight.device
 
     def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
         pos_ids = []
@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
 class Qwen2_5_VLForConditionalGeneration(nn.Module):
     # BitandBytes specific attributes
     default_bitsandbytes_target_modules = [
-        ".gate_proj.",
+        ".gate_up_proj.",
         ".down_proj.",
-        ".up_proj.",
         ".q_proj.",
         ".k_proj.",
         ".v_proj.",
@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                if "visual" in name:
+                if (
+                    "visual" in name
+                    and "up_proj" not in name
+                    and "gate_proj" not in name
+                ):
                     continue
                 name = name.replace(weight_name, param_name)
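With the vision MLP fused, checkpoint tensors named visual...gate_proj / visual...up_proj can no longer be copied into a parameter of the same name; they must be renamed to gate_up_proj and written into the correct half of the merged weight, which is why the visual skip above now carves out those two names. A rough, hypothetical illustration of that routing, assuming mapping entries of the (param_name, weight_name, shard_id) form, e.g. ("gate_up_proj", "gate_proj", 0) and ("gate_up_proj", "up_proj", 1):

import torch

# Hypothetical fused parameter: rows [0:hidden) hold gate_proj, rows [hidden:2*hidden) hold up_proj.
hidden, dim = 64, 32
gate_up_weight = torch.empty(2 * hidden, dim)

stacked_params_mapping = [
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]


def load_visual_mlp_weight(name: str, loaded_weight: torch.Tensor) -> str:
    """Rename a checkpoint tensor and copy it into its half of the fused weight."""
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        name = name.replace(weight_name, param_name)
        # shard 0 -> first `hidden` rows (gate), shard 1 -> next `hidden` rows (up)
        gate_up_weight[shard_id * hidden : (shard_id + 1) * hidden].copy_(loaded_weight)
        return name
    return name


new_name = load_visual_mlp_weight(
    "visual.blocks.0.mlp.gate_proj.weight", torch.randn(hidden, dim)
)
print(new_name)  # visual.blocks.0.mlp.gate_up_proj.weight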