Support InternVL3 (#5350)
Co-authored-by: Mick <mickjagger19@icloud.com>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
@@ -19,6 +19,7 @@ import warnings
 from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
+import transformers
 from huggingface_hub import snapshot_download
 from transformers import (
     AutoConfig,
@@ -26,6 +27,7 @@ from transformers import (
     AutoTokenizer,
     PretrainedConfig,
     PreTrainedTokenizer,
     PreTrainedTokenizerBase,
+    PreTrainedTokenizerFast,
 )
 from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
     KimiVLConfig,
     MultiModalityConfig,
 )
+from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
 from sglang.srt.utils import is_remote_url
 
@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
     KimiVLConfig.model_type: KimiVLConfig,
+    InternVLChatConfig.model_type: InternVLChatConfig,
 }
 
 for name, cls in _CONFIG_REGISTRY.items():
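
The registry entry is what lets the loader resolve the `internvl_chat` model type to the custom config class instead of failing on an unknown type. A minimal sketch of that dispatch pattern, assuming the loop after the registry calls `AutoConfig.register` (the `ToyChatConfig` class and `toy_chat` type are stand-ins, not part of this commit):

    import contextlib

    from transformers import AutoConfig, PretrainedConfig


    class ToyChatConfig(PretrainedConfig):
        # stand-in for a custom config class such as InternVLChatConfig
        model_type = "toy_chat"


    registry = {ToyChatConfig.model_type: ToyChatConfig}

    for name, cls in registry.items():
        # suppress ValueError so re-registering the same type is harmless
        with contextlib.suppress(ValueError):
            AutoConfig.register(name, cls)

    # AutoConfig now resolves the custom model type to our class
    assert isinstance(AutoConfig.for_model("toy_chat"), ToyChatConfig)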
@@ -90,6 +94,12 @@ def get_config(
         config = config_class.from_pretrained(model, revision=revision)
         # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
         setattr(config, "_name_or_path", model)
+
+    if isinstance(model, str) and config.model_type == "internvl_chat":
+        for key, val in config.llm_config.__dict__.items():
+            if not hasattr(config, key):
+                setattr(config, key, val)
+
     if model_override_args:
         config.update(model_override_args)
 
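
The new `internvl_chat` branch flattens the nested LLM config onto the top-level chat config, so downstream code that reads attributes like `hidden_size` directly off `config` keeps working. A self-contained illustration with stand-in namespaces (all attribute names and values below are made up):

    from types import SimpleNamespace

    # toy stand-ins for InternVLChatConfig and its nested llm_config
    llm_config = SimpleNamespace(hidden_size=4096, num_hidden_layers=32)
    config = SimpleNamespace(
        model_type="internvl_chat", hidden_size=1024, llm_config=llm_config
    )

    for key, val in config.llm_config.__dict__.items():
        if not hasattr(config, key):  # existing top-level attributes win
            setattr(config, key, val)

    assert config.num_hidden_layers == 32  # promoted from llm_config
    assert config.hidden_size == 1024  # pre-existing value left untouched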
@@ -211,6 +221,13 @@ def get_tokenizer(
     return tokenizer
 
 
+# Some models don't have an available processor, e.g. InternVL
+def get_tokenizer_from_processor(processor):
+    if isinstance(processor, PreTrainedTokenizerBase):
+        return processor
+    return processor.tokenizer
+
+
 def get_processor(
     tokenizer_name: str,
     *args,
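
A quick check of the helper's two branches, assuming only what the diff shows: a bare tokenizer passes through unchanged, while anything else is expected to expose a `.tokenizer` attribute the way HF processors do (`FakeProcessor` and the `gpt2` checkpoint are illustrative choices):

    from transformers import AutoTokenizer, PreTrainedTokenizerBase


    def get_tokenizer_from_processor(processor):
        if isinstance(processor, PreTrainedTokenizerBase):
            return processor
        return processor.tokenizer


    class FakeProcessor:
        # minimal stand-in for a multimodal processor wrapping a tokenizer
        def __init__(self, tokenizer):
            self.tokenizer = tokenizer


    tok = AutoTokenizer.from_pretrained("gpt2")
    assert get_tokenizer_from_processor(tok) is tok  # passthrough branch
    assert get_tokenizer_from_processor(FakeProcessor(tok)) is tok  # unwrap branch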
@@ -246,7 +263,9 @@ def get_processor(
         **kwargs,
     )
 
-    attach_additional_stop_token_ids(processor.tokenizer)
+    tokenizer = get_tokenizer_from_processor(processor)
+
+    attach_additional_stop_token_ids(tokenizer)
     return processor
 
 
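For context on why `get_processor` now routes through the helper: for a model that ships no processor class, `AutoProcessor.from_pretrained` can fall back to returning a plain tokenizer, and the old `processor.tokenizer` access would then raise `AttributeError`. A sketch of that failure mode, assuming the standard transformers fallback behavior (`gpt2` is just a convenient tokenizer-only checkpoint):

    from transformers import AutoProcessor, PreTrainedTokenizerBase

    # gpt2 has no processor class, so AutoProcessor hands back its tokenizer
    processor = AutoProcessor.from_pretrained("gpt2")
    print(isinstance(processor, PreTrainedTokenizerBase))  # True
    # processor.tokenizer  # would raise AttributeError on this object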