Fix and Clean up chat-template requirement for VLM (#6114)
Signed-off-by: Xinyuan Tong <justinning0323@outlook.com>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import re
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum, auto
|
||||
from typing import Callable, Dict, List, Tuple
|
||||
@@ -71,9 +72,9 @@ def get_chat_template(name):
|
||||
|
||||
def get_chat_template_by_model_path(model_path):
|
||||
for matching_func in matching_function_registry:
|
||||
template = matching_func(model_path)
|
||||
if template is not None:
|
||||
return template
|
||||
template_name = matching_func(model_path)
|
||||
if template_name is not None:
|
||||
return get_chat_template(template_name)
|
||||
return get_chat_template("default")
|
||||
|
||||
|
||||
@@ -479,134 +480,112 @@ register_chat_template(
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_deepseek(model_path: str):
|
||||
if (
|
||||
"deepseek-v3" in model_path.lower() or "deepseek-r1" in model_path.lower()
|
||||
) and "base" not in model_path.lower():
|
||||
return get_chat_template("deepseek-v3")
|
||||
if re.search(r"deepseek-(v3|r1)", model_path, re.IGNORECASE) and not re.search(
|
||||
r"base", model_path, re.IGNORECASE
|
||||
):
|
||||
return "deepseek-v3"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_deepseek_janus_pro(model_path: str):
|
||||
if "janus" in model_path.lower():
|
||||
return get_chat_template("janus-pro")
|
||||
if re.search(r"janus", model_path, re.IGNORECASE):
|
||||
return "janus-pro"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_dbrx(model_path: str):
|
||||
if "dbrx" in model_path.lower() and "instruct" in model_path.lower():
|
||||
return get_chat_template("dbrx-instruct")
|
||||
if re.search(r"dbrx", model_path, re.IGNORECASE) and re.search(
|
||||
r"instruct", model_path, re.IGNORECASE
|
||||
):
|
||||
return "dbrx-instruct"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_vicuna(model_path: str):
|
||||
if "vicuna" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if "llava-v1.5" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if "llava-next-video-7b" in model_path.lower():
|
||||
return get_chat_template("vicuna_v1.1")
|
||||
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
|
||||
return "vicuna_v1.1"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_llama2_chat(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "llama-2" in model_path and "chat" in model_path:
|
||||
return get_chat_template("llama-2-chat")
|
||||
if (
|
||||
"mistral" in model_path or "mixtral" in model_path
|
||||
) and "instruct" in model_path:
|
||||
return get_chat_template("llama-2-chat")
|
||||
if "codellama" in model_path and "instruct" in model_path:
|
||||
return get_chat_template("llama-2-chat")
|
||||
if re.search(
|
||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "llama-2-chat"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_llama3_instruct(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "llama-3" in model_path and "instruct" in model_path:
|
||||
return get_chat_template("llama-3-instruct")
|
||||
if re.search(r"llama-3.*instruct", model_path, re.IGNORECASE):
|
||||
return "llama-3-instruct"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_ml(model_path: str):
|
||||
# import pdb;pdb.set_trace()
|
||||
model_path = model_path.lower()
|
||||
if "tinyllama" in model_path:
|
||||
return get_chat_template("chatml")
|
||||
# Now the suffix for qwen2 chat model is "instruct"
|
||||
if "qwen" in model_path and "vl" in model_path:
|
||||
return get_chat_template("qwen2-vl")
|
||||
if "qwen" in model_path:
|
||||
if "vl" in model_path:
|
||||
return get_chat_template("qwen2-vl")
|
||||
if ("chat" in model_path or "instruct" in model_path) and (
|
||||
"llava" not in model_path
|
||||
):
|
||||
return get_chat_template("qwen")
|
||||
if (
|
||||
"llava-v1.6-34b" in model_path
|
||||
or "llava-v1.6-yi-34b" in model_path
|
||||
or "llava-next-video-34b" in model_path
|
||||
or "llava-onevision-qwen2" in model_path
|
||||
if re.search(r"tinyllama", model_path, re.IGNORECASE):
|
||||
return "chatml"
|
||||
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
||||
return "qwen2-vl"
|
||||
if re.search(r"qwen.*(chat|instruct)", model_path, re.IGNORECASE) and not re.search(
|
||||
r"llava", model_path, re.IGNORECASE
|
||||
):
|
||||
return get_chat_template("chatml-llava")
|
||||
return "qwen"
|
||||
if re.search(
|
||||
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "chatml-llava"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_chat_yi(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "yi-vl" in model_path and "llava" not in model_path:
|
||||
return get_chat_template("yi-vl")
|
||||
elif "yi-1.5" in model_path and "chat" in model_path:
|
||||
return get_chat_template("yi-1.5")
|
||||
if re.search(r"yi-vl", model_path, re.IGNORECASE) and not re.search(
|
||||
r"llava", model_path, re.IGNORECASE
|
||||
):
|
||||
return "yi-vl"
|
||||
elif re.search(r"yi-1\.5.*chat", model_path, re.IGNORECASE):
|
||||
return "yi-1.5"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_gemma_it(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "gemma" in model_path and "it" in model_path:
|
||||
return get_chat_template("gemma-it")
|
||||
if re.search(r"gemma.*it", model_path, re.IGNORECASE):
|
||||
return "gemma-it"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_openbmb_minicpm(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "minicpm-v" in model_path:
|
||||
return get_chat_template("minicpmv")
|
||||
elif "minicpm-o" in model_path:
|
||||
return get_chat_template("minicpmo")
|
||||
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
||||
return "minicpmv"
|
||||
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
|
||||
return "minicpmo"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_c4ai_command_r(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "c4ai-command-r" in model_path:
|
||||
return get_chat_template("c4ai-command-r")
|
||||
if re.search(r"c4ai-command-r", model_path, re.IGNORECASE):
|
||||
return "c4ai-command-r"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_granite_instruct(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
# When future versions of Granite are released, this code may
|
||||
# need to be updated. For now, assume that the Granite 3.0
|
||||
# template works across the board.
|
||||
if "granite" in model_path and "instruct" in model_path:
|
||||
return get_chat_template("granite-3-instruct")
|
||||
if re.search(r"granite.*instruct", model_path, re.IGNORECASE):
|
||||
return "granite-3-instruct"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_gemma3_instruct(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "gemma-3" in model_path and "1b" not in model_path:
|
||||
# gemma-3-1b-it is completion model
|
||||
return get_chat_template("gemma-it")
|
||||
if re.search(r"gemma-3", model_path, re.IGNORECASE):
|
||||
return "gemma-it"
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_internvl_chat(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "internvl" in model_path:
|
||||
return get_chat_template("internvl-2-5")
|
||||
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
|
||||
return "internvl-2-5"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -16,6 +16,7 @@
|
||||
# Adapted from
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
import dataclasses
|
||||
import re
|
||||
from enum import IntEnum, auto
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -852,91 +853,75 @@ register_conv_template(
|
||||
)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_internvl(model_path: str):
|
||||
if re.search(r"internvl2_5", model_path, re.IGNORECASE):
|
||||
return "internvl-2-5"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_llama_3_vision(model_path: str):
|
||||
if (
|
||||
"llama" in model_path.lower()
|
||||
and "3.2" in model_path.lower()
|
||||
and "vision" in model_path.lower()
|
||||
):
|
||||
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
||||
return "llama_3_vision"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_deepseek_janus_pro(model_path: str):
|
||||
if "janus" in model_path.lower():
|
||||
if re.search(r"janus", model_path, re.IGNORECASE):
|
||||
return "janus-pro"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_vicuna(model_path: str):
|
||||
if "vicuna" in model_path.lower():
|
||||
return "vicuna_v1.1"
|
||||
if "llava-v1.5" in model_path.lower():
|
||||
return "vicuna_v1.1"
|
||||
if "llava-next-video-7b" in model_path.lower():
|
||||
if re.search(r"vicuna|llava-v1\.5|llava-next-video-7b", model_path, re.IGNORECASE):
|
||||
return "vicuna_v1.1"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_llama2_chat(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "llama-2" in model_path and "chat" in model_path:
|
||||
return "llama-2"
|
||||
if (
|
||||
"mistral" in model_path or "mixtral" in model_path
|
||||
) and "instruct" in model_path:
|
||||
return "llama-2"
|
||||
if "codellama" in model_path and "instruct" in model_path:
|
||||
if re.search(
|
||||
r"llama-2.*chat|(mistral|mixtral).*instruct|codellama.*instruct",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "llama-2"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_deepseek_vl(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "deepseek" in model_path and "vl2" in model_path:
|
||||
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
||||
return "deepseek-vl2"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_chat_ml(model_path: str):
|
||||
# import pdb;pdb.set_trace()
|
||||
model_path = model_path.lower()
|
||||
# Now the suffix for qwen2 chat model is "instruct"
|
||||
if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
|
||||
def match_qwen_chat_ml(model_path: str):
|
||||
if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
|
||||
return "gme-qwen2-vl"
|
||||
if "qwen" in model_path and "vl" in model_path:
|
||||
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
||||
return "qwen2-vl"
|
||||
if (
|
||||
"llava-v1.6-34b" in model_path
|
||||
or "llava-v1.6-yi-34b" in model_path
|
||||
or "llava-next-video-34b" in model_path
|
||||
or "llava-onevision-qwen2" in model_path
|
||||
if re.search(
|
||||
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
|
||||
model_path,
|
||||
re.IGNORECASE,
|
||||
):
|
||||
return "chatml-llava"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_gemma_it(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "gemma" in model_path and "it" in model_path:
|
||||
return "gemma-it"
|
||||
if "gemma-3" in model_path and "1b" not in model_path:
|
||||
# gemma-3-1b-it is completion model
|
||||
def match_gemma3_instruct(model_path: str):
|
||||
if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
|
||||
return "gemma-it"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_openbmb_minicpm(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "minicpm-v" in model_path:
|
||||
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
||||
return "minicpmv"
|
||||
elif "minicpm-o" in model_path:
|
||||
elif re.search(r"minicpm-o", model_path, re.IGNORECASE):
|
||||
return "minicpmo"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
|
||||
def match_moonshot_kimivl(model_path: str):
|
||||
model_path = model_path.lower()
|
||||
if "kimi" in model_path and "vl" in model_path:
|
||||
if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
|
||||
return "kimi-vl"
|
||||
|
||||
Reference in New Issue
Block a user