[Feature] add support kimi vl model (#5383)

Co-authored-by: wenju.li <wenju.li@deepctr.cn>
This commit is contained in:
liwenju0
2025-04-30 12:31:19 +08:00
committed by GitHub
parent 403b855a22
commit 8fefdd32c7
13 changed files with 1189 additions and 11 deletions

View File

@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
from sglang.srt.configs.exaone import ExaoneConfig
from sglang.srt.configs.janus_pro import MultiModalityConfig
from sglang.srt.configs.kimi_vl import KimiVLConfig
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
__all__ = [
"ExaoneConfig",
@@ -10,4 +12,6 @@ __all__ = [
"DbrxConfig",
"DeepseekVL2Config",
"MultiModalityConfig",
"KimiVLConfig",
"MoonViTConfig",
]

View File

@@ -0,0 +1,38 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from typing import Optional, Union
from transformers.configuration_utils import PretrainedConfig
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
class KimiVLConfig(PretrainedConfig):
model_type = "kimi_vl"
def __init__(
self,
vision_config: Optional[Union[dict, MoonViTConfig]] = None,
text_config: Optional[Union[dict, DeepseekV2Config]] = None,
ignore_index: int = -100,
media_placeholder_token_id: int = 163605,
pad_token_id: int = 0,
**kwargs
):
if vision_config is None:
vision_config = MoonViTConfig()
elif isinstance(vision_config, dict):
vision_config = MoonViTConfig(**vision_config)
self.vision_config = vision_config
if text_config is None:
text_config = DeepseekV2Config()
elif isinstance(text_config, dict):
text_config = DeepseekV2Config(**text_config)
self.text_config = text_config
self.ignore_index = ignore_index
self.media_placeholder_token_id = media_placeholder_token_id
super().__init__(pad_token_id=pad_token_id, **kwargs)

View File

@@ -0,0 +1,32 @@
# SPDX-License-Identifier: Apache-2.0
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
from transformers.configuration_utils import PretrainedConfig
class MoonViTConfig(PretrainedConfig):
model_type = "moonvit"
def __init__(
self,
patch_size: int = 14,
init_pos_emb_height: int = 64,
init_pos_emb_width: int = 64,
num_attention_heads: int = 16,
num_hidden_layers: int = 27,
hidden_size: int = 1152,
intermediate_size: int = 4304,
merge_kernel_size: tuple[int, int] = (2, 2),
**kwargs,
):
super().__init__(**kwargs)
self.patch_size = patch_size
# Positional embedding config
self.init_pos_emb_height = init_pos_emb_height
self.init_pos_emb_width = init_pos_emb_width
# Transformer config
self.num_hidden_layers = num_hidden_layers
self.num_attention_heads = num_attention_heads
self.hidden_size = hidden_size
self.intermediate_size = intermediate_size
# Patch merger config
self.merge_kernel_size = merge_kernel_size

View File

@@ -176,6 +176,13 @@ class ModelConfig:
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
self.head_dim = 256
self.attention_arch = AttentionArch.MLA
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
self.v_head_dim = self.hf_text_config.v_head_dim
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
else:
self.attention_arch = AttentionArch.MHA
@@ -530,6 +537,7 @@ multimodal_model_archs = [
"Qwen2VLForConditionalGeneration",
"Qwen2_5_VLForConditionalGeneration",
"CLIPModel",
"KimiVLForConditionalGeneration",
]