Files
sglang/python/sglang/lang/chat_template.py
2024-02-24 16:34:22 +08:00

233 lines
7.3 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import Callable, Dict, List, Optional, Tuple
class ChatTemplateStyle(Enum):
PLAIN = auto()
LLAMA2 = auto()
@dataclass
class ChatTemplate:
name: str
default_system_prompt: str
role_prefix_and_suffix: Dict[str, Tuple[str]]
stop_str: List[str] = ()
image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role, hist_messages):
if self.style == ChatTemplateStyle.PLAIN:
return self.role_prefix_and_suffix[role]
elif self.style == ChatTemplateStyle.LLAMA2:
if len(hist_messages) == 0 and role == "system":
return (
self.role_prefix_and_suffix["user"][0]
+ self.role_prefix_and_suffix["system"][0],
self.role_prefix_and_suffix["system"][1],
)
elif (
len(hist_messages) == 1
and role == "user"
and hist_messages[0]["content"] is not None
):
return ("", self.role_prefix_and_suffix["user"][1])
return self.role_prefix_and_suffix[role]
else:
raise ValueError(f"Invalid style: {self.style}")
def get_prompt(self, messages):
prompt = ""
for i in range(len(messages)):
role, content = messages[i]["role"], messages[i]["content"]
if role == "system" and content is None:
content = self.default_system_prompt
if content is None:
continue
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
prompt += prefix + content + suffix
return prompt
chat_template_registry: Dict[str, ChatTemplate] = {}
matching_function_registry: List[Callable] = []
def register_chat_template(template):
chat_template_registry[template.name] = template
def register_chat_template_matching_function(func):
matching_function_registry.append(func)
def get_chat_template(name):
return chat_template_registry[name]
def get_chat_template_by_model_path(model_path):
for matching_func in matching_function_registry:
template = matching_func(model_path)
if template is not None:
return template
return get_chat_template("default")
register_chat_template(
ChatTemplate(
name="default",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("SYSTEM:", "\n"),
"user": ("USER:", "\n"),
"assistant": ("ASSISTANT:", "\n"),
},
)
)
register_chat_template(
ChatTemplate(
name="claude",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("\n\nHuman: ", ""),
"assistant": ("\n\nAssistant:", ""),
},
)
)
register_chat_template(
ChatTemplate(
name="chatml",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
)
)
register_chat_template(
ChatTemplate(
name="chatml-llava",
default_system_prompt="Answer the questions.",
role_prefix_and_suffix={
"system": ("<|im_start|>system\n", "<|im_end|>\n"),
"user": ("<|im_start|>user\n", "<|im_end|>\n"),
"assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
},
style=ChatTemplateStyle.PLAIN,
stop_str=("<|im_end|>",),
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
default_system_prompt=(
"A chat between a curious user and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the user's questions."
),
role_prefix_and_suffix={
"system": ("", " "),
"user": ("USER:", " "),
"assistant": ("ASSISTANT:", "</s>"),
},
image_token=" <image>\n",
)
)
register_chat_template(
ChatTemplate(
name="llama-2-chat",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
"user": ("[INST] ", " [/INST]"),
"assistant": ("", " </s><s>"),
},
style=ChatTemplateStyle.LLAMA2,
)
)
# Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
register_chat_template(
ChatTemplate(
name="yi",
default_system_prompt=(
"This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
),
role_prefix_and_suffix={
"system": ("", "\n\n"),
"user": ("### Human:", "\n"),
"assistant": ("### Assistant:", "\n"),
},
image_token=" <image_placeholder>\n",
)
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@register_chat_template_matching_function
def match_llama2_chat(model_path: str):
model_path = model_path.lower()
if "llama-2" in model_path and "chat" in model_path:
return get_chat_template("llama-2-chat")
if (
"mistral" in model_path or "mixtral" in model_path
) and "instruct" in model_path:
return get_chat_template("llama-2-chat")
if "codellama" in model_path and "instruct" in model_path:
return get_chat_template("llama-2-chat")
@register_chat_template_matching_function
def match_chat_ml(model_path: str):
model_path = model_path.lower()
if "tinyllama" in model_path:
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
return get_chat_template("chatml-llava")
@register_chat_template_matching_function
def match_chat_yi(model_path: str):
model_path = model_path.lower()
if "yi" in model_path:
return get_chat_template("yi")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default
# {"role": "system", "content": "You are a helpful, respectful and honest assistant."},
{"role": "user", "content": "Hello!"},
{"role": "assistant", "content": "Hi!"},
{"role": "user", "content": "What can you do?"},
{"role": "assistant", "content": "I can chat with you."},
]
template = get_chat_template("llama-2-chat")
print(template.get_prompt(messages))