[Feature] Function Calling (#2544)
Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com>
This commit is contained in:
@@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import (
|
||||
FileDeleteResponse,
|
||||
FileRequest,
|
||||
FileResponse,
|
||||
FunctionResponse,
|
||||
LogProbs,
|
||||
ToolCall,
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -879,6 +882,21 @@ def v1_chat_generate_request(
|
||||
# None skips any image processing in GenerateReqInput.
|
||||
if not isinstance(request.messages, str):
|
||||
# Apply chat template and its stop strings.
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if request.stream:
|
||||
logger.warning("Streaming is not supported with tools.")
|
||||
request.stream = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
for item in request.tools
|
||||
if item.function.name == request.tool_choice.function.name
|
||||
]
|
||||
else:
|
||||
tools = [item.function.model_dump() for item in request.tools]
|
||||
|
||||
if chat_template_name is None:
|
||||
openai_compatible_messages = []
|
||||
for message in request.messages:
|
||||
@@ -902,6 +920,7 @@ def v1_chat_generate_request(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
if assistant_prefix:
|
||||
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
||||
@@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
|
||||
finish_reason = ret_item["meta_info"]["finish_reason"]
|
||||
|
||||
tool_calls = None
|
||||
text = ret_item["text"]
|
||||
|
||||
if isinstance(request, list):
|
||||
tool_choice = request[idx].tool_choice
|
||||
tools = request[idx].tools
|
||||
else:
|
||||
tool_choice = request.tool_choice
|
||||
tools = request.tools
|
||||
|
||||
if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]):
|
||||
if finish_reason == "stop":
|
||||
finish_reason = "tool_calls"
|
||||
try:
|
||||
text, call_info_list = parse_tool_response(text, tools) # noqa
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info[0]),
|
||||
function=FunctionResponse(
|
||||
name=call_info[1], arguments=call_info[2]
|
||||
),
|
||||
)
|
||||
for call_info in call_info_list
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Exception: {e}")
|
||||
return create_error_response(
|
||||
HTTPStatus.BAD_REQUEST,
|
||||
"Failed to parse fc related info to json format!",
|
||||
)
|
||||
|
||||
if to_file:
|
||||
# to make the choice data json serializable
|
||||
choice_data = {
|
||||
"index": 0,
|
||||
"message": {"role": "assistant", "content": ret_item["text"]},
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": ret_item["text"] if tool_calls is None else None,
|
||||
"tool_calls": tool_calls,
|
||||
},
|
||||
"logprobs": choice_logprobs,
|
||||
"finish_reason": (finish_reason["type"] if finish_reason else ""),
|
||||
"matched_stop": (
|
||||
@@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
else:
|
||||
choice_data = ChatCompletionResponseChoice(
|
||||
index=idx,
|
||||
message=ChatMessage(role="assistant", content=ret_item["text"]),
|
||||
message=ChatMessage(
|
||||
role="assistant",
|
||||
content=ret_item["text"] if tool_calls is None else None,
|
||||
tool_calls=tool_calls,
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
matched_stop=(
|
||||
|
||||
@@ -257,6 +257,34 @@ class ResponseFormat(BaseModel):
|
||||
json_schema: Optional[JsonSchemaResponseFormat] = None
|
||||
|
||||
|
||||
class Function(BaseModel):
    """Schema of one callable function, as supplied in a request's ``tools`` list."""

    # Optional natural-language description of what the function does.
    description: Optional[str] = Field(default=None, examples=[None])
    # Function name; matched against names parsed from the model's tool-call output.
    name: str
    # Parameter specification (JSON-schema-like object); treated as opaque here.
    parameters: Optional[object] = None
|
||||
|
||||
|
||||
class Tool(BaseModel):
    """Wrapper pairing a tool ``type`` (currently always "function") with its Function spec."""

    # OpenAI-compatible tool type; only "function" is used.
    type: str = Field(default="function", examples=["function"])
    # The function definition this tool exposes.
    function: Function
|
||||
|
||||
|
||||
class ToolChoiceFuncName(BaseModel):
    """The name of tool choice function."""

    # Should match a Function.name from the request's tools list.
    name: str
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
    """Object form of ``tool_choice``: forces the model to call the named function."""

    # Names the single function the request pins the model to.
    function: ToolChoiceFuncName
    type: Literal["function"] = Field(default="function", examples=["function"])
|
||||
|
||||
|
||||
class ChatCompletionRequest(BaseModel):
|
||||
# Ordered by official OpenAI API documentation
|
||||
# https://platform.openai.com/docs/api-reference/chat/create
|
||||
@@ -277,6 +305,10 @@ class ChatCompletionRequest(BaseModel):
|
||||
temperature: float = 0.7
|
||||
top_p: float = 1.0
|
||||
user: Optional[str] = None
|
||||
tools: Optional[List[Tool]] = Field(default=None, examples=[None])
|
||||
tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field(
|
||||
default="auto", examples=["none"]
|
||||
) # noqa
|
||||
|
||||
# Extra parameters for SRT backend only and will be ignored by OpenAI models.
|
||||
top_k: int = -1
|
||||
@@ -292,9 +324,25 @@ class ChatCompletionRequest(BaseModel):
|
||||
ebnf: Optional[str] = None
|
||||
|
||||
|
||||
class FunctionResponse(BaseModel):
    """Function call emitted by the model: name plus JSON-encoded arguments."""

    # Name of the function the model decided to call.
    name: str
    # Arguments serialized as a JSON string.
    arguments: str
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
    """One tool call in a chat response (OpenAI ``tool_calls`` entry)."""

    # Identifier of this call; here the tool's index in the request, stringified.
    id: str
    type: Literal["function"] = "function"
    # The function name and arguments being invoked.
    function: FunctionResponse
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
    # Role of the message author (e.g. "assistant"); None when unset.
    role: Optional[str] = None
    # Text content; set to None when the reply consists solely of tool calls.
    content: Optional[str] = None
    # Tool calls parsed out of the model output, if any.
    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionResponseChoice(BaseModel):
|
||||
|
||||
@@ -1273,3 +1273,65 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
||||
)
|
||||
else:
|
||||
return str(data)
|
||||
|
||||
|
||||
# Sentinel substrings whose presence in model output indicates a tool/function
# call, one per supported model family: internlm2, llama3.1, qwen2.5, llama3.2.
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
|
||||
|
||||
|
||||
def parse_tool_response(text, tools, **kwargs):
    """Parse a model response containing tool/function-call markup.

    Supports the tag formats emitted by internlm2 (``<|plugin|>``),
    llama3.1 (``<function=``), qwen2.5 (``<tool_call>``) and
    llama3.2 (``<|python_tag|>``).

    Args:
        text(str): model response in string format
        tools(List): tools from user request; each item must expose
            ``item.function.name``

    Returns:
        tuple: ``(text, call_info_list)`` where ``text`` is the response with
        tool markup stripped (qwen2.5) or the prefix before the markup
        (internlm2), and ``call_info_list`` is a list of
        ``(tool_index, name, json_arguments)`` tuples.

    Raises:
        RuntimeError: if no known tool tag is found in ``text``.
        ValueError: if a tag is malformed, its payload is not valid JSON, or a
            called function name is not present in ``tools``.
    """
    if "<|plugin|>" in text:  # internlm2
        text, action = text.split("<|action_start|><|plugin|>")
        # NOTE: a previous version called .strip() on the literal delimiter,
        # which was a no-op; the delimiter is matched exactly.
        action = action.split("<|action_end|>")[0]
        # The payload starts at the first "{" of the JSON object.
        action = action[action.find("{") :]
        action = json.loads(action)
        name, parameters = action["name"], json.dumps(
            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
        )
        call_info_list = [(name, parameters)]
    elif "<function=" in text:  # llama3.1
        action, _ = text.split("</function>")
        parameters = action[action.find("{") :]
        name = action.split("<function=")[1].split(">{")[0]
        call_info_list = [(name, parameters)]
    elif "<tool_call>" in text and "</tool_call>" in text:  # qwen2.5
        # Collect every <tool_call>...</tool_call> payload in the text.
        pattern = r"<tool_call>(.*?)</tool_call>"
        match_result_list = re.findall(pattern, text, re.DOTALL)
        call_info_list = []
        for match_result in match_result_list:
            action = json.loads(match_result)
            call_info_list.append(
                (action["name"], json.dumps(action["arguments"], ensure_ascii=False))
            )
        # Keep only the plain text outside the tool-call tags.
        if not text.startswith("<tool_call>"):
            text = text[: text.find("<tool_call>")]
        elif not text.endswith("</tool_call>"):
            text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
        else:
            text = ""
    elif "<|python_tag|>" in text:  # llama3.2
        _, action = text.split("<|python_tag|>")
        action = json.loads(action)
        name, parameters = action["name"], json.dumps(
            action.get("parameters", action.get("arguments", {})), ensure_ascii=False
        )
        call_info_list = [(name, parameters)]
    else:
        raise RuntimeError(f"Unexpected model response: {text}")

    # Resolve each called name to its index in the request's tool list.
    # Build the name lookup once instead of once per call entry.
    tool_names = [tool.function.name for tool in tools]
    call_info_list = [
        (tool_names.index(name), name, parameters)
        for name, parameters in call_info_list
    ]
    return text, call_info_list
|
||||
|
||||
Reference in New Issue
Block a user