[Feature] Function Calling (#2544)

Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com>
This commit is contained in:
Tanjiro
2024-12-29 11:28:52 +05:30
committed by GitHub
parent fd28640dc5
commit 8ee9a8501a
5 changed files with 368 additions and 2 deletions

View File

@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
FileDeleteResponse,
FileRequest,
FileResponse,
FunctionResponse,
LogProbs,
ToolCall,
TopLogprob,
UsageInfo,
)
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
from sglang.utils import get_exception_traceback
logger = logging.getLogger(__name__)
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
# None skips any image processing in GenerateReqInput.
if not isinstance(request.messages, str):
# Apply chat template and its stop strings.
tools = None
if request.tools and request.tool_choice != "none":
request.skip_special_tokens = False
if request.stream:
logger.warning("Streaming is not supported with tools.")
request.stream = False
if not isinstance(request.tool_choice, str):
tools = [
item.function.model_dump()
for item in request.tools
if item.function.name == request.tool_choice.function.name
]
else:
tools = [item.function.model_dump() for item in request.tools]
if chat_template_name is None:
openai_compatible_messages = []
for message in request.messages:
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
openai_compatible_messages,
tokenize=True,
add_generation_prompt=True,
tools=tools,
)
if assistant_prefix:
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
finish_reason = ret_item["meta_info"]["finish_reason"]
tool_calls = None
text = ret_item["text"]
if isinstance(request, list):
tool_choice = request[idx].tool_choice
tools = request[idx].tools
else:
tool_choice = request.tool_choice
tools = request.tools
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
if finish_reason == "stop":
finish_reason = "tool_calls"
try:
text, call_info_list = parse_tool_response(text, tools) # noqa
tool_calls = [
ToolCall(
id=str(call_info[0]),
function=FunctionResponse(
name=call_info[1], arguments=call_info[2]
),
)
for call_info in call_info_list
]
except Exception as e:
logger.error(f"Exception: {e}")
return create_error_response(
HTTPStatus.BAD_REQUEST,
"Failed to parse fc related info to json format!",
)
if to_file:
# to make the choice data json serializable
choice_data = {
"index": 0,
"message": {"role": "assistant", "content": ret_item["text"]},
"message": {
"role": "assistant",
"content": ret_item["text"] if tool_calls is None else None,
"tool_calls": tool_calls,
},
"logprobs": choice_logprobs,
"finish_reason": (finish_reason["type"] if finish_reason else ""),
"matched_stop": (
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
else:
choice_data = ChatCompletionResponseChoice(
index=idx,
message=ChatMessage(role="assistant", content=ret_item["text"]),
message=ChatMessage(
role="assistant",
content=ret_item["text"] if tool_calls is None else None,
tool_calls=tool_calls,
),
logprobs=choice_logprobs,
finish_reason=(finish_reason["type"] if finish_reason else ""),
matched_stop=(

View File

@@ -257,6 +257,34 @@ class ResponseFormat(BaseModel):
json_schema: Optional[JsonSchemaResponseFormat] = None
class Function(BaseModel):
    """Description of a callable function the model may invoke (request-side tool schema)."""

    # Optional natural-language description of what the function does.
    description: Optional[str] = Field(default=None, examples=[None])
    # Name the model uses to refer to this function.
    name: str
    # Argument schema; kept as an opaque object here (presumably a JSON-schema
    # dict per the OpenAI tools convention — confirm against callers).
    parameters: Optional[object] = None
class Tool(BaseModel):
    """Wrapper pairing a tool type tag with its function definition."""

    # Tool kind; "function" is the only value used in this module.
    type: str = Field(default="function", examples=["function"])
    # The function definition this tool exposes.
    function: Function
class ToolChoiceFuncName(BaseModel):
    """Name of the function selected by an explicit tool choice."""

    # Must match a `Function.name` from the request's `tools` list.
    name: str
class ToolChoice(BaseModel):
    """Explicit tool-choice object: force the model to call one named function."""

    # Which function to call.
    function: ToolChoiceFuncName
    # Only "function" tool choices are supported.
    type: Literal["function"] = Field(default="function", examples=["function"])
class ChatCompletionRequest(BaseModel):
# Ordered by official OpenAI API documentation
# https://platform.openai.com/docs/api-reference/chat/create
@@ -277,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
temperature: float = 0.7
top_p: float = 1.0
user: Optional[str] = None
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
default="auto", examples=["none"]
) # noqa
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
top_k: int = -1
@@ -292,9 +324,25 @@ class ChatCompletionRequest(BaseModel):
ebnf: Optional[str] = None
class FunctionResponse(BaseModel):
    """Function call emitted by the model (response side)."""

    # Name of the function the model decided to call.
    name: str
    # Raw argument payload as a string (presumably JSON-encoded — confirm
    # against `parse_tool_response`, which produces this value).
    arguments: str
class ToolCall(BaseModel):
    """Single tool invocation attached to an assistant message."""

    # Identifier for this call (built from the parsed call info upstream).
    id: str
    # Only function-type tool calls are supported.
    type: Literal["function"] = "function"
    # The function name and arguments the model produced.
    function: FunctionResponse
class ChatMessage(BaseModel):
    """Assistant message in a chat completion response.

    When tool calls are parsed from the generated text, `content` is set to
    None and the calls are carried in `tool_calls` instead (see the response
    builder in the adapter).
    """

    role: Optional[str] = None
    content: Optional[str] = None
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
class ChatCompletionResponseChoice(BaseModel):