From c2fbf60f398c1d44b675c0fc3fbd3039dca01003 Mon Sep 17 00:00:00 2001 From: Binyao Jiang Date: Mon, 18 Aug 2025 14:40:13 -0700 Subject: [PATCH] [GLM4.1V and GLM4.5V] Add vision transformer num_dummy_head support: max tp=4 -> max tp=8 (#9059) --- benchmark/mmmu/bench_hf.py | 8 ++- benchmark/mmmu/bench_sglang.py | 8 ++- benchmark/mmmu/eval_utils.py | 6 +- .../srt/layers/attention/vision_utils.py | 65 +++++++++++++++++++ python/sglang/srt/models/glm4v.py | 53 ++++++++++++++- python/sglang/srt/models/glm4v_moe.py | 6 ++ python/sglang/srt/models/interns1.py | 51 ++------------- python/sglang/srt/models/internvl.py | 53 ++------------- python/sglang/srt/models/qwen2_5_vl.py | 2 + 9 files changed, 150 insertions(+), 102 deletions(-) create mode 100644 python/sglang/srt/layers/attention/vision_utils.py diff --git a/benchmark/mmmu/bench_hf.py b/benchmark/mmmu/bench_hf.py index 0295bc5dc..949b63b80 100644 --- a/benchmark/mmmu/bench_hf.py +++ b/benchmark/mmmu/bench_hf.py @@ -141,9 +141,13 @@ def eval_mmmu(args): print(f"response: {response}") process_result(response, sample, answer_dict, out_samples) - args.output_path = f"{args.model_path}_val_hf.json" + args.output_path = f"{args.model_path}_answer_hf.json" save_json(args.output_path, out_samples) - eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) + eval_result( + model_answer_path=args.output_path, + answer_dict=answer_dict, + eval_output_path=f"{args.model_path}_val_hf.json", + ) if __name__ == "__main__": diff --git a/benchmark/mmmu/bench_sglang.py b/benchmark/mmmu/bench_sglang.py index 372bfeed8..d8834ea5f 100644 --- a/benchmark/mmmu/bench_sglang.py +++ b/benchmark/mmmu/bench_sglang.py @@ -187,9 +187,13 @@ async def eval_mmmu(args) -> None: print("Profiler stopped") print(f"Benchmark time: {time.perf_counter() - start}") - args.output_path = f"./val_sglang.json" + args.output_path = "./answer_sglang.json" save_json(args.output_path, out_samples) - eval_result(model_answer_path=args.output_path, answer_dict=answer_dict) + eval_result( + model_answer_path=args.output_path, + answer_dict=answer_dict, + eval_output_path="./val_sglang.json", + ) def parse_args(): diff --git a/benchmark/mmmu/eval_utils.py b/benchmark/mmmu/eval_utils.py index 83f6dd7fb..ca0e87c6a 100644 --- a/benchmark/mmmu/eval_utils.py +++ b/benchmark/mmmu/eval_utils.py @@ -544,7 +544,9 @@ def process_result(response, sample, answer_dict, out_samples): } -def eval_result(model_answer_path, answer_dict): +def eval_result(model_answer_path, answer_dict, eval_output_path=None): + if eval_output_path is None: + eval_output_path = model_answer_path print("Evaluating...") output_dict = json.load(open(model_answer_path)) # answer_dict = json.load(open(answer_path)) @@ -639,7 +641,7 @@ def eval_result(model_answer_path, answer_dict): "acc": overall_acc, } pprint.pprint(printable_results) - out = model_answer_path + out = eval_output_path with open(out, "w", encoding="utf-8") as outfile: json.dump(printable_results, outfile) print(f"eval out saved to {out}") diff --git a/python/sglang/srt/layers/attention/vision_utils.py b/python/sglang/srt/layers/attention/vision_utils.py new file mode 100644 index 000000000..ecccb1f85 --- /dev/null +++ b/python/sglang/srt/layers/attention/vision_utils.py @@ -0,0 +1,65 @@ +"""Utility functions for vision attention layers.""" + +import torch + +from sglang.srt.layers.dp_attention import get_attention_tp_size + + +def update_vit_attn_dummy_heads_config(config): + """Update HF config to ensure vision attention num_attention_heads is 
divisible by tp_size""" + tp_size = get_attention_tp_size() + num_heads = getattr( + config.vision_config, + "num_heads", + getattr(config.vision_config, "num_attention_heads", None), + ) + head_dim = config.vision_config.hidden_size // num_heads + num_dummy_heads = 0 + + if num_heads % tp_size != 0: + num_dummy_heads = ((num_heads + tp_size - 1) // tp_size) * tp_size - num_heads + + setattr(config.vision_config, "head_dim", head_dim) + setattr(config.vision_config, "num_dummy_heads", num_dummy_heads) + + +def pad_vit_attn_dummy_heads(config, name: str, loaded_weight: torch.Tensor): + """Pad attention qkv weights for dummy heads""" + num_dummy_heads = config.vision_config.num_dummy_heads + if num_dummy_heads == 0: + return loaded_weight + head_dim = config.vision_config.head_dim + + if "attn.qkv_proj" in name: + wq, wk, wv = loaded_weight.chunk(3, dim=0) + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + pad_func = lambda x: torch.cat( + [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0 + ).flatten(0, 1) + wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv) + loaded_weight = torch.cat([wq, wk, wv], dim=0) + elif any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]): + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + padded_weight = loaded_weight.new_zeros(dummy_shape) + loaded_weight = torch.cat( + [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0 + ).flatten(0, 1) + elif "attn.proj.weight" in name: + padded_weight = loaded_weight.new_zeros( + loaded_weight.shape[0], head_dim * num_dummy_heads + ) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) + elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: + padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) + return loaded_weight diff --git a/python/sglang/srt/models/glm4v.py b/python/sglang/srt/models/glm4v.py index fbd757849..79eae3946 100644 --- a/python/sglang/srt/models/glm4v.py +++ b/python/sglang/srt/models/glm4v.py @@ -9,6 +9,7 @@ from transformers.models.glm4v.configuration_glm4v import Glm4vConfig, Glm4vVisi from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.activation import SiluAndMul +from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -91,6 +92,7 @@ class Glm4vVisionBlock(Qwen2_5_VisionBlock): norm_layer=norm_layer, quant_config=quant_config, prefix=prefix, + num_dummy_heads=config.num_dummy_heads, ) self.norm1 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.norm2 = Glm4vRMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -469,7 +471,7 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): nn.Module.__init__(self) self.config = config - + vision_utils.update_vit_attn_dummy_heads_config(self.config) self.model = Glm4Model( config, quant_config, @@ -537,6 +539,51 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): video_embeds = torch.split(video_embeds, split_sizes) return 
torch.cat(video_embeds) + def _update_hf_config(self): + """update hf config to ensure vision attention num_attention_heads is divisible by tp_size""" + tp_size = get_attention_tp_size() + num_heads = self.config.vision_config.num_heads + head_dim = self.config.vision_config.hidden_size // num_heads + num_dummy_heads = 0 + + if num_heads % tp_size != 0: + num_dummy_heads = ( + (num_heads + tp_size - 1) // tp_size + ) * tp_size - num_heads + + setattr(self.config.vision_config, "head_dim", head_dim) + setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) + + def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): + """pad attn qkv weights for dummy heads""" + num_dummy_heads = self.config.vision_config.num_dummy_heads + if num_dummy_heads == 0: + return loaded_weight + head_dim = self.config.vision_config.head_dim + + if "attn.qkv_proj" in name: + wq, wk, wv = loaded_weight.chunk(3, dim=0) + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + pad_func = lambda x: torch.cat( + [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0 + ).flatten(0, 1) + wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv) + loaded_weight = torch.cat([wq, wk, wv], dim=0) + elif "attn.proj.weight" in name: + padded_weight = loaded_weight.new_zeros( + loaded_weight.shape[0], head_dim * num_dummy_heads + ) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) + elif "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: + padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) + return loaded_weight + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -583,6 +630,10 @@ class Glm4vForConditionalGeneration(Qwen2_5_VLForConditionalGeneration): raise weight_loader = getattr(param, "weight_loader", default_weight_loader) + if "visual" in name: + loaded_weight = vision_utils.pad_vit_attn_dummy_heads( + self.config, name, loaded_weight + ) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/glm4v_moe.py b/python/sglang/srt/models/glm4v_moe.py index 576cb3490..86cca4ab2 100644 --- a/python/sglang/srt/models/glm4v_moe.py +++ b/python/sglang/srt/models/glm4v_moe.py @@ -11,6 +11,7 @@ from sglang.srt.distributed import ( get_tensor_model_parallel_world_size, ) from sglang.srt.hf_transformers_utils import get_processor +from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.logits_processor import LogitsProcessor from sglang.srt.layers.moe.fused_moe_triton import FusedMoE from sglang.srt.layers.pooler import Pooler, PoolingType @@ -40,6 +41,7 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): config.moe_layer_freq = 1 self.config = config + vision_utils.update_vit_attn_dummy_heads_config(self.config) self.tp_size = get_tensor_model_parallel_world_size() self.quant_config = quant_config self.determine_num_fused_shared_experts("Glm4MoeForCausalLM") @@ -385,6 +387,10 @@ class Glm4vMoeForConditionalGeneration(Glm4vForConditionalGeneration): weight_loader = getattr( param, "weight_loader", default_weight_loader ) + if "visual" in name: + loaded_weight = vision_utils.pad_vit_attn_dummy_heads( + self.config, name, loaded_weight + ) weight_loader(param, 
loaded_weight) diff --git a/python/sglang/srt/models/interns1.py b/python/sglang/srt/models/interns1.py index d72deca41..267170301 100644 --- a/python/sglang/srt/models/interns1.py +++ b/python/sglang/srt/models/interns1.py @@ -4,7 +4,7 @@ import torch from torch import nn from transformers import PretrainedConfig -from sglang.srt.distributed import parallel_state +from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -35,7 +35,7 @@ class InternS1ForConditionalGeneration(nn.Module): super().__init__() self.config = config self.quant_config = quant_config - self._update_hf_config() + vision_utils.update_vit_attn_dummy_heads_config(self.config) image_size = ( getattr(config, "force_image_size", None) or config.vision_config.image_size ) @@ -87,21 +87,6 @@ class InternS1ForConditionalGeneration(nn.Module): nn.Linear(llm_hidden_size, llm_hidden_size), ) - def _update_hf_config(self): - """update hf config to support tp""" - world_size = parallel_state.get_tensor_model_parallel_world_size() - num_heads = self.config.vision_config.num_attention_heads - head_dim = self.config.vision_config.hidden_size // num_heads - num_dummy_heads = 0 - - if num_heads % world_size != 0: - num_dummy_heads = ( - (num_heads + world_size) // world_size - ) * world_size - num_heads - - setattr(self.config.vision_config, "head_dim", head_dim) - setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) - def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale @@ -184,34 +169,6 @@ class InternS1ForConditionalGeneration(nn.Module): return helper.pad_input_tokens(input_ids, mm_inputs) - def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): - """pad attn qkv weights for dummy heads""" - num_dummy_heads = self.config.vision_config.num_dummy_heads - if num_dummy_heads == 0: - return loaded_weight - head_dim = self.config.vision_config.head_dim - - if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]): - if name.endswith(".weight"): - dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]] - elif name.endswith(".bias"): - dummy_shape = [num_dummy_heads, head_dim] - else: - raise RuntimeError(f"Unsupported weight with name={name}") - padded_weight = loaded_weight.new_zeros(dummy_shape) - loaded_weight = torch.cat( - [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0 - ).flatten(0, 1) - if "attn.proj.weight" in name: - padded_weight = loaded_weight.new_zeros( - loaded_weight.shape[0], head_dim * num_dummy_heads - ) - loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) - if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: - padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) - loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) - return loaded_weight - def _mapping_interns1_name(self, name): names_map = { "lm_head.weight": "language_model.lm_head.weight", @@ -270,7 +227,9 @@ class InternS1ForConditionalGeneration(nn.Module): continue name = self._mapping_interns1_name(name) if "vision_model" in name: - loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) + loaded_weight = vision_utils.pad_vit_attn_dummy_heads( + self.config, name, loaded_weight + ) for param_name, weight_name, shard_id in stacked_params_mapping: if 
weight_name not in name: diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index 94470cc0a..925bef445 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -10,7 +10,7 @@ from transformers import PretrainedConfig, PreTrainedModel from transformers.activations import ACT2FN from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling -from sglang.srt.distributed import parallel_state +from sglang.srt.layers.attention import vision_utils from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE from sglang.srt.layers.quantization.base_config import QuantizationConfig @@ -412,7 +412,7 @@ class InternVLChatModel(nn.Module): super().__init__() self.config = config self.quant_config = quant_config - self._update_vision_config() + vision_utils.update_vit_attn_dummy_heads_config(self.config) image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size @@ -462,21 +462,6 @@ class InternVLChatModel(nn.Module): nn.Linear(llm_hidden_size, llm_hidden_size), ) - def _update_vision_config(self): - """update vision config to support tp""" - world_size = parallel_state.get_tensor_model_parallel_world_size() - num_heads = self.config.vision_config.num_attention_heads - head_dim = self.config.vision_config.hidden_size // num_heads - num_dummy_heads = 0 - - if num_heads % world_size != 0: - num_dummy_heads = ( - (num_heads + world_size) // world_size - ) * world_size - num_heads - - setattr(self.config.vision_config, "head_dim", head_dim) - setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) - def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale @@ -559,36 +544,6 @@ class InternVLChatModel(nn.Module): return helper.pad_input_tokens(input_ids, mm_inputs) - def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): - """pad attn qkv weights for dummy heads""" - num_dummy_heads = self.config.vision_config.num_dummy_heads - if num_dummy_heads == 0: - return loaded_weight - head_dim = self.config.vision_config.head_dim - - if "attn.qkv_proj" in name: - wq, wk, wv = loaded_weight.chunk(3, dim=0) - if name.endswith(".weight"): - dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]] - elif name.endswith(".bias"): - dummy_shape = [num_dummy_heads, head_dim] - else: - raise RuntimeError(f"Unsupported weight with name={name}") - pad_func = lambda x: torch.cat( - [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0 - ).flatten(0, 1) - wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv) - loaded_weight = torch.cat([wq, wk, wv], dim=0) - if "attn.proj.weight" in name: - padded_weight = loaded_weight.new_zeros( - loaded_weight.shape[0], head_dim * num_dummy_heads - ) - loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) - if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: - padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) - loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) - return loaded_weight - def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): expert_params_mapping = [] if "InternLM2ForCausalLM" in self.config.llm_config.architectures: @@ -699,8 +654,8 @@ class InternVLChatModel(nn.Module): param, "weight_loader", default_weight_loader ) if "vision_model" in 
name: - loaded_weight = self._pad_vit_attn_dummy_heads( - name, loaded_weight + loaded_weight = vision_utils.pad_vit_attn_dummy_heads( + self.config, name, loaded_weight ) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2_5_vl.py b/python/sglang/srt/models/qwen2_5_vl.py index 3d7567d2c..48270ee21 100644 --- a/python/sglang/srt/models/qwen2_5_vl.py +++ b/python/sglang/srt/models/qwen2_5_vl.py @@ -117,6 +117,7 @@ class Qwen2_5_VisionBlock(nn.Module): attn_implementation: Optional[str] = None, quant_config: Optional[QuantizationConfig] = None, prefix: str = "", + num_dummy_heads: int = 0, ) -> None: super().__init__() if norm_layer is None: @@ -157,6 +158,7 @@ class Qwen2_5_VisionBlock(nn.Module): flatten_batch=flatten_batch, quant_config=quant_config, prefix=add_prefix("attn", prefix), + num_dummy_heads=num_dummy_heads, ) self.mlp = Qwen2_5_VLMLP( dim,
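
Note on the padding scheme introduced above: update_vit_attn_dummy_heads_config rounds the ViT head count up to the next multiple of the attention TP size, and pad_vit_attn_dummy_heads zero-fills the extra head slots at weight-load time. Below is a minimal sketch of the arithmetic and of the row padding applied to a per-head projection weight such as q_proj.weight. The 12-head / head_dim=128 / tp_size=8 numbers are illustrative stand-ins (the real values come from config.vision_config), and round_up is a hypothetical helper, not part of the patch.

import torch

def round_up(x: int, multiple: int) -> int:
    # Smallest multiple of `multiple` that is >= x (ceiling division).
    return (x + multiple - 1) // multiple * multiple

num_heads, head_dim, tp_size = 12, 128, 8  # illustrative only
hidden_size = num_heads * head_dim

num_dummy_heads = round_up(num_heads, tp_size) - num_heads  # 16 - 12 = 4

# Mirror the unflatten/cat/flatten pattern from pad_vit_attn_dummy_heads:
# append one zero (head_dim x in_features) block per dummy head.
wq = torch.randn(num_heads * head_dim, hidden_size)
wq_padded = torch.cat(
    [
        wq.unflatten(0, (-1, head_dim)),                       # (12, 128, 1536)
        wq.new_zeros(num_dummy_heads, head_dim, hidden_size),  # ( 4, 128, 1536)
    ],
    dim=0,
).flatten(0, 1)                                                # (2048, 1536)

assert wq_padded.shape[0] % tp_size == 0  # now splits evenly across 8 ranks

With four dummy heads, the padded 16 heads shard as 2 per rank, which is the mechanism behind the "max tp=4 -> max tp=8" change in the subject line.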
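
The output projection is handled differently: attn.proj.weight consumes the concatenated head outputs, so dummy heads contribute zero columns (input features) rather than zero rows, and q_norm/k_norm weights get head_dim * num_dummy_heads trailing zeros. A self-contained check, reusing the same illustrative sizes as above, that both padded tensors shard evenly across ranks:

import torch

num_heads, head_dim, tp_size, num_dummy_heads = 12, 128, 8, 4  # as in the sketch above
hidden_size = num_heads * head_dim

wq_padded = torch.randn((num_heads + num_dummy_heads) * head_dim, hidden_size)
w_proj = torch.randn(hidden_size, num_heads * head_dim)
w_proj_padded = torch.cat(
    [w_proj, w_proj.new_zeros(hidden_size, num_dummy_heads * head_dim)], dim=-1
)

heads_per_rank = (num_heads + num_dummy_heads) // tp_size  # 2
assert all(
    s.shape[0] == heads_per_rank * head_dim for s in wq_padded.chunk(tp_size, dim=0)
)
assert all(
    s.shape[-1] == heads_per_rank * head_dim for s in w_proj_padded.chunk(tp_size, dim=-1)
)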
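
Why zero padding preserves the model's output: a zero-initialized dummy head computes q = k = v = 0, so it attends uniformly over zero values and emits exactly zero, and the zero columns appended to attn.proj.weight then discard that zero output. A toy single-head demonstration of the first half of that argument:

import torch
import torch.nn.functional as F

torch.manual_seed(0)
head_dim, seq_len = 4, 3
x = torch.randn(seq_len, head_dim)

w_zero = torch.zeros(head_dim, head_dim)  # dummy-head q/k/v weights
q, k, v = x @ w_zero.T, x @ w_zero.T, x @ w_zero.T
attn = F.softmax(q @ k.T / head_dim**0.5, dim=-1)  # scores all zero -> uniform
out = attn @ v                                     # uniform average of zeros
assert torch.all(out == 0)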
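
Separately, the benchmark changes stop the metrics summary from clobbering the raw answers: bench_hf.py and bench_sglang.py now write per-sample answers to answer_*.json and pass an explicit eval_output_path (val_*.json) to eval_result, whose new parameter defaults to the old overwrite-in-place behavior. A hypothetical condensed version of that flow (save_and_eval and the stand-in metrics below are not part of the patch):

import json

def save_and_eval(answers, answer_path, eval_output_path=None):
    # Raw answers land in answer_path; metrics go to eval_output_path,
    # falling back to overwriting answer_path when omitted (old behavior).
    if eval_output_path is None:
        eval_output_path = answer_path
    with open(answer_path, "w", encoding="utf-8") as f:
        json.dump(answers, f)
    metrics = {"num_answers": len(answers)}  # stand-in for the real MMMU metrics
    with open(eval_output_path, "w", encoding="utf-8") as f:
        json.dump(metrics, f)

save_and_eval({"q1": "A"}, "./answer_sglang.json", "./val_sglang.json")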