diff --git a/python/sglang/srt/conversation.py b/python/sglang/srt/conversation.py
index 3fd45b467..33792101b 100644
--- a/python/sglang/srt/conversation.py
+++ b/python/sglang/srt/conversation.py
@@ -17,7 +17,7 @@
 # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
 import dataclasses
 from enum import IntEnum, auto
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 from sglang.srt.openai_api.protocol import ChatCompletionRequest
 
@@ -407,6 +407,7 @@ class Conversation:
 
 # A global registry for all conversation templates
 chat_templates: Dict[str, Conversation] = {}
+matching_function_registry: List[Callable] = []
 
 
 def register_conv_template(template: Conversation, override: bool = False):
@@ -419,6 +420,18 @@ def register_conv_template(template: Conversation, override: bool = False):
         chat_templates[template.name] = template
 
 
+def register_conv_template_matching_function(func):
+    matching_function_registry.append(func)
+
+
+def get_conv_template_by_model_path(model_path):
+    for matching_func in matching_function_registry:
+        conv_name = matching_func(model_path)
+        if conv_name is not None:
+            return conv_name
+    return None
+
+
 def chat_template_exists(template_name: str) -> bool:
     return template_name in chat_templates
 
@@ -792,3 +805,87 @@ register_conv_template(
         audio_token="()",
     )
 )
+
+
+@register_conv_template_matching_function
+def match_llama_3_vision(model_path: str):
+    if (
+        "llama" in model_path.lower()
+        and "3.2" in model_path.lower()
+        and "vision" in model_path.lower()
+    ):
+        return "llama_3_vision"
+
+
+@register_conv_template_matching_function
+def match_deepseek_janus_pro(model_path: str):
+    if "janus" in model_path.lower():
+        return "janus-pro"
+
+
+@register_conv_template_matching_function
+def match_vicuna(model_path: str):
+    if "vicuna" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-v1.5" in model_path.lower():
+        return "vicuna_v1.1"
+    if "llava-next-video-7b" in model_path.lower():
+        return "vicuna_v1.1"
+
+
+@register_conv_template_matching_function
+def match_llama2_chat(model_path: str):
+    model_path = model_path.lower()
+    if "llama-2" in model_path and "chat" in model_path:
+        return "llama-2"
+    if (
+        "mistral" in model_path or "mixtral" in model_path
+    ) and "instruct" in model_path:
+        return "llama-2"
+    if "codellama" in model_path and "instruct" in model_path:
+        return "llama-2"
+
+
+@register_conv_template_matching_function
+def match_deepseek_vl(model_path: str):
+    model_path = model_path.lower()
+    if "deepseek" in model_path and "vl2" in model_path:
+        return "deepseek-vl2"
+
+
+@register_conv_template_matching_function
+def match_chat_ml(model_path: str):
+    model_path = model_path.lower()
+    # Now the suffix for qwen2 chat model is "instruct"
+    if "gme" in model_path and "qwen" in model_path and "vl" in model_path:
+        return "gme-qwen2-vl"
+    if "qwen" in model_path and "vl" in model_path:
+        return "qwen2-vl"
+    if (
+        "llava-v1.6-34b" in model_path
+        or "llava-v1.6-yi-34b" in model_path
+        or "llava-next-video-34b" in model_path
+        or "llava-onevision-qwen2" in model_path
+    ):
+        return "chatml-llava"
+
+
+@register_conv_template_matching_function
+def match_gemma_it(model_path: str):
+    model_path = model_path.lower()
+    # gemma-3-1b(-it) is a completion model, so it must not get a chat template.
+    if "gemma-3" in model_path and "1b" in model_path:
+        return None
+    if "gemma-3" in model_path:
+        return "gemma-it"
+    if "gemma" in model_path and "it" in model_path:
+        return "gemma-it"
+
+
+@register_conv_template_matching_function
+def match_openbmb_minicpm(model_path: str):
+    model_path = model_path.lower()
+    if "minicpm-v" in model_path:
+        return "minicpmv"
+    elif "minicpm-o" in model_path:
+        return "minicpmo"
diff --git a/python/sglang/srt/entrypoints/engine.py b/python/sglang/srt/entrypoints/engine.py
index b03e3a494..3e8222f13 100644
--- a/python/sglang/srt/entrypoints/engine.py
+++ b/python/sglang/srt/entrypoints/engine.py
@@ -58,7 +58,10 @@ from sglang.srt.managers.io_struct import (
 )
 from sglang.srt.managers.scheduler import run_scheduler_process
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
-from sglang.srt.openai_api.adapter import load_chat_template_for_openai_api
+from sglang.srt.openai_api.adapter import (
+    guess_chat_template_name_from_model_path,
+    load_chat_template_for_openai_api,
+)
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
@@ -584,6 +587,8 @@ def _launch_subprocesses(
         load_chat_template_for_openai_api(
             tokenizer_manager, server_args.chat_template, server_args.model_path
         )
+    else:
+        guess_chat_template_name_from_model_path(server_args.model_path)
 
     if server_args.completion_template:
         load_completion_template_for_openai_api(server_args.completion_template)
diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py
index 911969859..619f1d404 100644
--- a/python/sglang/srt/openai_api/adapter.py
+++ b/python/sglang/srt/openai_api/adapter.py
@@ -36,6 +36,7 @@ from sglang.srt.conversation import (
     chat_template_exists,
     generate_chat_conv,
     generate_embedding_convs,
+    get_conv_template_by_model_path,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +164,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
     else:
         chat_template_name = chat_template_arg
 
-    # Check chat-template
-    # TODO:
-    # 1. Do not import any code from sglang.lang
-    # 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
+
+def guess_chat_template_name_from_model_path(model_path):
+    global chat_template_name
+    chat_template_name = get_conv_template_by_model_path(model_path)
+    if chat_template_name is not None:
+        logger.info(
+            f"Inferred chat template name from the model path: {chat_template_name}."
+        )
 
 
 async def v1_files_create(
diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py
index 8f5eed220..6cef5e5e5 100644
--- a/test/srt/test_vision_openai_server.py
+++ b/test/srt/test_vision_openai_server.py
@@ -47,11 +47,6 @@ class TestOpenAIVisionServer(CustomTestCase):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
-            other_args=[
-                "--chat-template",
-                "chatml-llava",
-                # "--log-requests",
-            ],
         )
         cls.base_url += "/v1"
 
@@ -475,8 +470,6 @@ class TestQwen2VLServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=[
-                "--chat-template",
-                "qwen2-vl",
                 "--mem-fraction-static",
                 "0.4",
             ],
@@ -496,8 +489,6 @@ class TestQwen2_5_VLServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=[
-                "--chat-template",
-                "qwen2-vl",
                 "--mem-fraction-static",
                 "0.4",
             ],
@@ -517,8 +508,6 @@ class TestVLMContextLengthIssue(CustomTestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
             other_args=[
-                "--chat-template",
-                "qwen2-vl",
                 "--context-length",
                 "300",
                 "--mem-fraction-static=0.80",
@@ -573,10 +562,6 @@ class TestMllamaServer(TestOpenAIVisionServer):
             cls.base_url,
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             api_key=cls.api_key,
-            other_args=[
-                "--chat-template",
-                "llama_3_vision",
-            ],
         )
         cls.base_url += "/v1"
 
@@ -596,8 +581,6 @@ class TestMinicpmvServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--trust-remote-code",
-                "--chat-template",
-                "minicpmv",
                 "--mem-fraction-static",
                 "0.4",
             ],
@@ -617,8 +600,6 @@ class TestMinicpmoServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--trust-remote-code",
-                "--chat-template",
-                "minicpmo",
                 "--mem-fraction-static",
                 "0.7",
             ],
@@ -642,8 +623,6 @@ class TestDeepseekVL2Server(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--trust-remote-code",
-                "--chat-template",
-                "deepseek-vl2",
                 "--context-length",
                 "4096",
             ],
@@ -690,8 +669,6 @@ class TestJanusProServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--trust-remote-code",
-                "--chat-template",
-                "janus-pro",
                 "--mem-fraction-static",
                 "0.4",
             ],
@@ -744,8 +721,6 @@ class TestGemma3itServer(TestOpenAIVisionServer):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             other_args=[
                 "--trust-remote-code",
-                "--chat-template",
-                "gemma-it",
                 "--mem-fraction-static",
                 "0.75",
                 "--enable-multimodal",