Yi-VL Model (#112)
This commit is contained in:
68
examples/quick_start/srt_example_yi_vl.py
Normal file
68
examples/quick_start/srt_example_yi_vl.py
Normal file
@@ -0,0 +1,68 @@
|
||||
"""
|
||||
Usage: python3 srt_example_yi_vl.py
|
||||
"""
|
||||
import sglang as sgl
|
||||
|
||||
|
||||
@sgl.function
|
||||
def image_qa(s, image_path, question):
|
||||
s += sgl.user(sgl.image(image_path) + question)
|
||||
s += sgl.assistant(sgl.gen("answer"))
|
||||
|
||||
|
||||
def single():
|
||||
state = image_qa.run(
|
||||
image_path="images/cat.jpeg",
|
||||
question="What is this?",
|
||||
max_new_tokens=64,
|
||||
stop="###")
|
||||
print(state["answer"], "\n")
|
||||
|
||||
|
||||
def stream():
|
||||
state = image_qa.run(
|
||||
image_path="images/cat.jpeg",
|
||||
question="What is this?",
|
||||
max_new_tokens=64,
|
||||
stream=True,
|
||||
stop="###")
|
||||
|
||||
for out in state.text_iter("answer"):
|
||||
print(out, end="", flush=True)
|
||||
print()
|
||||
|
||||
|
||||
def batch():
|
||||
states = image_qa.run_batch(
|
||||
[
|
||||
{"image_path": "images/cat.jpeg", "question":"What is this?"},
|
||||
{"image_path": "images/dog.jpeg", "question":"What is this?"},
|
||||
],
|
||||
max_new_tokens=64,
|
||||
stop="###"
|
||||
)
|
||||
for s in states:
|
||||
print(s["answer"], "\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
runtime = sgl.Runtime(model_path="BabyChou/Yi-VL-6B",
|
||||
tokenizer_path="BabyChou/Yi-VL-6B")
|
||||
sgl.set_default_backend(runtime)
|
||||
# Or you can use API models
|
||||
# sgl.set_default_backend(sgl.OpenAI("gpt-4-vision-preview"))
|
||||
# sgl.set_default_backend(sgl.VertexAI("gemini-pro-vision"))
|
||||
|
||||
# Run a single request
|
||||
print("\n========== single ==========\n")
|
||||
single()
|
||||
|
||||
# Stream output
|
||||
print("\n========== stream ==========\n")
|
||||
stream()
|
||||
|
||||
# Run a batch of requests
|
||||
print("\n========== batch ==========\n")
|
||||
batch()
|
||||
|
||||
runtime.shutdown()
|
||||
@@ -146,6 +146,23 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
|
||||
register_chat_template(
|
||||
ChatTemplate(
|
||||
name="yi",
|
||||
default_system_prompt=(
|
||||
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
|
||||
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
|
||||
),
|
||||
role_prefix_and_suffix={
|
||||
"system": ("", "\n\n"),
|
||||
"user": ("### Human:", "\n"),
|
||||
"assistant": ("### Assistant:", "\n"),
|
||||
},
|
||||
image_token=" <image_placeholder>\n",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_vicuna(model_path: str):
|
||||
@@ -176,6 +193,12 @@ def match_chat_ml(model_path: str):
|
||||
if "qwen" in model_path and "chat" in model_path:
|
||||
return get_chat_template("chatml")
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_yi(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "yi" in model_path:
|
||||
return get_chat_template("yi")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
messages = [
|
||||
|
||||
101
python/sglang/srt/models/yivl.py
Normal file
101
python/sglang/srt/models/yivl.py
Normal file
@@ -0,0 +1,101 @@
|
||||
"""Inference-only Yi-VL model."""
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import CLIPVisionModel, LlavaConfig
|
||||
from vllm.model_executor.weight_utils import (
|
||||
default_weight_loader,
|
||||
hf_model_weights_iterator,
|
||||
)
|
||||
|
||||
from sglang.srt.models.llava import LlavaLlamaForCausalLM, clip_vision_embed_forward, monkey_path_clip_vision_embed_forward
|
||||
|
||||
|
||||
class YiVLForCausalLM(LlavaLlamaForCausalLM):
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.config = kwargs["config"]
|
||||
super().__init__(self.config)
|
||||
|
||||
self.multi_modal_projector = YiVLMultiModalProjector(self.config)
|
||||
self.vision_tower_subfolder = self.config.mm_vision_tower.replace("./", "") # Everything after "./"
|
||||
|
||||
def load_weights(
|
||||
self,
|
||||
model_name_or_path: str,
|
||||
cache_dir: Optional[str] = None,
|
||||
load_format: str = "auto",
|
||||
revision: Optional[str] = None,
|
||||
):
|
||||
# We have to use the subfolder of the main model directory (e.g. 01-ai/Yi-VL-6B)
|
||||
self.vision_tower = CLIPVisionModel.from_pretrained(
|
||||
model_name_or_path, torch_dtype=torch.float16, subfolder=self.vision_tower_subfolder
|
||||
).cuda()
|
||||
|
||||
self.vision_tower.eval()
|
||||
|
||||
self.vision_feature_layer = self.config.mm_vision_select_layer
|
||||
self.vision_feature_select_strategy = self.config.mm_vision_select_feature
|
||||
self.image_size = self.vision_tower.config.image_size
|
||||
self.patch_size = self.vision_tower.config.patch_size
|
||||
|
||||
self.mm_patch_merge_type = getattr(self.config, "mm_patch_merge_type", "flat")
|
||||
self.image_aspect_ratio = getattr(self.config, "image_aspect_ratio", "square")
|
||||
self.image_grid_pinpoints = getattr(self.config, "image_grid_pinpoints", None)
|
||||
|
||||
self.image_feature_len = int((self.image_size / self.patch_size) ** 2)
|
||||
if self.vision_feature_select_strategy == "patch":
|
||||
pass
|
||||
elif self.vision_feature_select_strategy == "cls_patch":
|
||||
self.image_feature_len += 1
|
||||
else:
|
||||
raise ValueError(f"Unexpected select feature: {self.select_feature}")
|
||||
|
||||
# load mm_projector
|
||||
# TODO: support TP?
|
||||
projector_weights = {
|
||||
"model.mm_projector.0": "multi_modal_projector.linear_1",
|
||||
"model.mm_projector.1": "multi_modal_projector.ln_1",
|
||||
"model.mm_projector.3": "multi_modal_projector.linear_2",
|
||||
"model.mm_projector.4": "multi_modal_projector.ln_2",
|
||||
"model.vision_tower.vision_tower": "vision_tower", # Update the vision tower weights if we find them in the checkpoint (it may be finetuned).
|
||||
}
|
||||
params_dict = dict(self.named_parameters())
|
||||
for name, loaded_weight in hf_model_weights_iterator(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
):
|
||||
if "projector" in name or "vision_tower" in name:
|
||||
for weight_name, param_name in projector_weights.items():
|
||||
if weight_name in name:
|
||||
name = name.replace(weight_name, param_name)
|
||||
param = params_dict[name]
|
||||
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
||||
weight_loader(param, loaded_weight)
|
||||
|
||||
# load language model
|
||||
self.language_model.load_weights(
|
||||
model_name_or_path, cache_dir, load_format, revision
|
||||
)
|
||||
|
||||
monkey_path_clip_vision_embed_forward()
|
||||
|
||||
class YiVLMultiModalProjector(nn.Module):
|
||||
def __init__(self, config: LlavaConfig):
|
||||
super().__init__()
|
||||
|
||||
self.linear_1 = nn.Linear(config.vision_config.hidden_size, config.text_config.hidden_size)
|
||||
self.ln_1 = nn.LayerNorm(config.text_config.hidden_size)
|
||||
self.act = nn.GELU()
|
||||
self.linear_2 = nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size)
|
||||
self.ln_2 = nn.LayerNorm(config.text_config.hidden_size)
|
||||
|
||||
def forward(self, image_features):
|
||||
hidden_states = self.linear_1(image_features)
|
||||
hidden_state = self.ln_1(hidden_states)
|
||||
hidden_states = self.act(hidden_states)
|
||||
hidden_states = self.linear_2(hidden_states)
|
||||
hidden_states = self.ln_2(hidden_states)
|
||||
return hidden_states
|
||||
|
||||
EntryClass = YiVLForCausalLM
|
||||
@@ -233,11 +233,12 @@ def wrap_kernel_launcher(kernel):
|
||||
|
||||
def is_multimodal_model(model):
|
||||
if isinstance(model, str):
|
||||
return "llava" in model
|
||||
return "llava" or "yi-vl" in model
|
||||
from sglang.srt.model_config import ModelConfig
|
||||
|
||||
if isinstance(model, ModelConfig):
|
||||
return "llava" in model.path.lower()
|
||||
model_path = model.path.lower()
|
||||
return "llava" in model_path or "yi-vl" in model_path
|
||||
raise Exception("unrecognized type")
|
||||
|
||||
|
||||
|
||||
38
scripts/convert_yi_vl.py
Normal file
38
scripts/convert_yi_vl.py
Normal file
@@ -0,0 +1,38 @@
|
||||
"""
|
||||
Convert Yi-VL config into a format useable with SGLang
|
||||
|
||||
Usage: python3 scripts/convert_yi_vl.py --model-path <path-to-model>
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
|
||||
from transformers import AutoConfig, AutoTokenizer
|
||||
|
||||
def add_image_token(model_path: str):
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
||||
tokenizer.add_tokens(
|
||||
["<image_placeholder>"],
|
||||
special_tokens=True
|
||||
)
|
||||
|
||||
print(tokenizer)
|
||||
tokenizer.save_pretrained(model_path)
|
||||
|
||||
def edit_model_config(model_path):
|
||||
config = AutoConfig.from_pretrained(model_path)
|
||||
|
||||
setattr(config, "architectures", ["YiVLForCausalLM"])
|
||||
setattr(config, "image_token_index", 64002)
|
||||
|
||||
print(config)
|
||||
config.save_pretrained(model_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model-path", type=str)
|
||||
args = parser.parse_args()
|
||||
|
||||
add_image_token(args.model_path)
|
||||
edit_model_config(args.model_path)
|
||||
13
scripts/convert_yi_vl.sh
Normal file
13
scripts/convert_yi_vl.sh
Normal file
@@ -0,0 +1,13 @@
|
||||
# For 34B Model
|
||||
mkdir ~/model_weights
|
||||
cd ~/model_weights
|
||||
git clone https://huggingface.co/01-ai/Yi-VL-34B
|
||||
cp ~/model_weights/Yi-VL-34B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-34B-448/preprocessor_config.json ~/model_weights/Yi-VL-34B
|
||||
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-34B
|
||||
|
||||
# For 6B Model
|
||||
mkdir ~/model_weights
|
||||
cd ~/model_weights
|
||||
git clone https://huggingface.co/01-ai/Yi-VL-6B
|
||||
cp ~/model_weights/Yi-VL-6B/vit/clip-vit-H-14-laion2B-s32B-b79K-yi-vl-6B-448/preprocessor_config.json ~/model_weights/Yi-VL-6B
|
||||
python3 convert_yi_vl.py --model-path ~/model_weights/Yi-VL-6B
|
||||
Reference in New Issue
Block a user