Refactor ChatTemplate for Enhanced Clarity and Efficiency (#201)
This commit is contained in:
@@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum):
|
|||||||
class ChatTemplate:
|
class ChatTemplate:
|
||||||
name: str
|
name: str
|
||||||
default_system_prompt: str
|
default_system_prompt: str
|
||||||
role_prefix_and_suffix: Dict[str, Tuple[str]]
|
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
|
||||||
stop_str: List[str] = ()
|
stop_str: List[str] = ()
|
||||||
image_token: str = "<image>"
|
image_token: str = "<image>"
|
||||||
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
||||||
|
|
||||||
def get_prefix_and_suffix(self, role, hist_messages):
|
def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
|
||||||
if self.style == ChatTemplateStyle.PLAIN:
|
prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
|
||||||
return self.role_prefix_and_suffix[role]
|
|
||||||
elif self.style == ChatTemplateStyle.LLAMA2:
|
|
||||||
if len(hist_messages) == 0 and role == "system":
|
|
||||||
return (
|
|
||||||
self.role_prefix_and_suffix["user"][0]
|
|
||||||
+ self.role_prefix_and_suffix["system"][0],
|
|
||||||
self.role_prefix_and_suffix["system"][1],
|
|
||||||
)
|
|
||||||
elif (
|
|
||||||
len(hist_messages) == 1
|
|
||||||
and role == "user"
|
|
||||||
and hist_messages[0]["content"] is not None
|
|
||||||
):
|
|
||||||
return ("", self.role_prefix_and_suffix["user"][1])
|
|
||||||
return self.role_prefix_and_suffix[role]
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Invalid style: {self.style}")
|
|
||||||
|
|
||||||
def get_prompt(self, messages):
|
if self.style == ChatTemplateStyle.LLAMA2:
|
||||||
|
if role == "system" and not hist_messages:
|
||||||
|
user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
|
||||||
|
system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
|
||||||
|
return (user_prefix + system_prefix, system_suffix)
|
||||||
|
elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
|
||||||
|
return ("", suffix)
|
||||||
|
|
||||||
|
return prefix, suffix
|
||||||
|
|
||||||
|
def get_prompt(self, messages: List[Dict]) -> str:
|
||||||
prompt = ""
|
prompt = ""
|
||||||
for i in range(len(messages)):
|
for i, message in enumerate(messages):
|
||||||
role, content = messages[i]["role"], messages[i]["content"]
|
role, content = message["role"], message["content"]
|
||||||
if role == "system" and content is None:
|
if role == "system" and content is None:
|
||||||
content = self.default_system_prompt
|
content = self.default_system_prompt
|
||||||
if content is None:
|
if content is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
|
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
|
||||||
prompt += prefix + content + suffix
|
prompt += f"{prefix}{content}{suffix}"
|
||||||
return prompt
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user