Files
sglang/scripts/deprecated/convert_yi_vl.py

39 lines
937 B
Python
Raw Normal View History

2024-02-01 08:33:22 -08:00
"""
Convert a Yi-VL model's config and tokenizer into a format usable with SGLang.
Usage: python3 scripts/deprecated/convert_yi_vl.py --model-path <path-to-model>
"""
import argparse
import json
import os
from transformers import AutoConfig, AutoTokenizer
2024-07-18 04:55:39 +10:00
2024-02-01 08:33:22 -08:00
def add_image_token(model_path: str):
    """Add ``<image_placeholder>`` as a special token and save the tokenizer.

    The tokenizer is loaded from *model_path*, extended with the placeholder
    token, printed for inspection, and written back to the same *model_path*.
    """
    placeholder_tokens = ["<image_placeholder>"]
    tok = AutoTokenizer.from_pretrained(model_path)
    tok.add_tokens(placeholder_tokens, special_tokens=True)
    print(tok)
    tok.save_pretrained(model_path)
2024-07-18 04:55:39 +10:00
2024-02-01 08:33:22 -08:00
def edit_model_config(model_path: str, image_token_index: int = 64002):
    """Patch the model config at *model_path* so SGLang treats it as Yi-VL.

    Sets ``architectures`` to ``["YiVLForCausalLM"]`` and stores
    *image_token_index* as ``image_token_index``. The updated config is
    printed for inspection and saved back to the same *model_path*.

    Args:
        model_path: Model location passed to ``AutoConfig.from_pretrained``
            and then to ``save_pretrained``, so it must be writable.
        image_token_index: Token id written to ``config.image_token_index``.
            Defaults to 64002, the value this script has always used.
            NOTE(review): presumably this must equal the id the tokenizer
            assigns to ``<image_placeholder>`` in ``add_image_token``;
            confirm against the saved tokenizer.
    """
    config = AutoConfig.from_pretrained(model_path)
    # Plain attribute assignment rather than setattr() with constant names
    # (flake8-bugbear B010); both go through the same __setattr__.
    config.architectures = ["YiVLForCausalLM"]
    config.image_token_index = image_token_index
    print(config)
    config.save_pretrained(model_path)
2024-07-18 04:55:39 +10:00
2024-02-01 08:33:22 -08:00
if __name__ == "__main__":
    # CLI entry point: patch the tokenizer first, then the model config.
    # Both steps save their result back into --model-path.
    parser = argparse.ArgumentParser(
        description="Convert a Yi-VL model for use with SGLang."
    )
    # required=True: without it, a missing flag yields model_path=None and
    # from_pretrained(None) fails with an obscure error instead of a usage message.
    parser.add_argument(
        "--model-path",
        type=str,
        required=True,
        help="Path to the Yi-VL model; its tokenizer and config are rewritten there.",
    )
    args = parser.parse_args()
    add_image_token(args.model_path)
    edit_model_config(args.model_path)