Refactor ChatTemplate for Enhanced Clarity and Efficiency (#201)
This commit is contained in:
@@ -12,42 +12,35 @@ class ChatTemplateStyle(Enum):
|
||||
class ChatTemplate:
|
||||
name: str
|
||||
default_system_prompt: str
|
||||
role_prefix_and_suffix: Dict[str, Tuple[str]]
|
||||
role_prefix_and_suffix: Dict[str, Tuple[str, str]]
|
||||
stop_str: List[str] = ()
|
||||
image_token: str = "<image>"
|
||||
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
||||
|
||||
def get_prefix_and_suffix(self, role, hist_messages):
|
||||
if self.style == ChatTemplateStyle.PLAIN:
|
||||
return self.role_prefix_and_suffix[role]
|
||||
elif self.style == ChatTemplateStyle.LLAMA2:
|
||||
if len(hist_messages) == 0 and role == "system":
|
||||
return (
|
||||
self.role_prefix_and_suffix["user"][0]
|
||||
+ self.role_prefix_and_suffix["system"][0],
|
||||
self.role_prefix_and_suffix["system"][1],
|
||||
)
|
||||
elif (
|
||||
len(hist_messages) == 1
|
||||
and role == "user"
|
||||
and hist_messages[0]["content"] is not None
|
||||
):
|
||||
return ("", self.role_prefix_and_suffix["user"][1])
|
||||
return self.role_prefix_and_suffix[role]
|
||||
else:
|
||||
raise ValueError(f"Invalid style: {self.style}")
|
||||
def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
|
||||
prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
|
||||
|
||||
if self.style == ChatTemplateStyle.LLAMA2:
|
||||
if role == "system" and not hist_messages:
|
||||
user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
|
||||
system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
|
||||
return (user_prefix + system_prefix, system_suffix)
|
||||
elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
|
||||
return ("", suffix)
|
||||
|
||||
def get_prompt(self, messages):
|
||||
return prefix, suffix
|
||||
|
||||
def get_prompt(self, messages: List[Dict]) -> str:
|
||||
prompt = ""
|
||||
for i in range(len(messages)):
|
||||
role, content = messages[i]["role"], messages[i]["content"]
|
||||
for i, message in enumerate(messages):
|
||||
role, content = message["role"], message["content"]
|
||||
if role == "system" and content is None:
|
||||
content = self.default_system_prompt
|
||||
if content is None:
|
||||
continue
|
||||
|
||||
prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
|
||||
prompt += prefix + content + suffix
|
||||
prompt += f"{prefix}{content}{suffix}"
|
||||
return prompt
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user