feat: (chat-template matching) enhance multimodal model detection with config.json (#9597)

This commit is contained in:
Kevin Tuan
2025-08-27 08:55:40 +08:00
committed by GitHub
parent c04c17edfa
commit b21fdd5373

View File

@@ -26,6 +26,8 @@ Key components:
# Adapted from # Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses import dataclasses
import json
import os
import re import re
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -959,16 +961,42 @@ register_conv_template(
) )
MODEL_TYPE_TO_TEMPLATE = {
"internvl_chat": "internvl-2-5",
"deepseek_vl_v2": "deepseek-vl2",
"multi_modality": "janus-pro",
"phi4mm": "phi-4-mm",
"minicpmv": "minicpmv",
"minicpmo": "minicpmo",
}
def get_model_type(model_path: str) -> Optional[str]:
config_path = os.path.join(model_path, "config.json")
if not os.path.exists(config_path):
return None
try:
with open(config_path, "r", encoding="utf-8") as f:
config = json.load(f)
return config.get("model_type")
except (IOError, json.JSONDecodeError):
return None
@register_conv_template_matching_function @register_conv_template_matching_function
def match_internvl(model_path: str): def match_internvl(model_path: str):
if re.search(r"internvl", model_path, re.IGNORECASE): if re.search(r"internvl", model_path, re.IGNORECASE):
return "internvl-2-5" return "internvl-2-5"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@register_conv_template_matching_function @register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str): def match_deepseek_janus_pro(model_path: str):
if re.search(r"janus", model_path, re.IGNORECASE): if re.search(r"janus", model_path, re.IGNORECASE):
return "janus-pro" return "janus-pro"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@register_conv_template_matching_function @register_conv_template_matching_function
@@ -981,6 +1009,8 @@ def match_vicuna(model_path: str):
def match_deepseek_vl(model_path: str): def match_deepseek_vl(model_path: str):
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE): if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
return "deepseek-vl2" return "deepseek-vl2"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@register_conv_template_matching_function @register_conv_template_matching_function
@@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str):
@register_conv_template_matching_function @register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str): def match_minicpm(model_path: str):
if re.search(r"minicpm-v", model_path, re.IGNORECASE): match = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE)
return "minicpmv" if match:
elif re.search(r"minicpm-o", model_path, re.IGNORECASE): return f"minicpm{match.group(1).lower()}"
return "minicpmo" model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)
@register_conv_template_matching_function @register_conv_template_matching_function
def match_phi_4_mm(model_path: str): def match_phi_4_mm(model_path: str):
if "phi-4-multimodal" in model_path.lower(): if "phi-4-multimodal" in model_path.lower():
return "phi-4-mm" return "phi-4-mm"
model_type = get_model_type(model_path)
return MODEL_TYPE_TO_TEMPLATE.get(model_type)