""" Copyright 2023-2024 SGLang Team Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. """ """Conversation chat templates.""" # Adapted from # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py import dataclasses from enum import IntEnum, auto from typing import Dict, List, Optional, Tuple, Union from sglang.srt.openai_api.protocol import ChatCompletionRequest class SeparatorStyle(IntEnum): """Separator styles.""" ADD_COLON_SINGLE = auto() ADD_COLON_TWO = auto() ADD_COLON_SPACE_SINGLE = auto() NO_COLON_SINGLE = auto() NO_COLON_TWO = auto() ADD_NEW_LINE_SINGLE = auto() LLAMA2 = auto() LLAMA3 = auto() CHATGLM = auto() CHATML = auto() CHATINTERN = auto() DOLLY = auto() RWKV = auto() PHOENIX = auto() ROBIN = auto() FALCON_CHAT = auto() CHATGLM3 = auto() DEEPSEEK_CHAT = auto() METAMATH = auto() @dataclasses.dataclass class Conversation: """A class that manages prompt templates and keeps all conversation history.""" # The name of this template name: str # The template of the system prompt system_template: str = "{system_message}" # The system message system_message: str = "" # The names of two roles roles: Tuple[str] = ("USER", "ASSISTANT") # All messages. Each item is (role, message). messages: List[List[str]] = () # The number of few shot examples offset: int = 0 # The separator style and configurations sep_style: SeparatorStyle = SeparatorStyle.ADD_COLON_SINGLE sep: str = "\n" sep2: str = None # Stop criteria (the default one is EOS token) stop_str: Union[str, List[str]] = None image_data: Optional[List[str]] = None modalities: Optional[List[str]] = None def get_prompt(self) -> str: """Get the prompt for generation.""" system_prompt = self.system_template.format(system_message=self.system_message) if self.sep_style == SeparatorStyle.ADD_COLON_SINGLE: ret = system_prompt + self.sep for role, message in self.messages: if message: ret += role + ": " + message + self.sep else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.ADD_COLON_TWO: seps = [self.sep, self.sep2] ret = system_prompt + seps[0] for i, (role, message) in enumerate(self.messages): if message: ret += role + ": " + message + seps[i % 2] else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.ADD_COLON_SPACE_SINGLE: ret = system_prompt + self.sep for role, message in self.messages: if message: ret += role + ": " + message + self.sep else: ret += role + ": " # must be end with a space return ret elif self.sep_style == SeparatorStyle.ADD_NEW_LINE_SINGLE: ret = "" if system_prompt == "" else system_prompt + self.sep for role, message in self.messages: if message: ret += role + "\n" + message + self.sep else: ret += role + "\n" return ret elif self.sep_style == SeparatorStyle.NO_COLON_SINGLE: ret = system_prompt for role, message in self.messages: if message: ret += role + message + self.sep else: ret += role return ret elif self.sep_style == SeparatorStyle.NO_COLON_TWO: seps = [self.sep, self.sep2] ret = system_prompt for i, (role, message) in enumerate(self.messages): if message: ret += role + message + seps[i % 2] else: ret += role return ret elif self.sep_style == SeparatorStyle.RWKV: ret = system_prompt for i, (role, message) in enumerate(self.messages): if message: ret += ( role + ": " + message.replace("\r\n", "\n").replace("\n\n", "\n") ) ret += "\n\n" else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.LLAMA3: ret = "<|begin_of_text|>" if self.system_message: ret += system_prompt else: ret += "" for i, (role, message) in enumerate(self.messages): if message: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" ret += f"{message.strip()}<|eot_id|>" else: ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n" # print(ret) return ret elif self.sep_style == SeparatorStyle.LLAMA2: seps = [self.sep, self.sep2] if self.system_message: ret = system_prompt else: ret = "[INST] " for i, (role, message) in enumerate(self.messages): tag = self.roles[i % 2] if message: if i == 0: ret += message + " " else: ret += tag + " " + message + seps[i % 2] else: ret += tag return ret elif self.sep_style == SeparatorStyle.CHATGLM: # source: https://huggingface.co/THUDM/chatglm-6b/blob/1d240ba371910e9282298d4592532d7f0f3e9f3e/modeling_chatglm.py#L1302-L1308 # source2: https://huggingface.co/THUDM/chatglm2-6b/blob/e186c891cf64310ac66ef10a87e6635fa6c2a579/modeling_chatglm.py#L926 round_add_n = 1 if self.name == "chatglm2" else 0 if system_prompt: ret = system_prompt + self.sep else: ret = "" for i, (role, message) in enumerate(self.messages): if i % 2 == 0: ret += f"[Round {i//2 + round_add_n}]{self.sep}" if message: ret += f"{role}:{message}{self.sep}" else: ret += f"{role}:" return ret elif self.sep_style == SeparatorStyle.CHATML: ret = "" if system_prompt == "" else system_prompt + self.sep + "\n" for role, message in self.messages: if message: ret += role + "\n" + message + self.sep + "\n" else: ret += role + "\n" return ret elif self.sep_style == SeparatorStyle.CHATGLM3: ret = "" if self.system_message: ret += system_prompt for role, message in self.messages: if message: ret += role + "\n" + message else: ret += role return ret elif self.sep_style == SeparatorStyle.CHATINTERN: # source: https://huggingface.co/internlm/internlm-chat-7b-8k/blob/bd546fa984b4b0b86958f56bf37f94aa75ab8831/modeling_internlm.py#L771 seps = [self.sep, self.sep2] ret = system_prompt for i, (role, message) in enumerate(self.messages): if i % 2 == 0: ret += "" if message: ret += role + ":" + message + seps[i % 2] + "\n" else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.DOLLY: seps = [self.sep, self.sep2] ret = system_prompt for i, (role, message) in enumerate(self.messages): if message: ret += role + ":\n" + message + seps[i % 2] if i % 2 == 1: ret += "\n\n" else: ret += role + ":\n" return ret elif self.sep_style == SeparatorStyle.PHOENIX: ret = system_prompt for role, message in self.messages: if message: ret += role + ": " + "" + message + "" else: ret += role + ": " + "" return ret elif self.sep_style == SeparatorStyle.ROBIN: ret = system_prompt + self.sep for role, message in self.messages: if message: ret += role + ":\n" + message + self.sep else: ret += role + ":\n" return ret elif self.sep_style == SeparatorStyle.FALCON_CHAT: ret = "" if self.system_message: ret += system_prompt + self.sep for role, message in self.messages: if message: ret += role + ": " + message + self.sep else: ret += role + ":" return ret elif self.sep_style == SeparatorStyle.METAMATH: ret = "" if system_prompt == "" else system_prompt + self.sep for i, (role, message) in enumerate(self.messages): # For MetaMath, sep2 is used to prefix the message. starting_sep = ":\n" if i % 2 == 0 else ": " + self.sep2 ending_sep = self.sep if i % 2 == 0 else "" if message: ret += role + starting_sep + message + ending_sep else: ret += role + starting_sep return ret elif self.sep_style == SeparatorStyle.DEEPSEEK_CHAT: seps = [self.sep, self.sep2] ret = system_prompt for i, (role, message) in enumerate(self.messages): if message: ret += role + ": " + message + seps[i % 2] else: ret += role + ":" return ret else: raise ValueError(f"Invalid style: {self.sep_style}") def set_system_message(self, system_message: str): """Set the system message.""" self.system_message = system_message def append_message(self, role: str, message: str): """Append a new message.""" self.messages.append([role, message]) def append_image(self, image: str): """Append a new message.""" self.image_data.append(image) def update_last_message(self, message: str): """Update the last output. The last message is typically set to be None when constructing the prompt, so we need to update it in-place after getting the response from a model. """ self.messages[-1][1] = message def to_gradio_chatbot(self): """Convert the conversation to gradio chatbot format.""" ret = [] for i, (role, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: ret.append([msg, None]) else: ret[-1][-1] = msg return ret def to_openai_api_messages(self): """Convert the conversation to OpenAI chat completion format.""" if self.system_message == "": ret = [] else: ret = [{"role": "system", "content": self.system_message}] for i, (_, msg) in enumerate(self.messages[self.offset :]): if i % 2 == 0: ret.append({"role": "user", "content": msg}) else: if msg is not None: ret.append({"role": "assistant", "content": msg}) return ret def copy(self): return Conversation( name=self.name, system_template=self.system_template, system_message=self.system_message, roles=self.roles, messages=[[x, y] for x, y in self.messages], offset=self.offset, sep_style=self.sep_style, sep=self.sep, sep2=self.sep2, stop_str=self.stop_str, ) def dict(self): return { "template_name": self.name, "system_message": self.system_message, "roles": self.roles, "messages": self.messages, "offset": self.offset, } # A global registry for all conversation templates chat_templates: Dict[str, Conversation] = {} def register_conv_template(template: Conversation, override: bool = False): """Register a new conversation template.""" if not override: assert ( template.name not in chat_templates ), f"{template.name} has been registered." chat_templates[template.name] = template def chat_template_exists(template_name: str) -> bool: return template_name in chat_templates def generate_chat_conv( request: ChatCompletionRequest, template_name: str ) -> Conversation: conv = chat_templates[template_name].copy() conv = Conversation( name=conv.name, system_template=conv.system_template, system_message=conv.system_message, roles=conv.roles, messages=list(conv.messages), # prevent in-place modification offset=conv.offset, sep_style=SeparatorStyle(conv.sep_style), sep=conv.sep, sep2=conv.sep2, stop_str=conv.stop_str, image_data=[], modalities=[], ) if isinstance(request.messages, str): raise ValueError("The messages should be a list of dict.") for message in request.messages: msg_role = message.role if msg_role == "system": if isinstance(message.content, str): conv.system_message = message.content elif isinstance(message.content, list): if ( len(message.content) != 1 or getattr(message.content[0], "type", None) != "text" ): raise ValueError("The system message should be a single text.") else: conv.system_message = getattr(message.content[0], "text", "") elif msg_role == "user": # Handle the various types of Chat Request content types here. role = conv.roles[0] if isinstance(message.content, str): conv.append_message(conv.roles[0], message.content) else: real_content = "" # calculate number of image_url num_image_url = 0 for content in message.content: if content.type == "image_url": num_image_url += 1 conv.modalities.append(content.modalities) if num_image_url > 1: image_token = "" else: image_token = "\n" for content in message.content: if content.type == "text": if num_image_url > 16: real_content += "\n" # for video real_content += content.text elif content.type == "image_url": # NOTE: Only works for llava real_content += image_token conv.append_image(content.image_url.url) conv.append_message(conv.roles[0], real_content) elif msg_role == "assistant": parsed_content = "" if isinstance(message.content, str): parsed_content = message.content elif isinstance(message.content, list): if ( len(message.content) != 1 or getattr(message.content[0], "type", None) != "text" ): raise ValueError( "The assistant's response should be a single text." ) else: parsed_content = getattr(message.content[0], "text", "") conv.append_message(conv.roles[1], parsed_content) else: raise ValueError(f"Unknown role: {msg_role}") # Add a blank message for the assistant. conv.append_message(conv.roles[1], None) return conv # llama2 template # reference: https://huggingface.co/blog/codellama#conversational-instructions # reference: https://github.com/facebookresearch/llama/blob/1a240688810f8036049e8da36b073f63d2ac552c/llama/generation.py#L212 register_conv_template( Conversation( name="llama-2", system_template="[INST] <>\n{system_message}\n<>\n\n", roles=("[INST]", "[/INST]"), sep_style=SeparatorStyle.LLAMA2, sep=" ", sep2=" ", stop_str=["[INST]", "[/INST]", "<>", "<>"], ) ) register_conv_template( Conversation( name="chatml", system_template="<|im_start|>system\n{system_message}", system_message="You are a helpful assistant.", roles=("<|im_start|>user", "<|im_start|>assistant"), sep_style=SeparatorStyle.CHATML, sep="<|im_end|>", stop_str=["<|endoftext|>", "<|im_end|>"], ) ) register_conv_template( Conversation( name="chatml-llava", system_template="<|im_start|>system\n{system_message}", system_message="You are a helpful assistant.", roles=("<|im_start|>user", "<|im_start|>assistant"), sep_style=SeparatorStyle.CHATML, sep="<|im_end|>", stop_str=["<|endoftext|>", "<|im_end|>"], ) ) register_conv_template( Conversation( name="vicuna_v1.1", system_message="A chat between a curious user and an artificial intelligence assistant. " "The assistant gives helpful, detailed, and polite answers to the user's questions.", roles=("USER", "ASSISTANT"), sep_style=SeparatorStyle.ADD_COLON_TWO, sep=" ", sep2="", ) ) register_conv_template( Conversation( name="llava_llama_3", system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.", system_template="<|start_header_id|>system<|end_header_id|>\n\n{system_message}<|eot_id|>", roles=("user", "assistant"), sep_style=SeparatorStyle.LLAMA3, sep="", stop_str=["<|end_of_text|>", "<|eot_id|>"], ) ) # Reference: https://github.com/InternLM/lmdeploy/blob/387bf54b4f124e72aab30ae9755f562e435d3d01/lmdeploy/model.py#L425-L442 register_conv_template( Conversation( name="internlm2-chat", system_template="<|im_start|>system\n{system_message}", roles=("<|im_start|>user", "<|im_start|>assistant"), sep="\n", stop_str=["<|im_end|>", "<|action_end|>"], ) )