"""Chat prompt templates and helpers to pick one from a model path."""

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple


class ChatTemplateStyle(Enum):
    """How role prefixes/suffixes are combined when rendering a prompt."""

    # Each role's (prefix, suffix) pair is used verbatim.
    PLAIN = auto()
    # Llama-2 layout: a leading system message is folded into the first user turn.
    LLAMA2 = auto()


@dataclass
class ChatTemplate:
    """A named prompt format.

    Attributes:
        name: Registry key for this template.
        default_system_prompt: Text substituted for a system message whose
            content is ``None``; when this is also ``None`` the system message
            is skipped entirely.
        role_prefix_and_suffix: Maps a role ("system", "user", "assistant")
            to the ``(prefix, suffix)`` strings wrapped around its content.
        stop_str: Stop strings associated with this template.
        image_token: Text associated with an image in this template.
        style: How prefixes/suffixes are combined (see ``ChatTemplateStyle``).
    """

    name: str
    default_system_prompt: Optional[str]
    role_prefix_and_suffix: Dict[str, Tuple[str, str]]
    stop_str: Tuple[str, ...] = ()
    image_token: str = ""
    style: ChatTemplateStyle = ChatTemplateStyle.PLAIN

    def _renders_system(self, message) -> bool:
        """Return True if ``message`` is a system message that get_prompt emits."""
        return message["role"] == "system" and (
            message["content"] is not None or self.default_system_prompt is not None
        )

    def get_prefix_and_suffix(self, role, hist_messages) -> Tuple[str, str]:
        """Return the ``(prefix, suffix)`` to wrap around a message.

        Args:
            role: Role of the message being rendered.
            hist_messages: The messages that precede it in the conversation.

        Raises:
            KeyError: If ``role`` has no entry in ``role_prefix_and_suffix``.
            ValueError: If ``style`` is not a supported ``ChatTemplateStyle``.
        """
        if self.style == ChatTemplateStyle.PLAIN:
            return self.role_prefix_and_suffix[role]
        if self.style == ChatTemplateStyle.LLAMA2:
            if not hist_messages and role == "system":
                # A leading system block is opened with the user prefix so it
                # sits inside the first user turn.
                return (
                    self.role_prefix_and_suffix["user"][0]
                    + self.role_prefix_and_suffix["system"][0],
                    self.role_prefix_and_suffix["system"][1],
                )
            if (
                len(hist_messages) == 1
                and role == "user"
                and self._renders_system(hist_messages[0])
            ):
                # The user prefix was already emitted in front of the system
                # block, so don't repeat it. The check considers the default
                # system prompt too, since get_prompt renders that in place of
                # a None content.
                return ("", self.role_prefix_and_suffix["user"][1])
            return self.role_prefix_and_suffix[role]
        raise ValueError(f"Invalid style: {self.style}")

    def get_prompt(self, messages) -> str:
        """Render a list of ``{"role": ..., "content": ...}`` dicts into one string.

        A system message with ``content=None`` is replaced by
        ``default_system_prompt``; any message whose content is still ``None``
        is skipped.
        """
        pieces: List[str] = []
        for i, message in enumerate(messages):
            role, content = message["role"], message["content"]
            if role == "system" and content is None:
                content = self.default_system_prompt
            if content is None:
                continue
            prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
            pieces.extend((prefix, content, suffix))
        return "".join(pieces)


chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []


def register_chat_template(template):
    """Add ``template`` to the registry under ``template.name`` (overwriting)."""
    chat_template_registry[template.name] = template


def register_chat_template_matching_function(func):
    """Decorator: register ``func(model_path) -> Optional[ChatTemplate]``.

    Matchers are tried in registration order by get_chat_template_by_model_path.
    The function is returned unchanged so the decorated name stays bound to it.
    """
    matching_function_registry.append(func)
    return func


def get_chat_template(name):
    """Return the registered template called ``name`` (raises KeyError if absent)."""
    return chat_template_registry[name]


def get_chat_template_by_model_path(model_path):
    """Return the first matcher hit for ``model_path``, else the "default" template."""
    for matching_func in matching_function_registry:
        template = matching_func(model_path)
        if template is not None:
            return template
    return get_chat_template("default")


register_chat_template(
    ChatTemplate(
        name="default",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("SYSTEM:", "\n"),
            "user": ("USER:", "\n"),
            "assistant": ("ASSISTANT:", "\n"),
        },
    )
)

register_chat_template(
    ChatTemplate(
        name="claude",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("", ""),
            "user": ("\n\nHuman: ", ""),
            "assistant": ("\n\nAssistant:", ""),
        },
    )
)

register_chat_template(
    ChatTemplate(
        name="chatml",
        default_system_prompt=None,
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
    )
)

register_chat_template(
    ChatTemplate(
        name="chatml-llava",
        default_system_prompt="Answer the questions.",
        role_prefix_and_suffix={
            "system": ("<|im_start|>system\n", "<|im_end|>\n"),
            "user": ("<|im_start|>user\n", "<|im_end|>\n"),
            "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
        },
        style=ChatTemplateStyle.PLAIN,
        stop_str=("<|im_end|>",),
        # NOTE(review): restored from a garbled " \n"; LLaVA's image
        # placeholder is "<image>" -- confirm against the model's processor.
        image_token=" <image>\n",
    )
)

register_chat_template(
    ChatTemplate(
        name="vicuna_v1.1",
        default_system_prompt=(
            "A chat between a curious user and an artificial intelligence assistant. "
            "The assistant gives helpful, detailed, and polite answers to the user's questions."
        ),
        role_prefix_and_suffix={
            "system": ("", " "),
            "user": ("USER:", " "),
            "assistant": ("ASSISTANT:", ""),
        },
        # NOTE(review): restored from a garbled " \n"; LLaVA's image
        # placeholder is "<image>" -- confirm against the model's processor.
        image_token=" <image>\n",
    )
)

register_chat_template(
    ChatTemplate(
        name="llama-2-chat",
        default_system_prompt=None,
        # Llama-2 chat markers: <<SYS>> ... <</SYS>> wraps the system prompt,
        # and each finished assistant turn is closed with </s><s>.
        role_prefix_and_suffix={
            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
            "user": ("[INST] ", " [/INST]"),
            "assistant": ("", " </s><s>"),
        },
        style=ChatTemplateStyle.LLAMA2,
    )
)

# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
    ChatTemplate(
        name="yi",
        default_system_prompt=(
            "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
            "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
        ),
        role_prefix_and_suffix={
            "system": ("", "\n\n"),
            "user": ("### Human:", "\n"),
            "assistant": ("### Assistant:", "\n"),
        },
        # NOTE(review): restored from a garbled " \n"; Yi-VL's image
        # placeholder is "<image_placeholder>" -- confirm against Yi-VL code.
        image_token=" <image_placeholder>\n",
    )
)


@register_chat_template_matching_function
def match_vicuna(model_path: str):
    """Vicuna and LLaVA-v1.5 checkpoints use the vicuna_v1.1 template."""
    model_path = model_path.lower()
    if "vicuna" in model_path:
        return get_chat_template("vicuna_v1.1")
    if "llava-v1.5" in model_path:
        return get_chat_template("vicuna_v1.1")
    return None


@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
    """Llama-2 chat, Mistral/Mixtral instruct and CodeLlama instruct models."""
    model_path = model_path.lower()
    if "llama-2" in model_path and "chat" in model_path:
        return get_chat_template("llama-2-chat")
    if ("mistral" in model_path or "mixtral" in model_path) and "instruct" in model_path:
        return get_chat_template("llama-2-chat")
    if "codellama" in model_path and "instruct" in model_path:
        return get_chat_template("llama-2-chat")
    return None


@register_chat_template_matching_function
def match_chat_ml(model_path: str):
    """TinyLlama, Qwen chat and LLaVA-v1.6-34B use ChatML-style templates."""
    model_path = model_path.lower()
    if "tinyllama" in model_path:
        return get_chat_template("chatml")
    if "qwen" in model_path and "chat" in model_path:
        return get_chat_template("chatml")
    if "llava-v1.6-34b" in model_path:
        return get_chat_template("chatml-llava")
    return None


@register_chat_template_matching_function
def match_chat_yi(model_path: str):
    """Yi models use the "yi" template."""
    model_path = model_path.lower()
    # NOTE(review): a bare substring test also matches unrelated paths that
    # merely contain "yi" (e.g. a user directory); consider a stricter match.
    if "yi" in model_path:
        return get_chat_template("yi")
    return None


if __name__ == "__main__":
    messages = [
        {"role": "system", "content": None},  # None means default
        # {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
        {"role": "user", "content": "Hello!"},
        {"role": "assistant", "content": "Hi!"},
        {"role": "user", "content": "What can you do?"},
        {"role": "assistant", "content": "I can chat with you."},
    ]

    template = get_chat_template("llama-2-chat")
    print(template.get_prompt(messages))