Gemma Support (#256)
This commit is contained in:
@@ -17,15 +17,23 @@ class ChatTemplate:
|
||||
image_token: str = "<image>"
|
||||
style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
|
||||
|
||||
def get_prefix_and_suffix(
    self, role: str, hist_messages: List[Dict]
) -> Tuple[str, str]:
    """Return the ``(prefix, suffix)`` pair that wraps a message of ``role``.

    The pair is looked up in ``self.role_prefix_and_suffix``; a role with no
    entry gets ``("", "")``.

    Templates whose ``style`` is ``ChatTemplateStyle.LLAMA2`` have two
    special cases:

    * A ``"system"`` message with no prior history is wrapped with the
      *user* prefix followed by the system prefix, and closed with the
      system suffix.
    * A ``"user"`` message whose history consists of exactly one message
      with non-``None`` content gets an empty prefix (presumably because
      the user prefix was already emitted together with that first
      system message -- confirm against the callers).

    Args:
        role: Role name of the message being rendered, e.g. ``"system"``,
            ``"user"`` or ``"assistant"``.
        hist_messages: Messages that precede the one being rendered. Each
            entry is indexed with ``["content"]`` in the LLAMA2 branch.

    Returns:
        A ``(prefix, suffix)`` tuple of strings.
    """
    prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))

    if self.style == ChatTemplateStyle.LLAMA2:
        if role == "system" and not hist_messages:
            # Leading system prompt: open the user turn first, then the
            # system section inside it.
            user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
            system_prefix, system_suffix = self.role_prefix_and_suffix.get(
                "system", ("", "")
            )
            return (user_prefix + system_prefix, system_suffix)
        elif (
            role == "user"
            and len(hist_messages) == 1
            and hist_messages[0]["content"] is not None
        ):
            # First user message right after a single earlier message with
            # real content: suppress the prefix, keep the normal suffix.
            return ("", suffix)

    return prefix, suffix
|
||||
@@ -171,6 +179,19 @@ register_chat_template(
|
||||
)
|
||||
)
|
||||
|
||||
# Chat template registered under the name "gemma-it".
# System messages are passed through with no wrapper; user and assistant
# turns are delimited by <start_of_turn>/<end_of_turn>, with the assistant
# turn tagged "model".
register_chat_template(
    ChatTemplate(
        name="gemma-it",
        default_system_prompt=None,
        role_prefix_and_suffix=dict(
            system=("", ""),
            user=("<start_of_turn>user\n", "<end_of_turn>\n"),
            assistant=("<start_of_turn>model\n", "<end_of_turn>\n"),
        ),
        style=ChatTemplateStyle.PLAIN,
    )
)
|
||||
|
||||
|
||||
@register_chat_template_matching_function
|
||||
def match_vicuna(model_path: str):
|
||||
@@ -211,6 +232,13 @@ def match_chat_yi(model_path: str):
|
||||
return get_chat_template("yi")
|
||||
|
||||
|
||||
@register_chat_template_matching_function
def match_gemma_it(model_path: str):
    """Return the ``gemma-it`` chat template for instruction-tuned Gemma paths.

    A path matches when, case-insensitively, it contains ``gemma`` and has
    ``it`` as a standalone token, i.e. delimited by non-alphanumeric
    characters or the ends of the string (``google/gemma-7b-it``,
    ``Gemma_2B_IT``). A plain substring test for ``it`` is not used because
    it also fires on unrelated path components such as ``repository``,
    ``smith`` or ``with``, which would give base models the chat template.

    Args:
        model_path: Model name or filesystem path to inspect.

    Returns:
        The template registered as ``"gemma-it"``, or ``None`` when the path
        does not look like an instruction-tuned Gemma model.
    """
    import re  # local import so the block does not rely on a file-level one

    model_path = model_path.lower()
    if "gemma" not in model_path:
        return None

    # Break the path into alphanumeric tokens so "it" must stand alone.
    tokens = re.split(r"[^a-z0-9]+", model_path)
    if "it" in tokens:
        return get_chat_template("gemma-it")
    return None
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
messages = [
|
||||
{"role": "system", "content": None}, # None means default
|
||||
|
||||
Reference in New Issue
Block a user