Support the InternVL3.5 model family in SGLang (#9705)
This commit is contained in:
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
import sentencepiece as spm
|
||||
from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
GptOssConfig,
|
||||
LlamaConfig,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
Qwen2Config,
|
||||
Qwen3Config,
|
||||
Qwen3MoeConfig,
|
||||
)
|
||||
|
||||
from sglang.utils import logger
|
||||
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
|
||||
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
|
||||
self.llm_config = Qwen2Config(**llm_config)
|
||||
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
|
||||
self.llm_config = Qwen3MoeConfig(**llm_config)
|
||||
elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
|
||||
self.llm_config = Qwen3Config(**llm_config)
|
||||
elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
|
||||
self.llm_config = GptOssConfig(**llm_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported architecture: {}".format(
|
||||
|
||||
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_janus_pro import DropPath
|
||||
from sglang.srt.models.gpt_oss import GptOssForCausalLM
|
||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.srt.models.qwen3 import Qwen3ForCausalLM
|
||||
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||
from sglang.utils import logger
|
||||
|
||||
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
|
||||
self.language_model = Qwen3MoeForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
elif config.llm_config.architectures[0] == "GptOssForCausalLM":
|
||||
self.language_model = GptOssForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
|
||||
self.language_model = Qwen3ForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{config.llm_config.architectures[0]} is not implemented."
|
||||
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=self.config.num_experts,
|
||||
)
|
||||
elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
loaded_params: Set[str] = set()
|
||||
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
|
||||
|
||||
loaded_params.add(name)
|
||||
unloaded_params = params_dict.keys() - loaded_params
|
||||
# Skip params that are created by quantization wrappers and are not expected in the ckpt
|
||||
_quant_only_fragments = (
|
||||
"weight_scale", # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
|
||||
)
|
||||
unloaded_params = {
|
||||
n
|
||||
for n in unloaded_params
|
||||
if not any(frag in n for frag in _quant_only_fragments)
|
||||
}
|
||||
if unloaded_params:
|
||||
raise RuntimeError(
|
||||
f"Some weights are not initialized from checkpoints: {unloaded_params}"
|
||||
|
||||
Reference in New Issue
Block a user