[Feature] support auto chat template (#4949)

This commit is contained in:
woodx
2025-04-29 13:34:18 +08:00
committed by GitHub
parent 5bb0accbcf
commit 2c3ea29476
4 changed files with 112 additions and 31 deletions

View File

@@ -17,7 +17,7 @@
# https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
import dataclasses import dataclasses
from enum import IntEnum, auto from enum import IntEnum, auto
from typing import Dict, List, Optional, Tuple, Union from typing import Callable, Dict, List, Optional, Tuple, Union
from sglang.srt.openai_api.protocol import ChatCompletionRequest from sglang.srt.openai_api.protocol import ChatCompletionRequest
@@ -407,6 +407,7 @@ class Conversation:
# A global registry for all conversation templates # A global registry for all conversation templates
chat_templates: Dict[str, Conversation] = {} chat_templates: Dict[str, Conversation] = {}
matching_function_registry: List[Callable] = []
def register_conv_template(template: Conversation, override: bool = False): def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
chat_templates[template.name] = template chat_templates[template.name] = template
def register_conv_template_matching_function(func):
    """Register ``func`` as a model-path matcher and return it unchanged.

    The callable is appended to ``matching_function_registry``;
    ``get_conv_template_by_model_path`` later calls each registered entry
    with the model path and uses the first non-None result.

    Args:
        func: Callable that accepts a model path and returns a conversation
            template name, or None when the path does not match.

    Returns:
        ``func`` itself, so that using this function as a decorator keeps the
        decorated name bound to the matcher instead of rebinding it to None.
    """
    matching_function_registry.append(func)
    return func
def get_conv_template_by_model_path(model_path):
    """Return the first non-None result produced by the registered matchers.

    Matchers in ``matching_function_registry`` are tried in registration
    order and evaluation stops at the first hit; None is returned when no
    matcher recognises ``model_path``.
    """
    # Lazy generator: matchers after the first hit are never invoked.
    candidates = (matcher(model_path) for matcher in matching_function_registry)
    return next((name for name in candidates if name is not None), None)
def chat_template_exists(template_name: str) -> bool: def chat_template_exists(template_name: str) -> bool:
return template_name in chat_templates return template_name in chat_templates
@@ -792,3 +805,86 @@ register_conv_template(
audio_token="(<audio>./</audio>)", audio_token="(<audio>./</audio>)",
) )
) )
@register_conv_template_matching_function
def match_llama_3_vision(model_path: str):
    """Map Llama 3.2 Vision model paths to the ``llama_3_vision`` template.

    Args:
        model_path: Model path or identifier; compared case-insensitively.

    Returns:
        ``"llama_3_vision"`` when the path contains "llama", "3.2" and
        "vision"; otherwise None.
    """
    # This matcher used to be named match_deepseek_janus_pro, which collided
    # with the Janus-Pro matcher defined right after it and misdescribed what
    # it matches. The old name was immediately rebound by that later
    # definition, so nothing could reach this function through it.
    lowered = model_path.lower()
    if all(marker in lowered for marker in ("llama", "3.2", "vision")):
        return "llama_3_vision"
    return None
@register_conv_template_matching_function
def match_deepseek_janus_pro(model_path: str):
    """Return ``"janus-pro"`` when the lower-cased path contains "janus", else None."""
    is_janus = "janus" in model_path.lower()
    return "janus-pro" if is_janus else None
@register_conv_template_matching_function
def match_vicuna(model_path: str):
    """Return ``"vicuna_v1.1"`` for Vicuna-style model paths, else None.

    A path matches when, case-insensitively, it contains any of "vicuna",
    "llava-v1.5" or "llava-next-video-7b".
    """
    lowered = model_path.lower()
    markers = ("vicuna", "llava-v1.5", "llava-next-video-7b")
    if any(marker in lowered for marker in markers):
        return "vicuna_v1.1"
    return None
@register_conv_template_matching_function
def match_llama2_chat(model_path: str):
    """Return ``"llama-2"`` for paths that use the Llama-2 chat format, else None.

    Matches (case-insensitively) either a path containing both "llama-2" and
    "chat", or a path containing "instruct" together with one of "mistral",
    "mixtral" or "codellama".
    """
    path = model_path.lower()
    is_llama2_chat = "llama-2" in path and "chat" in path
    # Mistral / Mixtral / CodeLlama instruct variants map to the same template.
    is_instruct_family = "instruct" in path and any(
        family in path for family in ("mistral", "mixtral", "codellama")
    )
    return "llama-2" if is_llama2_chat or is_instruct_family else None
@register_conv_template_matching_function
def match_deepseek_vl(model_path: str):
    """Return ``"deepseek-vl2"`` when the path names a DeepSeek VL2 model, else None.

    The lower-cased path must contain both "deepseek" and "vl2".
    """
    path = model_path.lower()
    if "deepseek" not in path or "vl2" not in path:
        return None
    return "deepseek-vl2"
@register_conv_template_matching_function
def match_chat_ml(model_path: str):
    """Pick a ChatML-family template name from the model path, else None.

    Checked case-insensitively, in this order:
      * "qwen" and "vl" plus "gme"  -> ``"gme-qwen2-vl"``
      * "qwen" and "vl"             -> ``"qwen2-vl"``
      * one of the listed LLaVA 34B / OneVision markers -> ``"chatml-llava"``
    """
    path = model_path.lower()

    if "qwen" in path and "vl" in path:
        # The GME check comes first so it is not shadowed by plain Qwen-VL.
        return "gme-qwen2-vl" if "gme" in path else "qwen2-vl"

    llava_markers = (
        "llava-v1.6-34b",
        "llava-v1.6-yi-34b",
        "llava-next-video-34b",
        "llava-onevision-qwen2",
    )
    if any(marker in path for marker in llava_markers):
        return "chatml-llava"
    return None
@register_conv_template_matching_function
def match_gemma_it(model_path: str):
    """Return ``"gemma-it"`` for instruction-tuned Gemma model paths, else None.

    Rules, applied to the lower-cased path:
      * a path without "gemma" never matches;
      * a "gemma-3" path matches unless it carries a "1b" token;
      * any other Gemma path matches only when "it" appears as its own
        delimited token (e.g. "gemma-2-9b-it").

    Args:
        model_path: Model path or identifier.

    Returns:
        ``"gemma-it"`` on a match, otherwise None.
    """
    # Function-scope import keeps this block self-contained.
    import re

    model_path = model_path.lower()
    if "gemma" not in model_path:
        return None

    # Compare whole tokens: a raw substring test for "it" also fires on
    # unrelated path pieces such as "gitlab" or "unit".
    tokens = re.split(r"[^a-z0-9]+", model_path)

    if "gemma-3" in model_path:
        # gemma-3-1b-it is a completion model. This exclusion must be decided
        # before the generic "it" rule, otherwise it never takes effect.
        return None if "1b" in tokens else "gemma-it"

    if "it" in tokens:
        return "gemma-it"
    return None
@register_conv_template_matching_function
def match_openbmb_minicpm(model_path: str):
    """Return the MiniCPM template name for the path, else None.

    "minicpm-v" maps to ``"minicpmv"`` and is checked before "minicpm-o",
    which maps to ``"minicpmo"``; matching is case-insensitive.
    """
    path = model_path.lower()
    marker_to_template = (
        ("minicpm-v", "minicpmv"),
        ("minicpm-o", "minicpmo"),
    )
    for marker, template in marker_to_template:
        if marker in path:
            return template
    return None

View File

@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
) )
from sglang.srt.managers.scheduler import run_scheduler_process from sglang.srt.managers.scheduler import run_scheduler_process
from sglang.srt.managers.tokenizer_manager import TokenizerManager from sglang.srt.managers.tokenizer_manager import TokenizerManager
from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api from sglang.srt.openai_api.adapter import (
guess_chat_template_name_from_model_path,
load_chat_template_for_openai_api,
)
from sglang.srt.server_args import PortArgs, ServerArgs from sglang.srt.server_args import PortArgs, ServerArgs
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import ( from sglang.srt.utils import (
@@ -584,6 +587,8 @@ def _launch_subprocesses(
load_chat_template_for_openai_api( load_chat_template_for_openai_api(
tokenizer_manager, server_args.chat_template, server_args.model_path tokenizer_manager, server_args.chat_template, server_args.model_path
) )
else:
guess_chat_template_name_from_model_path(server_args.model_path)
if server_args.completion_template: if server_args.completion_template:
load_completion_template_for_openai_api(server_args.completion_template) load_completion_template_for_openai_api(server_args.completion_template)

View File

@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
chat_template_exists, chat_template_exists,
generate_chat_conv, generate_chat_conv,
generate_embedding_convs, generate_embedding_convs,
get_conv_template_by_model_path,
register_conv_template, register_conv_template,
) )
from sglang.srt.function_call_parser import FunctionCallParser from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
else: else:
chat_template_name = chat_template_arg chat_template_name = chat_template_arg
# Check chat-template
# TODO: def guess_chat_template_name_from_model_path(model_path):
# 1. Do not import any code from sglang.lang global chat_template_name
# 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path. chat_template_name = get_conv_template_by_model_path(model_path)
if chat_template_name is not None:
logger.info(
f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
)
async def v1_files_create( async def v1_files_create(

View File

@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key, api_key=cls.api_key,
other_args=[
"--chat-template",
"chatml-llava",
# "--log-requests",
],
) )
cls.base_url += "/v1" cls.base_url += "/v1"
@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key, api_key=cls.api_key,
other_args=[ other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static", "--mem-fraction-static",
"0.4", "0.4",
], ],
@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key, api_key=cls.api_key,
other_args=[ other_args=[
"--chat-template",
"qwen2-vl",
"--mem-fraction-static", "--mem-fraction-static",
"0.4", "0.4",
], ],
@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key, api_key=cls.api_key,
other_args=[ other_args=[
"--chat-template",
"qwen2-vl",
"--context-length", "--context-length",
"300", "300",
"--mem-fraction-static=0.80", "--mem-fraction-static=0.80",
@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
cls.base_url, cls.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
api_key=cls.api_key, api_key=cls.api_key,
other_args=[
"--chat-template",
"llama_3_vision",
],
) )
cls.base_url += "/v1" cls.base_url += "/v1"
@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--chat-template",
"minicpmv",
"--mem-fraction-static", "--mem-fraction-static",
"0.4", "0.4",
], ],
@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--chat-template",
"minicpmo",
"--mem-fraction-static", "--mem-fraction-static",
"0.7", "0.7",
], ],
@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--chat-template",
"deepseek-vl2",
"--context-length", "--context-length",
"4096", "4096",
], ],
@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--chat-template",
"janus-pro",
"--mem-fraction-static", "--mem-fraction-static",
"0.4", "0.4",
], ],
@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH, timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=[ other_args=[
"--trust-remote-code", "--trust-remote-code",
"--chat-template",
"gemma-it",
"--mem-fraction-static", "--mem-fraction-static",
"0.75", "0.75",
"--enable-multimodal", "--enable-multimodal",