From 8ee9a8501a897395e9d21dbf02986b0f98b378d0 Mon Sep 17 00:00:00 2001 From: Tanjiro Date: Sun, 29 Dec 2024 11:28:52 +0530 Subject: [PATCH] [Feature] Function Calling (#2544) Co-authored-by: Haoyu Wang <120358163+HaoyuWang4188@users.noreply.github.com> --- docs/backend/function_calling.ipynb | 146 +++++++++++++++++++++++ python/sglang/srt/openai_api/adapter.py | 62 +++++++++- python/sglang/srt/openai_api/protocol.py | 48 ++++++++ python/sglang/srt/utils.py | 62 ++++++++++ test/srt/test_openai_server.py | 52 ++++++++ 5 files changed, 368 insertions(+), 2 deletions(-) create mode 100644 docs/backend/function_calling.ipynb diff --git a/docs/backend/function_calling.ipynb b/docs/backend/function_calling.ipynb new file mode 100644 index 000000000..47a2e2278 --- /dev/null +++ b/docs/backend/function_calling.ipynb @@ -0,0 +1,146 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Function Calling\n", + "\n", + "This notebook provides a quick-start guide to using function calling with the SGLang chat completions API.\n", + "\n", + "## Supported Models\n", + "\n", + "Currently, we have added support for tool calling in the following models:\n", + " - Llama 3.2 models\n", + " - Llama 3.1 models\n", + " - Qwen 2.5 models\n", + " - InternLM Models" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Usage\n", + "\n", + "### Launch a server\n", + "\n", + "This code block is equivalent to executing\n", + "\n", + "`python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct \\\n", + "--port 30000 --host 0.0.0.0`\n", + "in your terminal and waiting for the server to be ready. Once the server is running, you can send test requests using curl or requests. The server implements the OpenAI-compatible APIs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from sglang.utils import (\n", + " execute_shell_command,\n", + " wait_for_server,\n", + " terminate_process,\n", + " print_highlight,\n", + ")\n", + "\n", + "\n", + "server_process = execute_shell_command(\n", + " \"\"\"\n", + " python -m sglang.launch_server --model-path meta-llama/Meta-Llama-3.1-8B-Instruct --port 30000 --host 0.0.0.0\n", + "\"\"\"\n", + ")\n", + "\n", + "wait_for_server(\"http://localhost:30000\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Single Round Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from openai import OpenAI\n", + "\n", + "tools = [\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather in a given location\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. 
San Francisco, CA\",\n", + " },\n", + " \"unit\": {\"type\": \"string\", \"enum\": [\"celsius\", \"fahrenheit\"]},\n", + " },\n", + " \"required\": [\"location\"],\n", + " },\n", + " },\n", + " }\n", + "]\n", + "messages = [{\"role\": \"user\", \"content\": \"What's the weather like in Boston today?\"}]\n", + "\n", + "client = OpenAI(api_key=\"YOUR_API_KEY\", base_url=\"http://0.0.0.0:30000/v1\")\n", + "model_name = client.models.list().data[0].id\n", + "response = client.chat.completions.create(\n", + " model=model_name,\n", + " messages=messages,\n", + " temperature=0.8,\n", + " top_p=0.8,\n", + " stream=False,\n", + " tools=tools,\n", + ")\n", + "\n", + "print(response)\n", + "\n", + "\"\"\"\n", + "\n", + "ChatCompletion(id='d6f620e1767e490d85b5ce45c15151cf', choices=[Choice(finish_reason='stop', index=0, logprobs=None, message=ChatCompletionMessage(content=None, refusal=None, \n", + "role='assistant', audio=None, function_call=None, tool_calls=[ChatCompletionMessageToolCall(id='0', function=Function(arguments='{\"a\": \"3\", \"b\": \"5\"}', name='add'), type='function')]), \n", + "matched_stop=128008)], created=1735411703, model='meta-llama/Llama-3.2-1B-Instruct', object='chat.completion', service_tier=None, system_fingerprint=None, \n", + "usage=CompletionUsage(completion_tokens=23, prompt_tokens=198, total_tokens=221, completion_tokens_details=None, prompt_tokens_details=None))\n", + "\n", + "\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "terminate_process(server_process)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## How to support a new model?\n", + "\n", + "For adding support of more different models:\n", + " 1. Update the `TOOLS_TAG_LIST` in `sglang/srt/utils.py` with the tool tag used by the model.\n", + " 2. 
Add support in `parse_tool_response` function for converting into tool calls `sglang/srt/utils.py`\n" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/python/sglang/srt/openai_api/adapter.py b/python/sglang/srt/openai_api/adapter.py index 46094c556..cbbc741d6 100644 --- a/python/sglang/srt/openai_api/adapter.py +++ b/python/sglang/srt/openai_api/adapter.py @@ -65,10 +65,13 @@ from sglang.srt.openai_api.protocol import ( FileDeleteResponse, FileRequest, FileResponse, + FunctionResponse, LogProbs, + ToolCall, TopLogprob, UsageInfo, ) +from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response from sglang.utils import get_exception_traceback logger = logging.getLogger(__name__) @@ -879,6 +882,21 @@ def v1_chat_generate_request( # None skips any image processing in GenerateReqInput. if not isinstance(request.messages, str): # Apply chat template and its stop strings. + tools = None + if request.tools and request.tool_choice != "none": + request.skip_special_tokens = False + if request.stream: + logger.warning("Streaming is not supported with tools.") + request.stream = False + if not isinstance(request.tool_choice, str): + tools = [ + item.function.model_dump() + for item in request.tools + if item.function.name == request.tool_choice.function.name + ] + else: + tools = [item.function.model_dump() for item in request.tools] + if chat_template_name is None: openai_compatible_messages = [] for message in request.messages: @@ -902,6 +920,7 @@ def v1_chat_generate_request( openai_compatible_messages, tokenize=True, add_generation_prompt=True, + tools=tools, ) if assistant_prefix: prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix) @@ -1041,11 +1060,46 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): finish_reason = ret_item["meta_info"]["finish_reason"] + tool_calls = None + text = ret_item["text"] + + if isinstance(request, list): + 
tool_choice = request[idx].tool_choice + tools = request[idx].tools + else: + tool_choice = request.tool_choice + tools = request.tools + + if tool_choice != "none" and any([i in text for i in TOOLS_TAG_LIST]): + if finish_reason == "stop": + finish_reason = "tool_calls" + try: + text, call_info_list = parse_tool_response(text, tools) # noqa + tool_calls = [ + ToolCall( + id=str(call_info[0]), + function=FunctionResponse( + name=call_info[1], arguments=call_info[2] + ), + ) + for call_info in call_info_list + ] + except Exception as e: + logger.error(f"Exception: {e}") + return create_error_response( + HTTPStatus.BAD_REQUEST, + "Failed to parse fc related info to json format!", + ) + if to_file: # to make the choice data json serializable choice_data = { "index": 0, - "message": {"role": "assistant", "content": ret_item["text"]}, + "message": { + "role": "assistant", + "content": ret_item["text"] if tool_calls is None else None, + "tool_calls": tool_calls, + }, "logprobs": choice_logprobs, "finish_reason": (finish_reason["type"] if finish_reason else ""), "matched_stop": ( @@ -1057,7 +1111,11 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False): else: choice_data = ChatCompletionResponseChoice( index=idx, - message=ChatMessage(role="assistant", content=ret_item["text"]), + message=ChatMessage( + role="assistant", + content=ret_item["text"] if tool_calls is None else None, + tool_calls=tool_calls, + ), logprobs=choice_logprobs, finish_reason=(finish_reason["type"] if finish_reason else ""), matched_stop=( diff --git a/python/sglang/srt/openai_api/protocol.py b/python/sglang/srt/openai_api/protocol.py index 4f7833a23..2599cea3d 100644 --- a/python/sglang/srt/openai_api/protocol.py +++ b/python/sglang/srt/openai_api/protocol.py @@ -257,6 +257,34 @@ class ResponseFormat(BaseModel): json_schema: Optional[JsonSchemaResponseFormat] = None +class Function(BaseModel): + """Function descriptions.""" + + description: Optional[str] = 
Field(default=None, examples=[None]) + name: str + parameters: Optional[object] = None + + +class Tool(BaseModel): + """Function wrapper.""" + + type: str = Field(default="function", examples=["function"]) + function: Function + + +class ToolChoiceFuncName(BaseModel): + """The name of tool choice function.""" + + name: str + + +class ToolChoice(BaseModel): + """The tool choice definition.""" + + function: ToolChoiceFuncName + type: Literal["function"] = Field(default="function", examples=["function"]) + + class ChatCompletionRequest(BaseModel): # Ordered by official OpenAI API documentation # https://platform.openai.com/docs/api-reference/chat/create @@ -277,6 +305,10 @@ class ChatCompletionRequest(BaseModel): temperature: float = 0.7 top_p: float = 1.0 user: Optional[str] = None + tools: Optional[List[Tool]] = Field(default=None, examples=[None]) + tool_choice: Union[ToolChoice, Literal["auto", "required", "none"]] = Field( + default="auto", examples=["none"] + ) # noqa # Extra parameters for SRT backend only and will be ignored by OpenAI models. 
top_k: int = -1 @@ -292,9 +324,25 @@ class ChatCompletionRequest(BaseModel): ebnf: Optional[str] = None +class FunctionResponse(BaseModel): + """Function response.""" + + name: str + arguments: str + + +class ToolCall(BaseModel): + """Tool call response.""" + + id: str + type: Literal["function"] = "function" + function: FunctionResponse + + class ChatMessage(BaseModel): role: Optional[str] = None content: Optional[str] = None + tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None]) class ChatCompletionResponseChoice(BaseModel): diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index b2f3e5ccb..7c3efa9a2 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -1273,3 +1273,65 @@ def dataclass_to_string_truncated(data, max_length=2048): ) else: return str(data) + + +TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"] + + +def parse_tool_response(text, tools, **kwargs): + """Parse model response containing tool information. + + Args: + text(str): model response in string format + tools(List): tools from user request + """ + if "<|plugin|>" in text: # internlm2 + text, action = text.split("<|action_start|><|plugin|>") + action = action.split("<|action_end|>".strip())[0] + action = action[action.find("{") :] + action = json.loads(action) + name, parameters = action["name"], json.dumps( + action.get("parameters", action.get("arguments", {})), ensure_ascii=False + ) + call_info_list = [(name, parameters)] + elif "<function=" in text: # llama3.1 + action, _ = text.split("</function>") + parameters = action[action.find("{") :] + name = action.split("<function=")[1].split(">{")[0] + call_info_list = [(name, parameters)] + elif "<tool_call>" in text and "</tool_call>" in text: # qwen2.5 + # get tool_call in text + pattern = r"<tool_call>(.*?)</tool_call>" + match_result_list = re.findall(pattern, text, re.DOTALL) + call_info_list = [] + for match_result in match_result_list: + action = json.loads(match_result) + call_info_list.append( + (action["name"], json.dumps(action["arguments"], ensure_ascii=False)) + ) + # get text outside of tags + if not text.startswith("<tool_call>"): + text = text[: text.find("<tool_call>")] + elif not text.endswith("</tool_call>"): + text = text[text.rfind("</tool_call>") + len("</tool_call>") :] + else: + text = "" + elif "<|python_tag|>" in text: # llama3.2 + _, action = text.split("<|python_tag|>") + action = json.loads(action) + name, parameters = action["name"], json.dumps( + action.get("parameters", action.get("arguments", {})), ensure_ascii=False + ) + call_info_list = [(name, parameters)] + else: + raise RuntimeError(f"Unexpected model response: {text}") + + call_info_list = [ + ( + [tool.function.name for tool in tools].index(call_info[0]), + call_info[0], + call_info[1], + ) + for call_info in call_info_list + ] + return text, call_info_list diff --git a/test/srt/test_openai_server.py b/test/srt/test_openai_server.py index 47932ae41..379e57f35 100644 --- a/test/srt/test_openai_server.py +++ b/test/srt/test_openai_server.py @@ -622,6 +622,58 @@ class TestOpenAIServerEBNF(unittest.TestCase): text, pattern, f"Text '{text}' not matching the EBNF strict JSON shape" ) + def test_function_calling_format(self): + + client = openai.Client(api_key=self.api_key, base_url=self.base_url) + tools = [ + { + "type": "function", + "function": { + "name": "add", + "description": "Compute the sum of two numbers", + "parameters": { + "type": "object", + "properties": { + "a": { + "type": "int", + "description": "A number", + }, + "b": { + "type": "int", + "description": "A number", + }, + }, + "required": ["a", "b"], + }, + }, + } + ] + + messages = [{"role": "user", "content": "Compute (3+5)"}] + response = client.chat.completions.create( + model=self.model, + messages=messages, + temperature=0.8, + top_p=0.8, + stream=False, + tools=tools, + ) + + content = response.choices[0].message.content + tool_calls = response.choices[0].message.tool_calls + + assert ( + content is None + ), "When tools provided by the response, content should be None" + assert ( + isinstance(tool_calls, list) and len(tool_calls) > 0 + ), "Format not matched, tool_calls 
should be a list" + + function_name = tool_calls[0].function.name + assert ( + function_name == "add" + ), "Function name should be add for the above response" + if __name__ == "__main__": unittest.main()