From b7094a5ef197743d9fb5540feac06f8f2814444b Mon Sep 17 00:00:00 2001
From: RunningLeon
Date: Sun, 27 Jul 2025 04:48:51 +0800
Subject: [PATCH] model: support intern-s1 (#8350)

Signed-off-by: Xinyuan Tong
Co-authored-by: zxy
Co-authored-by: Xinyuan Tong
Co-authored-by: Mick
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
---
 python/sglang/lang/chat_template.py          |  21 ++
 python/sglang/srt/configs/internvl.py        |   3 +
 python/sglang/srt/configs/model_config.py    |   1 +
 python/sglang/srt/conversation.py            |  17 +-
 python/sglang/srt/layers/attention/vision.py |  64 +++-
 python/sglang/srt/layers/layernorm.py        |  27 +-
 python/sglang/srt/models/interns1.py         | 328 ++++++++++++++++++
 python/sglang/srt/models/internvl.py         | 190 +++++++---
 python/sglang/srt/models/qwen3_moe.py        |   3 +
 .../srt/multimodal/processors/internvl.py    |  25 +-
 10 files changed, 616 insertions(+), 63 deletions(-)
 create mode 100644 python/sglang/srt/models/interns1.py

diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py
index f309d053d..ef348d27e 100644
--- a/python/sglang/lang/chat_template.py
+++ b/python/sglang/lang/chat_template.py
@@ -448,6 +448,19 @@ register_chat_template(
     )
 )
 
+register_chat_template(
+    ChatTemplate(
+        name="interns1",
+        default_system_prompt="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within <think>...</think> tags.",
tags.", + role_prefix_and_suffix={ + "system": ("<|im_start|>system\n", "<|im_end|>\n"), + "user": ("<|im_start|>user\n", "<|im_end|>\n"), + "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"), + }, + stop_str=["<|im_end|>", "<|action_end|>"], + ) +) + register_chat_template( ChatTemplate( name="granite-3-instruct", @@ -609,6 +622,14 @@ def match_internvl_chat(model_path: str): return "internvl-2-5" +@register_chat_template_matching_function +def match_interns1_chat(model_path: str): + if re.search(r"intern-s1", model_path, re.IGNORECASE): + return "interns1" + if re.search(r"interns1", model_path, re.IGNORECASE): + return "interns1" + + if __name__ == "__main__": messages = [ {"role": "system", "content": None}, # None means default diff --git a/python/sglang/srt/configs/internvl.py b/python/sglang/srt/configs/internvl.py index b4ddda227..7033ef359 100644 --- a/python/sglang/srt/configs/internvl.py +++ b/python/sglang/srt/configs/internvl.py @@ -10,6 +10,7 @@ from transformers import ( PretrainedConfig, PreTrainedTokenizer, Qwen2Config, + Qwen3Config, ) from sglang.utils import logger @@ -314,6 +315,8 @@ class InternVLChatConfig(PretrainedConfig): self.llm_config = InternLM2Config(**llm_config) elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM": self.llm_config = Qwen2Config(**llm_config) + elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM": + self.llm_config = Qwen3Config(**llm_config) else: raise ValueError( "Unsupported architecture: {}".format( diff --git a/python/sglang/srt/configs/model_config.py b/python/sglang/srt/configs/model_config.py index cea455a24..c2d1d1415 100644 --- a/python/sglang/srt/configs/model_config.py +++ b/python/sglang/srt/configs/model_config.py @@ -635,6 +635,7 @@ multimodal_model_archs = [ "Qwen2_5_VLForConditionalGeneration", "KimiVLForConditionalGeneration", "InternVLChatModel", + "InternS1ForConditionalGeneration", "Phi4MMForCausalLM", "VILAForConditionalGeneration", ] diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index 80b706430..cc0071628 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -623,7 +623,7 @@ def generate_chat_conv( real_content += content.text elif content.type == "image_url": # NOTE: works for llava and intervl2_5 - if conv.name == "internvl-2-5": + if conv.name in ["internvl-2-5", "interns1"]: real_content = image_token + real_content else: real_content += image_token @@ -817,6 +817,19 @@ register_conv_template( ) ) +register_conv_template( + Conversation( + name="interns1", + system_template="<|im_start|>system\n{system_message}", + system_message="You are an AI assistant whose name is Intern-S1 (书生大模型).\n- Intern-S1 (书生大模型) is a vision-language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). It is designed to be helpful, honest, and harmless.\n- Intern-S1 (书生大模型) can understand and communicate fluently in the language chosen by the user such as English and 中文.\nYou are an expert reasoner with extensive experience in all areas. You approach problems through systematic thinking and rigorous reasoning. Your response should reflect deep understanding and precise logical thinking, making your solution path and reasoning clear to others. Please put your thinking process within ... 
tags.", + roles=("<|im_start|>user\n", "<|im_start|>assistant\n"), + sep_style=SeparatorStyle.MPT, + sep="<|im_end|>\n", + stop_str=["<|im_end|>", "<|action_end|>"], + image_token="", + ) +) + # Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example register_conv_template( Conversation( @@ -986,6 +999,8 @@ register_conv_template( def match_internvl(model_path: str): if re.search(r"internvl", model_path, re.IGNORECASE): return "internvl-2-5" + if re.search(r"interns1", model_path, re.IGNORECASE): + return "interns1" @register_conv_template_matching_function diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 41f3110cd..c7bbd3ea6 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -3,7 +3,7 @@ from __future__ import annotations import dataclasses import functools import math -from functools import lru_cache +from functools import lru_cache, partial from typing import Any, Optional, Tuple, Union import torch @@ -18,11 +18,16 @@ _is_cuda = is_cuda() if _is_cuda: from sgl_kernel.flash_attn import flash_attn_varlen_func -from sglang.srt.distributed import parallel_state +from sglang.srt.distributed import ( + parallel_state, + split_tensor_along_last_dim, + tensor_model_parallel_all_gather, +) from sglang.srt.distributed import utils as dist_utils from sglang.srt.layers.attention.triton_ops.prefill_attention import ( context_attention_fwd, ) +from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, QKVParallelLinear, @@ -349,25 +354,44 @@ class VisionAttention(nn.Module): flatten_batch: bool = False, prefix: str = "", proj_bias: bool = True, + num_dummy_heads: int = 0, + qkv_bias: bool = True, + qk_normalization: bool = False, + layer_norm_eps: float = 1e-06, **kwargs, ): super().__init__() world_size = parallel_state.get_tensor_model_parallel_world_size() + self.tp_size = world_size + self.tp_rank = parallel_state.get_tensor_model_parallel_rank() self.dropout = dropout self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads ) self.num_attention_heads_per_partition = dist_utils.divide( - num_heads, world_size + num_dummy_heads + num_heads, world_size ) self.num_attention_kv_heads_per_partition = dist_utils.divide( - num_heads, world_size + num_dummy_heads + num_heads, world_size ) self.q_size = self.num_attention_heads_per_partition * self.head_size self.kv_size = self.num_attention_kv_heads_per_partition * self.head_size + self.qk_normalization = qk_normalization + + # Additional dummy heads are used to enable TP for common GPU counts. 
+ self.dummy_dim = (num_dummy_heads + num_heads) * self.head_size + + if self.qk_normalization: + self.q_norm = RMSNorm( + self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim + ) + self.k_norm = RMSNorm( + self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim + ) + if global_server_args_dict["mm_attention_backend"] is None: if qkv_backend is None: qkv_backend = "sdpa" @@ -391,26 +415,46 @@ class VisionAttention(nn.Module): self.qkv_proj = QKVParallelLinear( hidden_size=embed_dim, head_size=self.head_size, - total_num_heads=num_heads, - total_num_kv_heads=num_heads, + total_num_heads=num_dummy_heads + num_heads, + total_num_kv_heads=num_dummy_heads + num_heads, + bias=qkv_bias, quant_config=quant_config, prefix=add_prefix("qkv_proj", prefix), ) else: self.qkv_proj = ColumnParallelLinear( input_size=embed_dim, - output_size=3 * projection_size, + output_size=3 * self.dummy_dim, + bias=qkv_bias, quant_config=quant_config, prefix=add_prefix("qkv_proj", prefix), ) self.proj = RowParallelLinear( - input_size=embed_dim, + input_size=self.dummy_dim, output_size=embed_dim, bias=proj_bias, quant_config=quant_config, prefix=add_prefix("proj", prefix), ) + def _apply_qk_norm(self, q: torch.Tensor, k: torch.Tensor): + """apply qk norm for internvl vit attn""" + q = q.flatten(1, 2) + k = k.flatten(1, 2) + + if self.tp_size > 1: + q = tensor_model_parallel_all_gather(q.contiguous()) + k = tensor_model_parallel_all_gather(k.contiguous()) + q = self.q_norm(q) + k = self.k_norm(k) + if self.tp_size > 1: + splitter = partial(split_tensor_along_last_dim, num_partitions=self.tp_size) + q = splitter(q)[self.tp_rank] + k = splitter(k)[self.tp_rank] + q = q.unflatten(-1, (-1, self.head_size)) + k = k.unflatten(-1, (-1, self.head_size)) + return q, k + def forward( self, x: torch.Tensor, @@ -489,6 +533,10 @@ class VisionAttention(nn.Module): assert k.dim() == 3, k.dim() assert v.dim() == 3, v.dim() + # internvl + if self.qk_normalization: + q, k = self._apply_qk_norm(q, k) + output = self.qkv_backend.forward( q=q, k=k, diff --git a/python/sglang/srt/layers/layernorm.py b/python/sglang/srt/layers/layernorm.py index 0ad32a380..4c1f2268b 100644 --- a/python/sglang/srt/layers/layernorm.py +++ b/python/sglang/srt/layers/layernorm.py @@ -61,10 +61,15 @@ class RMSNorm(CustomOp): self, hidden_size: int, eps: float = 1e-6, + var_hidden_size: Optional[int] = None, ) -> None: super().__init__() self.weight = nn.Parameter(torch.ones(hidden_size)) self.variance_epsilon = eps + self.hidden_size = hidden_size + self.variance_size_override = ( + None if var_hidden_size == hidden_size else var_hidden_size + ) if _use_aiter: self._forward_method = self.forward_aiter @@ -73,6 +78,8 @@ class RMSNorm(CustomOp): x: torch.Tensor, residual: Optional[torch.Tensor] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + if self.variance_size_override is not None: + return self.forward_native(x, residual) if residual is not None: fused_add_rmsnorm(x, residual, self.weight.data, self.variance_epsilon) return x, residual @@ -138,7 +145,25 @@ class RMSNorm(CustomOp): x = x + residual.to(torch.float32) residual = x.to(orig_dtype) - variance = x.pow(2).mean(dim=-1, keepdim=True) + hidden_size = x.shape[-1] + if hidden_size != self.hidden_size: + raise ValueError( + "Expected hidden_size to be " + f"{self.hidden_size}, but found: {hidden_size}" + ) + + if self.variance_size_override is None: + x_var = x + else: + if hidden_size < self.variance_size_override: + raise ValueError( + "Expected hidden_size to be at 
least " + f"{self.variance_size_override}, but found: {hidden_size}" + ) + + x_var = x[..., : self.variance_size_override] + + variance = x_var.pow(2).mean(dim=-1, keepdim=True) x = x * torch.rsqrt(variance + self.variance_epsilon) x = (x * self.weight).to(orig_dtype) if residual is None: diff --git a/python/sglang/srt/models/interns1.py b/python/sglang/srt/models/interns1.py new file mode 100644 index 000000000..75f2cb775 --- /dev/null +++ b/python/sglang/srt/models/interns1.py @@ -0,0 +1,328 @@ +from typing import Iterable, List, Optional, Set, Tuple + +import torch +from torch import nn +from transformers import PretrainedConfig + +from sglang.srt.distributed import parallel_state +from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class +from sglang.srt.layers.quantization.base_config import QuantizationConfig +from sglang.srt.managers.mm_utils import ( + MultiModalityDataPaddingPatternTokenPairs, + general_mm_embed_routine, +) +from sglang.srt.managers.schedule_batch import ( + Modality, + MultimodalDataItem, + MultimodalInputs, +) +from sglang.srt.model_executor.forward_batch_info import ForwardBatch +from sglang.srt.model_loader.weight_utils import default_weight_loader +from sglang.srt.models.internvl import InternVisionModel +from sglang.srt.models.qwen2 import Qwen2ForCausalLM +from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM +from sglang.utils import logger + + +class InternS1ForConditionalGeneration(nn.Module): + def __init__( + self, + config: PretrainedConfig, + quant_config: Optional[QuantizationConfig] = None, + use_flash_attn=True, + ) -> None: + super().__init__() + self.config = config + self.quant_config = quant_config + self._update_hf_config() + image_size = ( + getattr(config, "force_image_size", None) or config.vision_config.image_size + ) + patch_size = config.vision_config.patch_size + if isinstance(image_size, list): + image_size = image_size[0] + if isinstance(patch_size, list): + patch_size = patch_size[0] + self.patch_size = patch_size + self.select_layer = config.vision_feature_layer + self.num_image_token = int( + (image_size // patch_size) ** 2 * (config.downsample_ratio**2) + ) + self.downsample_ratio = config.downsample_ratio + self.ps_version = getattr(config, "ps_version", "v1") + # self.template = getattr(config, 'template', 'internvl2_5') + + config.vision_config.use_flash_attn = True if use_flash_attn else False + config.text_config._attn_implementation = ( + "flash_attention_2" if use_flash_attn else "eager" + ) + + logger.info(f"num_image_token: {self.num_image_token}") + logger.info(f"ps_version: {self.ps_version}") + + self.vision_model = InternVisionModel(config.vision_config) + if config.text_config.architectures[0] == "Qwen2ForCausalLM": + self.language_model = Qwen2ForCausalLM( + config=config.text_config, quant_config=quant_config + ) + elif config.text_config.architectures[0] == "Qwen3MoeForCausalLM": + self.language_model = Qwen3MoeForCausalLM( + config=config.text_config, quant_config=quant_config + ) + else: + raise NotImplementedError( + f"{config.text_config.architectures[0]} is not implemented." 
+            )
+
+        vit_hidden_size = config.vision_config.hidden_size
+        llm_hidden_size = config.text_config.hidden_size
+
+        self.mlp1 = nn.Sequential(
+            nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
+            nn.Linear(
+                vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
+            ),
+            nn.GELU(),
+            nn.Linear(llm_hidden_size, llm_hidden_size),
+        )
+
+    def _update_hf_config(self):
+        """Update the vision config so attention heads divide the TP world size."""
+        world_size = parallel_state.get_tensor_model_parallel_world_size()
+        num_heads = self.config.vision_config.num_attention_heads
+        head_dim = self.config.vision_config.hidden_size // num_heads
+        num_dummy_heads = 0
+
+        if num_heads % world_size != 0:
+            num_dummy_heads = (
+                (num_heads + world_size) // world_size
+            ) * world_size - num_heads
+
+        setattr(self.config.vision_config, "head_dim", head_dim)
+        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
+
+    def pixel_shuffle(self, x, scale_factor=0.5):
+        n, w, h, c = x.size()
+        # N, W, H, C --> N, W, H * scale, C // scale
+        x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
+        # N, W, H * scale, C // scale --> N, H * scale, W, C // scale
+        x = x.permute(0, 2, 1, 3).contiguous()
+        # N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
+        x = x.view(
+            n,
+            int(h * scale_factor),
+            int(w * scale_factor),
+            int(c / (scale_factor * scale_factor)),
+        )
+        if self.ps_version == "v1":
+            logger.warning(
+                "In ps_version 'v1', the height and width have not been swapped back, "
+                "which results in a transposed image."
+            )
+        else:
+            x = x.permute(0, 2, 1, 3).contiguous()
+        return x
+
+    def extract_feature(self, pixel_values):
+        if self.select_layer == -1:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=False, return_dict=True
+            ).last_hidden_state
+        else:
+            vit_embeds = self.vision_model(
+                pixel_values=pixel_values, output_hidden_states=True, return_dict=True
+            ).hidden_states[self.select_layer]
+        vit_embeds = vit_embeds[:, 1:, :]
+
+        h = w = int(vit_embeds.shape[1] ** 0.5)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
+        vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
+        vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
+        vit_embeds = self.mlp1(vit_embeds)
+        return vit_embeds
+
+    def get_image_feature(self, items: List[MultimodalDataItem]):
+        """
+        Projects the last hidden state from the vision model into language model space.
+
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
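+
+        Example (illustrative shapes; assuming a 448x448 crop, patch_size=14,
+        downsample_ratio=0.5, hence num_image_token = (448 // 14) ** 2 * 0.25 = 256):
+            pixel_values: (num_images, 3, 448, 448)
+            -> ViT last hidden state: (num_images, 1025, vit_hidden)  # 32*32 patches + CLS
+            -> drop CLS, pixel shuffle: (num_images, 256, vit_hidden * 4)
+            -> mlp1 projection: (num_images, 256, llm_hidden)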
+ """ + pixel_values = torch.cat([item.feature for item in items]) + image_features = self.extract_feature(pixel_values) + return image_features + + @torch.no_grad() + def forward( + self, + input_ids: torch.Tensor, + positions: torch.Tensor, + forward_batch: ForwardBatch, + input_embeds: torch.Tensor = None, + ) -> torch.Tensor: + + hs = general_mm_embed_routine( + input_ids=input_ids, + forward_batch=forward_batch, + language_model=self.language_model, + data_embedding_funcs={ + Modality.IMAGE: self.get_image_feature, + }, + positions=positions, + ) + + return hs + + def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs): + # Get all special token IDs + im_start_id: int = mm_inputs.im_start_id + im_end_id: int = mm_inputs.im_end_id + + media_token_pairs = [(im_start_id, im_end_id)] + helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs) + + return helper.pad_input_tokens(input_ids, mm_inputs) + + def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): + """pad attn qkv weights for dummy heads""" + num_dummy_heads = self.config.vision_config.num_dummy_heads + if num_dummy_heads == 0: + return loaded_weight + head_dim = self.config.vision_config.head_dim + + if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]): + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + padded_weight = loaded_weight.new_zeros(dummy_shape) + loaded_weight = torch.cat( + [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0 + ).flatten(0, 1) + if "attn.proj.weight" in name: + padded_weight = loaded_weight.new_zeros( + loaded_weight.shape[0], head_dim * num_dummy_heads + ) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) + if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: + padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) + return loaded_weight + + def _mapping_interns1_name(self, name): + names_map = { + "lm_head.weight": "language_model.lm_head.weight", + "model.multi_modal_projector.layer_norm.bias": "mlp1.0.bias", + "model.multi_modal_projector.layer_norm.weight": "mlp1.0.weight", + "model.multi_modal_projector.linear_1.bias": "mlp1.1.bias", + "model.multi_modal_projector.linear_1.weight": "mlp1.1.weight", + "model.multi_modal_projector.linear_2.bias": "mlp1.3.bias", + "model.multi_modal_projector.linear_2.weight": "mlp1.3.weight", + "model.vision_tower.embeddings.cls_token": "vision_model.embeddings.class_embedding", + "model.vision_tower.embeddings.patch_embeddings.projection.bias": "vision_model.embeddings.patch_embedding.bias", + "model.vision_tower.embeddings.patch_embeddings.projection.weight": "vision_model.embeddings.patch_embedding.weight", + "model.vision_tower.embeddings.position_embeddings": "vision_model.embeddings.position_embedding", + } + if name in names_map: + name = names_map[name] + elif name.startswith("model.language_model."): + name = "language_model.model." + name[len("model.language_model.") :] + elif name.startswith("model.vision_tower."): + name = "vision_model." 
+ name[len("model.vision_tower.") :] + + if name.startswith("vision_model.encoder.layer"): + + name = name.replace(r".layer.", r".layers.") + name = name.replace(r".attention.", r".attn.attn.") + name = name.replace(r".projection_layer.", r".proj.") + name = name.replace(r".lambda_1", r".ls1") + name = name.replace(r".lambda_2", r".ls2") + name = name.replace(r".layernorm_before.", r".norm1.") + name = name.replace(r".layernorm_after.", r".norm2.") + return name + + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + expert_params_mapping = [] + if "Qwen3MoeForCausalLM" in self.config.text_config.architectures: + expert_params_mapping = get_moe_impl_class().make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + + params_dict = dict(self.named_parameters()) + loaded_params: Set[str] = set() + + for name, loaded_weight in weights: + if "rotary_emb.inv_freq" in name: + continue + name = self._mapping_interns1_name(name) + if "vision_model" in name: + loaded_weight = self._pad_vit_attn_dummy_heads(name, loaded_weight) + + for param_name, weight_name, shard_id in stacked_params_mapping: + if weight_name not in name: + continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue + name = name.replace(weight_name, param_name) + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader(param, loaded_weight, shard_id) + break + else: + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] + weight_loader = param.weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, + ) + break + else: + # Skip loading extra bias for GPTQ models. + if name.endswith(".bias") and name not in params_dict: + continue + param = params_dict[name] + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, loaded_weight) + + loaded_params.add(name) + unloaded_params = params_dict.keys() - loaded_params + if unloaded_params: + raise RuntimeError( + f"Some weights are not initialized from checkpoints: {unloaded_params}" + ) + return loaded_params + + +EntryClass = [InternS1ForConditionalGeneration] diff --git a/python/sglang/srt/models/internvl.py b/python/sglang/srt/models/internvl.py index 056797cbf..db093dd08 100644 --- a/python/sglang/srt/models/internvl.py +++ b/python/sglang/srt/models/internvl.py @@ -1,16 +1,3 @@ -# Copyright 2023-2024 SGLang Team -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
 from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
+from sglang.srt.distributed import parallel_state
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.deepseek_janus_pro import DropPath
 from sglang.srt.models.internlm2 import InternLM2ForCausalLM
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
-        self.scale = self.head_dim**-0.5
 
         self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=getattr(config, "dropout", 0.0),
-            proj_bias=getattr(config, "qkv_bias", True),
+            qkv_bias=getattr(config, "qkv_bias", False)
+            or getattr(config, "attention_bias", False),
+            num_dummy_heads=getattr(config, "num_dummy_heads", 0),
+            qk_normalization=getattr(config, "qk_normalization", False)
+            or getattr(config, "use_qk_norm", False),
             flatten_batch=False,
         )
 
         self.proj_drop = nn.Dropout(config.dropout)
 
-        self.qk_normalization = config.qk_normalization
-
-        if self.qk_normalization:
-            self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-            self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
         super().__init__()
         self.config = config
         self.embed_dim = config.hidden_size
-        self.image_size = config.image_size
-        self.patch_size = config.patch_size
+        self.image_size = (
+            config.image_size
+            if isinstance(config.image_size, int)
+            else config.image_size[0]
+        )
+        self.patch_size = (
+            config.patch_size
+            if isinstance(config.patch_size, int)
+            else config.patch_size[0]
+        )
 
         self.class_embedding = nn.Parameter(
             torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
         self.embed_dim = config.hidden_size
         self.intermediate_size = config.intermediate_size
         self.norm_type = config.norm_type
-        self.attn = InternAttention(config)
+        self.attn = InternAttention(config=config, quant_config=quant_config)
         self.mlp = InternMLP(config)
         self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
         self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class
InternVLChatModel(nn.Module): super().__init__() self.config = config self.quant_config = quant_config - + self._update_vision_config() image_size = config.force_image_size or config.vision_config.image_size patch_size = config.vision_config.patch_size self.patch_size = patch_size @@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module): self.language_model = InternLM2ForCausalLM( config=config.llm_config, quant_config=quant_config ) + elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM": + self.language_model = Qwen3MoeForCausalLM( + config=config.llm_config, quant_config=quant_config + ) else: raise NotImplementedError( f"{config.llm_config.architectures[0]} is not implemented." @@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module): nn.Linear(llm_hidden_size, llm_hidden_size), ) + def _update_vision_config(self): + """update vision config to support tp""" + world_size = parallel_state.get_tensor_model_parallel_world_size() + num_heads = self.config.vision_config.num_attention_heads + head_dim = self.config.vision_config.hidden_size // num_heads + num_dummy_heads = 0 + + if num_heads % world_size != 0: + num_dummy_heads = ( + (num_heads + world_size) // world_size + ) * world_size - num_heads + + setattr(self.config.vision_config, "head_dim", head_dim) + setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads) + def pixel_shuffle(self, x, scale_factor=0.5): n, w, h, c = x.size() # N, W, H, C --> N, W, H * scale, C // scale @@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module): return helper.pad_input_tokens(input_ids, mm_inputs) + def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor): + """pad attn qkv weights for dummy heads""" + num_dummy_heads = self.config.vision_config.num_dummy_heads + if num_dummy_heads == 0: + return loaded_weight + head_dim = self.config.vision_config.head_dim + + if "attn.qkv_proj" in name: + wq, wk, wv = loaded_weight.chunk(3, dim=0) + if name.endswith(".weight"): + dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]] + elif name.endswith(".bias"): + dummy_shape = [num_dummy_heads, head_dim] + else: + raise RuntimeError(f"Unsupported weight with name={name}") + pad_func = lambda x: torch.cat( + [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0 + ).flatten(0, 1) + wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv) + loaded_weight = torch.cat([wq, wk, wv], dim=0) + if "attn.proj.weight" in name: + padded_weight = loaded_weight.new_zeros( + loaded_weight.shape[0], head_dim * num_dummy_heads + ) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1) + if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name: + padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads) + loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0) + return loaded_weight + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + expert_params_mapping = [] if "InternLM2ForCausalLM" in self.config.llm_config.architectures: stacked_params_mapping = [ # (param_name, shard_name, shard_id) @@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module): ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] + elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures: + stacked_params_mapping = [ + # (param_name, shard_name, shard_id) + ("qkv_proj", "q_proj", "q"), + ("qkv_proj", "k_proj", "k"), + ("qkv_proj", "v_proj", "v"), + ("gate_up_proj", "gate_proj", 0), + ("gate_up_proj", "up_proj", 1), + ] + + expert_params_mapping = 
get_moe_impl_class().make_expert_params_mapping( + ckpt_gate_proj_name="gate_proj", + ckpt_down_proj_name="down_proj", + ckpt_up_proj_name="up_proj", + num_experts=self.config.num_experts, + ) + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: if "rotary_emb.inv_freq" in name: continue + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue + # We have mlp.experts[0].gate_proj in the checkpoint. + # Since we handle the experts below in expert_params_mapping, + # we need to skip here BEFORE we update the name, otherwise + # name will be updated to mlp.experts[0].gate_up_proj, which + # will then be updated below in expert_params_mapping + # for mlp.experts[0].gate_gate_up_proj, which breaks load. + if "mlp.experts" in name: + continue name = name.replace(weight_name, param_name) # Skip loading extra bias for GPTQ models. if name.endswith(".bias") and name not in params_dict: @@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module): name = name.replace(r"attn.", r"attn.attn.") name = name.replace(r"qkv.", r"qkv_proj.") - # Skip loading extra bias for GPTQ models. - if name.endswith(".bias") and name not in params_dict: - continue - param = params_dict[name] - if "wqkv" in name: - config = self.config - kv_groups = config.num_attention_heads // config.num_key_value_heads - head_dim = config.hidden_size // config.num_attention_heads - loaded_weight = loaded_weight.view( - -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1] - ) - wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1) - wq = wq.reshape(-1, wq.shape[-1]) - wk = wk.reshape(-1, wk.shape[-1]) - wv = wv.reshape(-1, wv.shape[-1]) + for mapping in expert_params_mapping: + param_name, weight_name, expert_id, shard_id = mapping + if weight_name not in name: + continue + name = name.replace(weight_name, param_name) + param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, wq, "q") - weight_loader(param, wk, "k") - weight_loader(param, wv, "v") - else: - weight_loader = getattr( - param, "weight_loader", default_weight_loader + weight_loader( + param, + loaded_weight, + name, + shard_id=shard_id, + expert_id=expert_id, ) - weight_loader(param, loaded_weight) + break + else: + # Skip loading extra bias for GPTQ models. 
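+                    # (Reached via for/else: no stacked or expert mapping
+                    # matched, so this is a plain dense, norm, or vision
+                    # weight that is loaded directly below.)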
+                    if name.endswith(".bias") and name not in params_dict:
+                        continue
+                    param = params_dict[name]
+                    if "wqkv" in name:
+                        config = self.config
+                        kv_groups = (
+                            config.num_attention_heads // config.num_key_value_heads
+                        )
+                        head_dim = config.hidden_size // config.num_attention_heads
+                        loaded_weight = loaded_weight.view(
+                            -1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
+                        )
+                        wq, wk, wv = torch.split(
+                            loaded_weight, [kv_groups, 1, 1], dim=1
+                        )
+                        wq = wq.reshape(-1, wq.shape[-1])
+                        wk = wk.reshape(-1, wk.shape[-1])
+                        wv = wv.reshape(-1, wv.shape[-1])
+                        weight_loader = param.weight_loader
+                        weight_loader(param, wq, "q")
+                        weight_loader(param, wk, "k")
+                        weight_loader(param, wv, "v")
+                    else:
+                        weight_loader = getattr(
+                            param, "weight_loader", default_weight_loader
+                        )
+                        if "vision_model" in name:
+                            loaded_weight = self._pad_vit_attn_dummy_heads(
+                                name, loaded_weight
+                            )
+                        weight_loader(param, loaded_weight)
+
             loaded_params.add(name)
         unloaded_params = params_dict.keys() - loaded_params
         if unloaded_params:
diff --git a/python/sglang/srt/models/qwen3_moe.py b/python/sglang/srt/models/qwen3_moe.py
index 8eeee74fa..6b8655459 100644
--- a/python/sglang/srt/models/qwen3_moe.py
+++ b/python/sglang/srt/models/qwen3_moe.py
@@ -707,6 +707,9 @@ class Qwen3MoeForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.capture_aux_hidden_states = False
 
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
     @torch.no_grad()
     def forward(
         self,
diff --git a/python/sglang/srt/multimodal/processors/internvl.py b/python/sglang/srt/multimodal/processors/internvl.py
index 234d57d35..6ab17b1a9 100644
--- a/python/sglang/srt/multimodal/processors/internvl.py
+++ b/python/sglang/srt/multimodal/processors/internvl.py
@@ -6,6 +6,7 @@ from decord import VideoReader, cpu
 from PIL import Image
 
 from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
+from sglang.srt.models.interns1 import InternS1ForConditionalGeneration
 from sglang.srt.models.internvl import InternVLChatModel
 from sglang.srt.multimodal.processors.base_processor import (
     BaseMultimodalProcessor,
@@ -14,12 +15,19 @@
 
 class InternVLImageProcessor(BaseMultimodalProcessor):
-    models = [InternVLChatModel]
+    models = [InternVLChatModel, InternS1ForConditionalGeneration]
 
     def __init__(self, hf_config, server_args, _image_processor, *args, **kwargs):
         super().__init__(hf_config, server_args, _image_processor, *args, **kwargs)
-        image_size = hf_config.force_image_size or hf_config.vision_config.image_size
+        image_size = (
+            getattr(hf_config, "force_image_size", None)
+            or hf_config.vision_config.image_size
+        )
         patch_size = hf_config.vision_config.patch_size
+        if isinstance(image_size, list):
+            image_size = image_size[0]
+        if isinstance(patch_size, list):
+            patch_size = patch_size[0]
 
         self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
         self.IMG_START_TOKEN = "<img>"
@@ -27,8 +35,12 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
         self.num_image_token = int(
             (image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
         )
+        if hasattr(self._processor, "tokenizer"):
+            tokenizer = self._processor.tokenizer
+        else:
+            tokenizer = self._processor
+        self.tokenizer = tokenizer
 
-        tokenizer = self._processor
         self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
         self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
         self.mm_tokens = MultimodalSpecialTokens(
@@ -195,7 +207,7 @@ class
InternVLImageProcessor(BaseMultimodalProcessor):
                 try:
                     # TODO: video input
                     raw_image = process_image_internvl(image)
-                    pixel_value = [raw_image.to(torch.bfloat16).cuda()]
+                    pixel_value = [raw_image.to(torch.bfloat16)]
                     pixel_values += pixel_value
                     num_patches = raw_image.shape[0]
                     num_patches_list += [num_patches]
@@ -214,8 +226,9 @@ class InternVLImageProcessor(BaseMultimodalProcessor):
             )
             input_text = input_text.replace("<image>", image_tokens, 1)
 
-        tokenizer = self._processor
-        input_ids = tokenizer(input_text, return_tensors="pt")["input_ids"].flatten()
+        input_ids = self.tokenizer(input_text, return_tensors="pt")[
+            "input_ids"
+        ].flatten()
 
         image_offsets = self.get_mm_items_offset(
             input_ids=input_ids,
             mm_token_id=self.mm_tokens.image_token_id,
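
Note on the `RMSNorm` change above: with dummy heads, the gathered q/k vectors handed to `q_norm`/`k_norm` are wider than the hidden size the norm was trained on, so the variance is taken over only the first `var_hidden_size` channels while the weight spans the full padded width. A minimal sketch of these semantics (standalone code with a hypothetical helper name, not the fused sgl-kernel path):

import torch

def rms_norm_partial_variance(
    x: torch.Tensor, weight: torch.Tensor, var_size: int, eps: float = 1e-6
) -> torch.Tensor:
    # Variance over the first `var_size` channels only, so the zero-padded
    # dummy-head channels do not dilute the statistic.
    orig_dtype = x.dtype
    x = x.to(torch.float32)
    variance = x[..., :var_size].pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    return (x * weight).to(orig_dtype)

# Illustrative shapes for InternViT-6B under tp_size=8: hidden size 3200
# (25 heads x head_dim 128) is padded to 4096 (32 heads), so the norm
# weight is zero-padded by 896 channels at load time.
q = torch.randn(10, 4096)
q[:, 3200:] = 0.0  # dummy-head channels carry zeros
w = torch.cat([torch.ones(3200), torch.zeros(896)])
q_normed = rms_norm_partial_variance(q, w, var_size=3200)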