199 lines
8.0 KiB
Python
199 lines
8.0 KiB
Python
|
|
import json
|
||
|
|
import re
|
||
|
|
from typing import Dict, List, Sequence, Union
|
||
|
|
import partial_json_parser
|
||
|
|
from partial_json_parser.core.options import Allow
|
||
|
|
|
||
|
|
from vllm.entrypoints.openai.protocol import (
|
||
|
|
ChatCompletionRequest, DeltaMessage, DeltaToolCall,
|
||
|
|
DeltaFunctionCall, ExtractedToolCallInformation, ToolCall, FunctionCall
|
||
|
|
)
|
||
|
|
from vllm.entrypoints.openai.tool_parsers.abstract_tool_parser import ToolParser, ToolParserManager
|
||
|
|
from vllm.utils import random_uuid
|
||
|
|
from vllm.logger import init_logger
|
||
|
|
from transformers import PreTrainedTokenizerBase
|
||
|
|
from vllm.entrypoints.openai.tool_parsers.utils import (find_common_prefix,
|
||
|
|
is_complete_json,
|
||
|
|
partial_json_loads)
|
||
|
|
|
||
|
|
# Module-level logger, created through vLLM's init_logger and keyed by this
# module's import name; used by the parser below for debug and error output.
logger = init_logger(__name__)
|
||
|
|
|
||
|
|
@ToolParserManager.register_module("xlam")
class xLAMToolParser(ToolParser):
    """Tool-call parser for models that emit calls as a bare JSON array.

    The model output is expected to be a JSON array of objects, each with a
    ``"name"`` entry and an ``"arguments"`` entry, for example
    ``[{"name": "get_weather", "arguments": {"city": "Paris"}}]``.
    Output that does not start with ``[`` (ignoring surrounding whitespace)
    is treated as ordinary assistant content.

    The parser is registered with ``ToolParserManager`` under the name
    ``"xlam"``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """Create the parser and reset all streaming state.

        Args:
            tokenizer: Tokenizer forwarded unchanged to ``ToolParser``.
        """
        super().__init__(tokenizer)
        # State for streaming mode.  ``current_tools_sent`` and
        # ``streamed_args`` hold one entry per tool call that streaming has
        # started, in array order, so they can be indexed by tool position.
        self.prev_tool_calls: List[Dict] = []
        self.current_tools_sent: List[bool] = []
        self.streamed_args: List[str] = []
        # The streaming method reads both attributes below, so they are
        # initialised here explicitly instead of relying on the base class.
        # NOTE(review): ToolParser presumably sets the same defaults; confirm.
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False

    @staticmethod
    def _content_only(model_output: str) -> ExtractedToolCallInformation:
        """Wrap ``model_output`` as plain content with no tool calls."""
        return ExtractedToolCallInformation(
            tools_called=False,
            tool_calls=[],
            content=model_output,
        )

    def extract_tool_calls(
        self,
        model_output: str,
        request: ChatCompletionRequest,
    ) -> ExtractedToolCallInformation:
        """Extract tool calls from a complete (non-streamed) model output.

        Args:
            model_output: Full text generated by the model.
            request: The originating request (not used by this parser).

        Returns:
            ``tools_called=True`` with one ``ToolCall`` per array element
            when the output is a non-empty JSON array of call objects.
            Otherwise ``tools_called=False`` with the original text as
            ``content``; this covers non-array output, an empty array, and
            any parsing failure (which is also logged).
        """
        try:
            # Direct JSON parsing: the array is not wrapped in ``` fences.
            if not model_output.strip().startswith('['):
                return self._content_only(model_output)

            tool_calls_data = json.loads(model_output)
            tool_calls: List[ToolCall] = [
                ToolCall(
                    id=f"call_{idx}_{random_uuid()}",
                    type="function",
                    function=FunctionCall(
                        name=call["name"],
                        # A call without "arguments" is treated as a call
                        # with no arguments rather than a parse failure.
                        arguments=json.dumps(call.get("arguments", {})),
                    ),
                )
                for idx, call in enumerate(tool_calls_data)
            ]

            # "[]" carries no calls: report it as content instead of
            # claiming tools were called with an empty call list.
            if not tool_calls:
                return self._content_only(model_output)

            return ExtractedToolCallInformation(
                tools_called=True,
                tool_calls=tool_calls,
                content=None,
            )

        except Exception:
            # Best effort: anything unparsable is handed back as content.
            logger.exception("Error extracting tool calls")
            return self._content_only(model_output)

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """Produce the next streaming delta for the text generated so far.

        Each invocation re-parses ``current_text`` as a (possibly partial)
        JSON array and emits at most one delta: either the name of the tool
        call currently being streamed or the next slice of its serialized
        arguments.

        Args:
            previous_text: Text generated before this chunk (unused).
            current_text: All text generated so far, including this chunk.
            delta_text: Text added by this chunk.
            previous_token_ids: Unused.
            current_token_ids: Unused.
            delta_token_ids: Unused.
            request: The originating request (unused).

        Returns:
            A content ``DeltaMessage`` when the output is not a JSON array,
            a tool-call ``DeltaMessage`` when there is something new to
            send, or ``None`` when this chunk yields nothing (including
            when an error occurred, which is logged).
        """
        if not current_text.strip().startswith('['):
            return DeltaMessage(content=delta_text)

        # Until the current tool's name has been sent, partial strings are
        # disallowed so that a half-generated name is never emitted.
        flags = (Allow.ALL
                 if self.current_tool_name_sent else Allow.ALL & ~Allow.STR)

        try:
            text = current_text.lstrip()
            try:
                # The whole output is a single JSON array, so one partial
                # parse yields the list of call objects seen so far.
                parsed, end_idx = partial_json_loads(text, flags)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug('not enough tokens to parse into JSON yet')
                return None

            # Nothing usable yet (empty array, or not an array at all).
            if not isinstance(parsed, list) or not parsed:
                return None

            # Keep positions stable: a non-object element becomes an empty
            # call so later indices still line up with the model output.
            tool_call_arr: List[Dict] = [
                item if isinstance(item, dict) else {} for item in parsed
            ]
            array_complete = is_complete_json(text[:end_idx])

            delta = self._next_tool_delta(tool_call_arr, array_complete)
            # Remember this parse so the next chunk can be diffed against it.
            self.prev_tool_calls = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error in streaming tool calls")
            logger.debug("Skipping chunk due to streaming error")
            return None

    def _start_tool(self, tool_index: int) -> None:
        """Begin tracking the tool call at ``tool_index`` with fresh state."""
        self.current_tool_id = tool_index
        self.current_tool_name_sent = False
        self.current_tools_sent.append(False)
        self.streamed_args.append("")
        logger.debug("starting new tool %d", self.current_tool_id)

    def _next_tool_delta(
        self,
        tool_call_arr: List[Dict],
        array_complete: bool,
    ) -> Union[DeltaMessage, None]:
        """Advance the streaming state machine by one step.

        Args:
            tool_call_arr: Call objects parsed from the text so far.
            array_complete: Whether the parsed text is complete JSON.

        Returns:
            The delta to send for this step, or ``None`` if there is none.
        """
        idx = self.current_tool_id

        # No tool tracked yet: start with the first element.  Tools are
        # always entered one at a time so the per-tool state lists stay
        # aligned even when several calls arrive in a single chunk.
        if idx < 0:
            self._start_tool(0)
            return None

        # Defensive: a re-parse exposing fewer elements than already
        # tracked gives nothing to send for the current tool.
        if idx >= len(tool_call_arr):
            return None

        current_tool_call = tool_call_arr[idx]
        has_successor = idx + 1 < len(tool_call_arr)

        # Send the tool name first; arguments only follow afterwards.
        if not self.current_tools_sent[idx]:
            function_name = current_tool_call.get("name")
            if not function_name:
                # A finished element without a name can never be sent;
                # move on instead of stalling the whole stream.
                if has_successor:
                    self._start_tool(idx + 1)
                return None
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(
                    index=idx,
                    type="function",
                    id=f"call_{idx}_{random_uuid()}",
                    function=DeltaFunctionCall(
                        name=function_name
                    ).model_dump(exclude_none=True),
                )
            ])
            self.current_tools_sent[idx] = True
            self.current_tool_name_sent = True
            return delta

        # The element is final once a later element exists or the array
        # itself has been closed.
        finished = has_successor or array_complete
        argument_diff = self._pending_arguments(idx, current_tool_call,
                                                finished)
        if argument_diff:
            delta = DeltaMessage(tool_calls=[
                DeltaToolCall(
                    index=idx,
                    function=DeltaFunctionCall(
                        arguments=argument_diff
                    ).model_dump(exclude_none=True),
                )
            ])
            self.streamed_args[idx] += argument_diff
            return delta

        # Current tool fully flushed: move on to the next one, if any.
        if has_successor:
            self._start_tool(idx + 1)
        return None

    def _pending_arguments(
        self,
        idx: int,
        current_tool_call: Dict,
        finished: bool,
    ) -> str:
        """Return the not-yet-streamed part of a tool's serialized arguments.

        For a finished call this is everything after what was already sent.
        For an unfinished call only the prefix shared with the previous
        parse is considered stable enough to send, because the partial
        parser closes open strings and objects that may still grow.

        Args:
            idx: Position of the tool call in the array.
            current_tool_call: The call object as currently parsed.
            finished: Whether the call object can no longer change.

        Returns:
            The next argument text to stream; empty if there is none.
        """
        if "arguments" not in current_tool_call:
            return ""

        cur_arguments = current_tool_call["arguments"]
        already_sent = len(self.streamed_args[idx])

        if finished:
            # Includes the empty-object case, which serializes to "{}".
            return json.dumps(cur_arguments)[already_sent:]

        if not cur_arguments:
            return ""

        prev_arguments = (self.prev_tool_calls[idx].get("arguments")
                          if idx < len(self.prev_tool_calls) else None)
        if not prev_arguments:
            return ""

        cur_args_json = json.dumps(cur_arguments)
        prev_args_json = json.dumps(prev_arguments)
        if cur_args_json == prev_args_json:
            return ""

        stable_prefix = find_common_prefix(prev_args_json, cur_args_json)
        return stable_prefix[already_sent:]
|
||
|
|
|