Support InternVL3 (#5350)
Co-authored-by: Mick <mickjagger19@icloud.com> Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
@@ -270,6 +270,29 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="janus",
|
||||
default_system_prompt=None,
|
||||
role_prefix_and_suffix={
|
||||
"system": (
|
||||
"",
|
||||
"",
|
||||
),
|
||||
"user": (
|
||||
"<|User|>",
|
||||
"",
|
||||
),
|
||||
"assistant": (
|
||||
"<|Assistant|>",
|
||||
"<|end▁of▁sentence|>",
|
||||
),
|
||||
},
|
||||
stop_str=("<|end▁of▁sentence|>",),
|
||||
image_token="<image_placeholder>\n",
|
||||
)
|
||||
)
|
||||
|
||||
# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
@@ -395,6 +418,20 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="internvl-2-5",
|
||||
default_system_prompt="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
|
||||
role_prefix_and_suffix={
|
||||
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
|
||||
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
|
||||
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
|
||||
},
|
||||
stop_str=["<|im_end|>", "<|action_end|>"],
|
||||
)
|
||||
)
|
||||
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="granite-3-instruct",
|
||||
@@ -565,6 +602,13 @@ def match_gemma3_instruct(model_path: str):
|
||||
return get_chat_template("gemma-it")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_internvl_chat(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "internvl" in model_path:
|
||||
return get_chat_template("internvl-2-5")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
messages = [
|
||||
{"role": "system", "content": None}, # None means default
|
||||
|
||||
696
python/sglang/srt/configs/internvl.py
Normal file
696
python/sglang/srt/configs/internvl.py
Normal file
@@ -0,0 +1,696 @@
|
||||
import copy
|
||||
import os
|
||||
from shutil import copyfile
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import sentencepiece as spm
|
||||
from transformers import (
|
||||
TOKENIZER_MAPPING,
|
||||
LlamaConfig,
|
||||
Phi3Config,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerFast,
|
||||
Qwen2Config,
|
||||
)
|
||||
|
||||
from sglang.utils import logger
|
||||
|
||||
# Copied from: https://github.com/OpenGVLab/InternVL/blob/34a81000402bf8f716bab8c9b57aff1f6b436bd0/internvl_chat/internvl/model/internvl_chat/configuration_internvl_chat.py#L21
|
||||
|
||||
|
||||
VOCAB_FILES_NAMES = {"vocab_file": "./tokenizer.model"}
|
||||
|
||||
PRETRAINED_VOCAB_FILES_MAP = {}
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.configuration_llama.LlamaConfig
|
||||
class InternLM2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternLM2Model`]. It is used to instantiate
|
||||
an InternLM2 model according to the specified arguments, defining the model architecture. Instantiating a
|
||||
configuration with the defaults will yield a similar configuration to that of the InternLM2-7B.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 32000):
|
||||
Vocabulary size of the InternLM2 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`InternLM2Model`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 11008):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1 the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details checkout [this
|
||||
paper](https://arxiv.org/pdf/2305.13245.pdf). If it is not specified, will default to
|
||||
`num_attention_heads`.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 2048):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-12):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings(`bool`, *optional*, defaults to `False`):
|
||||
Whether to tie weight embeddings
|
||||
Example:
|
||||
|
||||
"""
|
||||
|
||||
model_type = "internlm2"
|
||||
_auto_class = "AutoConfig"
|
||||
|
||||
def __init__( # pylint: disable=W0102
|
||||
self,
|
||||
vocab_size=103168,
|
||||
hidden_size=4096,
|
||||
intermediate_size=11008,
|
||||
num_hidden_layers=32,
|
||||
num_attention_heads=32,
|
||||
num_key_value_heads=None,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=2048,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=1e-6,
|
||||
use_cache=True,
|
||||
pad_token_id=0,
|
||||
bos_token_id=1,
|
||||
eos_token_id=2,
|
||||
tie_word_embeddings=False,
|
||||
bias=True,
|
||||
rope_theta=10000,
|
||||
rope_scaling=None,
|
||||
attn_implementation="eager",
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.bias = bias
|
||||
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self._rope_scaling_validation()
|
||||
|
||||
self.attn_implementation = attn_implementation
|
||||
if self.attn_implementation is None:
|
||||
self.attn_implementation = "eager"
|
||||
super().__init__(
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _rope_scaling_validation(self):
|
||||
"""
|
||||
Validate the `rope_scaling` configuration.
|
||||
"""
|
||||
if self.rope_scaling is None:
|
||||
return
|
||||
|
||||
if not isinstance(self.rope_scaling, dict) or len(self.rope_scaling) != 2:
|
||||
raise ValueError(
|
||||
"`rope_scaling` must be a dictionary with with two fields, `type` and `factor`, "
|
||||
f"got {self.rope_scaling}"
|
||||
)
|
||||
rope_scaling_type = self.rope_scaling.get("type", None)
|
||||
rope_scaling_factor = self.rope_scaling.get("factor", None)
|
||||
if rope_scaling_type is None or rope_scaling_type not in ["linear", "dynamic"]:
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s type field must be one of ['linear', 'dynamic'], got {rope_scaling_type}"
|
||||
)
|
||||
if (
|
||||
rope_scaling_factor is None
|
||||
or not isinstance(rope_scaling_factor, float)
|
||||
or rope_scaling_factor < 1.0
|
||||
):
|
||||
raise ValueError(
|
||||
f"`rope_scaling`'s factor field must be a float >= 1, got {rope_scaling_factor}"
|
||||
)
|
||||
|
||||
|
||||
class InternVisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`InternVisionModel`]. It is used to
|
||||
instantiate a vision encoder according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
num_channels (`int`, *optional*, defaults to 3):
|
||||
Number of color channels in the input images (e.g., 3 for RGB).
|
||||
patch_size (`int`, *optional*, defaults to 14):
|
||||
The size (resolution) of each patch.
|
||||
image_size (`int`, *optional*, defaults to 224):
|
||||
The size (resolution) of each image.
|
||||
qkv_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to add a bias to the queries and values in the self-attention layers.
|
||||
hidden_size (`int`, *optional*, defaults to 3200):
|
||||
Dimensionality of the encoder layers and the pooler layer.
|
||||
num_attention_heads (`int`, *optional*, defaults to 25):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
intermediate_size (`int`, *optional*, defaults to 12800):
|
||||
Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
|
||||
qk_normalization (`bool`, *optional*, defaults to `True`):
|
||||
Whether to normalize the queries and keys in the self-attention layers.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 48):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
use_flash_attn (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use flash attention mechanism.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"selu"` and `"gelu_new"` ``"gelu"` are supported.
|
||||
layer_norm_eps (`float`, *optional*, defaults to 1e-6):
|
||||
The epsilon used by the layer normalization layers.
|
||||
dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.0):
|
||||
Dropout rate for stochastic depth.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
initializer_factor (`float`, *optional*, defaults to 0.1):
|
||||
A factor for layer scale.
|
||||
"""
|
||||
|
||||
model_type = "intern_vit_6b"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_channels=3,
|
||||
patch_size=14,
|
||||
image_size=224,
|
||||
qkv_bias=False,
|
||||
hidden_size=3200,
|
||||
num_attention_heads=25,
|
||||
intermediate_size=12800,
|
||||
qk_normalization=True,
|
||||
num_hidden_layers=48,
|
||||
use_flash_attn=True,
|
||||
hidden_act="gelu",
|
||||
layer_norm_eps=1e-6,
|
||||
dropout=0.0,
|
||||
drop_path_rate=0.0,
|
||||
attention_dropout=0.0,
|
||||
initializer_range=0.02,
|
||||
initializer_factor=0.1,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.dropout = dropout
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_channels = num_channels
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
self.initializer_range = initializer_range
|
||||
self.initializer_factor = initializer_factor
|
||||
self.attention_dropout = attention_dropout
|
||||
self.layer_norm_eps = layer_norm_eps
|
||||
self.hidden_act = hidden_act
|
||||
self.qkv_bias = qkv_bias
|
||||
self.qk_normalization = qk_normalization
|
||||
self.use_flash_attn = use_flash_attn
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs
|
||||
) -> "PretrainedConfig":
|
||||
config_dict, kwargs = cls.get_config_dict(
|
||||
pretrained_model_name_or_path, **kwargs
|
||||
)
|
||||
|
||||
if "vision_config" in config_dict:
|
||||
config_dict = config_dict["vision_config"]
|
||||
|
||||
if (
|
||||
"model_type" in config_dict
|
||||
and hasattr(cls, "model_type")
|
||||
and config_dict["model_type"] != cls.model_type
|
||||
):
|
||||
logger.warning(
|
||||
f"You are using a model of type {config_dict['model_type']} to instantiate a model of type "
|
||||
f"{cls.model_type}. This is not supported for all configurations of models and can yield errors."
|
||||
)
|
||||
|
||||
return cls.from_dict(config_dict, **kwargs)
|
||||
|
||||
|
||||
class InternVLChatConfig(PretrainedConfig):
|
||||
model_type = "internvl_chat"
|
||||
is_composition = True
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
llm_config=None,
|
||||
use_backbone_lora=0,
|
||||
use_llm_lora=0,
|
||||
pad2square=False,
|
||||
select_layer=-1,
|
||||
force_image_size=None,
|
||||
downsample_ratio=0.5,
|
||||
template=None,
|
||||
dynamic_image_size=False,
|
||||
use_thumbnail=False,
|
||||
ps_version="v1",
|
||||
min_dynamic_patch=1,
|
||||
max_dynamic_patch=6,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if vision_config is None:
|
||||
vision_config = {"architectures": ["InternVisionModel"]}
|
||||
logger.info(
|
||||
"vision_config is None. Initializing the InternVisionConfig with default values."
|
||||
)
|
||||
|
||||
if llm_config is None:
|
||||
# TODO: There might still be a bug in transformers version 4.44 and above.
|
||||
llm_config = {"architectures": [""]}
|
||||
logger.info(
|
||||
"llm_config is None. Initializing the LlamaConfig config with default values (`LlamaConfig`)."
|
||||
)
|
||||
self.vision_config = InternVisionConfig(**vision_config)
|
||||
if llm_config["architectures"][0] == "LlamaForCausalLM":
|
||||
self.llm_config = LlamaConfig(**llm_config)
|
||||
elif llm_config["architectures"][0] == "InternLM2ForCausalLM":
|
||||
self.llm_config = InternLM2Config(**llm_config)
|
||||
elif llm_config["architectures"][0] == "Phi3ForCausalLM":
|
||||
self.llm_config = Phi3Config(**llm_config)
|
||||
elif llm_config["architectures"][0] == "Qwen2ForCausalLM":
|
||||
self.llm_config = Qwen2Config(**llm_config)
|
||||
else:
|
||||
raise ValueError(
|
||||
"Unsupported architecture: {}".format(llm_config["architectures"][0])
|
||||
)
|
||||
self.use_backbone_lora = use_backbone_lora
|
||||
self.use_llm_lora = use_llm_lora
|
||||
self.pad2square = pad2square
|
||||
self.select_layer = select_layer
|
||||
self.force_image_size = force_image_size
|
||||
self.downsample_ratio = downsample_ratio
|
||||
self.template = template
|
||||
self.dynamic_image_size = dynamic_image_size
|
||||
self.use_thumbnail = use_thumbnail
|
||||
self.ps_version = ps_version # pixel shuffle version
|
||||
self.min_dynamic_patch = min_dynamic_patch
|
||||
self.max_dynamic_patch = max_dynamic_patch
|
||||
|
||||
self.hidden_size = self.llm_config.hidden_size
|
||||
# By default, we use tie_word_embeddings=False for models of all sizes.
|
||||
self.tie_word_embeddings = False
|
||||
self.llm_config.tie_word_embeddings = self.tie_word_embeddings
|
||||
|
||||
def to_dict(self):
|
||||
"""
|
||||
Serializes this instance to a Python dictionary. Override the default [`~PretrainedConfig.to_dict`].
|
||||
|
||||
Returns:
|
||||
`Dict[str, any]`: Dictionary of all the attributes that make up this configuration instance,
|
||||
"""
|
||||
output = copy.deepcopy(self.__dict__)
|
||||
output["vision_config"] = self.vision_config.to_dict()
|
||||
output["llm_config"] = self.llm_config.to_dict()
|
||||
output["model_type"] = self.__class__.model_type
|
||||
output["use_backbone_lora"] = self.use_backbone_lora
|
||||
output["use_llm_lora"] = self.use_llm_lora
|
||||
output["select_layer"] = self.select_layer
|
||||
output["force_image_size"] = self.force_image_size
|
||||
output["downsample_ratio"] = self.downsample_ratio
|
||||
output["template"] = self.template
|
||||
output["dynamic_image_size"] = self.dynamic_image_size
|
||||
output["use_thumbnail"] = self.use_thumbnail
|
||||
output["ps_version"] = self.ps_version
|
||||
output["min_dynamic_patch"] = self.min_dynamic_patch
|
||||
output["max_dynamic_patch"] = self.max_dynamic_patch
|
||||
|
||||
return output
|
||||
|
||||
|
||||
# # Modified from transformers.model.llama.tokenization_llama_fast.LlamaTokenizerFast -> InternLM2TokenizerFast
|
||||
# class InternLM2TokenizerFast(PreTrainedTokenizerFast):
|
||||
# vocab_files_names = VOCAB_FILES_NAMES
|
||||
# slow_tokenizer_class = InternLM2Tokenizer
|
||||
# padding_side = 'left'
|
||||
# model_input_names = ['input_ids', 'attention_mask']
|
||||
# _auto_class = 'AutoTokenizer'
|
||||
#
|
||||
# def __init__(
|
||||
# self,
|
||||
# vocab_file,
|
||||
# unk_token='<unk>',
|
||||
# bos_token='<s>',
|
||||
# eos_token='</s>',
|
||||
# pad_token='</s>',
|
||||
# sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
# add_bos_token=True,
|
||||
# add_eos_token=False,
|
||||
# decode_with_prefix_space=False,
|
||||
# clean_up_tokenization_spaces=False,
|
||||
# **kwargs,
|
||||
# ):
|
||||
# super().__init__(
|
||||
# vocab_file=vocab_file,
|
||||
# unk_token=unk_token,
|
||||
# bos_token=bos_token,
|
||||
# eos_token=eos_token,
|
||||
# pad_token=pad_token,
|
||||
# sp_model_kwargs=sp_model_kwargs,
|
||||
# add_bos_token=add_bos_token,
|
||||
# add_eos_token=add_eos_token,
|
||||
# decode_with_prefix_space=decode_with_prefix_space,
|
||||
# clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
# **kwargs,
|
||||
# )
|
||||
# self._add_bos_token = add_bos_token
|
||||
# self._add_eos_token = add_eos_token
|
||||
# self.update_post_processor()
|
||||
# self.vocab_file = vocab_file
|
||||
#
|
||||
# @property
|
||||
# def can_save_slow_tokenizer(self) -> bool:
|
||||
# return os.path.isfile(self.vocab_file) if self.vocab_file else False
|
||||
#
|
||||
# def update_post_processor(self):
|
||||
# """
|
||||
# Updates the underlying post processor with the current `bos_token` and `eos_token`.
|
||||
# """
|
||||
# bos = self.bos_token
|
||||
# bos_token_id = self.bos_token_id
|
||||
# if bos is None and self.add_bos_token:
|
||||
# raise ValueError('add_bos_token = True but bos_token = None')
|
||||
#
|
||||
# eos = self.eos_token
|
||||
# eos_token_id = self.eos_token_id
|
||||
# if eos is None and self.add_eos_token:
|
||||
# raise ValueError('add_eos_token = True but eos_token = None')
|
||||
#
|
||||
# single = f"{(bos + ':0 ') if self.add_bos_token else ''}$A:0{(' ' + eos + ':0') if self.add_eos_token else ''}"
|
||||
# pair = f"{single}{(' ' + bos + ':1') if self.add_bos_token else ''} $B:1{(' ' + eos + ':1') if self.add_eos_token else ''}"
|
||||
#
|
||||
# special_tokens = []
|
||||
# if self.add_bos_token:
|
||||
# special_tokens.append((bos, bos_token_id))
|
||||
# if self.add_eos_token:
|
||||
# special_tokens.append((eos, eos_token_id))
|
||||
# self._tokenizer.post_processor = processors.TemplateProcessing(
|
||||
# single=single, pair=pair, special_tokens=special_tokens
|
||||
# )
|
||||
#
|
||||
# @property
|
||||
# def add_eos_token(self):
|
||||
# return self._add_eos_token
|
||||
#
|
||||
# @property
|
||||
# def add_bos_token(self):
|
||||
# return self._add_bos_token
|
||||
#
|
||||
# @add_eos_token.setter
|
||||
# def add_eos_token(self, value):
|
||||
# self._add_eos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# @add_bos_token.setter
|
||||
# def add_bos_token(self, value):
|
||||
# self._add_bos_token = value
|
||||
# self.update_post_processor()
|
||||
#
|
||||
# def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> Tuple[str]:
|
||||
# if not self.can_save_slow_tokenizer:
|
||||
# raise ValueError(
|
||||
# 'Your fast tokenizer does not have the necessary information to save the vocabulary for a slow '
|
||||
# 'tokenizer.'
|
||||
# )
|
||||
#
|
||||
# if not os.path.isdir(save_directory):
|
||||
# logger.error(f'Vocabulary path ({save_directory}) should be a directory')
|
||||
# return
|
||||
# out_vocab_file = os.path.join(
|
||||
# save_directory, (filename_prefix + '-' if filename_prefix else '') + VOCAB_FILES_NAMES['vocab_file']
|
||||
# )
|
||||
#
|
||||
# if os.path.abspath(self.vocab_file) != os.path.abspath(out_vocab_file):
|
||||
# copyfile(self.vocab_file, out_vocab_file)
|
||||
#
|
||||
# return (out_vocab_file,)
|
||||
|
||||
|
||||
# Modified from transformers.model.llama.tokenization_llama.LlamaTokenizer
|
||||
class InternLM2Tokenizer(PreTrainedTokenizer):
|
||||
"""
|
||||
Construct a InternLM2 tokenizer. Based on byte-level Byte-Pair-Encoding.
|
||||
|
||||
Args:
|
||||
vocab_file (`str`):
|
||||
Path to the vocabulary file.
|
||||
"""
|
||||
|
||||
vocab_files_names = VOCAB_FILES_NAMES
|
||||
pretrained_vocab_files_map = PRETRAINED_VOCAB_FILES_MAP
|
||||
model_input_names = ["input_ids", "attention_mask"]
|
||||
_auto_class = "AutoTokenizer"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_file,
|
||||
unk_token="<unk>",
|
||||
bos_token="<s>",
|
||||
eos_token="</s>",
|
||||
pad_token="</s>",
|
||||
sp_model_kwargs: Optional[Dict[str, Any]] = None,
|
||||
add_bos_token=True,
|
||||
add_eos_token=False,
|
||||
decode_with_prefix_space=False,
|
||||
clean_up_tokenization_spaces=False,
|
||||
**kwargs,
|
||||
):
|
||||
print("register succeed")
|
||||
self.sp_model_kwargs = {} if sp_model_kwargs is None else sp_model_kwargs
|
||||
self.vocab_file = vocab_file
|
||||
self.add_bos_token = add_bos_token
|
||||
self.add_eos_token = add_eos_token
|
||||
self.decode_with_prefix_space = decode_with_prefix_space
|
||||
self.sp_model = spm.SentencePieceProcessor(**self.sp_model_kwargs)
|
||||
self.sp_model.Load(vocab_file)
|
||||
self._no_prefix_space_tokens = None
|
||||
super().__init__(
|
||||
bos_token=bos_token,
|
||||
eos_token=eos_token,
|
||||
unk_token=unk_token,
|
||||
pad_token=pad_token,
|
||||
clean_up_tokenization_spaces=clean_up_tokenization_spaces,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@property
|
||||
def no_prefix_space_tokens(self):
|
||||
if self._no_prefix_space_tokens is None:
|
||||
vocab = self.convert_ids_to_tokens(list(range(self.vocab_size)))
|
||||
self._no_prefix_space_tokens = {
|
||||
i for i, tok in enumerate(vocab) if not tok.startswith("▁")
|
||||
}
|
||||
return self._no_prefix_space_tokens
|
||||
|
||||
@property
|
||||
def vocab_size(self):
|
||||
"""Returns vocab size"""
|
||||
return self.sp_model.get_piece_size()
|
||||
|
||||
@property
|
||||
def bos_token_id(self) -> Optional[int]:
|
||||
return self.sp_model.bos_id()
|
||||
|
||||
@property
|
||||
def eos_token_id(self) -> Optional[int]:
|
||||
return self.sp_model.eos_id()
|
||||
|
||||
def get_vocab(self):
|
||||
"""Returns vocab as a dict"""
|
||||
vocab = {self.convert_ids_to_tokens(i): i for i in range(self.vocab_size)}
|
||||
vocab.update(self.added_tokens_encoder)
|
||||
return vocab
|
||||
|
||||
def _tokenize(self, text):
|
||||
"""Returns a tokenized string."""
|
||||
return self.sp_model.encode(text, out_type=str)
|
||||
|
||||
def _convert_token_to_id(self, token):
|
||||
"""Converts a token (str) in an id using the vocab."""
|
||||
return self.sp_model.piece_to_id(token)
|
||||
|
||||
def _convert_id_to_token(self, index):
|
||||
"""Converts an index (integer) in a token (str) using the vocab."""
|
||||
token = self.sp_model.IdToPiece(index)
|
||||
return token
|
||||
|
||||
def _maybe_add_prefix_space(self, tokens, decoded):
|
||||
if tokens and tokens[0] not in self.no_prefix_space_tokens:
|
||||
return " " + decoded
|
||||
else:
|
||||
return decoded
|
||||
|
||||
def convert_tokens_to_string(self, tokens):
|
||||
"""Converts a sequence of tokens (string) in a single string."""
|
||||
current_sub_tokens = []
|
||||
out_string = ""
|
||||
prev_is_special = False
|
||||
for token in tokens:
|
||||
# make sure that special tokens are not decoded using sentencepiece model
|
||||
if token in self.all_special_tokens:
|
||||
if not prev_is_special:
|
||||
out_string += " "
|
||||
out_string += self.sp_model.decode(current_sub_tokens) + token
|
||||
prev_is_special = True
|
||||
current_sub_tokens = []
|
||||
else:
|
||||
current_sub_tokens.append(token)
|
||||
prev_is_special = False
|
||||
out_string += self.sp_model.decode(current_sub_tokens)
|
||||
out_string = self.clean_up_tokenization(out_string)
|
||||
out_string = self._maybe_add_prefix_space(tokens=tokens, decoded=out_string)
|
||||
return out_string[1:]
|
||||
|
||||
def save_vocabulary(
|
||||
self, save_directory, filename_prefix: Optional[str] = None
|
||||
) -> Tuple[str]:
|
||||
"""
|
||||
Save the vocabulary and special tokens file to a directory.
|
||||
|
||||
Args:
|
||||
save_directory (`str`):
|
||||
The directory in which to save the vocabulary.
|
||||
|
||||
Returns:
|
||||
`Tuple(str)`: Paths to the files saved.
|
||||
"""
|
||||
if not os.path.isdir(save_directory):
|
||||
logger.error(f"Vocabulary path ({save_directory}) should be a directory")
|
||||
return
|
||||
out_vocab_file = os.path.join(
|
||||
save_directory,
|
||||
(filename_prefix + "-" if filename_prefix else "")
|
||||
+ VOCAB_FILES_NAMES["vocab_file"],
|
||||
)
|
||||
|
||||
if os.path.abspath(self.vocab_file) != os.path.abspath(
|
||||
out_vocab_file
|
||||
) and os.path.isfile(self.vocab_file):
|
||||
copyfile(self.vocab_file, out_vocab_file)
|
||||
elif not os.path.isfile(self.vocab_file):
|
||||
with open(out_vocab_file, "wb") as fi:
|
||||
content_spiece_model = self.sp_model.serialized_model_proto()
|
||||
fi.write(content_spiece_model)
|
||||
|
||||
return (out_vocab_file,)
|
||||
|
||||
def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
|
||||
if self.add_bos_token:
|
||||
bos_token_ids = [self.bos_token_id]
|
||||
else:
|
||||
bos_token_ids = []
|
||||
|
||||
output = bos_token_ids + token_ids_0
|
||||
|
||||
if token_ids_1 is not None:
|
||||
output = output + token_ids_1
|
||||
|
||||
if self.add_eos_token:
|
||||
output = output + [self.eos_token_id]
|
||||
|
||||
return output
|
||||
|
||||
def get_special_tokens_mask(
|
||||
self,
|
||||
token_ids_0: List[int],
|
||||
token_ids_1: Optional[List[int]] = None,
|
||||
already_has_special_tokens: bool = False,
|
||||
) -> List[int]:
|
||||
"""
|
||||
Retrieve sequence ids from a token list that has no special tokens added. This method is called when adding
|
||||
special tokens using the tokenizer `prepare_for_model` method.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
already_has_special_tokens (`bool`, *optional*, defaults to `False`):
|
||||
Whether or not the token list is already formatted with special tokens for the model.
|
||||
|
||||
Returns:
|
||||
`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
|
||||
"""
|
||||
if already_has_special_tokens:
|
||||
return super().get_special_tokens_mask(
|
||||
token_ids_0=token_ids_0,
|
||||
token_ids_1=token_ids_1,
|
||||
already_has_special_tokens=True,
|
||||
)
|
||||
|
||||
if token_ids_1 is None:
|
||||
return [1] + ([0] * len(token_ids_0)) + [1]
|
||||
return [1] + ([0] * len(token_ids_0)) + [1, 1] + ([0] * len(token_ids_1)) + [1]
|
||||
|
||||
def create_token_type_ids_from_sequences(
|
||||
self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
|
||||
) -> List[int]:
|
||||
"""
|
||||
Create a mask from the two sequences passed to be used in a sequence-pair classification task. T5 does not make
|
||||
use of token type ids, therefore a list of zeros is returned.
|
||||
|
||||
Args:
|
||||
token_ids_0 (`List[int]`):
|
||||
List of IDs.
|
||||
token_ids_1 (`List[int]`, *optional*):
|
||||
Optional second list of IDs for sequence pairs.
|
||||
|
||||
Returns:
|
||||
`List[int]`: List of zeros.
|
||||
"""
|
||||
eos = [self.eos_token_id]
|
||||
|
||||
if token_ids_1 is None:
|
||||
return len(token_ids_0 + eos) * [0]
|
||||
return len(token_ids_0 + eos + token_ids_1 + eos) * [0]
|
||||
|
||||
|
||||
TOKENIZER_MAPPING.register(
|
||||
InternVLChatConfig, (InternLM2Tokenizer, None), exist_ok=True
|
||||
)
|
||||
@@ -538,6 +538,7 @@ multimodal_model_archs = [
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"CLIPModel",
|
||||
"KimiVLForConditionalGeneration",
|
||||
"InternVLChatModel",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -48,6 +48,7 @@ class SeparatorStyle(IntEnum):
|
||||
DeepSeekVL2 = auto()
|
||||
QWEN2_VL_EMBED = auto()
|
||||
GEMMA3 = auto()
|
||||
MPT = auto()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@@ -327,6 +328,16 @@ class Conversation:
|
||||
ret += role
|
||||
return ret
|
||||
|
||||
elif self.sep_style == SeparatorStyle.MPT:
|
||||
ret = system_prompt + self.sep
|
||||
for role, message in self.messages:
|
||||
if message:
|
||||
if type(message) is tuple:
|
||||
message, _, _ = message
|
||||
ret += role + message + self.sep
|
||||
else:
|
||||
ret += role
|
||||
return ret
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.sep_style}")
|
||||
|
||||
@@ -570,8 +581,11 @@ def generate_chat_conv(
|
||||
real_content += "\n" # for video
|
||||
real_content += content.text
|
||||
elif content.type == "image_url":
|
||||
# NOTE: Only works for llava
|
||||
real_content += image_token
|
||||
# NOTE: works for llava and intervl2_5
|
||||
if conv.name == "internvl-2-5":
|
||||
real_content = image_token + real_content
|
||||
else:
|
||||
real_content += image_token
|
||||
conv.append_image(content.image_url.url)
|
||||
elif content.type == "audio_url":
|
||||
real_content += audio_token
|
||||
@@ -703,6 +717,19 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="internvl-2-5",
|
||||
system_template="<|im_start|>system\n{system_message}",
|
||||
system_message="你是书生·万象,英文名是InternVL,是由上海人工智能实验室、清华大学及多家合作单位联合开发的多模态大语言模型。",
|
||||
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
||||
sep_style=SeparatorStyle.MPT,
|
||||
sep="<|im_end|>\n",
|
||||
stop_str=["<|im_end|>", "<|action_end|>"],
|
||||
image_token="<image>",
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
|
||||
@@ -19,6 +19,7 @@ import warnings
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Type, Union
|
||||
|
||||
import transformers
|
||||
from huggingface_hub import snapshot_download
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
@@ -26,6 +27,7 @@ from transformers import (
|
||||
AutoTokenizer,
|
||||
PretrainedConfig,
|
||||
PreTrainedTokenizer,
|
||||
PreTrainedTokenizerBase,
|
||||
PreTrainedTokenizerFast,
|
||||
)
|
||||
from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
|
||||
@@ -38,6 +40,7 @@ from sglang.srt.configs import (
|
||||
KimiVLConfig,
|
||||
MultiModalityConfig,
|
||||
)
|
||||
from sglang.srt.configs.internvl import InternVLChatConfig
|
||||
from sglang.srt.connector import create_remote_connector
|
||||
from sglang.srt.utils import is_remote_url
|
||||
|
||||
@@ -48,6 +51,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
DeepseekVL2Config.model_type: DeepseekVL2Config,
|
||||
MultiModalityConfig.model_type: MultiModalityConfig,
|
||||
KimiVLConfig.model_type: KimiVLConfig,
|
||||
InternVLChatConfig.model_type: InternVLChatConfig,
|
||||
}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
@@ -90,6 +94,12 @@ def get_config(
|
||||
config = config_class.from_pretrained(model, revision=revision)
|
||||
# NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
|
||||
setattr(config, "_name_or_path", model)
|
||||
|
||||
if isinstance(model, str) and config.model_type == "internvl_chat":
|
||||
for key, val in config.llm_config.__dict__.items():
|
||||
if not hasattr(config, key):
|
||||
setattr(config, key, val)
|
||||
|
||||
if model_override_args:
|
||||
config.update(model_override_args)
|
||||
|
||||
@@ -211,6 +221,13 @@ def get_tokenizer(
|
||||
return tokenizer
|
||||
|
||||
|
||||
# Some models doesn't have an available processor, e.g.: InternVL
|
||||
def get_tokenizer_from_processor(processor):
|
||||
if isinstance(processor, PreTrainedTokenizerBase):
|
||||
return processor
|
||||
return processor.tokenizer
|
||||
|
||||
|
||||
def get_processor(
|
||||
tokenizer_name: str,
|
||||
*args,
|
||||
@@ -246,7 +263,9 @@ def get_processor(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
attach_additional_stop_token_ids(processor.tokenizer)
|
||||
tokenizer = get_tokenizer_from_processor(processor)
|
||||
|
||||
attach_additional_stop_token_ids(tokenizer)
|
||||
return processor
|
||||
|
||||
|
||||
|
||||
232
python/sglang/srt/managers/multimodal_processors/internvl.py
Normal file
232
python/sglang/srt/managers/multimodal_processors/internvl.py
Normal file
@@ -0,0 +1,232 @@
|
||||
# Adapted from https://huggingface.co/OpenGVLab/InternVL2-4B/blob/main/modeling_intern_vit.py
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from decord import VideoReader, cpu
|
||||
from numpy.distutils.cpuinfo import cpu
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor,
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.internvl import InternVLChatModel
|
||||
|
||||
|
||||
class InternVLImageProcessor(BaseMultimodalProcessor):
|
||||
models = [InternVLChatModel]
|
||||
|
||||
def __init__(self, hf_config, server_args, _image_processor):
|
||||
super().__init__(hf_config, server_args, _image_processor)
|
||||
image_size = hf_config.force_image_size or hf_config.vision_config.image_size
|
||||
patch_size = hf_config.vision_config.patch_size
|
||||
|
||||
self.IMG_CONTEXT_TOKEN = "<IMG_CONTEXT>"
|
||||
self.IMG_START_TOKEN = "<img>"
|
||||
self.IMG_END_TOKEN = "</img>"
|
||||
self.IMG_TOKEN = "<image>"
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (hf_config.downsample_ratio**2)
|
||||
)
|
||||
|
||||
tokenizer = self._processor
|
||||
self.img_start_token_id = tokenizer.convert_tokens_to_ids(self.IMG_START_TOKEN)
|
||||
self.img_end_token_id = tokenizer.convert_tokens_to_ids(self.IMG_END_TOKEN)
|
||||
self.img_context_token_id = tokenizer.convert_tokens_to_ids(
|
||||
self.IMG_CONTEXT_TOKEN
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def build_transform(input_size):
|
||||
IMAGENET_MEAN = (0.485, 0.456, 0.406)
|
||||
IMAGENET_STD = (0.229, 0.224, 0.225)
|
||||
|
||||
def resize_image(img, size):
|
||||
return img.resize((size, size), Image.Resampling.BICUBIC)
|
||||
|
||||
def to_tensor(img):
|
||||
# Convert PIL Image to numpy array
|
||||
img_array = np.array(img).astype(np.float32) / 255.0
|
||||
# Convert HWC to CHW format
|
||||
img_array = img_array.transpose(2, 0, 1)
|
||||
return torch.from_numpy(img_array)
|
||||
|
||||
def normalize(tensor, mean, std):
|
||||
mean = torch.tensor(mean).view(-1, 1, 1)
|
||||
std = torch.tensor(std).view(-1, 1, 1)
|
||||
return (tensor - mean) / std
|
||||
|
||||
def transform(img):
|
||||
img = img.convert("RGB") if img.mode != "RGB" else img
|
||||
img = resize_image(img, input_size)
|
||||
tensor = to_tensor(img)
|
||||
tensor = normalize(tensor, IMAGENET_MEAN, IMAGENET_STD)
|
||||
return tensor
|
||||
|
||||
return transform
|
||||
|
||||
@staticmethod
|
||||
def dynamic_preprocess(
|
||||
image, min_num=1, max_num=12, image_size=448, use_thumbnail=False
|
||||
):
|
||||
|
||||
def find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, width, height, image_size
|
||||
):
|
||||
best_ratio_diff = float("inf")
|
||||
best_ratio = (1, 1)
|
||||
area = width * height
|
||||
for ratio in target_ratios:
|
||||
target_aspect_ratio = ratio[0] / ratio[1]
|
||||
ratio_diff = abs(aspect_ratio - target_aspect_ratio)
|
||||
if ratio_diff < best_ratio_diff:
|
||||
best_ratio_diff = ratio_diff
|
||||
best_ratio = ratio
|
||||
elif ratio_diff == best_ratio_diff:
|
||||
if area > 0.5 * image_size * image_size * ratio[0] * ratio[1]:
|
||||
best_ratio = ratio
|
||||
return best_ratio
|
||||
|
||||
orig_width, orig_height = image.size
|
||||
aspect_ratio = orig_width / orig_height
|
||||
|
||||
# calculate the existing image aspect ratio
|
||||
target_ratios = set(
|
||||
(i, j)
|
||||
for n in range(min_num, max_num + 1)
|
||||
for i in range(1, n + 1)
|
||||
for j in range(1, n + 1)
|
||||
if i * j <= max_num and i * j >= min_num
|
||||
)
|
||||
target_ratios = sorted(target_ratios, key=lambda x: x[0] * x[1])
|
||||
|
||||
# find the closest aspect ratio to the target
|
||||
target_aspect_ratio = find_closest_aspect_ratio(
|
||||
aspect_ratio, target_ratios, orig_width, orig_height, image_size
|
||||
)
|
||||
|
||||
# calculate the target width and height
|
||||
target_width = image_size * target_aspect_ratio[0]
|
||||
target_height = image_size * target_aspect_ratio[1]
|
||||
blocks = target_aspect_ratio[0] * target_aspect_ratio[1]
|
||||
|
||||
# resize the image
|
||||
resized_img = image.resize((target_width, target_height))
|
||||
processed_images = []
|
||||
for i in range(blocks):
|
||||
box = (
|
||||
(i % (target_width // image_size)) * image_size,
|
||||
(i // (target_width // image_size)) * image_size,
|
||||
((i % (target_width // image_size)) + 1) * image_size,
|
||||
((i // (target_width // image_size)) + 1) * image_size,
|
||||
)
|
||||
# split the image
|
||||
split_img = resized_img.crop(box)
|
||||
processed_images.append(split_img)
|
||||
assert len(processed_images) == blocks
|
||||
if use_thumbnail and len(processed_images) != 1:
|
||||
thumbnail_img = image.resize((image_size, image_size))
|
||||
processed_images.append(thumbnail_img)
|
||||
return processed_images
|
||||
|
||||
@staticmethod
|
||||
def get_index(bound, fps, max_frame, first_idx=0, num_segments=32):
|
||||
if bound:
|
||||
start, end = bound[0], bound[1]
|
||||
else:
|
||||
start, end = -100000, 100000
|
||||
start_idx = max(first_idx, round(start * fps))
|
||||
end_idx = min(round(end * fps), max_frame)
|
||||
seg_size = float(end_idx - start_idx) / num_segments
|
||||
frame_indices = np.array(
|
||||
[
|
||||
int(start_idx + (seg_size / 2) + np.round(seg_size * idx))
|
||||
for idx in range(num_segments)
|
||||
]
|
||||
)
|
||||
return frame_indices
|
||||
|
||||
@staticmethod
|
||||
def load_video(video_path, bound=None, input_size=448, max_num=1, num_segments=32):
|
||||
vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
|
||||
max_frame = len(vr) - 1
|
||||
fps = float(vr.get_avg_fps())
|
||||
|
||||
pixel_values_list, num_patches_list = [], []
|
||||
transform = InternVLImageProcessor.build_transform(input_size=input_size)
|
||||
frame_indices = InternVLImageProcessor.get_index(
|
||||
bound, fps, max_frame, first_idx=0, num_segments=num_segments
|
||||
)
|
||||
for frame_index in frame_indices:
|
||||
img = Image.fromarray(vr[frame_index].asnumpy()).convert("RGB")
|
||||
img = InternVLImageProcessor.dynamic_preprocess(
|
||||
img, image_size=input_size, use_thumbnail=True, max_num=max_num
|
||||
)
|
||||
pixel_values = [transform(tile) for tile in img]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
num_patches_list.append(pixel_values.shape[0])
|
||||
pixel_values_list.append(pixel_values)
|
||||
pixel_values = torch.cat(pixel_values_list)
|
||||
return pixel_values, num_patches_list
|
||||
|
||||
async def process_mm_data_async(
|
||||
self, image_data, input_text, request_obj, max_req_input_len, **kwargs
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMG_TOKEN),
|
||||
max_req_input_len=max_req_input_len,
|
||||
discard_alpha_channel=True,
|
||||
)
|
||||
|
||||
def process_image_internvl(image, input_size=448, max_num=12):
|
||||
transform = InternVLImageProcessor.build_transform(input_size=input_size)
|
||||
images = InternVLImageProcessor.dynamic_preprocess(
|
||||
image, image_size=input_size, use_thumbnail=True, max_num=max_num
|
||||
)
|
||||
pixel_values = [transform(image) for image in images]
|
||||
pixel_values = torch.stack(pixel_values)
|
||||
return pixel_values
|
||||
|
||||
num_patches_list = []
|
||||
pixel_values = []
|
||||
# Process each input with allocated frames
|
||||
for image_index, (image) in enumerate(base_output.images):
|
||||
try:
|
||||
# TODO: video input
|
||||
raw_image = process_image_internvl(image)
|
||||
pixel_value = [raw_image.to(torch.bfloat16).cuda()]
|
||||
pixel_values += pixel_value
|
||||
num_patches = raw_image.shape[0]
|
||||
num_patches_list += [num_patches]
|
||||
|
||||
except FileNotFoundError as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
pixel_values = torch.cat(pixel_values, dim=0)
|
||||
items = [MultimodalDataItem(pixel_values=pixel_values, modality=Modality.IMAGE)]
|
||||
|
||||
for idx, num_patches in enumerate(num_patches_list):
|
||||
image_tokens = (
|
||||
self.IMG_START_TOKEN
|
||||
+ self.IMG_CONTEXT_TOKEN * self.num_image_token * num_patches
|
||||
+ self.IMG_END_TOKEN
|
||||
)
|
||||
input_text = input_text.replace("<image>", image_tokens, 1)
|
||||
|
||||
tokenizer = self._processor
|
||||
return {
|
||||
"input_ids": tokenizer(input_text, return_tensors="pt")["input_ids"]
|
||||
.flatten()
|
||||
.tolist(),
|
||||
"mm_items": items,
|
||||
"im_start_id": self.img_start_token_id,
|
||||
"im_end_id": self.img_end_token_id,
|
||||
"im_token_id": self.img_context_token_id,
|
||||
}
|
||||
@@ -52,7 +52,11 @@ from sglang.srt.disaggregation.utils import (
|
||||
TransferBackend,
|
||||
)
|
||||
from sglang.srt.distributed import get_pp_group, get_world_group
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.layers.dp_attention import compute_dp_attention_world_info
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
|
||||
@@ -475,7 +479,7 @@ class Scheduler(
|
||||
revision=server_args.revision,
|
||||
use_fast=not server_args.disable_fast_image_processor,
|
||||
)
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
server_args.tokenizer_path,
|
||||
|
||||
@@ -54,7 +54,11 @@ from sglang.srt.disaggregation.utils import (
|
||||
TransferBackend,
|
||||
get_kv_class,
|
||||
)
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.managers.io_struct import (
|
||||
AbortReq,
|
||||
BatchEmbeddingOut,
|
||||
@@ -199,7 +203,7 @@ class TokenizerManager:
|
||||
self.tokenizer = self.processor = None
|
||||
else:
|
||||
self.processor = _processor
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
else:
|
||||
self.mm_processor = get_dummy_processor()
|
||||
|
||||
@@ -21,7 +21,11 @@ import torch
|
||||
|
||||
from sglang.srt.configs.model_config import ModelConfig
|
||||
from sglang.srt.distributed import get_pp_group, get_tp_group, get_world_group
|
||||
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
||||
from sglang.srt.hf_transformers_utils import (
|
||||
get_processor,
|
||||
get_tokenizer,
|
||||
get_tokenizer_from_processor,
|
||||
)
|
||||
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
||||
from sglang.srt.managers.io_struct import (
|
||||
GetWeightsByNameReqInput,
|
||||
@@ -102,7 +106,7 @@ class TpModelWorker:
|
||||
trust_remote_code=server_args.trust_remote_code,
|
||||
revision=server_args.revision,
|
||||
)
|
||||
self.tokenizer = self.processor.tokenizer
|
||||
self.tokenizer = get_tokenizer_from_processor(self.processor)
|
||||
else:
|
||||
self.tokenizer = get_tokenizer(
|
||||
server_args.tokenizer_path,
|
||||
|
||||
@@ -290,6 +290,9 @@ class InternLM2ForCausalLM(nn.Module):
|
||||
)
|
||||
self.logits_processor = LogitsProcessor(config)
|
||||
|
||||
def get_input_embeddings(self) -> nn.Embedding:
|
||||
return self.model.tok_embeddings
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
|
||||
670
python/sglang/srt/models/internvl.py
Normal file
670
python/sglang/srt/models/internvl.py
Normal file
@@ -0,0 +1,670 @@
|
||||
# Copyright 2023-2024 SGLang Team
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==========================582====================================================
|
||||
|
||||
from typing import Iterable, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
# Adapted from https://raw.githubusercontent.com/vllm-project/vllm/7f62077af5159c625fe3ad1c812e6c1a2b93ba3b/vllm/model_executor/models/internlm2.py
|
||||
# Adapted from https://raw.githubusercontent.com/hehesangsj/sglang/refs/heads/internvl/python/sglang/srt/models/internvl.py
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from sgl_kernel.flash_attn import flash_attn_varlen_func
|
||||
from torch import nn
|
||||
from transformers import PretrainedConfig, PreTrainedModel
|
||||
from transformers.activations import ACT2FN
|
||||
from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
|
||||
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternTokenPairs,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
||||
from sglang.srt.models.deepseek_janus_pro import DropPath
|
||||
from sglang.srt.models.internlm2 import InternLM2ForCausalLM
|
||||
from sglang.srt.models.qwen2 import Qwen2ForCausalLM
|
||||
from sglang.utils import logger
|
||||
|
||||
|
||||
class FlashAttention(nn.Module):
|
||||
"""Implement the scaled dot product attention with softmax.
|
||||
Arguments
|
||||
---------
|
||||
softmax_scale: The temperature to use for the softmax attention.
|
||||
(default: 1/sqrt(d_keys) where d_keys is computed at
|
||||
runtime)
|
||||
attention_dropout: The dropout rate to apply to the attention
|
||||
(default: 0.0)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, softmax_scale=None, attention_dropout=0.0, device=None, dtype=None
|
||||
):
|
||||
super().__init__()
|
||||
self.softmax_scale = softmax_scale
|
||||
self.dropout_p = attention_dropout
|
||||
|
||||
def forward(
|
||||
self,
|
||||
qkv,
|
||||
causal=False,
|
||||
max_s=None,
|
||||
):
|
||||
"""Implements the multihead softmax attention.
|
||||
Arguments
|
||||
---------
|
||||
qkv: The tensor containing the query, key, and value. (B, S, 3, H, D) if key_padding_mask is None
|
||||
if unpadded: (nnz, 3, h, d)
|
||||
"""
|
||||
assert qkv.dtype in [torch.float16, torch.bfloat16]
|
||||
assert qkv.is_cuda
|
||||
|
||||
batch_size, seqlen, _, nheads, d = qkv.shape
|
||||
if batch_size == 0 or seqlen == 0:
|
||||
output_shape = (batch_size, seqlen, nheads, d)
|
||||
return (
|
||||
torch.zeros(output_shape, dtype=qkv.dtype, device=qkv.device),
|
||||
None,
|
||||
)
|
||||
|
||||
qkv_reshaped = rearrange(qkv, "b s three h d -> (b s) three h d", three=3)
|
||||
q, k, v = qkv_reshaped.unbind(1)
|
||||
|
||||
max_s = seqlen
|
||||
cu_seqlens = torch.arange(
|
||||
0,
|
||||
(batch_size + 1) * seqlen,
|
||||
step=seqlen,
|
||||
dtype=torch.int32,
|
||||
device=qkv.device,
|
||||
)
|
||||
output_reshaped = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
max_s,
|
||||
max_s,
|
||||
softmax_scale=self.softmax_scale,
|
||||
causal=causal,
|
||||
)
|
||||
output = rearrange(output_reshaped, "(b s) h d -> b s h d", b=batch_size)
|
||||
return output, None
|
||||
|
||||
|
||||
class InternAttention(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.num_heads = config.num_attention_heads
|
||||
self.head_dim = self.embed_dim // self.num_heads
|
||||
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.qkv = nn.Linear(self.embed_dim, 3 * self.embed_dim, bias=config.qkv_bias)
|
||||
self.proj_drop = nn.Dropout(config.dropout)
|
||||
|
||||
self.qk_normalization = config.qk_normalization
|
||||
|
||||
if self.qk_normalization:
|
||||
self.q_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.k_norm = InternRMSNorm(self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.inner_attn = FlashAttention(softmax_scale=self.scale)
|
||||
|
||||
self.proj = nn.Linear(self.embed_dim, self.embed_dim)
|
||||
|
||||
def _flash_attn(
|
||||
self,
|
||||
x,
|
||||
):
|
||||
qkv = self.qkv(x)
|
||||
qkv = rearrange(
|
||||
qkv, "b s (three h d) -> b s three h d", three=3, h=self.num_heads
|
||||
)
|
||||
|
||||
if self.qk_normalization:
|
||||
q, k, v = qkv.unbind(2)
|
||||
q = self.q_norm(q.flatten(-2, -1)).view(q.shape)
|
||||
k = self.k_norm(k.flatten(-2, -1)).view(k.shape)
|
||||
qkv = torch.stack([q, k, v], dim=2)
|
||||
|
||||
context, _ = self.inner_attn(
|
||||
qkv,
|
||||
)
|
||||
outs = self.proj(rearrange(context, "b s h d -> b s (h d)"))
|
||||
outs = self.proj_drop(outs)
|
||||
return outs
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
x = self._flash_attn(hidden_states)
|
||||
return x
|
||||
|
||||
|
||||
class InternVisionEmbeddings(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.embed_dim = config.hidden_size
|
||||
self.image_size = config.image_size
|
||||
self.patch_size = config.patch_size
|
||||
|
||||
self.class_embedding = nn.Parameter(
|
||||
torch.randn(1, 1, self.embed_dim),
|
||||
)
|
||||
|
||||
self.patch_embedding = nn.Conv2d(
|
||||
in_channels=3,
|
||||
out_channels=self.embed_dim,
|
||||
kernel_size=self.patch_size,
|
||||
stride=self.patch_size,
|
||||
)
|
||||
|
||||
self.num_patches = (self.image_size // self.patch_size) ** 2
|
||||
self.num_positions = self.num_patches + 1
|
||||
|
||||
self.position_embedding = nn.Parameter(
|
||||
torch.randn(1, self.num_positions, self.embed_dim)
|
||||
)
|
||||
|
||||
def _get_pos_embed(self, pos_embed, H, W):
|
||||
target_dtype = pos_embed.dtype
|
||||
pos_embed = (
|
||||
pos_embed.float()
|
||||
.reshape(
|
||||
1,
|
||||
self.image_size // self.patch_size,
|
||||
self.image_size // self.patch_size,
|
||||
-1,
|
||||
)
|
||||
.permute(0, 3, 1, 2)
|
||||
)
|
||||
pos_embed = (
|
||||
F.interpolate(pos_embed, size=(H, W), mode="bicubic", align_corners=False)
|
||||
.reshape(1, -1, H * W)
|
||||
.permute(0, 2, 1)
|
||||
.to(target_dtype)
|
||||
)
|
||||
return pos_embed
|
||||
|
||||
def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
||||
target_dtype = self.patch_embedding.weight.dtype
|
||||
patch_embeds = self.patch_embedding(
|
||||
pixel_values
|
||||
) # shape = [*, channel, width, height]
|
||||
batch_size, _, height, width = patch_embeds.shape
|
||||
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
||||
class_embeds = self.class_embedding.expand(batch_size, 1, -1).to(target_dtype)
|
||||
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
||||
position_embedding = torch.cat(
|
||||
[
|
||||
self.position_embedding[:, :1, :],
|
||||
self._get_pos_embed(self.position_embedding[:, 1:, :], height, width),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
embeddings = embeddings + position_embedding.to(target_dtype)
|
||||
return embeddings
|
||||
|
||||
|
||||
class InternRMSNorm(nn.Module):
|
||||
def __init__(self, hidden_size, eps=1e-6):
|
||||
super().__init__()
|
||||
self.weight = nn.Parameter(torch.ones(hidden_size))
|
||||
self.variance_epsilon = eps
|
||||
|
||||
def forward(self, hidden_states):
|
||||
input_dtype = hidden_states.dtype
|
||||
hidden_states = hidden_states.to(torch.float32)
|
||||
variance = hidden_states.pow(2).mean(-1, keepdim=True)
|
||||
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
||||
return self.weight * hidden_states.to(input_dtype)
|
||||
|
||||
|
||||
class InternMLP(nn.Module):
|
||||
def __init__(self, config: PretrainedConfig):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.act = ACT2FN[config.hidden_act]
|
||||
self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
|
||||
self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.fc1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.fc2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
NORM2FN = {
|
||||
"rms_norm": InternRMSNorm,
|
||||
"layer_norm": nn.LayerNorm,
|
||||
}
|
||||
|
||||
|
||||
class InternVisionEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
drop_path_rate: float,
|
||||
quant_config: QuantizationConfig = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.embed_dim = config.hidden_size
|
||||
self.intermediate_size = config.intermediate_size
|
||||
self.norm_type = config.norm_type
|
||||
self.attn = InternAttention(config)
|
||||
self.mlp = InternMLP(config)
|
||||
self.norm1 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
||||
self.norm2 = NORM2FN[self.norm_type](self.embed_dim, eps=config.layer_norm_eps)
|
||||
|
||||
self.ls1 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
||||
self.ls2 = nn.Parameter(config.initializer_factor * torch.ones(self.embed_dim))
|
||||
self.drop_path1 = (
|
||||
DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
||||
)
|
||||
self.drop_path2 = (
|
||||
DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
) -> Tuple[
|
||||
torch.FloatTensor,
|
||||
Optional[torch.FloatTensor],
|
||||
Optional[Tuple[torch.FloatTensor]],
|
||||
]:
|
||||
"""
|
||||
Args:
|
||||
hidden_states (`Tuple[torch.FloatTensor, Optional[torch.FloatTensor]]`): input to the layer of shape `(batch, seq_len, embed_dim)`
|
||||
"""
|
||||
hidden_states = hidden_states + self.drop_path1(
|
||||
self.attn(self.norm1(hidden_states).to(hidden_states.dtype)) * self.ls1
|
||||
)
|
||||
|
||||
hidden_states = hidden_states + self.drop_path2(
|
||||
self.mlp(self.norm2(hidden_states).to(hidden_states.dtype)) * self.ls2
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class InternVisionEncoder(nn.Module):
|
||||
"""
|
||||
Transformer encoder consisting of `config.num_hidden_layers` self attention layers. Each layer is a
|
||||
[`InternEncoderLayer`].
|
||||
|
||||
Args:
|
||||
config (`InternConfig`):
|
||||
The corresponding vision configuration for the `InternEncoder`.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
# stochastic depth decay rule
|
||||
dpr = [
|
||||
x.item()
|
||||
for x in torch.linspace(0, config.drop_path_rate, config.num_hidden_layers)
|
||||
]
|
||||
self.layers = nn.ModuleList(
|
||||
[
|
||||
InternVisionEncoderLayer(config, dpr[idx], quant_config)
|
||||
for idx in range(config.num_hidden_layers)
|
||||
]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
inputs_embeds,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
) -> Union[Tuple, BaseModelOutput]:
|
||||
r"""
|
||||
Args:
|
||||
inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
|
||||
Embedded representation of the inputs. Should be float, not int tokens.
|
||||
output_hidden_states (`bool`, *optional*):
|
||||
Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
|
||||
for more detail.
|
||||
return_dict (`bool`, *optional*):
|
||||
Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
|
||||
"""
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
encoder_states = () if output_hidden_states else None
|
||||
hidden_states = inputs_embeds
|
||||
|
||||
for idx, encoder_layer in enumerate(self.layers):
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
layer_outputs = encoder_layer(
|
||||
hidden_states,
|
||||
)
|
||||
hidden_states = layer_outputs
|
||||
|
||||
if output_hidden_states:
|
||||
encoder_states = encoder_states + (hidden_states,)
|
||||
|
||||
if not return_dict:
|
||||
return tuple(v for v in [hidden_states, encoder_states] if v is not None)
|
||||
return BaseModelOutput(
|
||||
last_hidden_state=hidden_states, hidden_states=encoder_states
|
||||
)
|
||||
|
||||
|
||||
class InternVisionModel(PreTrainedModel):
|
||||
main_input_name = "pixel_values"
|
||||
_supports_flash_attn_2 = True
|
||||
config_class = PretrainedConfig
|
||||
_no_split_modules = ["InternVisionEncoderLayer"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
self.embeddings = InternVisionEmbeddings(
|
||||
config,
|
||||
)
|
||||
self.encoder = InternVisionEncoder(config, quant_config)
|
||||
|
||||
def resize_pos_embeddings(self, old_size, new_size, patch_size):
|
||||
pos_emb = self.embeddings.position_embedding
|
||||
_, num_positions, embed_dim = pos_emb.shape
|
||||
cls_emb = pos_emb[:, :1, :]
|
||||
pos_emb = (
|
||||
pos_emb[:, 1:, :]
|
||||
.reshape(1, old_size // patch_size, old_size // patch_size, -1)
|
||||
.permute(0, 3, 1, 2)
|
||||
)
|
||||
pos_emb = F.interpolate(
|
||||
pos_emb.float(),
|
||||
size=new_size // patch_size,
|
||||
mode="bicubic",
|
||||
align_corners=False,
|
||||
)
|
||||
pos_emb = pos_emb.to(cls_emb.dtype).reshape(1, embed_dim, -1).permute(0, 2, 1)
|
||||
pos_emb = torch.cat([cls_emb, pos_emb], dim=1)
|
||||
self.embeddings.position_embedding = nn.Parameter(pos_emb)
|
||||
self.embeddings.image_size = new_size
|
||||
logger.info(
|
||||
"Resized position embeddings from {} to {}".format(old_size, new_size)
|
||||
)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embeddings
|
||||
|
||||
def forward(
|
||||
self,
|
||||
pixel_values: Optional[torch.FloatTensor] = None,
|
||||
output_hidden_states: Optional[bool] = None,
|
||||
return_dict: Optional[bool] = None,
|
||||
pixel_embeds: Optional[torch.FloatTensor] = None,
|
||||
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
||||
pixel_values = pixel_values.to(device=self.device, dtype=self.dtype)
|
||||
output_hidden_states = (
|
||||
output_hidden_states
|
||||
if output_hidden_states is not None
|
||||
else self.config.output_hidden_states
|
||||
)
|
||||
return_dict = (
|
||||
return_dict if return_dict is not None else self.config.use_return_dict
|
||||
)
|
||||
|
||||
if pixel_values is None and pixel_embeds is None:
|
||||
raise ValueError("You have to specify pixel_values or pixel_embeds")
|
||||
|
||||
if pixel_embeds is not None:
|
||||
hidden_states = pixel_embeds
|
||||
else:
|
||||
if len(pixel_values.shape) == 4:
|
||||
hidden_states = self.embeddings(pixel_values)
|
||||
else:
|
||||
raise ValueError(f"wrong pixel_values size: {pixel_values.shape}")
|
||||
encoder_outputs = self.encoder(
|
||||
inputs_embeds=hidden_states,
|
||||
output_hidden_states=output_hidden_states,
|
||||
return_dict=return_dict,
|
||||
)
|
||||
last_hidden_state = encoder_outputs.last_hidden_state
|
||||
pooled_output = last_hidden_state[:, 0, :]
|
||||
|
||||
if not return_dict:
|
||||
return (last_hidden_state, pooled_output) + encoder_outputs[1:]
|
||||
|
||||
return BaseModelOutputWithPooling(
|
||||
last_hidden_state=last_hidden_state,
|
||||
pooler_output=pooled_output,
|
||||
hidden_states=encoder_outputs.hidden_states,
|
||||
attentions=encoder_outputs.attentions,
|
||||
)
|
||||
|
||||
|
||||
class InternVLChatModel(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: PretrainedConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
use_flash_attn=True,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.quant_config = quant_config
|
||||
|
||||
image_size = config.force_image_size or config.vision_config.image_size
|
||||
patch_size = config.vision_config.patch_size
|
||||
self.patch_size = patch_size
|
||||
self.select_layer = config.select_layer
|
||||
self.template = config.template
|
||||
self.num_image_token = int(
|
||||
(image_size // patch_size) ** 2 * (config.downsample_ratio**2)
|
||||
)
|
||||
self.downsample_ratio = config.downsample_ratio
|
||||
self.ps_version = config.ps_version
|
||||
|
||||
config.vision_config.use_flash_attn = True if use_flash_attn else False
|
||||
config.llm_config._attn_implementation = (
|
||||
"flash_attention_2" if use_flash_attn else "eager"
|
||||
)
|
||||
|
||||
logger.info(f"num_image_token: {self.num_image_token}")
|
||||
logger.info(f"ps_version: {self.ps_version}")
|
||||
|
||||
self.vision_model = InternVisionModel(config.vision_config)
|
||||
if config.llm_config.architectures[0] == "Qwen2ForCausalLM":
|
||||
self.language_model = Qwen2ForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
elif config.llm_config.architectures[0] == "InternLM2ForCausalLM":
|
||||
self.language_model = InternLM2ForCausalLM(
|
||||
config=config.llm_config, quant_config=quant_config
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"{config.llm_config.architectures[0]} is not implemented."
|
||||
)
|
||||
|
||||
vit_hidden_size = config.vision_config.hidden_size
|
||||
llm_hidden_size = config.llm_config.hidden_size
|
||||
|
||||
self.mlp1 = nn.Sequential(
|
||||
nn.LayerNorm(vit_hidden_size * int(1 / self.downsample_ratio) ** 2),
|
||||
nn.Linear(
|
||||
vit_hidden_size * int(1 / self.downsample_ratio) ** 2, llm_hidden_size
|
||||
),
|
||||
nn.GELU(),
|
||||
nn.Linear(llm_hidden_size, llm_hidden_size),
|
||||
)
|
||||
|
||||
def pixel_shuffle(self, x, scale_factor=0.5):
|
||||
n, w, h, c = x.size()
|
||||
# N, W, H, C --> N, W, H * scale, C // scale
|
||||
x = x.view(n, w, int(h * scale_factor), int(c / scale_factor))
|
||||
# N, W, H * scale, C // scale --> N, H * scale, W, C // scale
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
# N, H * scale, W, C // scale --> N, H * scale, W * scale, C // (scale ** 2)
|
||||
x = x.view(
|
||||
n,
|
||||
int(h * scale_factor),
|
||||
int(w * scale_factor),
|
||||
int(c / (scale_factor * scale_factor)),
|
||||
)
|
||||
if self.ps_version == "v1":
|
||||
logger.warn(
|
||||
"In ps_version 'v1', the height and width have not been swapped back, "
|
||||
"which results in a transposed image."
|
||||
)
|
||||
else:
|
||||
x = x.permute(0, 2, 1, 3).contiguous()
|
||||
return x
|
||||
|
||||
def extract_feature(self, pixel_values):
|
||||
if self.select_layer == -1:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=False, return_dict=True
|
||||
).last_hidden_state
|
||||
else:
|
||||
vit_embeds = self.vision_model(
|
||||
pixel_values=pixel_values, output_hidden_states=True, return_dict=True
|
||||
).hidden_states[self.select_layer]
|
||||
vit_embeds = vit_embeds[:, 1:, :]
|
||||
|
||||
h = w = int(vit_embeds.shape[1] ** 0.5)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], h, w, -1)
|
||||
vit_embeds = self.pixel_shuffle(vit_embeds, scale_factor=self.downsample_ratio)
|
||||
vit_embeds = vit_embeds.reshape(vit_embeds.shape[0], -1, vit_embeds.shape[-1])
|
||||
vit_embeds = self.mlp1(vit_embeds)
|
||||
return vit_embeds
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]):
|
||||
"""
|
||||
Projects the last hidden state from the vision model into language model space.
|
||||
|
||||
Returns:
|
||||
image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`).
|
||||
"""
|
||||
pixel_values = torch.cat([item.pixel_values for item in items])
|
||||
image_features = self.extract_feature(pixel_values)
|
||||
return image_features
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
input_embeds: torch.Tensor = None,
|
||||
) -> torch.Tensor:
|
||||
|
||||
hs = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
language_model=self.language_model,
|
||||
image_data_embedding_func=self.get_image_feature,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return hs
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
im_start_id: int = mm_inputs.im_start_id
|
||||
im_end_id: int = mm_inputs.im_end_id
|
||||
|
||||
media_token_pairs = [(im_start_id, im_end_id)]
|
||||
helper = MultiModalityDataPaddingPatternTokenPairs(media_token_pairs)
|
||||
|
||||
return helper.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("gate_up_proj", "w1", 0),
|
||||
("gate_up_proj", "w3", 1),
|
||||
]
|
||||
elif "Qwen2ForCausalLM" in self.config.llm_config.architectures:
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
("qkv_proj", "q_proj", "q"),
|
||||
("qkv_proj", "k_proj", "k"),
|
||||
("qkv_proj", "v_proj", "v"),
|
||||
("gate_up_proj", "gate_proj", 0),
|
||||
("gate_up_proj", "up_proj", 1),
|
||||
]
|
||||
params_dict = dict(self.named_parameters())
|
||||
|
||||
for name, loaded_weight in weights:
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id)
|
||||
break
|
||||
else:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
param = params_dict[name]
|
||||
if "wqkv" in name:
|
||||
config = self.config
|
||||
kv_groups = config.num_attention_heads // config.num_key_value_heads
|
||||
head_dim = config.hidden_size // config.num_attention_heads
|
||||
loaded_weight = loaded_weight.view(
|
||||
-1, 2 + kv_groups, head_dim, loaded_weight.shape[-1]
|
||||
)
|
||||
wq, wk, wv = torch.split(loaded_weight, [kv_groups, 1, 1], dim=1)
|
||||
wq = wq.reshape(-1, wq.shape[-1])
|
||||
wk = wk.reshape(-1, wk.shape[-1])
|
||||
wv = wv.reshape(-1, wv.shape[-1])
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, wq, "q")
|
||||
weight_loader(param, wk, "k")
|
||||
weight_loader(param, wv, "v")
|
||||
else:
|
||||
weight_loader = getattr(
|
||||
param, "weight_loader", default_weight_loader
|
||||
)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
|
||||
EntryClass = InternVLChatModel
|
||||
Reference in New Issue
Block a user