Fix and Clean up chat-template requirement for VLM (#6114)

Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
XinyuanTong
2025-05-10 09:14:09 -07:00
committed by GitHub
parent c178abdabc
commit 9d8ec2e67e
16 changed files with 104 additions and 195 deletions

View File

@@ -1,3 +1,4 @@
import re
from dataclasses import dataclass
from enum import Enum, auto
from typing import Callable, Dict, List, Tuple
@@ -71,9 +72,9 @@ def get_chat_template(name):
def get_chat_template_by_model_path(model_path):
for matching_func in matching_function_registry:
template = matching_func(model_path)
if template is not None:
return template
template_name = matching_func(model_path)
if template_name is not None:
return get_chat_template(template_name)
return get_chat_template("default")
@@ -479,134 +480,112 @@ register_chat_template(
@register_chat_template_matching_function
def match_deepseek(model_path: str):
if (
"deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower()
) and "base" not in model_path.lower():
return get_chat_template("deepseek-v3")
if re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE) and not re.search(
r"base", model_path, re.IGNORECASE
):
return "deepseek-v3"
@register_chat_template_matching_function
def match_deepseek_janus_pro(model_path: str):
if "janus" in model_path.lower():
return get_chat_template("janus-pro")
if re.search(r"janus", model_path, re.IGNORECASE):
return "janus-pro"
@register_chat_template_matching_function
def match_dbrx(model_path: str):
if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
return get_chat_template("dbrx-instruct")
if re.search(r"dbrx", model_path, re.IGNORECASE) and re.search(
r"instruct", model_path, re.IGNORECASE
):
return "dbrx-instruct"
@register_chat_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava-next-video-7b" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
return "vicuna_v1.1"
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return get_chat_template("llama-2-chat")
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return get_chat_template("llama-2-chat")
if "codellama" in model_path and "instruct" in model_path:
return get_chat_template("llama-2-chat")
if re.search(
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2-chat"
@register_chat_template_matching_function
def match_llama3_instruct(model_path: str):
model_path = model_path.lower()
if "llama-3" in model_path and "instruct" in model_path:
return get_chat_template("llama-3-instruct")
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
return "llama-3-instruct"
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
# import pdb;pdb.set_trace()
model_path = model_path.lower()
if "tinyllama" in model_path:
return get_chat_template("chatml")
# Now the suffix for qwen2 chat model is "instruct"
if "qwen" in model_path and "vl" in model_path:
return get_chat_template("qwen2-vl")
if "qwen" in model_path:
if "vl" in model_path:
return get_chat_template("qwen2-vl")
if ("chat" in model_path or "instruct" in model_path) and (
"llava" not in model_path
):
return get_chat_template("qwen")
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llava-onevision-qwen2" in model_path
if re.search(r"tinyllama", model_path, re.IGNORECASE):
return "chatml"
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
return "qwen2-vl"
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
r"llava", model_path, re.IGNORECASE
):
return get_chat_template("chatml-llava")
return "qwen"
if re.search(
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
model_path,
re.IGNORECASE,
):
return "chatml-llava"
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
if "yi-vl" in model_path and "llava" not in model_path:
return get_chat_template("yi-vl")
elif "yi-1.5" in model_path and "chat" in model_path:
return get_chat_template("yi-1.5")
if re.search(r"yi-vl", model_path, re.IGNORECASE) and not re.search(
r"llava", model_path, re.IGNORECASE
):
return "yi-vl"
elif re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
return "yi-1.5"
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
model_path = model_path.lower()
if "gemma" in model_path and "it" in model_path:
return get_chat_template("gemma-it")
if re.search(r"gemma.*it", model_path, re.IGNORECASE):
return "gemma-it"
@register_chat_template_matching_function
def match_openbmb_minicpm(model_path: str):
model_path = model_path.lower()
if "minicpm-v" in model_path:
return get_chat_template("minicpmv")
elif "minicpm-o" in model_path:
return get_chat_template("minicpmo")
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
return "minicpmv"
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
return "minicpmo"
@register_chat_template_matching_function
def match_c4ai_command_r(model_path: str):
model_path = model_path.lower()
if "c4ai-command-r" in model_path:
return get_chat_template("c4ai-command-r")
if re.search(r"c4ai-command-r", model_path, re.IGNORECASE):
return "c4ai-command-r"
@register_chat_template_matching_function
def match_granite_instruct(model_path: str):
model_path = model_path.lower()
# When future versions of Granite are released, this code may
# need to be updated. For now, assume that the Granite 3.0
# template works across the board.
if "granite" in model_path and "instruct" in model_path:
return get_chat_template("granite-3-instruct")
if re.search(r"granite.*instruct", model_path, re.IGNORECASE):
return "granite-3-instruct"
@register_chat_template_matching_function
def match_gemma3_instruct(model_path: str):
model_path = model_path.lower()
if "gemma-3" in model_path and "1b" not in model_path:
# gemma-3-1b-it is completion model
return get_chat_template("gemma-it")
if re.search(r"gemma-3", model_path, re.IGNORECASE):
return "gemma-it"
@register_chat_template_matching_function
def match_internvl_chat(model_path: str):
model_path = model_path.lower()
if "internvl" in model_path:
return get_chat_template("internvl-2-5")
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
return "internvl-2-5"
if __name__ == "__main__":

View File

@@ -16,6 +16,7 @@
# Adapted from
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
import re
from enum import IntEnum, auto
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -852,91 +853,75 @@ register_conv_template(
)
@register_conv_template_matching_function
def match_internvl(model_path: str):
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
return "internvl-2-5"
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
if (
"llama" in model_path.lower()
and "3.2" in model_path.lower()
and "vision" in model_path.lower()
):
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
return "llama_3_vision"
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
if "janus" in model_path.lower():
if re.search(r"janus", model_path, re.IGNORECASE):
return "janus-pro"
@register_conv_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return "vicuna_v1.1"
if "llava-v1.5" in model_path.lower():
return "vicuna_v1.1"
if "llava-next-video-7b" in model_path.lower():
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
return "vicuna_v1.1"
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return "llama-2"
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return "llama-2"
if "codellama" in model_path and "instruct" in model_path:
if re.search(
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
model_path,
re.IGNORECASE,
):
return "llama-2"
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
model_path = model_path.lower()
if "deepseek" in model_path and "vl2" in model_path:
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
return "deepseek-vl2"
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
# import pdb;pdb.set_trace()
model_path = model_path.lower()
# Now the suffix for qwen2 chat model is "instruct"
if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
def match_qwen_chat_ml(model_path: str):
if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
return "gme-qwen2-vl"
if "qwen" in model_path and "vl" in model_path:
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
return "qwen2-vl"
if (
"llava-v1.6-34b" in model_path
or "llava-v1.6-yi-34b" in model_path
or "llava-next-video-34b" in model_path
or "llava-onevision-qwen2" in model_path
if re.search(
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
model_path,
re.IGNORECASE,
):
return "chatml-llava"
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
model_path = model_path.lower()
if "gemma" in model_path and "it" in model_path:
return "gemma-it"
if "gemma-3" in model_path and "1b" not in model_path:
# gemma-3-1b-it is completion model
def match_gemma3_instruct(model_path: str):
if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
return "gemma-it"
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
model_path = model_path.lower()
if "minicpm-v" in model_path:
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
return "minicpmv"
elif "minicpm-o" in model_path:
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
return "minicpmo"
@register_conv_template_matching_function
def match_moonshot_kimivl(model_path: str):
model_path = model_path.lower()
if "kimi" in model_path and "vl" in model_path:
if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
return "kimi-vl"