model: support intern-s1 (#8350)

Signed-off-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: zxy <zhou0493@e.ntu.edu.sg>
Co-authored-by: Xinyuan Tong <xinyuantong.cs@gmail.com>
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Xinyuan Tong <115166877+JustinTong0323@users.noreply.github.com>
Authored by RunningLeon on 2025-07-27 04:48:51 +08:00; committed by GitHub
parent da0c026084, commit b7094a5ef1
10 changed files with 616 additions and 63 deletions
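To illustrate what this commit enables, here is a minimal sketch of serving an Intern-S1 checkpoint through sglang's offline Engine API. The model path "internlm/Intern-S1", the tp_size value, and the flags are illustrative assumptions, not part of the change.

import sglang as sgl

# Launch the offline engine; tensor parallelism is assumed here because Intern-S1 is a large MoE model.
llm = sgl.Engine(model_path="internlm/Intern-S1", tp_size=8, trust_remote_code=True)
out = llm.generate("Briefly describe benzene.", {"max_new_tokens": 64})
print(out)
llm.shutdown()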


@@ -1,16 +1,3 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
from typing import Iterable, List, Optional, Set, Tuple, Union
import torch
@@ -23,7 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
from sglang.srt.distributed import parallel_state
from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.managers.mm_utils import (
MultiModalityDataPaddingPatternTokenPairs,
@@ -39,6 +28,7 @@ from sglang.srt.model_loader.weight_utils import default_weight_loader
from sglang.srt.models.deepseek_janus_pro import DropPath
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
from sglang.utils import logger
@@ -53,7 +43,6 @@ class InternAttention(nn.Module):
self.embed_dim = config.hidden_size
self.num_heads = config.num_attention_heads
self.head_dim = self.embed_dim // self.num_heads
self.scale = self.head_dim**-0.5
self.attn = VisionAttention(
@@ -64,18 +53,16 @@ class InternAttention(nn.Module):
use_qkv_parallel=True,
quant_config=quant_config,
dropout=getattr(config, "dropout", 0.0),
proj_bias=getattr(config, "qkv_bias", True),
qkv_bias=getattr(config, "qkv_bias", False)
or getattr(config, "attention_bias", False),
num_dummy_heads=getattr(config, "num_dummy_heads", 0),
qk_normalization=getattr(config, "qk_normalization", False)
or getattr(config, "use_qk_norm", False),
flatten_batch=False,
)
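# Note: the getattr fallbacks above allow the same VisionAttention wrapper to read
# both the original InternVL vision-config keys (qkv_bias, qk_normalization) and
# the keys presumably used by the Intern-S1 config (attention_bias, use_qk_norm).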
self.proj_drop = nn.Dropout(config.dropout)
self.qk_normalization = config.qk_normalization
if self.qk_normalization:
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
def forward(
self,
hidden_states: torch.Tensor,
@@ -91,8 +78,16 @@ class InternVisionEmbeddings(nn.Module):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.patch_size = config.patch_size
self.image_size = (
config.image_size
if isinstance(config.image_size, int)
else config.image_size[0]
)
self.patch_size = (
config.patch_size
if isinstance(config.patch_size, int)
else config.patch_size[0]
)
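# Note: image_size / patch_size may arrive as an [H, W] pair in some InternVL /
# Intern-S1 vision configs; assuming square inputs, only the first entry is used.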
self.class_embedding = nn.Parameter(
torch.randn(1, 1, self.embed_dim),
@@ -199,7 +194,7 @@ class InternVisionEncoderLayer(nn.Module):
self.embed_dim = config.hidden_size
self.intermediate_size = config.intermediate_size
self.norm_type = config.norm_type
self.attn = InternAttention(config)
self.attn = InternAttention(config=config, quant_config=quant_config)
self.mlp = InternMLP(config)
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
@@ -417,7 +412,7 @@ class InternVLChatModel(nn.Module):
super().__init__()
self.config = config
self.quant_config = quant_config
self._update_vision_config()
image_size = config.force_image_size or config.vision_config.image_size
patch_size = config.vision_config.patch_size
self.patch_size = patch_size
@@ -446,6 +441,10 @@ class InternVLChatModel(nn.Module):
self.language_model = InternLM2ForCausalLM(
config=config.llm_config, quant_config=quant_config
)
elif config.llm_config.architectures[0] == "Qwen3MoeForCausalLM":
self.language_model = Qwen3MoeForCausalLM(
config=config.llm_config, quant_config=quant_config
)
else:
raise NotImplementedError(
f"{config.llm_config.architectures[0]} is not implemented."
@@ -463,6 +462,21 @@ class InternVLChatModel(nn.Module):
nn.Linear(llm_hidden_size, llm_hidden_size),
)
def _update_vision_config(self):
"""update vision config to support tp"""
world_size = parallel_state.get_tensor_model_parallel_world_size()
num_heads = self.config.vision_config.num_attention_heads
head_dim = self.config.vision_config.hidden_size // num_heads
num_dummy_heads = 0
if num_heads % world_size != 0:
num_dummy_heads = (
(num_heads + world_size) // world_size
) * world_size - num_heads
setattr(self.config.vision_config, "head_dim", head_dim)
setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
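# Example (head counts illustrative, not from a real Intern-S1 config): with
# num_attention_heads = 25 and a tensor-parallel world size of 4, 25 % 4 != 0, so
# num_dummy_heads = ((25 + 4) // 4) * 4 - 25 = 3; the vision attention is padded
# to 28 heads, giving each rank an equal 7 heads.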
def pixel_shuffle(self, x, scale_factor=0.5):
n, w, h, c = x.size()
# N, W, H, C --> N, W, H * scale, C // scale
@@ -545,7 +559,38 @@ class InternVLChatModel(nn.Module):
return helper.pad_input_tokens(input_ids, mm_inputs)
def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
"""pad attn qkv weights for dummy heads"""
num_dummy_heads = self.config.vision_config.num_dummy_heads
if num_dummy_heads == 0:
return loaded_weight
head_dim = self.config.vision_config.head_dim
if "attn.qkv_proj" in name:
wq, wk, wv = loaded_weight.chunk(3, dim=0)
if name.endswith(".weight"):
dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
elif name.endswith(".bias"):
dummy_shape = [num_dummy_heads, head_dim]
else:
raise RuntimeError(f"Unsupported weight with name={name}")
pad_func = lambda x: torch.cat(
[x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
).flatten(0, 1)
wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
loaded_weight = torch.cat([wq, wk, wv], dim=0)
if "attn.proj.weight" in name:
padded_weight = loaded_weight.new_zeros(
loaded_weight.shape[0], head_dim * num_dummy_heads
)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
return loaded_weight
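# Shape sketch for the qkv case above: a weight of shape [3 * num_heads * head_dim, in_dim]
# is split into wq/wk/wv, each unflattened to [num_heads, head_dim, in_dim], zero-padded to
# [num_heads + num_dummy_heads, head_dim, in_dim], flattened back and re-concatenated.
# Because the proj and q/k norm weights for the dummy heads are also zero-padded, the
# extra heads contribute nothing to the attention output.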
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
expert_params_mapping = []
if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
@@ -561,15 +606,41 @@ class InternVLChatModel(nn.Module):
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
elif "Qwen3MoeForCausalLM" in self.config.llm_config.architectures:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
("qkv_proj", "q_proj", "q"),
("qkv_proj", "k_proj", "k"),
("qkv_proj", "v_proj", "v"),
("gate_up_proj", "gate_proj", 0),
("gate_up_proj", "up_proj", 1),
]
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
ckpt_gate_proj_name="gate_proj",
ckpt_down_proj_name="down_proj",
ckpt_up_proj_name="up_proj",
num_experts=self.config.num_experts,
)
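# Each entry in expert_params_mapping is a (param_name, weight_name, expert_id, shard_id)
# tuple that redirects a per-expert checkpoint tensor such as "mlp.experts.0.gate_proj"
# into the fused MoE parameter for that expert; the exact fused parameter names depend on
# the MoE implementation returned by get_moe_impl_class().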
params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
if "rotary_emb.inv_freq" in name:
continue
for param_name, weight_name, shard_id in stacked_params_mapping:
if weight_name not in name:
continue
# The checkpoint stores mlp.experts[0].gate_proj. Since the experts are
# handled below via expert_params_mapping, skip here BEFORE the name is
# rewritten; otherwise it would become mlp.experts[0].gate_up_proj and
# then be rewritten again by expert_params_mapping into
# mlp.experts[0].gate_gate_up_proj, which breaks loading.
if "mlp.experts" in name:
continue
name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
@@ -584,30 +655,55 @@ class InternVLChatModel(nn.Module):
name = name.replace(r"attn.", r"attn.attn.")
name = name.replace(r"qkv.", r"qkv_proj.")
for mapping in expert_params_mapping:
param_name, weight_name, expert_id, shard_id = mapping
if weight_name not in name:
continue
name = name.replace(weight_name, param_name)
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(
param,
loaded_weight,
name,
shard_id=shard_id,
expert_id=expert_id,
)
break
else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
if "wqkv" in name:
config = self.config
kv_groups = (
config.num_attention_heads // config.num_key_value_heads
)
head_dim = config.hidden_size // config.num_attention_heads
loaded_weight = loaded_weight.view(
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
)
wq, wk, wv = torch.split(
loaded_weight, [kv_groups, 1, 1], dim=1
)
wq = wq.reshape(-1, wq.shape[-1])
wk = wk.reshape(-1, wk.shape[-1])
wv = wv.reshape(-1, wv.shape[-1])
weight_loader = param.weight_loader
weight_loader(param, wq, "q")
weight_loader(param, wk, "k")
weight_loader(param, wv, "v")
else:
weight_loader = getattr(
param, "weight_loader", default_weight_loader
)
if "vision_model" in name:
loaded_weight = self._pad_vit_attn_dummy_heads(
name, loaded_weight
)
weight_loader(param, loaded_weight)
loaded_params.add(name)
unloaded_params = params_dict.keys() - loaded_params
if unloaded_params: