[Feature] support auto chat template (#4949)

This commit is contained in:
woodx
2025-04-29 13:34:18 +08:00
committed by GitHub
parent 5bb0accbcf
commit 2c3ea29476
4 changed files with 112 additions and 31 deletions

View File

@@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses
from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple, Union
from typing import Callable, Dict, List, Optional, Tuple, Union
from sglang.srt.openai_api.protocol import ChatCompletionRequest
@@ -407,6 +407,7 @@ class Conversation:
# A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {}
# A global registry of matching functions: each takes a model path and returns
# a conversation-template name (or None), so a template can be inferred
# automatically when none is specified on the command line.
matching_function_registry: List[Callable] = []
def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
chat_templates[template.name] = template
def register_conv_template_matching_function(func):
    """Decorator that registers *func* as a chat-template matching function.

    A matching function takes a model path (str) and returns the name of a
    matching conversation template, or None if it does not recognize the path.
    """
    matching_function_registry.append(func)
    # Return the function so the decorated module-level name still refers to
    # it; the original returned None implicitly, rebinding every decorated
    # matcher's name to None.
    return func
def get_conv_template_by_model_path(model_path):
    """Infer a conversation-template name from *model_path*.

    Tries each registered matching function in registration order and
    returns the first non-None result, or None when no matcher recognizes
    the path.
    """
    candidates = (matcher(model_path) for matcher in matching_function_registry)
    return next((name for name in candidates if name is not None), None)
def chat_template_exists(template_name: str) -> bool:
    """Return True if a conversation template with this name is registered."""
    registered_names = chat_templates
    return template_name in registered_names
@@ -792,3 +805,86 @@ register_conv_template(
audio_token="(<audio>./</audio>)",
)
)
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Match Llama-3.2 Vision models to the "llama_3_vision" template.

    Renamed from the copy-pasted ``match_deepseek_janus_pro``, which collided
    with the Janus-Pro matcher defined right after it. Matching behavior is
    unchanged (dispatch happens through the registry, not the name), but the
    duplicate name was misleading and shadowed in the module namespace.
    """
    path = model_path.lower()
    # Require all three markers so plain Llama-3.2 text models do not match.
    if "llama" in path and "3.2" in path and "vision" in path:
        return "llama_3_vision"
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Match DeepSeek Janus-Pro models to the "janus-pro" template."""
    return "janus-pro" if "janus" in model_path.lower() else None
@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Match Vicuna-format models to the "vicuna_v1.1" template.

    Covers Vicuna itself plus LLaVA variants that use the Vicuna format.
    """
    path = model_path.lower()
    vicuna_markers = ("vicuna", "llava-v1.5", "llava-next-video-7b")
    if any(marker in path for marker in vicuna_markers):
        return "vicuna_v1.1"
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Match Llama-2-chat-format models to the "llama-2" template."""
    path = model_path.lower()
    is_llama2_chat = "llama-2" in path and "chat" in path
    # Mistral/Mixtral and CodeLlama instruct variants share the llama-2 format.
    is_mistral_instruct = ("mistral" in path or "mixtral" in path) and "instruct" in path
    is_codellama_instruct = "codellama" in path and "instruct" in path
    if is_llama2_chat or is_mistral_instruct or is_codellama_instruct:
        return "llama-2"
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Match DeepSeek-VL2 models to the "deepseek-vl2" template."""
    path = model_path.lower()
    if "deepseek" in path and "vl2" in path:
        return "deepseek-vl2"
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Match ChatML-family models (Qwen-VL variants and ChatML LLaVA models).

    Removed a commented-out ``pdb.set_trace()`` debug leftover from the
    original body; matching logic is unchanged.
    """
    model_path = model_path.lower()
    # GME embedding variants must be checked before the generic Qwen-VL rule,
    # since their paths also contain "qwen" and "vl".
    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
        return "gme-qwen2-vl"
    # Now the suffix for qwen2 chat model is "instruct"
    if "qwen" in model_path and "vl" in model_path:
        return "qwen2-vl"
    if (
        "llava-v1.6-34b" in model_path
        or "llava-v1.6-yi-34b" in model_path
        or "llava-next-video-34b" in model_path
        or "llava-onevision-qwen2" in model_path
    ):
        return "chatml-llava"
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Match Gemma instruction-tuned models to the "gemma-it" template.

    Fixes an ordering bug in the original: the generic ``gemma``+``it`` check
    ran first, so "gemma-3-1b-it" matched it before the gemma-3 branch could
    apply its 1b exclusion (the original's own comment states gemma-3-1b-it
    is a completion model and must not receive a chat template). The gemma-3
    branch is now checked first.
    """
    model_path = model_path.lower()
    if "gemma-3" in model_path:
        # gemma-3-1b-it is a completion model — no chat template.
        if "1b" in model_path:
            return None
        return "gemma-it"
    if "gemma" in model_path and "it" in model_path:
        return "gemma-it"
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Match OpenBMB MiniCPM-V / MiniCPM-o models to their templates."""
    path = model_path.lower()
    if "minicpm-v" in path:
        return "minicpmv"
    if "minicpm-o" in path:
        return "minicpmo"

View File

@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
)
from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
@@ -584,6 +587,8 @@ def _launch_subprocesses(
load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path
)
else:
guess_chat_template_name_from_model_path(server_args.model_path)
if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template)

View File

@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
chat_template_exists,
generate_chat_conv,
generate_embedding_convs,
get_conv_template_by_model_path,
register_conv_template,
)
from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else:
chat_template_name = chat_template_arg
# Check chat-template
# TODO:
# 1. Do not import any code from sglang.lang
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
def guess_chat_template_name_from_model_path(model_path):
    """Infer the chat template from *model_path* and store it globally.

    Writes the inferred name (or None) into the module-level
    ``chat_template_name`` consumed by the OpenAI-compatible API layer, and
    logs the result when a registered matcher recognizes the path.
    """
    global chat_template_name
    chat_template_name = get_conv_template_by_model_path(model_path)
    if chat_template_name is None:
        return
    logger.info(
        f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
    )
async def v1_files_create(