Gemma Support (#256)

This commit is contained in:
Liangsheng Yin
2024-03-11 12:14:27 +08:00
committed by GitHub
parent 64fe311593
commit 89885b31ef
10 changed files with 428 additions and 55 deletions

View File

@@ -17,15 +17,23 @@ class ChatTemplate:
image_token: str = "<image>"
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
def get_prefix_and_suffix(self, role: str, hist_messages: List[Dict]) -> Tuple[str, str]:
def get_prefix_and_suffix(
self, role: str, hist_messages: List[Dict]
) -> Tuple[str, str]:
prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
if self.style == ChatTemplateStyle.LLAMA2:
if role == "system" and not hist_messages:
user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
system_prefix, system_suffix = self.role_prefix_and_suffix.get("system", ("", ""))
system_prefix, system_suffix = self.role_prefix_and_suffix.get(
"system", ("", "")
)
return (user_prefix + system_prefix, system_suffix)
elif role == "user" and len(hist_messages) == 1 and hist_messages[0]["content"] is not None:
elif (
role == "user"
and len(hist_messages) == 1
and hist_messages[0]["content"] is not None
):
return ("", suffix)
return prefix, suffix
@@ -171,6 +179,19 @@ register_chat_template(
)
)
register_chat_template(
ChatTemplate(
name="gemma-it",
default_system_prompt=None,
role_prefix_and_suffix={
"system": ("", ""),
"user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
"assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
},
style=ChatTemplateStyle.PLAIN,
)
)
@register_chat_template_matching_function
def match_vicuna(model_path: str):
@@ -211,6 +232,13 @@ def match_chat_yi(model_path: str):
return get_chat_template("yi")
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
model_path = model_path.lower()
if "gemma" in model_path and "it" in model_path:
return get_chat_template("gemma-it")
if __name__ == "__main__":
messages = [
{"role": "system", "content": None}, # None means default