[Feature] support auto chat template (#4949)
This commit is contained in:
@@ -17,7 +17,7 @@
|
||||
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
|
||||
import dataclasses
|
||||
from enum import IntEnum, auto
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from sglang.srt.openai_api.protocol import ChatCompletionRequest
|
||||
|
||||
@@ -407,6 +407,7 @@ class Conversation:
|
||||
|
||||
# A global registry for all conversation templates
|
||||
chat_templates: Dict[str, Conversation] = {}
|
||||
matching_function_registry: List[Callable] = []
|
||||
|
||||
|
||||
def register_conv_template(template: Conversation, override: bool = False):
|
||||
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
|
||||
chat_templates[template.name] = template
|
||||
|
||||
|
||||
def register_conv_template_matching_function(func):
    """Register *func* as a model-path -> template-name matcher.

    Intended to be used as a decorator. The matcher receives a model path
    string and returns a registered template name, or None when it does not
    match.

    Returns the function unchanged so the decorated module-level name stays
    bound to the function. (Previously the decorator fell through and
    returned None, silently rebinding every decorated matcher's name to
    None; registration still worked, but the functions became uncallable by
    name.)
    """
    matching_function_registry.append(func)
    return func
|
||||
|
||||
|
||||
def get_conv_template_by_model_path(model_path):
    """Return the first template name any registered matcher yields, else None.

    Matchers are consulted in registration order; matchers after the first
    hit are not invoked.
    """
    candidates = (matcher(model_path) for matcher in matching_function_registry)
    return next((name for name in candidates if name is not None), None)
|
||||
|
||||
|
||||
def chat_template_exists(template_name: str) -> bool:
    """Report whether a conversation template was registered under *template_name*."""
    return any(registered == template_name for registered in chat_templates)
|
||||
|
||||
@@ -792,3 +805,86 @@ register_conv_template(
|
||||
audio_token="(<audio>./</audio>)",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Match Llama-3.2 Vision checkpoints to the llama_3_vision template.

    Renamed from a copy-pasted ``match_deepseek_janus_pro``, which collided
    with the real Janus-Pro matcher also defined in this file. Dispatch is
    driven by the decorator registry, so the rename does not change behavior.
    """
    model_path = model_path.lower()
    if "llama" in model_path and "3.2" in model_path and "vision" in model_path:
        return "llama_3_vision"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Map DeepSeek Janus-Pro checkpoints to the janus-pro template."""
    lowered = model_path.lower()
    return "janus-pro" if "janus" in lowered else None
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Match Vicuna-family checkpoints to the vicuna_v1.1 template.

    Covers vicuna itself plus llava-v1.5 and llava-next-video-7b, which use
    the same conversation format. Lower-cases the path once (the original
    called ``.lower()`` per check), consistent with the other matchers here.
    """
    model_path = model_path.lower()
    if (
        "vicuna" in model_path
        or "llava-v1.5" in model_path
        or "llava-next-video-7b" in model_path
    ):
        return "vicuna_v1.1"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Map Llama-2-chat checkpoints — and the Mistral/Mixtral/CodeLlama
    instruct variants that share its format — to the llama-2 template."""
    lowered = model_path.lower()
    instruct = "instruct" in lowered
    llama2_like = (
        ("llama-2" in lowered and "chat" in lowered)
        or (instruct and ("mistral" in lowered or "mixtral" in lowered))
        or (instruct and "codellama" in lowered)
    )
    if llama2_like:
        return "llama-2"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Map DeepSeek-VL2 checkpoints to the deepseek-vl2 template."""
    lowered = model_path.lower()
    if "deepseek" in lowered and "vl2" in lowered:
        return "deepseek-vl2"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Match checkpoints that use a ChatML-style conversation template.

    Order matters: the more specific "gme" Qwen-VL check must run before the
    generic Qwen-VL check. (A leftover commented-out ``pdb.set_trace()``
    debug line was removed from this matcher.)
    """
    model_path = model_path.lower()
    # Now the suffix for qwen2 chat model is "instruct"
    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
        return "gme-qwen2-vl"
    if "qwen" in model_path and "vl" in model_path:
        return "qwen2-vl"
    if (
        "llava-v1.6-34b" in model_path
        or "llava-v1.6-yi-34b" in model_path
        or "llava-next-video-34b" in model_path
        or "llava-onevision-qwen2" in model_path
    ):
        return "chatml-llava"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Match Gemma instruction-tuned checkpoints to the gemma-it template.

    Per the original in-code note, gemma-3-1b-it is a completion model and
    must NOT get the chat template. Previously the generic
    ``"gemma" and "it"`` check ran first and matched gemma-3-1b-it anyway,
    making the 1b exclusion dead code; the gemma-3 branch now runs first.
    """
    model_path = model_path.lower()
    if "gemma-3" in model_path:
        # gemma-3-1b-it is completion model
        if "1b" in model_path:
            return None
        return "gemma-it"
    if "gemma" in model_path and "it" in model_path:
        return "gemma-it"
|
||||
|
||||
|
||||
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Map OpenBMB MiniCPM-V / MiniCPM-o checkpoints to their templates."""
    lowered = model_path.lower()
    # Checked in order; "minicpm-v" takes precedence, as in the original chain.
    for marker, template in (("minicpm-v", "minicpmv"), ("minicpm-o", "minicpmo")):
        if marker in lowered:
            return template
|
||||
|
||||
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
|
||||
)
|
||||
from sglang.srt.managers.scheduler import run_scheduler_process
|
||||
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
||||
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
|
||||
from sglang.srt.openai_api.adapter import (
|
||||
guess_chat_template_name_from_model_path,
|
||||
load_chat_template_for_openai_api,
|
||||
)
|
||||
from sglang.srt.server_args import PortArgs, ServerArgs
|
||||
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
|
||||
from sglang.srt.utils import (
|
||||
@@ -584,6 +587,8 @@ def _launch_subprocesses(
|
||||
load_chat_template_for_openai_api(
|
||||
tokenizer_manager, server_args.chat_template, server_args.model_path
|
||||
)
|
||||
else:
|
||||
guess_chat_template_name_from_model_path(server_args.model_path)
|
||||
|
||||
if server_args.completion_template:
|
||||
load_completion_template_for_openai_api(server_args.completion_template)
|
||||
|
||||
@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
|
||||
chat_template_exists,
|
||||
generate_chat_conv,
|
||||
generate_embedding_convs,
|
||||
get_conv_template_by_model_path,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
|
||||
else:
|
||||
chat_template_name = chat_template_arg
|
||||
|
||||
# Check chat-template
|
||||
# TODO:
|
||||
# 1. Do not import any code from sglang.lang
|
||||
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
|
||||
|
||||
def guess_chat_template_name_from_model_path(model_path):
    """Infer the chat template name from *model_path* and store it in the
    module-level ``chat_template_name``.

    Leaves ``chat_template_name`` as None (and logs nothing) when no
    registered matcher recognizes the path.
    """
    global chat_template_name
    chat_template_name = get_conv_template_by_model_path(model_path)
    if chat_template_name is None:
        return
    logger.info(
        f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
    )
|
||||
|
||||
|
||||
async def v1_files_create(
|
||||
|
||||
Reference in New Issue
Block a user