Support InternVL3 (#5350)

Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
xm:D
2025-05-02 13:38:59 +08:00
committed by GitHub
parent 73dcf2b326
commit 3409aaab32
12 changed files with 1728 additions and 9 deletions

View File

@@ -19,6 +19,7 @@ import warnings
from pathlib import Path
from typing import Dict, Optional, Type, Union
import transformers
from huggingface_hub import snapshot_download
from transformers import (
AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
AutoTokenizer,
PretrainedConfig,
PreTrainedTokenizer,
PreTrainedTokenizerBase,
PreTrainedTokenizerFast,
)
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
KimiVLConfig,
MultiModalityConfig,
)
from sglang.srt.configs.internvl import InternVLChatConfig
from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url
@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
DeepseekVL2Config.model_type: DeepseekVL2Config,
MultiModalityConfig.model_type: MultiModalityConfig,
KimiVLConfig.model_type: KimiVLConfig,
InternVLChatConfig.model_type: InternVLChatConfig,
}
for name, cls in _CONFIG_REGISTRY.items():
@@ -90,6 +94,12 @@ def get_config(
config = config_class.from_pretrained(model, revision=revision)
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
setattr(config, "_name_or_path", model)
if isinstance(model, str) and config.model_type == "internvl_chat":
for key, val in config.llm_config.__dict__.items():
if not hasattr(config, key):
setattr(config, key, val)
if model_override_args:
config.update(model_override_args)
@@ -211,6 +221,13 @@ def get_tokenizer(
return tokenizer
# Some models don't have an available processor, e.g. InternVL
def get_tokenizer_from_processor(processor):
    """Return the tokenizer associated with ``processor``.

    If ``processor`` is itself a ``PreTrainedTokenizerBase`` instance it is
    returned unchanged; otherwise its ``tokenizer`` attribute is returned.
    """
    is_bare_tokenizer = isinstance(processor, PreTrainedTokenizerBase)
    # The conditional expression is lazy: ``.tokenizer`` is only looked up
    # when ``processor`` is not already a tokenizer.
    return processor if is_bare_tokenizer else processor.tokenizer
def get_processor(
tokenizer_name: str,
*args,
@@ -246,7 +263,9 @@ def get_processor(
**kwargs,
)
attach_additional_stop_token_ids(processor.tokenizer)
tokenizer = get_tokenizer_from_processor(processor)
attach_additional_stop_token_ids(tokenizer)
return processor