Add image_token in conversation.py (#1632)
Co-authored-by: yizhang2077 <1109276519@qq.com>
This commit is contained in:
@@ -70,6 +70,9 @@ class Conversation:
|
|||||||
sep2: str = None
|
sep2: str = None
|
||||||
# Stop criteria (the default one is EOS token)
|
# Stop criteria (the default one is EOS token)
|
||||||
stop_str: Union[str, List[str]] = None
|
stop_str: Union[str, List[str]] = None
|
||||||
|
# The string that represents an image token in the prompt
|
||||||
|
image_token: str = "<image>"
|
||||||
|
|
||||||
image_data: Optional[List[str]] = None
|
image_data: Optional[List[str]] = None
|
||||||
modalities: Optional[List[str]] = None
|
modalities: Optional[List[str]] = None
|
||||||
|
|
||||||
@@ -334,6 +337,7 @@ class Conversation:
|
|||||||
sep=self.sep,
|
sep=self.sep,
|
||||||
sep2=self.sep2,
|
sep2=self.sep2,
|
||||||
stop_str=self.stop_str,
|
stop_str=self.stop_str,
|
||||||
|
image_token=self.image_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
def dict(self):
|
def dict(self):
|
||||||
@@ -381,6 +385,7 @@ def generate_chat_conv(
|
|||||||
stop_str=conv.stop_str,
|
stop_str=conv.stop_str,
|
||||||
image_data=[],
|
image_data=[],
|
||||||
modalities=[],
|
modalities=[],
|
||||||
|
image_token=conv.image_token,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(request.messages, str):
|
if isinstance(request.messages, str):
|
||||||
@@ -412,9 +417,13 @@ def generate_chat_conv(
|
|||||||
num_image_url += 1
|
num_image_url += 1
|
||||||
conv.modalities.append(content.modalities)
|
conv.modalities.append(content.modalities)
|
||||||
if num_image_url > 1:
|
if num_image_url > 1:
|
||||||
image_token = "<image>"
|
image_token = conv.image_token
|
||||||
else:
|
else:
|
||||||
image_token = "<image>\n"
|
image_token = (
|
||||||
|
conv.image_token + "\n"
|
||||||
|
if conv.name != "qwen2-vl"
|
||||||
|
else conv.image_token
|
||||||
|
)
|
||||||
for content in message.content:
|
for content in message.content:
|
||||||
if content.type == "text":
|
if content.type == "text":
|
||||||
if num_image_url > 16:
|
if num_image_url > 16:
|
||||||
|
|||||||
@@ -117,7 +117,9 @@ def create_streaming_error_response(
|
|||||||
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
|
def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg):
|
||||||
global chat_template_name
|
global chat_template_name
|
||||||
|
|
||||||
logger.info(f"Use chat template: {chat_template_arg}")
|
logger.info(
|
||||||
|
f"Use chat template for the OpenAI-compatible API server: {chat_template_arg}"
|
||||||
|
)
|
||||||
if not chat_template_exists(chat_template_arg):
|
if not chat_template_exists(chat_template_arg):
|
||||||
if not os.path.exists(chat_template_arg):
|
if not os.path.exists(chat_template_arg):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
|
|||||||
Reference in New Issue
Block a user