[GLM4.1V and GLM4.5V] Add vision transformer num_dummy_head support: max tp=4 -> max tp=8 (#9059)
@@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi
 from sglang.srt.hf_transformers_utils import get_processor
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -91,6 +92,7 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock):
             norm_layer=norm_layer,
             quant_config=quant_config,
             prefix=prefix,
+            num_dummy_heads=config.num_dummy_heads,
         )
         self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
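The vision block change simply forwards `num_dummy_heads` into the shared Qwen2.5-VL attention block, which sizes its column-parallel q/k/v projection for `num_heads + num_dummy_heads` so the heads can be split evenly across tensor-parallel ranks. A rough sketch of the per-rank sizing this implies; the function name and the head counts below are illustrative, not values read from the GLM4V configs:

```python
# Rough sketch, not sglang code: per-rank output width of a column-parallel
# qkv projection once dummy heads are counted in. Numbers are illustrative.
def qkv_out_features_per_rank(num_heads, num_dummy_heads, head_dim, tp_size):
    total_heads = num_heads + num_dummy_heads
    assert total_heads % tp_size == 0, "padded head count must shard evenly"
    return 3 * (total_heads // tp_size) * head_dim  # q, k and v stacked

# 12 real heads cannot be split over 8 ranks; with 4 dummy heads they can.
print(qkv_out_features_per_rank(12, 4, 128, 8))  # 768 features per rank
```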
@@ -469,7 +471,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         nn.Module.__init__(self)
 
         self.config = config
 
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         self.model = Glm4Model(
             config,
             quant_config,
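`update_vit_attn_dummy_heads_config` is called right after `self.config` is stored and before the vision tower is built, so every vision module constructed later can read `num_dummy_heads` and `head_dim` from the HF vision config. A minimal sketch of the rounding rule, assuming a 1536-wide ViT with 12 heads purely for illustration (the numbers are not taken from the GLM4V checkpoints):

```python
from types import SimpleNamespace

# Hedged sketch of what such a config patch does; not the sglang helper itself.
def patch_vision_config(vision_config, tp_size):
    num_heads = vision_config.num_heads
    vision_config.head_dim = vision_config.hidden_size // num_heads
    # Pad the head count up to the next multiple of tp_size.
    vision_config.num_dummy_heads = (-num_heads) % tp_size

cfg = SimpleNamespace(hidden_size=1536, num_heads=12)
patch_vision_config(cfg, tp_size=8)
print(cfg.head_dim, cfg.num_dummy_heads)  # 128 4 -> 16 padded heads shard over 8 ranks
```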
@@ -537,6 +539,51 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
         video_embeds = torch.split(video_embeds, split_sizes)
         return torch.cat(video_embeds)
 
+    def _update_hf_config(self):
+        """update hf config to ensure vision attention num_attention_heads is divisible by tp_size"""
+        tp_size = get_attention_tp_size()
+        num_heads = self.config.vision_config.num_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % tp_size != 0:
+            num_dummy_heads = (
+                (num_heads + tp_size - 1) // tp_size
+            ) * tp_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
+        """pad attn qkv weights for dummy heads"""
+        num_dummy_heads = self.config.vision_config.num_dummy_heads
+        if num_dummy_heads == 0:
+            return loaded_weight
+        head_dim = self.config.vision_config.head_dim
+
+        if "attn.qkv_proj" in name:
+            wq, wk, wv = loaded_weight.chunk(3, dim=0)
+            if name.endswith(".weight"):
+                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
+            elif name.endswith(".bias"):
+                dummy_shape = [num_dummy_heads, head_dim]
+            else:
+                raise RuntimeError(f"Unsupported weight with name={name}")
+            pad_func = lambda x: torch.cat(
+                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
+            ).flatten(0, 1)
+            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
+            loaded_weight = torch.cat([wq, wk, wv], dim=0)
+        elif "attn.proj.weight" in name:
+            padded_weight = loaded_weight.new_zeros(
+                loaded_weight.shape[0], head_dim * num_dummy_heads
+            )
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
+        elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
+            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
+            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
+        return loaded_weight
+
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
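`_pad_vit_attn_dummy_heads` appends zero-initialized head slots to the checkpoint tensors so their shapes match the enlarged tensor-parallel layers. A small standalone check of the fused qkv case, using toy sizes rather than the real GLM4V dimensions:

```python
import torch

num_heads, num_dummy_heads, head_dim, hidden = 2, 1, 4, 8  # toy sizes

# Fused qkv weight as stored in the checkpoint: [3 * num_heads * head_dim, hidden].
qkv = torch.randn(3 * num_heads * head_dim, hidden)
wq, wk, wv = qkv.chunk(3, dim=0)

# Same trick as the diff: view each chunk as (heads, head_dim, hidden),
# append zero rows for the dummy heads, then flatten back.
def pad(x):
    zeros = x.new_zeros(num_dummy_heads, head_dim, x.shape[-1])
    return torch.cat([x.unflatten(0, (-1, head_dim)), zeros], dim=0).flatten(0, 1)

padded = torch.cat([pad(wq), pad(wk), pad(wv)], dim=0)
print(padded.shape)  # torch.Size([36, 8]) == [3 * (2 + 1) * 4, hidden]
```

Because the dummy rows of q/k/v and the matching columns of `attn.proj.weight` are all zeros, the extra heads contribute nothing to the attention output; the padding only changes how the weights shard, not the numerics.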
@@ -583,6 +630,10 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration):
                     raise
 
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                if "visual" in name:
+                    loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                        self.config, name, loaded_weight
+                    )
                 weight_loader(param, loaded_weight)
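In `load_weights`, the padding is applied only to vision-tower tensors (names containing `visual`) right before they are handed to each parameter's weight loader. A hedged sketch of that filtering pattern with stand-in names; `load_with_padding` and `pad_fn` are illustrative, not the sglang loader API:

```python
import torch

def load_with_padding(named_weights, params_dict, pad_fn):
    # pad_fn stands in for the dummy-head padding helper used in this diff.
    for name, tensor in named_weights:
        if "visual" in name:
            tensor = pad_fn(name, tensor)  # only ViT weights get extra zeros
        params_dict[name].data.copy_(tensor)

# Toy usage: a 1x1 "visual" weight is padded to 1x2 before being copied in.
params = {"visual.w": torch.nn.Parameter(torch.zeros(1, 2))}
weights = [("visual.w", torch.ones(1, 1))]
load_with_padding(weights, params, lambda n, t: torch.cat([t, t.new_zeros(1, 1)], dim=-1))
print(params["visual.w"].data)  # tensor([[1., 0.]])
```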