Support the internvl3.5 family models in sglang (#9705)
This commit is contained in:
@@ -6,11 +6,13 @@ from typing import Any, Dict, List, Optional, Tuple, Union
|
|||||||
import sentencepiece as spm
|
import sentencepiece as spm
|
||||||
from transformers import (
|
from transformers import (
|
||||||
TOKENIZER_MAPPING,
|
TOKENIZER_MAPPING,
|
||||||
|
GptOssConfig,
|
||||||
LlamaConfig,
|
LlamaConfig,
|
||||||
PretrainedConfig,
|
PretrainedConfig,
|
||||||
PreTrainedTokenizer,
|
PreTrainedTokenizer,
|
||||||
Qwen2Config,
|
Qwen2Config,
|
||||||
Qwen3Config,
|
Qwen3Config,
|
||||||
|
Qwen3MoeConfig,
|
||||||
)
|
)
|
||||||
|
|
||||||
from sglang.utils import logger
|
from sglang.utils import logger
|
||||||
@@ -316,7 +318,11 @@ class InternVLChatConfig(PretrainedConfig):
|
|||||||
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
|
elif llm_config.get("architectures")[0] == "Qwen2ForCausalLM":
|
||||||
self.llm_config = Qwen2Config(**llm_config)
|
self.llm_config = Qwen2Config(**llm_config)
|
||||||
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
|
elif llm_config.get("architectures")[0] == "Qwen3MoeForCausalLM":
|
||||||
|
self.llm_config = Qwen3MoeConfig(**llm_config)
|
||||||
|
elif llm_config.get("architectures")[0] == "Qwen3ForCausalLM":
|
||||||
self.llm_config = Qwen3Config(**llm_config)
|
self.llm_config = Qwen3Config(**llm_config)
|
||||||
|
elif llm_config.get("architectures")[0] == "GptOssForCausalLM":
|
||||||
|
self.llm_config = GptOssConfig(**llm_config)
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Unsupported architecture: {}".format(
|
"Unsupported architecture: {}".format(
|
||||||
|
|||||||
@@ -26,8 +26,10 @@ from sglang.srt.managers.schedule_batch import (
|
|||||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||||
from sglang.srt.models.deepseek_janus_pro import DropPath
|
from sglang.srt.models.deepseek_janus_pro import DropPath
|
||||||
|
from sglang.srt.models.gpt_oss import GptOssForCausalLM
|
||||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
||||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||||
|
from sglang.srt.models.qwen3 import Qwen3ForCausalLM
|
||||||
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
|
||||||
from sglang.utils import logger
|
from sglang.utils import logger
|
||||||
|
|
||||||
@@ -445,6 +447,14 @@ class InternVLChatModel(nn.Module):
|
|||||||
self.language_model = Qwen3MoeForCausalLM(
|
self.language_model = Qwen3MoeForCausalLM(
|
||||||
config=config.llm_config, quant_config=quant_config
|
config=config.llm_config, quant_config=quant_config
|
||||||
)
|
)
|
||||||
|
elif config.llm_config.architectures[0] == "GptOssForCausalLM":
|
||||||
|
self.language_model = GptOssForCausalLM(
|
||||||
|
config=config.llm_config, quant_config=quant_config
|
||||||
|
)
|
||||||
|
elif config.llm_config.architectures[0] == "Qwen3ForCausalLM":
|
||||||
|
self.language_model = Qwen3ForCausalLM(
|
||||||
|
config=config.llm_config, quant_config=quant_config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"{config.llm_config.architectures[0]} is not implemented."
|
f"{config.llm_config.architectures[0]} is not implemented."
|
||||||
@@ -577,6 +587,15 @@ class InternVLChatModel(nn.Module):
|
|||||||
ckpt_up_proj_name="up_proj",
|
ckpt_up_proj_name="up_proj",
|
||||||
num_experts=self.config.num_experts,
|
num_experts=self.config.num_experts,
|
||||||
)
|
)
|
||||||
|
elif "Qwen3ForCausalLM" in self.config.llm_config.architectures:
|
||||||
|
stacked_params_mapping = [
|
||||||
|
# (param_name, shard_name, shard_id)
|
||||||
|
("qkv_proj", "q_proj", "q"),
|
||||||
|
("qkv_proj", "k_proj", "k"),
|
||||||
|
("qkv_proj", "v_proj", "v"),
|
||||||
|
("gate_up_proj", "gate_proj", 0),
|
||||||
|
("gate_up_proj", "up_proj", 1),
|
||||||
|
]
|
||||||
|
|
||||||
params_dict = dict(self.named_parameters())
|
params_dict = dict(self.named_parameters())
|
||||||
loaded_params: Set[str] = set()
|
loaded_params: Set[str] = set()
|
||||||
@@ -661,6 +680,15 @@ class InternVLChatModel(nn.Module):
|
|||||||
|
|
||||||
loaded_params.add(name)
|
loaded_params.add(name)
|
||||||
unloaded_params = params_dict.keys() - loaded_params
|
unloaded_params = params_dict.keys() - loaded_params
|
||||||
|
# Skip params that are created by quantization wrappers and are not expected in the ckpt
|
||||||
|
_quant_only_fragments = (
|
||||||
|
"weight_scale", # per-matrix FP8 scales (e.g., w2_weight_scale, w13_weight_scale)
|
||||||
|
)
|
||||||
|
unloaded_params = {
|
||||||
|
n
|
||||||
|
for n in unloaded_params
|
||||||
|
if not any(frag in n for frag in _quant_only_fragments)
|
||||||
|
}
|
||||||
if unloaded_params:
|
if unloaded_params:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
f"Some weights are not initialized from checkpoints: {unloaded_params}"
|
f"Some weights are not initialized from checkpoints: {unloaded_params}"
|
||||||
|
|||||||
Reference in New Issue
Block a user