Feature/function calling update (#2700)

Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu>
Co-authored-by: Chayenne <zhaochen20@outlook.com>
Co-authored-by: shuaills <shishuaiuoe@gmail.com>
YAMY
2025-01-26 09:57:51 -08:00
committed by GitHub
parent f265d15b96
commit b045841bae
10 changed files with 1377 additions and 219 deletions

File: python/sglang/srt/openai_api/adapter.py

@@ -20,7 +20,7 @@ import os
import time
import uuid
from http import HTTPStatus
from typing import Dict, List
from typing import Dict, List, Optional
from fastapi import HTTPException, Request, UploadFile
from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
generate_chat_conv,
register_conv_template,
)
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
from sglang.srt.openai_api.protocol import (
BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
TopLogprob,
UsageInfo,
)
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
ret,
to_file=True,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
else:
responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if request.stream:
logger.warning("Streaming is not supported with tools.")
request.stream = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
openai_compatible_messages = openai_compatible_messages[:-1]
else:
assistant_prefix = None
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
try:
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
except:
# This branch is triggered when the chosen model has a tools input
# format that is not compatible with OpenAI's apply_chat_template
# tool_call format, e.g. Mistral.
tools = [t if "function" in t else {"function": t} for t in tools]
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
stop = request.stop
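
The fallback above only changes the shape of the tool schemas before retrying apply_chat_template. A minimal sketch of that transformation (the get_weather schema is illustrative, not from this commit):

# OpenAI-style dump: each entry is the bare function schema.
openai_style_tools = [
    {"name": "get_weather",
     "parameters": {"type": "object",
                    "properties": {"city": {"type": "string"}}}},
]
# Mistral-style chat templates expect each entry wrapped under "function".
wrapped = [t if "function" in t else {"function": t} for t in openai_style_tools]
# wrapped == [{"function": {"name": "get_weather", "parameters": {...}}}]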
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
def v1_chat_generate_response(
request, ret, to_file=False, cache_report=False, tool_call_parser=None
):
choices = []
for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
if finish_reason == "stop":
finish_reason = "tool_calls"
try:
text, call_info_list = parse_tool_response(text, tools) # noqa
parser = FunctionCallParser(tools, tool_call_parser)
full_normal_text, call_info_list = parser.parse_non_stream(text)
tool_calls = [
ToolCall(
id=str(call_info[0]),
id=str(call_info.tool_index),
function=FunctionResponse(
name=call_info[1], arguments=call_info[2]
name=call_info.name, arguments=call_info.parameters
),
)
for call_info in call_info_list
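
v1_chat_generate_response now receives structured call items (tool_index, name, parameters) from FunctionCallParser instead of indexing positional tuples. A hedged usage sketch; the parser name "llama3" and the plain-dict tools are assumptions (in the server, the request's Tool objects are passed):

from sglang.srt.function_call_parser import FunctionCallParser

tools = [{"type": "function",
          "function": {"name": "get_weather",
                       "parameters": {"type": "object", "properties": {}}}}]
text = "model output possibly containing tool-call tags"
parser = FunctionCallParser(tools, "llama3")  # parser name is an assumption
normal_text, calls = parser.parse_non_stream(text)
for call in calls:
    # Structured fields replace the old tuple positions 0/1/2:
    print(call.tool_index, call.name, call.parameters)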
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
if adapted_request.stream:
parser_dict = {}
async def generate_stream_resp():
is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
adapted_request, raw_request
):
index = content.get("index", 0)
text = content["text"]
is_first = is_firsts.get(index, True)
stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
text = content["text"]
delta = text[len(stream_buffer) :]
stream_buffer = stream_buffer + delta
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta),
finish_reason=(finish_reason["type"] if finish_reason else ""),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
new_stream_buffer = stream_buffer + delta
is_firsts[index] = is_first
stream_buffers[index] = stream_buffer
n_prev_tokens[index] = n_prev_token
if request.tool_choice != "none" and request.tools:
if index not in parser_dict:
parser_dict[index] = FunctionCallParser(
tools=request.tools,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
parser = parser_dict[index]
yield f"data: {chunk.model_dump_json()}\n\n"
# parse_stream_chunk => returns (normal_text, calls)
normal_text, calls = parser.parse_stream_chunk(delta)
# 1) if there's normal_text, output it as normal content
if normal_text:
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=normal_text),
finish_reason=(
finish_reason["type"] if finish_reason else ""
),
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
# 2) if we found calls, we output them as separate chunk(s)
for call_item in calls:
# transform call_item -> FunctionResponse + ToolCall
if (
content["meta_info"]["finish_reason"]
and content["meta_info"]["finish_reason"]["type"]
== "stop"
):
latest_delta_len = 0
if isinstance(call_item.parameters, str):
latest_delta_len = len(call_item.parameters)
expected_call = json.dumps(
parser.multi_format_parser.detectors[0]
.prev_tool_call_arr[index]
.get("arguments", {}),
ensure_ascii=False,
)
actual_call = parser.multi_format_parser.detectors[
0
].streamed_args_for_tool[index]
if latest_delta_len > 0:
actual_call = actual_call[:-latest_delta_len]
remaining_call = expected_call.replace(
actual_call, "", 1
)
call_item.parameters = remaining_call
tool_call = ToolCall(
id=str(call_item.tool_index),
function=FunctionResponse(
name=call_item.name,
arguments=call_item.parameters,
),
)
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(
role="assistant", tool_calls=[tool_call]
),
finish_reason="tool_call",
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
else:
# No tool calls => just treat this as normal text
choice_data = ChatCompletionResponseStreamChoice(
index=index,
delta=DeltaMessage(content=delta),
finish_reason=(
finish_reason["type"] if finish_reason else ""
),
matched_stop=(
finish_reason["matched"]
if finish_reason and "matched" in finish_reason
else None
),
logprobs=choice_logprobs,
)
chunk = ChatCompletionStreamResponse(
id=content["meta_info"]["id"],
choices=[choice_data],
model=request.model,
)
yield f"data: {chunk.model_dump_json()}\n\n"
stream_buffers[index] = new_stream_buffer
is_firsts[index] = is_first
if request.stream_options and request.stream_options.include_usage:
total_prompt_tokens = sum(
tokens
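
The subtle piece in the streaming branch is the final-chunk bookkeeping: when the model hits a stop finish reason, the parser may still hold argument text that was never emitted, so the code subtracts what has already been streamed (minus the latest delta) from the full argument JSON and sends the tail as one last tool_call delta. A toy illustration of that arithmetic (values invented):

import json

# Full arguments accumulated by the parser for this tool call:
expected_call = json.dumps({"city": "Paris", "unit": "C"}, ensure_ascii=False)
# Suppose only this prefix had been streamed before the stop chunk:
actual_call = '{"city": "Paris"'
# The unsent tail is what remaining_call carries out:
remaining_call = expected_call.replace(actual_call, "", 1)
assert remaining_call == ', "unit": "C"}'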
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
ret = [ret]
response = v1_chat_generate_response(
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
request,
ret,
cache_report=tokenizer_manager.server_args.enable_cache_report,
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
)
return response
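
Together with the removed "Streaming is not supported with tools" guard, this makes tool calls usable over streaming. A hypothetical client sketch; the endpoint, model name, and server launch flags are assumptions, not part of this commit:

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
tools = [{"type": "function",
          "function": {"name": "get_weather",
                       "parameters": {"type": "object",
                                      "properties": {"city": {"type": "string"}}}}}]
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    if not chunk.choices:
        continue
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")  # normal text chunks
    if delta.tool_calls:
        print(delta.tool_calls)       # incremental tool-call deltas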

File: python/sglang/srt/openai_api/protocol.py

@@ -262,7 +262,7 @@ class Function(BaseModel):
"""Function descriptions."""
description: Optional[str] = Field(default=None, examples=[None])
name: str
name: Optional[str] = None
parameters: Optional[object] = None
@@ -276,7 +276,7 @@ class Tool(BaseModel):
class ToolChoiceFuncName(BaseModel):
"""The name of tool choice function."""
name: str
name: Optional[str] = None
class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
class FunctionResponse(BaseModel):
"""Function response."""
name: str
arguments: str
name: Optional[str] = None
arguments: Optional[str] = None
class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
class DeltaMessage(BaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseStreamChoice(BaseModel):
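
Relaxing name/arguments to Optional is what lets a streamed call arrive in pieces: an early delta can carry only the function name, a later one only an arguments fragment. A sketch of those shapes (assuming defaults for any ToolCall fields this diff does not show):

from sglang.srt.openai_api.protocol import (
    DeltaMessage,
    FunctionResponse,
    ToolCall,
)

# First delta of a call: announce the name, no arguments yet.
first = DeltaMessage(
    role="assistant",
    tool_calls=[ToolCall(id="0",
                         function=FunctionResponse(name="get_weather"))],
)
# Later delta: an arguments fragment only, name omitted.
later = DeltaMessage(
    tool_calls=[ToolCall(id="0",
                         function=FunctionResponse(arguments='{"city": "Par'))],
)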