From b21fdd537350c8e652f263fca9439c7fa323ac7e Mon Sep 17 00:00:00 2001 From: Kevin Tuan <46362395+KEVINTUAN12@users.noreply.github.com> Date: Wed, 27 Aug 2025 08:55:40 +0800 Subject: [PATCH] feat: (chat-template matching) enhance multimodal model detection with config.json (#9597) --- python/sglang/srt/conversation.py | 43 +++++++++++++++++++++++++++---- 1 file changed, 38 insertions(+), 5 deletions(-) diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py index dde9632b8..8a2fe4e7f 100644 --- a/python/sglang/srt/conversation.py +++ b/python/sglang/srt/conversation.py @@ -26,6 +26,8 @@ Key components: # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py import dataclasses +import json +import os import re from enum import IntEnum, auto from typing import Callable, Dict, List, Optional, Tuple, Union @@ -959,16 +961,42 @@ register_conv_template( ) +MODEL_TYPE_TO_TEMPLATE = { + "internvl_chat": "internvl-2-5", + "deepseek_vl_v2": "deepseek-vl2", + "multi_modality": "janus-pro", + "phi4mm": "phi-4-mm", + "minicpmv": "minicpmv", + "minicpmo": "minicpmo", +} + + +def get_model_type(model_path: str) -> Optional[str]: + config_path = os.path.join(model_path, "config.json") + if not os.path.exists(config_path): + return None + try: + with open(config_path, "r", encoding="utf-8") as f: + config = json.load(f) + return config.get("model_type") + except (IOError, json.JSONDecodeError): + return None + + @register_conv_template_matching_function def match_internvl(model_path: str): if re.search(r"internvl", model_path, re.IGNORECASE): return "internvl-2-5" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type) @register_conv_template_matching_function def match_deepseek_janus_pro(model_path: str): if re.search(r"janus", model_path, re.IGNORECASE): return "janus-pro" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type) @register_conv_template_matching_function @@ -981,6 +1009,8 @@ def match_vicuna(model_path: str): def match_deepseek_vl(model_path: str): if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE): return "deepseek-vl2" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type) @register_conv_template_matching_function @@ -994,14 +1024,17 @@ def match_qwen_chat_ml(model_path: str): @register_conv_template_matching_function -def match_openbmb_minicpm(model_path: str): - if re.search(r"minicpm-v", model_path, re.IGNORECASE): - return "minicpmv" - elif re.search(r"minicpm-o", model_path, re.IGNORECASE): - return "minicpmo" +def match_minicpm(model_path: str): + match = re.search(r"minicpm-(v|o)", model_path, re.IGNORECASE) + if match: + return f"minicpm{match.group(1).lower()}" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type) @register_conv_template_matching_function def match_phi_4_mm(model_path: str): if "phi-4-multimodal" in model_path.lower(): return "phi-4-mm" + model_type = get_model_type(model_path) + return MODEL_TYPE_TO_TEMPLATE.get(model_type)