[Feature] add support kimi vl model (#5383)
Co-authored-by: wenju.li <wenju.li@deepctr.cn>
This commit is contained in:
@@ -28,4 +28,5 @@ python3 -m sglang.launch_server \
|
||||
| **LLaVA** (v1.5 & v1.6) | *e.g.* `liuhaotian/llava-v1.5-13b` | `vicuna_v1.1` | Open vision-chat models that add an image encoder to LLaMA/Vicuna (e.g. LLaMA2 13B) for following multimodal instruction prompts. |
|
||||
| **LLaVA-NeXT** (8B, 72B) | `lmms-lab/llava-next-72b` | `chatml-llava` | Improved LLaVA models (with an 8B Llama3 version and a 72B version) offering enhanced visual instruction-following and accuracy on multimodal benchmarks. |
|
||||
| **LLaVA-OneVision** | `lmms-lab/llava-onevision-qwen2-7b-ov` | `chatml-llava` | Enhanced LLaVA variant integrating Qwen as the backbone; supports multiple images (and even video frames) as inputs via an OpenAI Vision API-compatible format. |
|
||||
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
|
||||
| **Gemma 3 (Multimodal)** | `google/gemma-3-4b-it` | `gemma-it` | Gemma 3’s larger models (4B, 12B, 27B) accept images (each image encoded as 256 tokens) alongside text in a combined 128K-token context. |
|
||||
| **Kimi-VL** (A3B) | `moonshotai/Kimi-VL-A3B-Instruct` | `kimi-vl` | Kimi-VL is a multimodal model that can understand and generate text from images. |
|
||||
|
||||
@@ -42,6 +42,7 @@ runtime_common = [
|
||||
"uvicorn",
|
||||
"uvloop",
|
||||
"xgrammar==0.1.17",
|
||||
"blobfile==3.0.0"
|
||||
]
|
||||
|
||||
srt = [
|
||||
|
||||
@@ -3,6 +3,8 @@ from sglang.srt.configs.dbrx import DbrxConfig
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekVL2Config
|
||||
from sglang.srt.configs.exaone import ExaoneConfig
|
||||
from sglang.srt.configs.janus_pro import MultiModalityConfig
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
|
||||
__all__ = [
|
||||
"ExaoneConfig",
|
||||
@@ -10,4 +12,6 @@ __all__ = [
|
||||
"DbrxConfig",
|
||||
"DeepseekVL2Config",
|
||||
"MultiModalityConfig",
|
||||
"KimiVLConfig",
|
||||
"MoonViTConfig",
|
||||
]
|
||||
|
||||
38
python/sglang/srt/configs/kimi_vl.py
Normal file
38
python/sglang/srt/configs/kimi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from typing import Optional, Union
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
|
||||
|
||||
class KimiVLConfig(PretrainedConfig):
|
||||
model_type = "kimi_vl"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config: Optional[Union[dict, MoonViTConfig]] = None,
|
||||
text_config: Optional[Union[dict, DeepseekV2Config]] = None,
|
||||
ignore_index: int = -100,
|
||||
media_placeholder_token_id: int = 163605,
|
||||
pad_token_id: int = 0,
|
||||
**kwargs
|
||||
):
|
||||
if vision_config is None:
|
||||
vision_config = MoonViTConfig()
|
||||
elif isinstance(vision_config, dict):
|
||||
vision_config = MoonViTConfig(**vision_config)
|
||||
self.vision_config = vision_config
|
||||
|
||||
if text_config is None:
|
||||
text_config = DeepseekV2Config()
|
||||
elif isinstance(text_config, dict):
|
||||
text_config = DeepseekV2Config(**text_config)
|
||||
self.text_config = text_config
|
||||
|
||||
self.ignore_index = ignore_index
|
||||
self.media_placeholder_token_id = media_placeholder_token_id
|
||||
|
||||
super().__init__(pad_token_id=pad_token_id, **kwargs)
|
||||
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
32
python/sglang/srt/configs/kimi_vl_moonvit.py
Normal file
@@ -0,0 +1,32 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/configuration_kimi_vl.py
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
|
||||
|
||||
class MoonViTConfig(PretrainedConfig):
|
||||
model_type = "moonvit"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
patch_size: int = 14,
|
||||
init_pos_emb_height: int = 64,
|
||||
init_pos_emb_width: int = 64,
|
||||
num_attention_heads: int = 16,
|
||||
num_hidden_layers: int = 27,
|
||||
hidden_size: int = 1152,
|
||||
intermediate_size: int = 4304,
|
||||
merge_kernel_size: tuple[int, int] = (2, 2),
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.patch_size = patch_size
|
||||
# Positional embedding config
|
||||
self.init_pos_emb_height = init_pos_emb_height
|
||||
self.init_pos_emb_width = init_pos_emb_width
|
||||
# Transformer config
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
# Patch merger config
|
||||
self.merge_kernel_size = merge_kernel_size
|
||||
@@ -176,6 +176,13 @@ class ModelConfig:
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
||||
elif "KimiVLForConditionalGeneration" in self.hf_config.architectures:
|
||||
self.head_dim = 256
|
||||
self.attention_arch = AttentionArch.MLA
|
||||
self.kv_lora_rank = self.hf_text_config.kv_lora_rank
|
||||
self.qk_rope_head_dim = self.hf_text_config.qk_rope_head_dim
|
||||
self.v_head_dim = self.hf_text_config.v_head_dim
|
||||
self.qk_nope_head_dim = self.hf_text_config.qk_nope_head_dim
|
||||
else:
|
||||
self.attention_arch = AttentionArch.MHA
|
||||
|
||||
@@ -530,6 +537,7 @@ multimodal_model_archs = [
|
||||
"Qwen2VLForConditionalGeneration",
|
||||
"Qwen2_5_VLForConditionalGeneration",
|
||||
"CLIPModel",
|
||||
"KimiVLForConditionalGeneration",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -806,6 +806,24 @@ register_conv_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/chat_template.jinja
|
||||
register_conv_template(
|
||||
Conversation(
|
||||
name="kimi-vl",
|
||||
system_message="You are a helpful assistant",
|
||||
system_template="<|im_system|>system<|im_middle|>{system_message}",
|
||||
roles=(
|
||||
"<|im_user|>user<|im_middle|>",
|
||||
"<|im_assistant|>assistant<|im_middle|>",
|
||||
),
|
||||
messages=[],
|
||||
sep="<|im_end|>",
|
||||
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
||||
stop_str="<|im_end|>",
|
||||
image_token="<|media_start|>image<|media_content|><|media_pad|><|media_end|>",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_deepseek_janus_pro(model_path: str):
|
||||
@@ -888,3 +906,10 @@ def match_openbmb_minicpm(model_path: str):
|
||||
return "minicpmv"
|
||||
elif "minicpm-o" in model_path:
|
||||
return "minicpmo"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_moonshot_kimivl(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "kimi" in model_path and "vl" in model_path:
|
||||
return "kimi-vl"
|
||||
|
||||
@@ -35,6 +35,7 @@ from sglang.srt.configs import (
|
||||
DbrxConfig,
|
||||
DeepseekVL2Config,
|
||||
ExaoneConfig,
|
||||
KimiVLConfig,
|
||||
MultiModalityConfig,
|
||||
)
|
||||
from sglang.srt.connector import create_remote_connector
|
||||
@@ -46,6 +47,7 @@ _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
|
||||
ExaoneConfig.model_type: ExaoneConfig,
|
||||
DeepseekVL2Config.model_type: DeepseekVL2Config,
|
||||
MultiModalityConfig.model_type: MultiModalityConfig,
|
||||
KimiVLConfig.model_type: KimiVLConfig,
|
||||
}
|
||||
|
||||
for name, cls in _CONFIG_REGISTRY.items():
|
||||
|
||||
73
python/sglang/srt/managers/multimodal_processors/kimi_vl.py
Normal file
73
python/sglang/srt/managers/multimodal_processors/kimi_vl.py
Normal file
@@ -0,0 +1,73 @@
|
||||
import asyncio
|
||||
import math
|
||||
from typing import List, Union
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
BaseMultimodalProcessor as SGLangBaseProcessor,
|
||||
)
|
||||
from sglang.srt.managers.multimodal_processors.base_processor import (
|
||||
MultimodalSpecialTokens,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import Modality, MultimodalDataItem
|
||||
from sglang.srt.models.kimi_vl import KimiVLForConditionalGeneration
|
||||
|
||||
|
||||
# Compatible with KimiVLForConditionalGeneration
|
||||
class KimiVLImageProcessor(SGLangBaseProcessor):
|
||||
models = [KimiVLForConditionalGeneration]
|
||||
|
||||
def __init__(self, hf_config, server_args, _processor):
|
||||
super().__init__(hf_config, server_args, _processor)
|
||||
self.IMAGE_TOKEN = "<|media_pad|>"
|
||||
self.im_token_id = _processor.tokenizer.convert_tokens_to_ids(self.IMAGE_TOKEN)
|
||||
|
||||
self.im_start = "<|media_start|>"
|
||||
self.im_start_id = _processor.tokenizer.convert_tokens_to_ids(self.im_start)
|
||||
|
||||
self.im_end = "<|media_end|>"
|
||||
self.im_end_id = _processor.tokenizer.convert_tokens_to_ids(self.im_end)
|
||||
|
||||
self.im_content = "<|media_content|>"
|
||||
self.im_content_id = _processor.tokenizer.convert_tokens_to_ids(self.im_content)
|
||||
|
||||
async def process_mm_data_async(
|
||||
self,
|
||||
image_data: List[Union[str, bytes]],
|
||||
input_text,
|
||||
request_obj,
|
||||
max_req_input_len,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
if not image_data:
|
||||
return None
|
||||
if isinstance(image_data, str):
|
||||
image_data = [image_data]
|
||||
|
||||
base_output = self.load_mm_data(
|
||||
prompt=input_text,
|
||||
image_data=image_data,
|
||||
multimodal_tokens=MultimodalSpecialTokens(image_token=self.IMAGE_TOKEN),
|
||||
max_req_input_len=max_req_input_len,
|
||||
)
|
||||
ret = self.process_mm_data(
|
||||
input_text=base_output.input_text,
|
||||
images=base_output.images,
|
||||
)
|
||||
return {
|
||||
"input_ids": ret["input_ids"].flatten().tolist(),
|
||||
"mm_items": [
|
||||
MultimodalDataItem(
|
||||
pixel_values=ret["pixel_values"],
|
||||
image_grid_thws=ret["image_grid_hws"],
|
||||
modality=Modality.IMAGE,
|
||||
)
|
||||
],
|
||||
"im_token_id": self.im_token_id,
|
||||
"im_start_id": self.im_start_id,
|
||||
"im_end_id": self.im_end_id,
|
||||
"im_content_id": self.im_content_id,
|
||||
}
|
||||
@@ -752,7 +752,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
||||
q_nope_out = q_nope_out.transpose(0, 1)
|
||||
|
||||
k_nope = latent_cache[..., : self.kv_lora_rank]
|
||||
k_nope = self.kv_a_layernorm(k_nope).unsqueeze(1)
|
||||
k_nope = self.kv_a_layernorm(k_nope.contiguous()).unsqueeze(1)
|
||||
k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
|
||||
|
||||
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
||||
@@ -1391,6 +1391,9 @@ class DeepseekV2Model(nn.Module):
|
||||
|
||||
self.dp_size = get_attention_dp_size()
|
||||
|
||||
def get_input_embeddings(self) -> torch.Tensor:
|
||||
return self.embed_tokens
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
|
||||
308
python/sglang/srt/models/kimi_vl.py
Normal file
308
python/sglang/srt/models/kimi_vl.py
Normal file
@@ -0,0 +1,308 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# ruff: noqa: E501
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
|
||||
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
|
||||
#
|
||||
# Licensing Information:
|
||||
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
|
||||
# - Other parts of the code are licensed under the MIT License.
|
||||
#
|
||||
# Apache License, Version 2.0:
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License:
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import copy
|
||||
import logging
|
||||
import math
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Iterable, List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from transformers.activations import GELUActivation
|
||||
|
||||
from sglang.srt.configs import KimiVLConfig
|
||||
from sglang.srt.configs.deepseekvl2 import DeepseekV2Config
|
||||
from sglang.srt.configs.kimi_vl import KimiVLConfig
|
||||
from sglang.srt.configs.kimi_vl_moonvit import MoonViTConfig
|
||||
from sglang.srt.distributed import (
|
||||
get_tensor_model_parallel_rank,
|
||||
get_tensor_model_parallel_world_size,
|
||||
)
|
||||
from sglang.srt.layers.activation import QuickGELU
|
||||
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
||||
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
||||
from sglang.srt.managers.mm_utils import (
|
||||
MultiModalityDataPaddingPatternMultimodalTokens,
|
||||
general_mm_embed_routine,
|
||||
)
|
||||
from sglang.srt.managers.schedule_batch import MultimodalDataItem, MultimodalInputs
|
||||
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
||||
from sglang.srt.model_loader.weight_utils import (
|
||||
default_weight_loader,
|
||||
maybe_remap_kv_scale_name,
|
||||
)
|
||||
from sglang.srt.models.deepseek_v2 import DeepseekV2ForCausalLM
|
||||
from sglang.srt.models.kimi_vl_moonvit import MoonVitPretrainedModel
|
||||
from sglang.srt.utils import add_prefix
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# For dummy input only
|
||||
@dataclass
|
||||
class MaxImageTokenMeta:
|
||||
width: int = 1024
|
||||
height: int = 1024
|
||||
|
||||
|
||||
class KimiVLMultiModalProjector(nn.Module):
|
||||
|
||||
def __init__(self, config: KimiVLConfig):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = (
|
||||
config.vision_config.hidden_size
|
||||
* config.vision_config.merge_kernel_size[0]
|
||||
* config.vision_config.merge_kernel_size[1]
|
||||
)
|
||||
|
||||
self.pre_norm = torch.nn.LayerNorm(config.vision_config.hidden_size, eps=1e-5)
|
||||
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
self.act = GELUActivation()
|
||||
self.act = QuickGELU()
|
||||
self.linear_2 = nn.Linear(
|
||||
self.hidden_size, config.text_config.hidden_size, bias=True
|
||||
)
|
||||
|
||||
def forward(self, image_features: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(image_features).view(-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class KimiVLForConditionalGeneration(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
config: KimiVLConfig,
|
||||
quant_config: Optional[QuantizationConfig] = None,
|
||||
prefix: str = "",
|
||||
**kwargs, # fix init_tts argument error
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
assert isinstance(config.vision_config, MoonViTConfig)
|
||||
|
||||
self.vision_tower = MoonVitPretrainedModel(config.vision_config)
|
||||
|
||||
self.multi_modal_projector = KimiVLMultiModalProjector(config=config)
|
||||
self.quant_config = quant_config
|
||||
text_config = copy.deepcopy(config.text_config)
|
||||
text_config.architectures = ["DeepseekV2ForCausalLM"]
|
||||
self.language_model = DeepseekV2ForCausalLM(
|
||||
config=text_config,
|
||||
quant_config=quant_config,
|
||||
prefix=add_prefix("language_model", prefix),
|
||||
)
|
||||
|
||||
def get_image_feature(self, items: List[MultimodalDataItem]) -> torch.Tensor:
|
||||
pixel_values = (
|
||||
torch.cat([item.pixel_values for item in items], dim=0)
|
||||
.type(self.vision_tower.dtype)
|
||||
.to(self.vision_tower.device)
|
||||
)
|
||||
image_grid_thws = torch.concat(
|
||||
[item.image_grid_thws for item in items], dim=0
|
||||
).to(self.vision_tower.device)
|
||||
image_features = self.vision_tower(pixel_values, image_grid_thws)
|
||||
assert isinstance(image_features, list)
|
||||
# lengths = [x.shape[0] for x in image_features]
|
||||
res = self.multi_modal_projector(torch.cat(image_features)) # .split(lengths)
|
||||
return res
|
||||
|
||||
def pad_input_ids(self, input_ids: List[int], mm_inputs: MultimodalInputs):
|
||||
# Get all special token IDs
|
||||
pattern = MultiModalityDataPaddingPatternMultimodalTokens(mm_inputs.im_token_id)
|
||||
return pattern.pad_input_tokens(input_ids, mm_inputs)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
forward_batch: ForwardBatch,
|
||||
get_embedding: bool = False,
|
||||
):
|
||||
hidden_states = general_mm_embed_routine(
|
||||
input_ids=input_ids,
|
||||
forward_batch=forward_batch,
|
||||
language_model=self.language_model,
|
||||
image_data_embedding_func=self.get_image_feature,
|
||||
positions=positions,
|
||||
)
|
||||
|
||||
return hidden_states
|
||||
|
||||
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|
||||
config = self.config.text_config
|
||||
_KEYS_TO_MODIFY_MAPPING = {
|
||||
# "language_model.lm_head": "lm_head",
|
||||
# "language_model.model": "language_model",
|
||||
}
|
||||
# only doing this for language model part for now.
|
||||
stacked_params_mapping = [
|
||||
# (param_name, shard_name, shard_id)
|
||||
(".gate_up_proj", ".gate_proj", 0),
|
||||
(".gate_up_proj", ".up_proj", 1),
|
||||
]
|
||||
if not config.use_mla:
|
||||
stacked_params_mapping += [
|
||||
(".qkv_proj", ".q_proj", "q"),
|
||||
(".qkv_proj", ".k_proj", "k"),
|
||||
(".qkv_proj", ".v_proj", "v"),
|
||||
]
|
||||
if getattr(config, "n_routed_experts", None):
|
||||
# Params for weights, fp8 weight scales, fp8 activation scales
|
||||
# (param_name, weight_name, expert_id, shard_id)
|
||||
expert_params_mapping = FusedMoE.make_expert_params_mapping(
|
||||
ckpt_gate_proj_name="gate_proj",
|
||||
ckpt_down_proj_name="down_proj",
|
||||
ckpt_up_proj_name="up_proj",
|
||||
num_experts=config.n_routed_experts,
|
||||
)
|
||||
else:
|
||||
expert_params_mapping = []
|
||||
|
||||
params_dict = dict(self.named_parameters())
|
||||
for args in weights:
|
||||
name, loaded_weight = args[:2]
|
||||
kwargs = args[2] if len(args) > 2 else {}
|
||||
if "rotary_emb.inv_freq" in name:
|
||||
continue
|
||||
|
||||
spec_layer = get_spec_layer_idx_from_weight_name(config, name)
|
||||
if spec_layer is not None:
|
||||
continue # skip spec decode layers for main model
|
||||
|
||||
if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
|
||||
# Models trained using ColossalAI may include these tensors in
|
||||
# the checkpoint. Skip them.
|
||||
continue
|
||||
for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
|
||||
if key_to_modify in name:
|
||||
name = name.replace(key_to_modify, new_key)
|
||||
use_default_weight_loading = False
|
||||
if "vision" in name:
|
||||
if self.vision_tower is not None:
|
||||
# We only do sharding for language model and
|
||||
# not vision model for now.
|
||||
use_default_weight_loading = True
|
||||
else:
|
||||
for param_name, weight_name, shard_id in stacked_params_mapping:
|
||||
if weight_name not in name:
|
||||
continue
|
||||
# We have mlp.experts[0].gate_proj in the checkpoint.
|
||||
# Since we handle the experts below in expert_params_mapping,
|
||||
# we need to skip here BEFORE we update the name, otherwise
|
||||
# name will be updated to mlp.experts[0].gate_up_proj, which
|
||||
# will then be updated below in expert_params_mapping
|
||||
# for mlp.experts[0].gate_gate_up_proj, which breaks load.
|
||||
if ("mlp.experts." in name) and name not in params_dict:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(param, loaded_weight, shard_id, **kwargs)
|
||||
break
|
||||
else:
|
||||
for idx, (
|
||||
param_name,
|
||||
weight_name,
|
||||
expert_id,
|
||||
shard_id,
|
||||
) in enumerate(expert_params_mapping):
|
||||
if weight_name not in name:
|
||||
continue
|
||||
name = name.replace(weight_name, param_name)
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = param.weight_loader
|
||||
weight_loader(
|
||||
param,
|
||||
loaded_weight,
|
||||
name,
|
||||
expert_id=expert_id,
|
||||
shard_id=shard_id,
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
else:
|
||||
use_default_weight_loading = True
|
||||
if use_default_weight_loading:
|
||||
# Skip loading extra bias for GPTQ models.
|
||||
if name.endswith(".bias") and name not in params_dict:
|
||||
continue
|
||||
# Remapping the name of FP8 kv-scale.
|
||||
name = maybe_remap_kv_scale_name(name, params_dict)
|
||||
if name is None:
|
||||
continue
|
||||
|
||||
# if is_pp_missing_parameter(name, self):
|
||||
# continue
|
||||
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight, **kwargs)
|
||||
self.language_model.post_load_weights()
|
||||
|
||||
|
||||
def get_spec_layer_idx_from_weight_name(
|
||||
config: DeepseekV2Config, weight_name: str
|
||||
) -> Optional[int]:
|
||||
if hasattr(config, "num_nextn_predict_layers") and (
|
||||
config.num_nextn_predict_layers > 0
|
||||
):
|
||||
layer_idx = config.num_hidden_layers
|
||||
for i in range(config.num_nextn_predict_layers):
|
||||
if weight_name.startswith(f"model.layers.{layer_idx+i}."):
|
||||
return layer_idx + i
|
||||
return None
|
||||
|
||||
|
||||
EntryClass = [KimiVLForConditionalGeneration]
|
||||
639
python/sglang/srt/models/kimi_vl_moonvit.py
Normal file
639
python/sglang/srt/models/kimi_vl_moonvit.py
Normal file
@@ -0,0 +1,639 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
# ruff: noqa: E501
|
||||
# Adapted from https://huggingface.co/moonshotai/Kimi-VL-A3B-Instruct/blob/main/modeling_kimi_vl.py
|
||||
# This file is meant to be used in kimi_vl.py only
|
||||
# Copyright 2025 The Moonshot AI Team, DeepSeek-AI, and HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# The code is based on llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py), but modified for KimiVL.
|
||||
#
|
||||
# Licensing Information:
|
||||
# - Code derived from llava (llava/modeling_llava.py) and DeepSeek-V3 (DeepSeek-V3/modeling_deepseek.py) is licensed under the Apache License, Version 2.0.
|
||||
# - Other parts of the code are licensed under the MIT License.
|
||||
#
|
||||
# Apache License, Version 2.0:
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
# MIT License:
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
import math
|
||||
from copy import deepcopy
|
||||
from functools import cached_property
|
||||
from typing import List, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from transformers.activations import ACT2FN, PytorchGELUTanh
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func
|
||||
except ImportError:
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
from sglang.srt.configs import MoonViTConfig
|
||||
|
||||
|
||||
def multihead_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""Multi-head attention using flash attention 2.
|
||||
This function is used to handle the case where the query, key, and value are packed.
|
||||
Args:
|
||||
q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
|
||||
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
||||
The first element should be 0 and the last element should be q.shape[0].
|
||||
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
|
||||
The first element should be 0 and the last element should be k.shape[0].
|
||||
|
||||
Returns:
|
||||
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
|
||||
where dim = num_heads * head_dim
|
||||
"""
|
||||
if flash_attn_varlen_func is None:
|
||||
raise ImportError(
|
||||
"flash_attn is not installed, this function needs flash_attn_varlen_func from flash_attn"
|
||||
)
|
||||
# Unified format legal check
|
||||
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
|
||||
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
|
||||
assert (
|
||||
k_cu_seqlens[-1] == k.shape[0] == v.shape[0]
|
||||
), "k_cu_seqlens must sum to k.shape[0]"
|
||||
assert q.dtype in [
|
||||
torch.bfloat16,
|
||||
torch.float16,
|
||||
], f"unsupported dtype {q.dtype} for multihead attn"
|
||||
|
||||
max_seqlen_q = (q_cu_seqlens[1:] - q_cu_seqlens[:-1]).max().item()
|
||||
max_seqlen_k = (k_cu_seqlens[1:] - k_cu_seqlens[:-1]).max().item()
|
||||
attn_out = flash_attn_varlen_func(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
q_cu_seqlens,
|
||||
k_cu_seqlens,
|
||||
max_seqlen_q,
|
||||
max_seqlen_k,
|
||||
causal=False,
|
||||
)
|
||||
attn_out = attn_out.flatten(start_dim=-2)
|
||||
|
||||
return attn_out
|
||||
|
||||
|
||||
def sdpa_attention(
|
||||
q: torch.Tensor,
|
||||
k: torch.Tensor,
|
||||
v: torch.Tensor,
|
||||
q_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
k_cu_seqlens: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Multi-head attention using torch scaled dot product attention.
|
||||
This function is used to handle the case where the query, key, and value are packed.
|
||||
Args:
|
||||
q, k, v: tensor of shape (tot_seqlens, num_heads, head_dim).
|
||||
q_cu_seqlens (torch.Tensor): cumulative sequence lengths of q.
|
||||
The first element should be 0 and the last element should be q.shape[0].
|
||||
k_cu_seqlens (torch.Tensor): cumulative sequence lengths of k.
|
||||
The first element should be 0 and the last element should be k.shape[0].
|
||||
|
||||
Returns:
|
||||
output: shape (batch_size, seqlen, dim) or (tot_seqlens, dim) if packing,
|
||||
where dim = num_heads * head_dim
|
||||
"""
|
||||
# Unified format legal check
|
||||
assert q.dim() == k.dim() == v.dim() == 3, "q, k, v must have 3 dims"
|
||||
assert q_cu_seqlens[-1] == q.shape[0], "q_cu_seqlens must sum to q.shape[0]"
|
||||
seq_length = q.shape[0]
|
||||
attention_mask = torch.zeros(
|
||||
[1, seq_length, seq_length], device=q.device, dtype=torch.bool
|
||||
)
|
||||
for i in range(1, len(q_cu_seqlens)):
|
||||
attention_mask[
|
||||
...,
|
||||
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
|
||||
q_cu_seqlens[i - 1] : q_cu_seqlens[i],
|
||||
] = True
|
||||
q = q.transpose(0, 1)
|
||||
k = k.transpose(0, 1)
|
||||
v = v.transpose(0, 1)
|
||||
attn_output = F.scaled_dot_product_attention(q, k, v, attention_mask, dropout_p=0.0)
|
||||
attn_output = attn_output.transpose(0, 1)
|
||||
attn_output = attn_output.reshape(seq_length, -1)
|
||||
return attn_output
|
||||
|
||||
|
||||
VL_VISION_ATTENTION_FUNCTIONS = {
|
||||
"flash_attention_2": multihead_attention,
|
||||
"sdpa": sdpa_attention,
|
||||
}
|
||||
|
||||
|
||||
def _apply_rope_input_validation(x, freqs_cis):
|
||||
assert x.ndim == freqs_cis.ndim + 1, (x.shape, freqs_cis.shape)
|
||||
assert x.shape[:-2] == freqs_cis.shape[:-1], (x.shape, freqs_cis.shape)
|
||||
assert x.shape[-1] == 2 * freqs_cis.shape[-1], (x.shape, freqs_cis.shape)
|
||||
assert freqs_cis.dtype == torch.complex64, freqs_cis.dtype
|
||||
|
||||
|
||||
def apply_rope(
|
||||
xq: torch.Tensor, xk: torch.Tensor, freqs_cis: torch.Tensor
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Args: (The leading dimensions of all inputs should be the same)
|
||||
xq: query, tensor of shape (..., num_heads, head_dim)
|
||||
xk: key, tensor of shape (..., num_heads, head_dim)
|
||||
freqs_cis: tensor of shape (..., head_dim/2), dtype=torch.complex64. It contains the precomputed cis(freqs) for each position in the 2D grid.
|
||||
Returns:
|
||||
xq_out, xk_out: tensors of shape (..., num_heads, head_dim)
|
||||
"""
|
||||
_apply_rope_input_validation(xq, freqs_cis)
|
||||
_apply_rope_input_validation(xk, freqs_cis)
|
||||
|
||||
freqs_cis = freqs_cis.unsqueeze(-2) # ..., 1, head_dim/2
|
||||
# ..., num_heads, head_dim/2
|
||||
xq_ = torch.view_as_complex(xq.float().view(*xq.shape[:-1], -1, 2))
|
||||
xk_ = torch.view_as_complex(xk.float().view(*xq.shape[:-1], -1, 2))
|
||||
xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
|
||||
xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(-2) # ..., num_heads, head_dim
|
||||
return xq_out.type_as(xq), xk_out.type_as(xk)
|
||||
|
||||
|
||||
class Learnable2DInterpPosEmb(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self, height: int, width: int, dim: int, interpolation_mode: str = "bicubic"
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.height = height
|
||||
self.width = width
|
||||
self.interpolation_mode = interpolation_mode
|
||||
self.weight = nn.Parameter(torch.empty(height, width, dim))
|
||||
self.reset_parameters()
|
||||
|
||||
def reset_parameters(self):
|
||||
nn.init.normal_(self.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_hws: torch.Tensor) -> torch.Tensor:
|
||||
pos_embs = []
|
||||
for shape in grid_hws.tolist():
|
||||
if shape == self.weight.shape[:-1]:
|
||||
pos_embs.append(self.weight.flatten(end_dim=1))
|
||||
else:
|
||||
pos_embs.append(
|
||||
F.interpolate(
|
||||
self.weight.permute((2, 0, 1)).unsqueeze(0),
|
||||
size=shape,
|
||||
mode=self.interpolation_mode,
|
||||
)
|
||||
.squeeze(0)
|
||||
.permute((1, 2, 0))
|
||||
.flatten(end_dim=1)
|
||||
)
|
||||
out = x + torch.cat(pos_embs)
|
||||
return out
|
||||
|
||||
|
||||
class MoonVisionPatchEmbed(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
out_dim: int,
|
||||
in_dim: int = 3,
|
||||
patch_size: Union[int, Tuple[int, int]] = (14, 14),
|
||||
pos_emb_height: int = 14,
|
||||
pos_emb_width: int = 14,
|
||||
):
|
||||
super().__init__()
|
||||
assert isinstance(
|
||||
patch_size, (int, Sequence)
|
||||
), f"Invalid patch_size type: {type(patch_size)}"
|
||||
if isinstance(patch_size, int):
|
||||
patch_size = (patch_size, patch_size)
|
||||
assert (
|
||||
len(patch_size) == 2
|
||||
), f"Expected patch_size to be a tuple of 2, got {patch_size}"
|
||||
self.patch_size = patch_size
|
||||
|
||||
self.proj = nn.Conv2d(
|
||||
in_dim, out_dim, kernel_size=patch_size, stride=patch_size
|
||||
)
|
||||
|
||||
self.pos_emb = Learnable2DInterpPosEmb(
|
||||
height=pos_emb_height, width=pos_emb_width, dim=out_dim
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, grid_hw: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
x (L, Channels): input tensor
|
||||
grid_hw (N, 2): grid height and width
|
||||
|
||||
Returns:
|
||||
(L, Cout) tensor
|
||||
"""
|
||||
x = self.proj(x).view(x.size(0), -1)
|
||||
# apply positional embedding
|
||||
x = self.pos_emb(x, grid_hw)
|
||||
return x
|
||||
|
||||
|
||||
class Rope2DPosEmb(nn.Module):
|
||||
"""2D rotary position embedding with multi-resolution support.
|
||||
|
||||
This class is intended to be used in the following way:
|
||||
1. Before training, create an instance of Rope2DPosEmb. This instance will hold the precomputed cis.
|
||||
2. Before each forward pass, call `get_freqs_cis_by_*` to get the `freqs_cis` tensor for this iteration.
|
||||
3. During the forward pass, pass the `freqs_cis` tensor to each attention layer, and call `apply` just before each attention operation.
|
||||
The rope is shared across all attention layers and all heads.
|
||||
|
||||
Refs:
|
||||
- RoFormer: https://arxiv.org/abs/2104.09864
|
||||
- VisionLLaMA: https://arxiv.org/abs/2403.00522
|
||||
- https://github.com/Meituan-AutoML/VisionLLaMA/blob/main/dit/models.py
|
||||
|
||||
Args:
|
||||
dim (int): usually the multi-head attention dimension, should be divisible by 4 (TODO: relax this constraint if needed)
|
||||
max_height (int): the maximum height of the 2D grid
|
||||
max_width (int): the maximum width of the 2D grid
|
||||
theta_base (float): the base of the theta
|
||||
device (str): the device to store the precomputed cis
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, dim: int, max_height: int, max_width: int, theta_base=10000, device="cuda"
|
||||
):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
assert self.dim % 4 == 0, "dim must be divisible by 4"
|
||||
self.max_height = max_height
|
||||
self.max_width = max_width
|
||||
self.theta_base = theta_base
|
||||
self.device = device
|
||||
|
||||
def extra_repr(self):
|
||||
return f"dim={self.dim}, max_height={self.max_height}, max_width={self.max_width}, theta_base={self.theta_base}"
|
||||
|
||||
@cached_property
|
||||
def precomputed_freqs_cis(self) -> torch.Tensor:
|
||||
"""Calculate the cis(freqs) for each position in the 2D grid.
|
||||
|
||||
Return: complex tensor of shape (max_height, max_width, dim//2) and value:
|
||||
height axis: ret[h, w, 2*i] = cis(h * theta_base**(-4*i/dim))
|
||||
weight axis: ret[h, w, 2*i+1] = cis(w * theta_base**(-4*i/dim)) with (i in [0, dim//4))
|
||||
note: `cis` is a mathematical notation defined by cis x = cos x + i sin x,
|
||||
"""
|
||||
N = self.max_height * self.max_width
|
||||
flat_pos = torch.arange(0, N).float().to(self.device)
|
||||
x_pos = flat_pos % self.max_width
|
||||
y_pos = flat_pos // self.max_width
|
||||
dim_range = (
|
||||
torch.arange(0, self.dim, 4)[: (self.dim // 4)].float().to(self.device)
|
||||
) # C/4
|
||||
freqs = 1.0 / (self.theta_base ** (dim_range / self.dim))
|
||||
x_freqs = torch.outer(x_pos, freqs).float() # N, C/4
|
||||
y_freqs = torch.outer(y_pos, freqs).float() # N, C/4
|
||||
x_cis = torch.polar(torch.ones_like(x_freqs), x_freqs) # N, C/4
|
||||
y_cis = torch.polar(torch.ones_like(y_freqs), y_freqs) # N, C/4
|
||||
# N, C/4, 2
|
||||
freqs_cis = torch.cat(
|
||||
[x_cis.unsqueeze(dim=-1), y_cis.unsqueeze(dim=-1)], dim=-1
|
||||
)
|
||||
# max_height, max_width, C/2
|
||||
freqs_cis = freqs_cis.reshape(self.max_height, self.max_width, -1)
|
||||
return freqs_cis
|
||||
|
||||
def get_freqs_cis_by_seqlens(self, grid_hws: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
grid_hws (torch.Tensor): containing list of (height, width) or (t, height, width) tuples.
|
||||
Returns:
|
||||
freqs_cis: tensor of shape (sum(t * height * width), dim//2)
|
||||
"""
|
||||
shapes = grid_hws.tolist()
|
||||
assert all(
|
||||
1 <= h <= self.max_height and 1 <= w <= self.max_width for h, w in shapes
|
||||
), (
|
||||
shapes,
|
||||
self.max_height,
|
||||
self.max_width,
|
||||
)
|
||||
freqs_cis = torch.cat(
|
||||
[
|
||||
self.precomputed_freqs_cis[:h, :w].reshape(-1, self.dim // 2)
|
||||
for h, w in shapes
|
||||
],
|
||||
dim=0,
|
||||
)
|
||||
return freqs_cis
|
||||
|
||||
def get_freqs_cis_by_idx(
|
||||
self, pos_idx: torch.Tensor, pos_idx_mask: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pos_idx: tensor of shape (..., 2), It contains the (h, w) position indices of each 2D token.
|
||||
pos_idx_mask: a mask of shape (...), the leading dimensions should be the same as pos_idx.
|
||||
Rope will only be applied to the tokens with True mask. `freqs_cis` for the tokens with False mask with be ones.
|
||||
Return:
|
||||
freqs_cis: tensor of shape (..., dim//2)
|
||||
"""
|
||||
assert (
|
||||
pos_idx.shape[:-1] == pos_idx_mask.shape
|
||||
and pos_idx.shape[-1] == 2
|
||||
and pos_idx.ndim == pos_idx_mask.ndim + 1
|
||||
), (pos_idx.shape, pos_idx_mask.shape)
|
||||
assert pos_idx_mask.dtype == torch.bool, pos_idx_mask.dtype
|
||||
|
||||
shp = pos_idx_mask.shape + (self.dim // 2,) # ..., head_dim/2
|
||||
freqs_cis = torch.ones(
|
||||
shp, dtype=torch.complex64, device=self.device
|
||||
) # ..., head_dim/2
|
||||
freqs_cis[pos_idx_mask] = self.precomputed_freqs_cis[
|
||||
pos_idx[..., 0][pos_idx_mask], pos_idx[..., 1][pos_idx_mask]
|
||||
]
|
||||
return freqs_cis
|
||||
|
||||
|
||||
class MLP2(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
dims: [in_dim, hidden_dim, out_dim]
|
||||
bias: whether to use bias in linear layer.
|
||||
"""
|
||||
|
||||
def __init__(self, dims: list[int], activation, bias=True):
|
||||
super().__init__()
|
||||
assert len(dims) == 3
|
||||
self.fc0 = nn.Linear(dims[0], dims[1], bias=bias)
|
||||
self.fc1 = nn.Linear(dims[1], dims[2], bias=bias)
|
||||
self.activation = activation
|
||||
for m in [self.fc0, self.fc1]:
|
||||
nn.init.trunc_normal_(m.weight, std=math.sqrt(2 / m.in_features))
|
||||
if m.bias is not None:
|
||||
nn.init.zeros_(m.bias)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.fc0(x)
|
||||
x = self.activation(x)
|
||||
return self.fc1(x)
|
||||
|
||||
|
||||
class MoonVitEncoderLayer(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp_dim: int,
|
||||
*,
|
||||
attn_implementation: str = "flash_attention_2", # use fa2 in sglang by default
|
||||
activation=F.gelu,
|
||||
attn_bias: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.num_heads = num_heads
|
||||
self.hidden_dim = hidden_dim
|
||||
self.hidden_size_per_attention_head = self.hidden_dim // self.num_heads
|
||||
self.attn_implementation = attn_implementation
|
||||
|
||||
self.norm0 = nn.LayerNorm(hidden_dim)
|
||||
self.norm1 = nn.LayerNorm(hidden_dim)
|
||||
self.mlp = MLP2([hidden_dim, mlp_dim, hidden_dim], activation)
|
||||
self.wqkv = nn.Linear(hidden_dim, hidden_dim * 3, bias=attn_bias)
|
||||
self.wo = nn.Linear(hidden_dim, hidden_dim, bias=attn_bias)
|
||||
|
||||
def attention_qkvpacked(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Optional[torch.Tensor] = None,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
x (torch.Tensor): (batch_size, seqlen, hidden_dim)
|
||||
cu_seqlens (torch.Tensor):
|
||||
"""
|
||||
xqkv = self.wqkv(x)
|
||||
|
||||
qkv_shape = xqkv.size()[:-1] + (
|
||||
3,
|
||||
self.num_heads,
|
||||
self.hidden_size_per_attention_head,
|
||||
)
|
||||
# xqkv: (batch_size, seqlen, 3, nheads, headdim)
|
||||
xqkv = xqkv.view(*qkv_shape)
|
||||
xq, xk, xv = torch.unbind(xqkv, dim=-3)
|
||||
|
||||
xq, xk = apply_rope(xq, xk, rope_freqs_cis)
|
||||
|
||||
attn_func = VL_VISION_ATTENTION_FUNCTIONS[self.attn_implementation]
|
||||
attn_out = attn_func(
|
||||
xq, xk, xv, q_cu_seqlens=cu_seqlens, k_cu_seqlens=cu_seqlens
|
||||
)
|
||||
|
||||
attn_out = self.wo(attn_out)
|
||||
return attn_out
|
||||
|
||||
def forward(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
cu_seqlens: torch.Tensor,
|
||||
rope_freqs_cis: Union[torch.Tensor, None] = None,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
hidden_states: non-packed (B, N, D) or packed (L, D). if non-packed, seqlens should be None, if packed, seqlens should be set
|
||||
|
||||
Returns:
|
||||
output: same shape of input, non-packed (B, N, D) for non-packed input, (L, D) for packed input
|
||||
"""
|
||||
residual = hidden_states
|
||||
hidden_states = self.norm0(hidden_states)
|
||||
attn_out = self.attention_qkvpacked(
|
||||
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
|
||||
)
|
||||
hidden_states = residual + attn_out
|
||||
|
||||
residual = hidden_states
|
||||
hidden_states = self.mlp(self.norm1(hidden_states))
|
||||
hidden_states = residual + hidden_states
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MoonVitEncoder(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_dim: int,
|
||||
num_layers: int,
|
||||
block_cfg: dict,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
|
||||
self.rope_2d = Rope2DPosEmb(
|
||||
block_cfg["hidden_dim"] // block_cfg["num_heads"], 512, 512
|
||||
)
|
||||
self.blocks = nn.ModuleList(
|
||||
[MoonVitEncoderLayer(**block_cfg) for _ in range(num_layers)]
|
||||
)
|
||||
self.final_layernorm = nn.LayerNorm(hidden_dim)
|
||||
|
||||
def forward(
|
||||
self, hidden_states: torch.Tensor, grid_hw: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
rope_freqs_cis = self.rope_2d.get_freqs_cis_by_seqlens(grid_hws=grid_hw)
|
||||
|
||||
lengths = torch.cat(
|
||||
(
|
||||
torch.zeros(1, device=hidden_states.device, dtype=grid_hw.dtype),
|
||||
grid_hw[:, 0] * grid_hw[:, 1],
|
||||
)
|
||||
)
|
||||
cu_seqlens = lengths.cumsum(dim=0, dtype=torch.int32)
|
||||
|
||||
for _, block in enumerate(self.blocks):
|
||||
hidden_states = block(
|
||||
hidden_states, cu_seqlens, rope_freqs_cis=rope_freqs_cis
|
||||
)
|
||||
|
||||
hidden_states = self.final_layernorm(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
def patch_merger(
|
||||
x: torch.Tensor,
|
||||
grid_hw: torch.Tensor,
|
||||
merge_kernel_size: list[int, int] = (2, 2),
|
||||
) -> List[torch.Tensor]:
|
||||
d_model = x.size(-1)
|
||||
|
||||
outputs = []
|
||||
pre_sum = 0
|
||||
for x_shape in grid_hw.tolist():
|
||||
height, width = x_shape[0], x_shape[1]
|
||||
# Get the current sequence
|
||||
seq = x[pre_sum : pre_sum + height * width]
|
||||
# Reshape along self.merge_kernel_size and concat to the last dimension
|
||||
kernel_height, kernel_width = merge_kernel_size
|
||||
new_height, new_width = height // kernel_height, width // kernel_width
|
||||
reshaped_seq = seq.view(
|
||||
new_height, kernel_height, new_width, kernel_width, d_model
|
||||
)
|
||||
reshaped_seq = reshaped_seq.permute(0, 2, 1, 3, 4).contiguous()
|
||||
padded_seq = reshaped_seq.view(
|
||||
new_height * new_width, kernel_height * kernel_width, -1
|
||||
)
|
||||
outputs.append(padded_seq)
|
||||
pre_sum += height * width
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class MoonVitVLProjector(nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
merge_kernel_size: list[int, int],
|
||||
hidden_act: str = "gelu",
|
||||
ln_eps: float = 1e-5,
|
||||
out_dim: int = 4096,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_size = in_channels * merge_kernel_size[0] * merge_kernel_size[1]
|
||||
|
||||
self.pre_norm = nn.nn.LayerNorm(in_channels, eps=ln_eps)
|
||||
self.linear_1 = nn.Linear(self.hidden_size, self.hidden_size, bias=True)
|
||||
self.act = ACT2FN[hidden_act]
|
||||
self.linear_2 = nn.Linear(self.hidden_size, out_dim, bias=True)
|
||||
|
||||
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
||||
hidden_states = self.pre_norm(hidden_states).view(-1, self.hidden_size)
|
||||
hidden_states = self.linear_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
|
||||
class MoonVitPretrainedModel(PreTrainedModel):
|
||||
config_class = MoonViTConfig
|
||||
model_type = "moonvit"
|
||||
_no_split_modules = ["PackingTransformer"]
|
||||
_supports_flash_attn_2 = True
|
||||
_supports_sdpa = True
|
||||
|
||||
def __init__(self, config: MoonViTConfig, *inputs, **kwargs):
|
||||
super().__init__(config, *inputs, **kwargs)
|
||||
config = deepcopy(config)
|
||||
self.merge_kernel_size = config.merge_kernel_size
|
||||
self.patch_size = config.patch_size
|
||||
self.patch_embed = MoonVisionPatchEmbed(
|
||||
out_dim=config.hidden_size,
|
||||
patch_size=config.patch_size,
|
||||
pos_emb_height=config.init_pos_emb_height,
|
||||
pos_emb_width=config.init_pos_emb_width,
|
||||
)
|
||||
|
||||
self.encoder = MoonVitEncoder(
|
||||
hidden_dim=config.hidden_size,
|
||||
num_layers=config.num_hidden_layers,
|
||||
block_cfg={
|
||||
"num_heads": config.num_attention_heads,
|
||||
"hidden_dim": config.hidden_size,
|
||||
"mlp_dim": config.intermediate_size,
|
||||
"activation": PytorchGELUTanh(),
|
||||
"attn_bias": True,
|
||||
"attn_implementation": config._attn_implementation,
|
||||
},
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, pixel_values: torch.Tensor, grid_hw: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
pixel_values (torch.Tensor): The input pixel values.
|
||||
grid_hw (torch.Tensor): The grid height and width.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: The output tokens.
|
||||
"""
|
||||
hidden_states = self.patch_embed(pixel_values, grid_hw)
|
||||
hidden_states = self.encoder(hidden_states, grid_hw)
|
||||
hidden_states = patch_merger(
|
||||
hidden_states, grid_hw, merge_kernel_size=self.merge_kernel_size
|
||||
)
|
||||
return hidden_states
|
||||
@@ -81,10 +81,20 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
# `driver` is for gemma-3-it
|
||||
assert "man" in text or "person" or "driver" in text, text
|
||||
assert "cab" in text or "taxi" in text or "SUV" in text, text
|
||||
assert (
|
||||
"man" in text or "person" or "driver" in text
|
||||
), f"text: {text}, should contain man, person or driver"
|
||||
assert (
|
||||
"cab" in text
|
||||
or "taxi" in text
|
||||
or "SUV" in text
|
||||
or "vehicle" in text
|
||||
or "car" in text
|
||||
), f"text: {text}, should contain cab, taxi, SUV, vehicle or car"
|
||||
# MiniCPMO fails to recognize `iron`, but `hanging`
|
||||
assert "iron" in text or "hang" in text, text
|
||||
assert (
|
||||
"iron" in text or "hang" in text or "cloth" in text or "holding" in text
|
||||
), f"text: {text}, should contain iron, hang, cloth or holding"
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
@@ -132,7 +142,9 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
assert response.choices[0].message.role == "assistant"
|
||||
text = response.choices[0].message.content
|
||||
assert isinstance(text, str)
|
||||
assert "man" in text or "cab" in text, text
|
||||
assert (
|
||||
"man" in text or "cab" in text
|
||||
), f"text: {text}, should contain man or cab"
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
@@ -175,8 +187,12 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
print("-" * 30)
|
||||
print(f"Multi images response:\n{text}")
|
||||
print("-" * 30)
|
||||
assert "man" in text or "cab" in text or "SUV" in text or "taxi" in text, text
|
||||
assert "logo" in text or '"S"' in text or "SG" in text, text
|
||||
assert (
|
||||
"man" in text or "cab" in text or "SUV" in text or "taxi" in text
|
||||
), f"text: {text}, should contain man, cab, SUV or taxi"
|
||||
assert (
|
||||
"logo" in text or '"S"' in text or "SG" in text
|
||||
), f"text: {text}, should contain logo, S or SG"
|
||||
assert response.id
|
||||
assert response.created
|
||||
assert response.usage.prompt_tokens > 0
|
||||
@@ -305,9 +321,9 @@ class TestOpenAIVisionServer(CustomTestCase):
|
||||
client = openai.Client(api_key=self.api_key, base_url=self.base_url)
|
||||
|
||||
regex = (
|
||||
r"""\{\n"""
|
||||
+ r""" "color": "[\w]+",\n"""
|
||||
+ r""" "number_of_cars": [\d]+\n"""
|
||||
r"""\{"""
|
||||
+ r""""color":"[\w]+","""
|
||||
+ r""""number_of_cars":[\d]+"""
|
||||
+ r"""\}"""
|
||||
)
|
||||
|
||||
@@ -732,5 +748,33 @@ class TestGemma3itServer(TestOpenAIVisionServer):
|
||||
pass
|
||||
|
||||
|
||||
class TestKimiVLServer(TestOpenAIVisionServer):
|
||||
@classmethod
|
||||
def setUpClass(cls):
|
||||
cls.model = "moonshotai/Kimi-VL-A3B-Instruct"
|
||||
cls.base_url = DEFAULT_URL_FOR_TEST
|
||||
cls.api_key = "sk-123456"
|
||||
cls.process = popen_launch_server(
|
||||
cls.model,
|
||||
cls.base_url,
|
||||
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||
other_args=[
|
||||
"--trust-remote-code",
|
||||
"--chat-template",
|
||||
"kimi-vl",
|
||||
"--context-length",
|
||||
"4096",
|
||||
"--tensor-parallel-size",
|
||||
"2",
|
||||
"--dtype",
|
||||
"bfloat16",
|
||||
],
|
||||
)
|
||||
cls.base_url += "/v1"
|
||||
|
||||
def test_video_chat_completion(self):
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
Reference in New Issue
Block a user