[Feature] Prefill assistant response - add continue_final_message parameter (#4226)

Co-authored-by: Chayenne <zhaochen20@outlook.com>
This commit is contained in:
Adarsh Shirawalmath
2025-04-21 06:07:18 +05:30
committed by GitHub
parent 5156d5a413
commit 8b39274e34
6 changed files with 82 additions and 23 deletions

View File

@@ -950,9 +950,16 @@ def v1_chat_generate_request(
openai_compatible_messages.append(
{"role": message.role, "content": content["text"]}
)
if openai_compatible_messages[-1]["role"] == "assistant":
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
if (
openai_compatible_messages
and openai_compatible_messages[-1]["role"] == "assistant"
):
if request.continue_final_message:
# Remove the final assistant message so its content can be continued.
assistant_prefix = openai_compatible_messages[-1]["content"]
openai_compatible_messages = openai_compatible_messages[:-1]
else:
assistant_prefix = None
else:
assistant_prefix = None
@@ -991,7 +998,33 @@ def v1_chat_generate_request(
modalities = []
else:
conv = generate_chat_conv(request, chat_template_name)
prompt = conv.get_prompt()
# If we should continue the final assistant message, adjust the conversation.
if (
request.continue_final_message
and request.messages
and request.messages[-1].role == "assistant"
):
# Remove the auto-added blank assistant turn, if present.
if conv.messages and conv.messages[-1][1] is None:
conv.messages.pop()
# Rebuild the prompt from the conversation.
prompt = conv.get_prompt()
# Strip any trailing stop tokens or separators that indicate end-of-assistant.
if isinstance(conv.stop_str, list):
for stop_token in conv.stop_str:
if prompt.endswith(stop_token):
prompt = prompt[: -len(stop_token)]
elif isinstance(conv.stop_str, str) and prompt.endswith(
conv.stop_str
):
prompt = prompt[: -len(conv.stop_str)]
if conv.sep and prompt.endswith(conv.sep):
prompt = prompt[: -len(conv.sep)]
if getattr(conv, "sep2", None) and prompt.endswith(conv.sep2):
prompt = prompt[: -len(conv.sep2)]
else:
prompt = conv.get_prompt()
image_data = conv.image_data
audio_data = conv.audio_data
modalities = conv.modalities
@@ -1003,6 +1036,7 @@ def v1_chat_generate_request(
else:
stop.extend(request.stop)
prompt_ids = tokenizer_manager.tokenizer.encode(prompt)
else:
# Use the raw prompt and stop strings if the messages is already a string.
prompt_ids = request.messages

View File

@@ -355,6 +355,7 @@ class ChatCompletionRequest(BaseModel):
stop_token_ids: Optional[List[int]] = None
no_stop_trim: bool = False
ignore_eos: bool = False
continue_final_message: bool = False
skip_special_tokens: bool = True
lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
session_params: Optional[Dict] = None