[Feature] Prefill assistant response - add continue_final_message parameter (#4226)
Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
committed by
GitHub
parent
5156d5a413
commit
8b39274e34
@@ -950,9 +950,16 @@ def v1_chat_generate_request(
|
||||
openai_compatible_messages.append(
|
||||
{"role": message.role, "content": content["text"]}
|
||||
)
|
||||
if openai_compatible_messages[-1]["role"] == "assistant":
|
||||
assistant_prefix = openai_compatible_messages[-1]["content"]
|
||||
openai_compatible_messages = openai_compatible_messages[:-1]
|
||||
if (
|
||||
openai_compatible_messages
|
||||
and openai_compatible_messages[-1]["role"] == "assistant"
|
||||
):
|
||||
if request.continue_final_message:
|
||||
# Remove the final assistant message so its content can be continued.
|
||||
assistant_prefix = openai_compatible_messages[-1]["content"]
|
||||
openai_compatible_messages = openai_compatible_messages[:-1]
|
||||
else:
|
||||
assistant_prefix = None
|
||||
else:
|
||||
assistant_prefix = None
|
||||
|
||||
@@ -991,7 +998,33 @@ def v1_chat_generate_request(
|
||||
modalities = []
|
||||
else:
|
||||
conv = generate_chat_conv(request, chat_template_name)
|
||||
prompt = conv.get_prompt()
|
||||
# If we should continue the final assistant message, adjust the conversation.
|
||||
if (
|
||||
request.continue_final_message
|
||||
and request.messages
|
||||
and request.messages[-1].role == "assistant"
|
||||
):
|
||||
# Remove the auto-added blank assistant turn, if present.
|
||||
if conv.messages and conv.messages[-1][1] is None:
|
||||
conv.messages.pop()
|
||||
# Rebuild the prompt from the conversation.
|
||||
prompt = conv.get_prompt()
|
||||
# Strip any trailing stop tokens or separators that indicate end-of-assistant.
|
||||
if isinstance(conv.stop_str, list):
|
||||
for stop_token in conv.stop_str:
|
||||
if prompt.endswith(stop_token):
|
||||
prompt = prompt[: -len(stop_token)]
|
||||
elif isinstance(conv.stop_str, str) and prompt.endswith(
|
||||
conv.stop_str
|
||||
):
|
||||
prompt = prompt[: -len(conv.stop_str)]
|
||||
if conv.sep and prompt.endswith(conv.sep):
|
||||
prompt = prompt[: -len(conv.sep)]
|
||||
if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
|
||||
prompt = prompt[: -len(conv.sep2)]
|
||||
else:
|
||||
prompt = conv.get_prompt()
|
||||
|
||||
image_data = conv.image_data
|
||||
audio_data = conv.audio_data
|
||||
modalities = conv.modalities
|
||||
@@ -1003,6 +1036,7 @@ def v1_chat_generate_request(
|
||||
else:
|
||||
stop.extend(request.stop)
|
||||
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
|
||||
|
||||
else:
|
||||
# Use the raw prompt and stop strings if the messages is already a string.
|
||||
prompt_ids = request.messages
|
||||
|
||||
@@ -355,6 +355,7 @@ class ChatCompletionRequest(BaseModel):
|
||||
stop_token_ids: Optional[List[int]] = None
|
||||
no_stop_trim: bool = False
|
||||
ignore_eos: bool = False
|
||||
continue_final_message: bool = False
|
||||
skip_special_tokens: bool = True
|
||||
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
|
||||
session_params: Optional[Dict] = None
|
||||
|
||||
Reference in New Issue
Block a user