feat: (chat-template matching) enhance multimodal model detection with config.json (#9597)
This commit is contained in:
@@ -26,6 +26,8 @@ Key components:
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
import dataclasses
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from enum import IntEnum, auto
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
@@ -959,16 +961,42 @@ register_conv_template(
|
||||
)
|
||||
|
||||
|
||||
MODEL_TYPE_TO_TEMPLATE = {
|
||||
"internvl_chat": "internvl-2-5",
|
||||
"deepseek_vl_v2": "deepseek-vl2",
|
||||
"multi_modality": "janus-pro",
|
||||
"phi4mm": "phi-4-mm",
|
||||
"minicpmv": "minicpmv",
|
||||
"minicpmo": "minicpmo",
|
||||
}
|
||||
|
||||
|
||||
def get_model_type(model_path: str) -> Optional[str]:
|
||||
config_path = os.path.join(model_path, "config.json")
|
||||
if not os.path.exists(config_path):
|
||||
return None
|
||||
try:
|
||||
with open(config_path, "r", encoding="utf-8") as f:
|
||||
config = json.load(f)
|
||||
return config.get("model_type")
|
||||
except (IOError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_internvl(model_path: str):
|
||||
if re.search(r"internvl", model_path, re.IGNORECASE):
|
||||
return "internvl-2-5"
|
||||
model_type = get_model_type(model_path)
|
||||
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_deepseek_janus_pro(model_path: str):
|
||||
if re.search(r"janus", model_path, re.IGNORECASE):
|
||||
return "janus-pro"
|
||||
model_type = get_model_type(model_path)
|
||||
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
|
||||
def match_deepseek_vl(model_path: str):
|
||||
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
||||
return "deepseek-vl2"
|
||||
model_type = get_model_type(model_path)
|
||||
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_openbmb_minicpm(model_path: str):
|
||||
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
||||
return "minicpmv"
|
||||
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
|
||||
return "minicpmo"
|
||||
def match_minicpm(model_path: str):
|
||||
match = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
|
||||
if match:
|
||||
return f"minicpm{match.group(1).lower()}"
|
||||
model_type = get_model_type(model_path)
|
||||
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_phi_4_mm(model_path: str):
|
||||
if "phi-4-multimodal" in model_path.lower():
|
||||
return "phi-4-mm"
|
||||
model_type = get_model_type(model_path)
|
||||
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
|
||||
|
||||
Reference in New Issue
Block a user