Fuse gate_proj and up_proj in Qwen 2.5 VL's vision MLP (#9661)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Xiang (Kevin) Li <lik@nvidia.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
This commit is contained in:
@@ -43,7 +43,11 @@ from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import (
|
|||||||
from sglang.srt.hf_transformers_utils import get_processor
|
from sglang.srt.hf_transformers_utils import get_processor
|
||||||
from sglang.srt.layers.attention.vision import VisionAttention
|
from sglang.srt.layers.attention.vision import VisionAttention
|
||||||
from sglang.srt.layers.layernorm import RMSNorm
|
from sglang.srt.layers.layernorm import RMSNorm
|
||||||
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
from sglang.srt.layers.linear import (
|
||||||
|
ColumnParallelLinear,
|
||||||
|
MergedColumnParallelLinear,
|
||||||
|
RowParallelLinear,
|
||||||
|
)
|
||||||
from sglang.srt.layers.logits_processor import LogitsProcessor
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
||||||
from sglang.srt.layers.pooler import Pooler, PoolingType
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
||||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||||
@@ -62,7 +66,6 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VLMLP(nn.Module):
|
class Qwen2_5_VLMLP(nn.Module):
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
in_features: int,
|
in_features: int,
|
||||||
@@ -73,19 +76,12 @@ class Qwen2_5_VLMLP(nn.Module):
|
|||||||
prefix: str = "",
|
prefix: str = "",
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.gate_proj = ColumnParallelLinear(
|
self.gate_up_proj = MergedColumnParallelLinear(
|
||||||
in_features,
|
input_size=in_features,
|
||||||
hidden_features,
|
output_sizes=[hidden_features] * 2, # [gate_proj, up_proj]
|
||||||
bias=bias,
|
bias=bias,
|
||||||
quant_config=quant_config,
|
quant_config=quant_config,
|
||||||
prefix=add_prefix("gate_proj", prefix),
|
prefix=add_prefix("gate_up_proj", prefix),
|
||||||
)
|
|
||||||
self.up_proj = ColumnParallelLinear(
|
|
||||||
in_features,
|
|
||||||
hidden_features,
|
|
||||||
bias=bias,
|
|
||||||
quant_config=quant_config,
|
|
||||||
prefix=add_prefix("up_proj", prefix),
|
|
||||||
)
|
)
|
||||||
self.down_proj = RowParallelLinear(
|
self.down_proj = RowParallelLinear(
|
||||||
hidden_features,
|
hidden_features,
|
||||||
@@ -97,12 +93,11 @@ class Qwen2_5_VLMLP(nn.Module):
|
|||||||
self.act = ACT2FN[hidden_act]
|
self.act = ACT2FN[hidden_act]
|
||||||
|
|
||||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||||
x_parallel_gate, _ = self.gate_proj(x)
|
gate_up, _ = self.gate_up_proj(x)
|
||||||
x_parallel_gate = self.act(x_parallel_gate)
|
gate, up = gate_up.chunk(2, dim=-1)
|
||||||
x_parallel_up, _ = self.up_proj(x)
|
x = self.act(gate) * up
|
||||||
x_parallel = x_parallel_gate * x_parallel_up
|
x_down, _ = self.down_proj(x)
|
||||||
x, _ = self.down_proj(x_parallel)
|
return x_down
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
class Qwen2_5_VisionBlock(nn.Module):
|
class Qwen2_5_VisionBlock(nn.Module):
|
||||||
@@ -353,7 +348,7 @@ class Qwen2_5_VisionTransformer(nn.Module):
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def device(self) -> torch.device:
|
def device(self) -> torch.device:
|
||||||
return self.blocks[0].mlp.gate_proj.weight.device
|
return self.patch_embed.proj.weight.device
|
||||||
|
|
||||||
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
|
||||||
pos_ids = []
|
pos_ids = []
|
||||||
@@ -468,9 +463,8 @@ cached_get_processor = lru_cache(get_processor)
|
|||||||
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
||||||
# BitandBytes specific attributes
|
# BitandBytes specific attributes
|
||||||
default_bitsandbytes_target_modules = [
|
default_bitsandbytes_target_modules = [
|
||||||
".gate_proj.",
|
".gate_up_proj.",
|
||||||
".down_proj.",
|
".down_proj.",
|
||||||
".up_proj.",
|
|
||||||
".q_proj.",
|
".q_proj.",
|
||||||
".k_proj.",
|
".k_proj.",
|
||||||
".v_proj.",
|
".v_proj.",
|
||||||
@@ -617,7 +611,11 @@ class Qwen2_5_VLForConditionalGeneration(nn.Module):
|
|||||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||||
if weight_name not in name:
|
if weight_name not in name:
|
||||||
continue
|
continue
|
||||||
if "visual" in name:
|
if (
|
||||||
|
"visual" in name
|
||||||
|
and "up_proj" not in name
|
||||||
|
and "gate_proj" not in name
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
name = name.replace(weight_name, param_name)
|
name = name.replace(weight_name, param_name)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user