Fix the chat template for llava-v1.6-34b & format code (#177)

This commit is contained in:
Lianmin Zheng
2024-02-11 05:50:13 -08:00
committed by GitHub
parent 50afed4eaa
commit c51020cf0c
23 changed files with 101 additions and 44 deletions

View File

@@ -116,6 +116,21 @@ register_chat_template(
)
# ChatML-flavored template for llava models (e.g. llava-v1.6-34b):
# every turn is wrapped in <|im_start|>{role}\n ... \n<|im_end|>\n,
# generation stops on <|im_end|>, and images are injected as " <image>\n".
_chatml_llava = ChatTemplate(
    name="chatml-llava",
    default_system_prompt="Answer the questions.",
    role_prefix_and_suffix={
        "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
        "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
        "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
    },
    style=ChatTemplateStyle.PLAIN,
    stop_str=("<|im_end|>",),
    image_token=" <image>\n",
)
register_chat_template(_chatml_llava)
register_chat_template(
ChatTemplate(
name="vicuna_v1.1",
@@ -168,7 +183,7 @@ register_chat_template(
def match_vicuna(model_path: str):
if "vicuna" in model_path.lower():
return get_chat_template("vicuna_v1.1")
if "llava" in model_path.lower():
if "llava-v1.5" in model_path.lower():
return get_chat_template("vicuna_v1.1")
@@ -192,6 +207,8 @@ def match_chat_ml(model_path: str):
return get_chat_template("chatml")
if "qwen" in model_path and "chat" in model_path:
return get_chat_template("chatml")
if "llava-v1.6-34b" in model_path:
return get_chat_template("chatml-llava")
@register_chat_template_matching_function

View File

@@ -74,9 +74,9 @@ class SglSamplingParams:
)
return {
"max_tokens_to_sample": self.max_new_tokens,
"stop_sequences": self.stop
if isinstance(self.stop, (list, tuple))
else [self.stop],
"stop_sequences": (
self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
),
"temperature": self.temperature,
"top_p": self.top_p,
"top_k": self.top_k,

View File

@@ -1,4 +1,5 @@
"""Tracing a program."""
import uuid
from typing import Any, Callable, Dict, List, Optional, Union