diff --git a/python/sglang/lang/chat_template.py b/python/sglang/lang/chat_template.py index 612033543..6f7375299 100644 --- a/python/sglang/lang/chat_template.py +++ b/python/sglang/lang/chat_template.py @@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum): class ChatTemplate: name: str default_system_prompt: str - role_prefix_and_suffix: Dict[str, Tuple[str]] + role_prefix_and_suffix: Dict[str, Tuple[str, str]] stop_str: List[str] = () image_token: str = "" style: ChatTemplateStyle = ChatTemplateStyle.PLAIN - def get_prefix_and_suffix(self, role, hist_messages): - if self.style == ChatTemplateStyle.PLAIN: - return self.role_prefix_and_suffix[role] - elif self.style == ChatTemplateStyle.LLAMA2: - if len(hist_messages) == 0 and role == "system": - return ( - self.role_prefix_and_suffix["user"][0] - + self.role_prefix_and_suffix["system"][0], - self.role_prefix_and_suffix["system"][1], - ) - elif ( - len(hist_messages) == 1 - and role == "user" - and hist_messages[0]["content"] is not None - ): - return ("", self.role_prefix_and_suffix["user"][1]) - return self.role_prefix_and_suffix[role] - else: - raise ValueError(f"Invalid style: {self.style}") + def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]: + prefix, suffix = self.role_prefix_and_suffix.get(role, ("", "")) + + if self.style == ChatTemplateStyle.LLAMA2: + if role == "system" and not hist_messages: + user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", "")) + system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", "")) + return (user_prefix + system_prefix, system_suffix) + elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None: + return ("", suffix) - def get_prompt(self, messages): + return prefix, suffix + + def get_prompt(self, messages: List[Dict]) -> str: prompt = "" - for i in range(len(messages)): - role, content = messages[i]["role"], messages[i]["content"] + for i, message in enumerate(messages): + role, content = message["role"], message["content"] if role == "system" and content is None: content = self.default_system_prompt if content is None: continue prefix, suffix = self.get_prefix_and_suffix(role, messages[:i]) - prompt += prefix + content + suffix + prompt += f"{prefix}{content}{suffix}" return prompt