Feature/function calling update (#2700)
Co-authored-by: Mingyuan Ma <mamingyuan2001@berkeley.edu> Co-authored-by: Chayenne <zhaochen20@outlook.com> Co-authored-by: shuaills <shishuaiuoe@gmail.com>
This commit is contained in:
@@ -39,10 +39,12 @@ from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
||||
|
||||
from sglang.srt.entrypoints.engine import _launch_subprocesses
|
||||
from sglang.srt.function_call_parser import FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import (
|
||||
CloseSessionReqInput,
|
||||
ConfigureLoggingReq,
|
||||
EmbeddingReqInput,
|
||||
FunctionCallReqInput,
|
||||
GenerateReqInput,
|
||||
GetWeightsByNameReqInput,
|
||||
InitWeightsUpdateGroupReqInput,
|
||||
@@ -369,6 +371,28 @@ async def configure_logging(obj: ConfigureLoggingReq, request: Request):
|
||||
return Response(status_code=200)
|
||||
|
||||
|
||||
@app.post("/function_call")
|
||||
async def function_call_request(obj: FunctionCallReqInput, request: Request):
|
||||
"""
|
||||
A native API endpoint to parse function calls from a text.
|
||||
"""
|
||||
# 1) Initialize the parser based on the request body
|
||||
parser = FunctionCallParser(tools=obj.tools, tool_call_parser=obj.tool_call_parser)
|
||||
|
||||
# 2) Call the non-stream parsing method (non-stream)
|
||||
normal_text, calls = parser.parse_non_stream(obj.text)
|
||||
|
||||
# 3) Organize the response content
|
||||
response_data = {
|
||||
"normal_text": normal_text,
|
||||
"calls": [
|
||||
call.model_dump() for call in calls
|
||||
], # Convert pydantic objects to dictionaries
|
||||
}
|
||||
|
||||
return ORJSONResponse(content=response_data, status_code=200)
|
||||
|
||||
|
||||
##### OpenAI-compatible API endpoints #####
|
||||
|
||||
|
||||
|
||||
494
python/sglang/srt/function_call_parser.py
Normal file
494
python/sglang/srt/function_call_parser.py
Normal file
@@ -0,0 +1,494 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from json import JSONDecodeError, JSONDecoder
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import partial_json_parser
|
||||
from partial_json_parser.core.options import Allow
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
TOOLS_TAG_LIST = [
|
||||
"<|plugin|>",
|
||||
"<function=",
|
||||
"<tool_call>",
|
||||
"<|python_tag|>",
|
||||
"[TOOL_CALLS]",
|
||||
]
|
||||
|
||||
|
||||
class Function(BaseModel):
|
||||
"""Function Tool Template."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
class ToolCallItem(BaseModel):
|
||||
"""Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""
|
||||
|
||||
tool_index: int
|
||||
name: Optional[str] = None
|
||||
parameters: str # JSON string
|
||||
|
||||
|
||||
def _find_common_prefix(s1: str, s2: str) -> str:
|
||||
prefix = ""
|
||||
min_length = min(len(s1), len(s2))
|
||||
for i in range(0, min_length):
|
||||
if s1[i] == s2[i]:
|
||||
prefix += s1[i]
|
||||
else:
|
||||
break
|
||||
return prefix
|
||||
|
||||
|
||||
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
|
||||
try:
|
||||
return (partial_json_parser.loads(input_str, flags), len(input_str))
|
||||
except JSONDecodeError as e:
|
||||
if "Extra data" in e.msg:
|
||||
dec = JSONDecoder()
|
||||
return dec.raw_decode(input_str)
|
||||
raise
|
||||
|
||||
|
||||
def _is_complete_json(input_str: str) -> bool:
|
||||
try:
|
||||
json.loads(input_str)
|
||||
return True
|
||||
except JSONDecodeError:
|
||||
return False
|
||||
|
||||
|
||||
class StreamingParseResult:
|
||||
"""Result of streaming incremental parsing."""
|
||||
|
||||
def __init__(
|
||||
self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
|
||||
):
|
||||
self.normal_text = normal_text
|
||||
self.calls = calls or []
|
||||
|
||||
|
||||
class BaseFormatDetector:
|
||||
"""Base class providing two sets of interfaces: one-time and streaming incremental."""
|
||||
|
||||
def __init__(self):
|
||||
# initialize properties used for state when parsing tool calls in
|
||||
self._buffer = ""
|
||||
# streaming mode
|
||||
self.prev_tool_call_arr: List[Dict] = []
|
||||
self.current_tool_id: int = -1
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool: List[str] = (
|
||||
[]
|
||||
) # map what has been streamed for each tool so far to a list
|
||||
self.bot_token = ""
|
||||
self.eot_token = ""
|
||||
|
||||
def parse_base_json(self, action: Dict, tools: List[Function]):
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
tool_index = [tool.function.name for tool in tools].index(name)
|
||||
tool_call_item = ToolCallItem(
|
||||
tool_index=tool_index, name=name, parameters=parameters
|
||||
)
|
||||
calls = [tool_call_item]
|
||||
return calls
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
Parses the text in one go. Returns success=True if the format matches, otherwise False.
|
||||
Note that leftover_text here represents "content that this parser will not consume further".
|
||||
"""
|
||||
action = json.loads(text)
|
||||
return self.parse_base_json(action, tools)
|
||||
|
||||
def parse_streaming_increment(
|
||||
self, new_text: str, tools: List[Function]
|
||||
) -> StreamingParseResult:
|
||||
"""
|
||||
Streaming incremental parsing, referencing the logic of Llama32Detector.
|
||||
We partially parse JSON within <tool_call>...</tool_call>, and handle
|
||||
incremental argument output.
|
||||
"""
|
||||
# Append new text to buffer
|
||||
self._buffer += new_text
|
||||
current_text = self._buffer
|
||||
if not (self.bot_token in current_text or current_text.startswith("{")):
|
||||
self._buffer = ""
|
||||
if self.eot_token in new_text:
|
||||
new_text = new_text.replace(self.eot_token, "")
|
||||
return StreamingParseResult(normal_text=new_text)
|
||||
|
||||
# bit mask flags for partial JSON parsing. If the name hasn't been
|
||||
# sent yet, don't allow sending
|
||||
# an incomplete string since OpenAI only ever (as far as I have
|
||||
# seen) allows sending the entire tool/ function name at once.
|
||||
flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
|
||||
try:
|
||||
tool_call_arr = []
|
||||
is_complete = []
|
||||
try:
|
||||
# depending on the prompt format the Llama model may or may not
|
||||
# prefix the output with the <|python_tag|> token
|
||||
start_idx = (
|
||||
len(self.bot_token)
|
||||
if current_text.startswith(self.bot_token)
|
||||
else 0
|
||||
)
|
||||
while start_idx < len(current_text):
|
||||
(obj, end_idx) = _partial_json_loads(
|
||||
current_text[start_idx:], flags
|
||||
)
|
||||
is_complete.append(
|
||||
_is_complete_json(current_text[start_idx : start_idx + end_idx])
|
||||
)
|
||||
start_idx += end_idx + len("; ")
|
||||
# depending on the prompt Llama can use
|
||||
# either arguments or parameters
|
||||
if "parameters" in obj:
|
||||
assert (
|
||||
"arguments" not in obj
|
||||
), "model generated both parameters and arguments"
|
||||
obj["arguments"] = obj["parameters"]
|
||||
tool_call_arr.append(obj)
|
||||
|
||||
except partial_json_parser.core.exceptions.MalformedJSON:
|
||||
# not enough tokens to parse into JSON yet
|
||||
return StreamingParseResult()
|
||||
|
||||
# select as the current tool call the one we're on the state at
|
||||
current_tool_call: Dict = (
|
||||
tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
|
||||
)
|
||||
|
||||
# case -- if no tokens have been streamed for the tool, e.g.
|
||||
# only the array brackets, stream nothing
|
||||
if len(tool_call_arr) == 0:
|
||||
return StreamingParseResult()
|
||||
|
||||
# case: we are starting a new tool in the array
|
||||
# -> array has > 0 length AND length has moved past cursor
|
||||
elif (
|
||||
len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
|
||||
):
|
||||
|
||||
# if we're moving on to a new call, first make sure we
|
||||
# haven't missed anything in the previous one that was
|
||||
# auto-generated due to JSON completions, but wasn't
|
||||
# streamed to the client yet.
|
||||
if self.current_tool_id >= 0:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
if cur_arguments:
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
argument_diff = cur_args_json[sent:]
|
||||
|
||||
res = StreamingParseResult(
|
||||
normal_text=None,
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name="",
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
# re-set stuff pertaining to progress in the current tool
|
||||
self.current_tool_id = len(tool_call_arr) - 1
|
||||
self.current_tool_name_sent = False
|
||||
self.streamed_args_for_tool.append("")
|
||||
print("starting on new tool %d", self.current_tool_id)
|
||||
return res
|
||||
|
||||
# if the current tool name hasn't been sent, send if available
|
||||
# - otherwise send nothing
|
||||
elif not self.current_tool_name_sent:
|
||||
function_name = current_tool_call.get("name")
|
||||
if function_name:
|
||||
res = StreamingParseResult(
|
||||
normal_text=None,
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name=function_name,
|
||||
parameters="",
|
||||
)
|
||||
],
|
||||
)
|
||||
self.current_tool_name_sent = True
|
||||
else:
|
||||
res = StreamingParseResult()
|
||||
|
||||
# now we know we're on the same tool call and we're streaming
|
||||
# arguments
|
||||
else:
|
||||
cur_arguments = current_tool_call.get("arguments")
|
||||
res = StreamingParseResult()
|
||||
|
||||
if cur_arguments:
|
||||
sent = len(self.streamed_args_for_tool[self.current_tool_id])
|
||||
cur_args_json = json.dumps(cur_arguments)
|
||||
prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
|
||||
"arguments"
|
||||
)
|
||||
|
||||
argument_diff = None
|
||||
if is_complete[self.current_tool_id]:
|
||||
argument_diff = cur_args_json[sent:]
|
||||
self._buffer = ""
|
||||
self.prev_tool_call_arr[self.current_tool_id].clear()
|
||||
self.current_tool_name_sent: bool = False
|
||||
self.streamed_args_for_tool[self.current_tool_id] = ""
|
||||
|
||||
elif prev_arguments:
|
||||
prev_args_json = json.dumps(prev_arguments)
|
||||
if cur_args_json != prev_args_json:
|
||||
|
||||
prefix = _find_common_prefix(prev_args_json, cur_args_json)
|
||||
argument_diff = prefix[sent:]
|
||||
|
||||
if argument_diff is not None:
|
||||
res = StreamingParseResult(
|
||||
calls=[
|
||||
ToolCallItem(
|
||||
tool_index=self.current_tool_id,
|
||||
name="",
|
||||
parameters=argument_diff,
|
||||
)
|
||||
],
|
||||
)
|
||||
if not is_complete[self.current_tool_id]:
|
||||
self.streamed_args_for_tool[
|
||||
self.current_tool_id
|
||||
] += argument_diff
|
||||
|
||||
self.prev_tool_call_arr = tool_call_arr
|
||||
return res
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Skipping chunk as a result of tool streaming extraction error
|
||||
return StreamingParseResult()
|
||||
|
||||
|
||||
class Qwen25Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Qwen 2.5 models.
|
||||
Assumes function call format:
|
||||
<tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<tool_call>"
|
||||
self.eot_token = "</tool_call>"
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
if "<tool_call>" not in text:
|
||||
return []
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
calls = []
|
||||
for match_result in match_result_list:
|
||||
match_result = json.loads(match_result)
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
|
||||
|
||||
class MistralDetector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Mistral models.
|
||||
Assumes function call format:
|
||||
<|action_start|><|plugin|>{"name":"xxx", "arguments":{...}}<|action_end|>
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "[TOOL_CALLS] ["
|
||||
self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)
|
||||
|
||||
def _clean_text(self, text: str) -> str:
|
||||
"""
|
||||
clean text to only leave ''[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
|
||||
for example,
|
||||
text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is :{function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
|
||||
return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
|
||||
The key pattern is [TOOL_CALLS] [...]
|
||||
"""
|
||||
find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
|
||||
if len(find_results) > 0:
|
||||
return find_results[0]
|
||||
else:
|
||||
return ""
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
text = self._clean_text(text)
|
||||
tool_content = text.replace("[TOOL_CALLS]", "").strip()
|
||||
raw_tool_calls = self.tool_call_regex.findall(tool_content)
|
||||
calls = []
|
||||
if len(raw_tool_calls) > 0:
|
||||
raw_tool_call = raw_tool_calls[0]
|
||||
function_call_arr = json.loads(raw_tool_call)
|
||||
for match_result in function_call_arr:
|
||||
calls.extend(self.parse_base_json(match_result, tools))
|
||||
return calls
|
||||
|
||||
|
||||
class Llama32Detector(BaseFormatDetector):
|
||||
"""
|
||||
Detector for Llama 3.2 models.
|
||||
Assumes function call format:
|
||||
<|python_tag|>{"name":"xxx", "arguments":{...}}
|
||||
Does not require a closing tag "</python_tag|>",
|
||||
relies on json.loads(...) success to determine if JSON is complete.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""
|
||||
Initializes the detector with necessary state variables.
|
||||
"""
|
||||
super().__init__()
|
||||
self.bot_token = "<|python_tag|>"
|
||||
|
||||
def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
|
||||
"""
|
||||
One-time parsing: Detects and parses tool calls in the provided text.
|
||||
|
||||
:param text: The complete text to parse.
|
||||
:param tools: List of available tools.
|
||||
:return: ParseResult indicating success or failure, consumed text, leftover text, and parsed calls.
|
||||
"""
|
||||
|
||||
if "<|python_tag|>" not in text:
|
||||
return []
|
||||
_, action = text.split("<|python_tag|>")
|
||||
action = json.loads(action)
|
||||
return self.parse_base_json(action, tools)
|
||||
|
||||
|
||||
class MultiFormatParser:
|
||||
def __init__(self, detectors: List[BaseFormatDetector]):
|
||||
"""
|
||||
:param detectors: A series of available Detector instances passed in
|
||||
"""
|
||||
self.detectors = detectors
|
||||
|
||||
def parse_once(self, text: str, tools: List[Function]):
|
||||
"""
|
||||
One-time parsing: Loop through detectors until there are no new matches or text is exhausted
|
||||
Return: (final_text, all_calls)
|
||||
- final_text: The remaining text after parsing that was not consumed by any Detector (can be treated as normal text)
|
||||
- all_calls: All calls parsed by the Detectors
|
||||
"""
|
||||
final_calls = []
|
||||
final_normal_text = text
|
||||
for detector in self.detectors:
|
||||
tool_call_list = detector.detect_and_parse(text, tools)
|
||||
if len(tool_call_list) > 0: # parsed successfully
|
||||
final_calls = tool_call_list
|
||||
break
|
||||
|
||||
# leftover_text is the normal text not consumed by any Detector
|
||||
return final_normal_text, final_calls
|
||||
|
||||
def parse_streaming_increment(self, new_text: str, tools: List[Function]):
|
||||
"""
|
||||
Streaming incremental parsing: Feed new_text to each detector's parse_streaming_increment
|
||||
and merge their produced normal_text/calls to return.
|
||||
(The logic here can be "priority-based" or "parallel parsing" based on your needs)
|
||||
"""
|
||||
final_normal_text = ""
|
||||
final_calls = []
|
||||
|
||||
for detector in self.detectors:
|
||||
sp_result = detector.parse_streaming_increment(new_text, tools)
|
||||
# Merge normal_text and calls
|
||||
# If one sp_result contains result call, this should be a successful parse
|
||||
# If one sp_result only contains normal_text, this can either be a successful
|
||||
# parse or it is not using the desired parsing tool.
|
||||
if sp_result.normal_text:
|
||||
final_normal_text = sp_result.normal_text
|
||||
if sp_result.calls:
|
||||
final_calls.extend(sp_result.calls)
|
||||
final_normal_text = sp_result.normal_text
|
||||
break
|
||||
|
||||
return final_normal_text, final_calls
|
||||
|
||||
|
||||
class FunctionCallParser:
|
||||
"""
|
||||
In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
|
||||
and returns the resulting normal_text and calls to the upper layer (or SSE).
|
||||
"""
|
||||
|
||||
ToolCallParserEnum: Dict[str, BaseFormatDetector] = {
|
||||
"llama3": Llama32Detector,
|
||||
"qwen25": Qwen25Detector,
|
||||
"mistral": MistralDetector,
|
||||
}
|
||||
|
||||
def __init__(self, tools: List[Function], tool_call_parser: str = None):
|
||||
detectors = []
|
||||
if tool_call_parser:
|
||||
detector_class = self.ToolCallParserEnum.get(tool_call_parser)
|
||||
if detector_class:
|
||||
detectors.append(detector_class())
|
||||
else:
|
||||
raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
|
||||
else:
|
||||
raise ValueError("Tool Call Parser Not Given!")
|
||||
|
||||
self.multi_format_parser = MultiFormatParser(detectors)
|
||||
self.tools = tools
|
||||
|
||||
def parse_non_stream(self, full_text: str):
|
||||
"""
|
||||
Non-streaming call: one-time parsing
|
||||
"""
|
||||
full_normal_text, calls = self.multi_format_parser.parse_once(
|
||||
full_text, self.tools
|
||||
)
|
||||
return full_normal_text, calls
|
||||
|
||||
def parse_stream_chunk(self, chunk_text: str):
|
||||
"""
|
||||
Streaming call: incremental parsing
|
||||
"""
|
||||
normal_text, calls = self.multi_format_parser.parse_streaming_increment(
|
||||
chunk_text, self.tools
|
||||
)
|
||||
return normal_text, calls
|
||||
@@ -17,7 +17,7 @@ processes (TokenizerManager, DetokenizerManager, Controller).
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -540,3 +540,27 @@ class CloseSessionReqInput:
|
||||
class OpenSessionReqOutput:
|
||||
session_id: Optional[str]
|
||||
success: bool
|
||||
|
||||
|
||||
@dataclass
|
||||
class Function:
|
||||
description: Optional[str] = None
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class Tool:
|
||||
function: Function
|
||||
type: Optional[str] = "function"
|
||||
|
||||
|
||||
@dataclass
|
||||
class FunctionCallReqInput:
|
||||
text: str # The text to parse.
|
||||
tools: List[Tool] = field(
|
||||
default_factory=list
|
||||
) # A list of available function tools (name, parameters, etc.).
|
||||
tool_call_parser: Optional[str] = (
|
||||
None # Specify the parser type, e.g. 'llama3', 'qwen25', or 'mistral'. If not specified, tries all.
|
||||
)
|
||||
|
||||
@@ -20,7 +20,7 @@ import os
|
||||
import time
|
||||
import uuid
|
||||
from http import HTTPStatus
|
||||
from typing import Dict, List
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from fastapi import HTTPException, Request, UploadFile
|
||||
from fastapi.responses import ORJSONResponse, StreamingResponse
|
||||
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
|
||||
generate_chat_conv,
|
||||
register_conv_template,
|
||||
)
|
||||
from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
|
||||
from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
|
||||
from sglang.srt.openai_api.protocol import (
|
||||
BatchRequest,
|
||||
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
|
||||
TopLogprob,
|
||||
UsageInfo,
|
||||
)
|
||||
from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
|
||||
from sglang.utils import get_exception_traceback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
|
||||
ret,
|
||||
to_file=True,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
else:
|
||||
responses = v1_generate_response(
|
||||
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
|
||||
tools = None
|
||||
if request.tools and request.tool_choice != "none":
|
||||
request.skip_special_tokens = False
|
||||
if request.stream:
|
||||
logger.warning("Streaming is not supported with tools.")
|
||||
request.stream = False
|
||||
if not isinstance(request.tool_choice, str):
|
||||
tools = [
|
||||
item.function.model_dump()
|
||||
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
|
||||
openai_compatible_messages = openai_compatible_messages[:-1]
|
||||
else:
|
||||
assistant_prefix = None
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
try:
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
except:
|
||||
# This except branch will be triggered when the chosen model
|
||||
# has a different tools input format that is not compatiable
|
||||
# with openAI's apply_chat_template tool_call format, like Mistral.
|
||||
tools = [t if "function" in t else {"function": t} for t in tools]
|
||||
prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
|
||||
openai_compatible_messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if assistant_prefix:
|
||||
prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
|
||||
stop = request.stop
|
||||
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
|
||||
return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
|
||||
|
||||
|
||||
def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
def v1_chat_generate_response(
|
||||
request, ret, to_file=False, cache_report=False, tool_call_parser=None
|
||||
):
|
||||
choices = []
|
||||
|
||||
for idx, ret_item in enumerate(ret):
|
||||
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
|
||||
if finish_reason == "stop":
|
||||
finish_reason = "tool_calls"
|
||||
try:
|
||||
text, call_info_list = parse_tool_response(text, tools) # noqa
|
||||
parser = FunctionCallParser(tools, tool_call_parser)
|
||||
full_normal_text, call_info_list = parser.parse_non_stream(text)
|
||||
tool_calls = [
|
||||
ToolCall(
|
||||
id=str(call_info[0]),
|
||||
id=str(call_info.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_info[1], arguments=call_info[2]
|
||||
name=call_info.name, arguments=call_info.parameters
|
||||
),
|
||||
)
|
||||
for call_info in call_info_list
|
||||
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
|
||||
|
||||
if adapted_request.stream:
|
||||
parser_dict = {}
|
||||
|
||||
async def generate_stream_resp():
|
||||
is_firsts = {}
|
||||
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
adapted_request, raw_request
|
||||
):
|
||||
index = content.get("index", 0)
|
||||
text = content["text"]
|
||||
|
||||
is_first = is_firsts.get(index, True)
|
||||
stream_buffer = stream_buffers.get(index, "")
|
||||
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
|
||||
text = content["text"]
|
||||
delta = text[len(stream_buffer) :]
|
||||
stream_buffer = stream_buffer + delta
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(finish_reason["type"] if finish_reason else ""),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
new_stream_buffer = stream_buffer + delta
|
||||
|
||||
is_firsts[index] = is_first
|
||||
stream_buffers[index] = stream_buffer
|
||||
n_prev_tokens[index] = n_prev_token
|
||||
if request.tool_choice != "none" and request.tools:
|
||||
if index not in parser_dict:
|
||||
parser_dict[index] = FunctionCallParser(
|
||||
tools=request.tools,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
parser = parser_dict[index]
|
||||
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
# parse_increment => returns (normal_text, calls)
|
||||
normal_text, calls = parser.parse_stream_chunk(delta)
|
||||
|
||||
# 1) if there's normal_text, output it as normal content
|
||||
if normal_text:
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=normal_text),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
# 2) if we found calls, we output them as separate chunk(s)
|
||||
for call_item in calls:
|
||||
# transform call_item -> FunctionResponse + ToolCall
|
||||
|
||||
if (
|
||||
content["meta_info"]["finish_reason"]
|
||||
and content["meta_info"]["finish_reason"]["type"]
|
||||
== "stop"
|
||||
):
|
||||
latest_delta_len = 0
|
||||
if isinstance(call_item.parameters, str):
|
||||
latest_delta_len = len(call_item.parameters)
|
||||
|
||||
expected_call = json.dumps(
|
||||
parser.multi_format_parser.detectors[0]
|
||||
.prev_tool_call_arr[index]
|
||||
.get("arguments", {}),
|
||||
ensure_ascii=False,
|
||||
)
|
||||
actual_call = parser.multi_format_parser.detectors[
|
||||
0
|
||||
].streamed_args_for_tool[index]
|
||||
if latest_delta_len > 0:
|
||||
actual_call = actual_call[:-latest_delta_len]
|
||||
remaining_call = expected_call.replace(
|
||||
actual_call, "", 1
|
||||
)
|
||||
call_item.parameters = remaining_call
|
||||
|
||||
tool_call = ToolCall(
|
||||
id=str(call_item.tool_index),
|
||||
function=FunctionResponse(
|
||||
name=call_item.name,
|
||||
arguments=call_item.parameters,
|
||||
),
|
||||
)
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(
|
||||
role="assistant", tool_calls=[tool_call]
|
||||
),
|
||||
finish_reason="tool_call",
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
|
||||
else:
|
||||
# No tool calls => just treat this as normal text
|
||||
choice_data = ChatCompletionResponseStreamChoice(
|
||||
index=index,
|
||||
delta=DeltaMessage(content=delta),
|
||||
finish_reason=(
|
||||
finish_reason["type"] if finish_reason else ""
|
||||
),
|
||||
matched_stop=(
|
||||
finish_reason["matched"]
|
||||
if finish_reason and "matched" in finish_reason
|
||||
else None
|
||||
),
|
||||
logprobs=choice_logprobs,
|
||||
)
|
||||
chunk = ChatCompletionStreamResponse(
|
||||
id=content["meta_info"]["id"],
|
||||
choices=[choice_data],
|
||||
model=request.model,
|
||||
)
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
stream_buffers[index] = new_stream_buffer
|
||||
is_firsts[index] = is_first
|
||||
if request.stream_options and request.stream_options.include_usage:
|
||||
total_prompt_tokens = sum(
|
||||
tokens
|
||||
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
|
||||
ret = [ret]
|
||||
|
||||
response = v1_chat_generate_response(
|
||||
request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
|
||||
request,
|
||||
ret,
|
||||
cache_report=tokenizer_manager.server_args.enable_cache_report,
|
||||
tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
|
||||
)
|
||||
|
||||
return response
|
||||
|
||||
@@ -262,7 +262,7 @@ class Function(BaseModel):
|
||||
"""Function descriptions."""
|
||||
|
||||
description: Optional[str] = Field(default=None, examples=[None])
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
parameters: Optional[object] = None
|
||||
|
||||
|
||||
@@ -276,7 +276,7 @@ class Tool(BaseModel):
|
||||
class ToolChoiceFuncName(BaseModel):
|
||||
"""The name of tool choice function."""
|
||||
|
||||
name: str
|
||||
name: Optional[str] = None
|
||||
|
||||
|
||||
class ToolChoice(BaseModel):
|
||||
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
|
||||
class FunctionResponse(BaseModel):
|
||||
"""Function response."""
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
name: Optional[str] = None
|
||||
arguments: Optional[str] = None
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
|
||||
class DeltaMessage(BaseModel):
|
||||
role: Optional[str] = None
|
||||
content: Optional[str] = None
|
||||
tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
|
||||
|
||||
|
||||
class ChatCompletionResponseStreamChoice(BaseModel):
|
||||
|
||||
@@ -161,6 +161,7 @@ class ServerArgs:
|
||||
|
||||
# Custom logit processor
|
||||
enable_custom_logit_processor: bool = False
|
||||
tool_call_parser: str = None
|
||||
|
||||
def __post_init__(self):
|
||||
# Set missing default values
|
||||
@@ -877,6 +878,14 @@ class ServerArgs:
|
||||
action="store_true",
|
||||
help="Enable users to pass custom logit processors to the server (disabled by default for security)",
|
||||
)
|
||||
# Function Calling
|
||||
parser.add_argument(
|
||||
"--tool-call-parser",
|
||||
type=str,
|
||||
choices=["qwen25", "mistral", "llama3"],
|
||||
default=ServerArgs.tool_call_parser,
|
||||
help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_cli_args(cls, args: argparse.Namespace):
|
||||
|
||||
@@ -1243,68 +1243,6 @@ def dataclass_to_string_truncated(data, max_length=2048):
|
||||
return str(data)
|
||||
|
||||
|
||||
TOOLS_TAG_LIST = ["<|plugin|>", "<function=", "<tool_call>", "<|python_tag|>"]
|
||||
|
||||
|
||||
def parse_tool_response(text, tools, **kwargs):
|
||||
"""Parse model response containing tool information.
|
||||
|
||||
Args:
|
||||
text(str): model response in string format
|
||||
tools(List): tools from user request
|
||||
"""
|
||||
if "<|plugin|>" in text: # internlm2
|
||||
text, action = text.split("<|action_start|><|plugin|>")
|
||||
action = action.split("<|action_end|>".strip())[0]
|
||||
action = action[action.find("{") :]
|
||||
action = json.loads(action)
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
|
||||
)
|
||||
call_info_list = [(name, parameters)]
|
||||
elif "<function=" in text: # llama3.1
|
||||
action, _ = text.split("</function>")
|
||||
parameters = action[action.find("{") :]
|
||||
name = action.split("<function=")[1].split(">{")[0]
|
||||
call_info_list = [(name, parameters)]
|
||||
elif "<tool_call>" in text and "</tool_call>" in text: # qwen2.5
|
||||
# get tool_call in text
|
||||
pattern = r"<tool_call>(.*?)</tool_call>"
|
||||
match_result_list = re.findall(pattern, text, re.DOTALL)
|
||||
call_info_list = []
|
||||
for match_result in match_result_list:
|
||||
action = json.loads(match_result)
|
||||
call_info_list.append(
|
||||
(action["name"], json.dumps(action["arguments"], ensure_ascii=False))
|
||||
)
|
||||
# get text outside of tags
|
||||
if not text.startswith("<tool_call>"):
|
||||
text = text[: text.find("<tool_call>")]
|
||||
elif not text.endswith("</tool_call>"):
|
||||
text = text[text.rfind("</tool_call>") + len("</tool_call>") :]
|
||||
else:
|
||||
text = ""
|
||||
elif "<|python_tag|>" in text: # llama3.2
|
||||
_, action = text.split("<|python_tag|>")
|
||||
action = json.loads(action)
|
||||
name, parameters = action["name"], json.dumps(
|
||||
action.get("parameters", action.get("arguments", {})), ensure_ascii=False
|
||||
)
|
||||
call_info_list = [(name, parameters)]
|
||||
else:
|
||||
raise RuntimeError(f"Unexpected model response: {text}")
|
||||
|
||||
call_info_list = [
|
||||
(
|
||||
[tool.function.name for tool in tools].index(call_info[0]),
|
||||
call_info[0],
|
||||
call_info[1],
|
||||
)
|
||||
for call_info in call_info_list
|
||||
]
|
||||
return text, call_info_list
|
||||
|
||||
|
||||
def permute_weight(x: torch.Tensor) -> torch.Tensor:
|
||||
b_ = x.shape[0]
|
||||
n_ = x.shape[1]
|
||||
|
||||
Reference in New Issue
Block a user