2025-01-26 09:57:51 -08:00
import json
2025-02-07 04:52:01 +00:00
import logging
2025-01-26 09:57:51 -08:00
import re
from abc import ABC , abstractmethod
from json import JSONDecodeError , JSONDecoder
from typing import Any , Dict , List , Optional , Tuple
import partial_json_parser
from partial_json_parser . core . options import Allow
from pydantic import BaseModel , Field
2025-02-07 04:52:01 +00:00
logger = logging . getLogger ( __name__ )
2025-01-26 09:57:51 -08:00
TOOLS_TAG_LIST = [
" <|plugin|> " ,
" <function= " ,
" <tool_call> " ,
" <|python_tag|> " ,
" [TOOL_CALLS] " ,
]
class Function(BaseModel):
    """Function Tool Template."""

    # Human-readable description of the function (optional).
    description: Optional[str] = Field(default=None, examples=[None])
    # Function name used to match model output against the available tools.
    name: Optional[str] = None
    # Parameter specification; kept opaque (JSON-schema-like object expected).
    parameters: Optional[object] = None
class ToolCallItem(BaseModel):
    """Simple encapsulation of the parsed ToolCall result for easier usage in streaming contexts."""

    # Index of the matched tool in the caller-supplied tools list.
    tool_index: int
    # Tool name; empty/None for argument-delta items in streaming mode.
    name: Optional[str] = None
    parameters: str  # JSON string
def _find_common_prefix ( s1 : str , s2 : str ) - > str :
prefix = " "
min_length = min ( len ( s1 ) , len ( s2 ) )
for i in range ( 0 , min_length ) :
if s1 [ i ] == s2 [ i ] :
prefix + = s1 [ i ]
else :
break
return prefix
def _partial_json_loads(input_str: str, flags: Allow) -> Tuple[Any, int]:
    """Decode possibly-partial JSON, returning (object, chars consumed).

    A trailing-garbage failure ("Extra data") falls back to the stdlib
    raw decoder, which reports exactly how far it got.
    """
    try:
        parsed = partial_json_parser.loads(input_str, flags)
    except JSONDecodeError as err:
        if "Extra data" not in err.msg:
            raise
        return JSONDecoder().raw_decode(input_str)
    return parsed, len(input_str)
def _is_complete_json ( input_str : str ) - > bool :
try :
json . loads ( input_str )
return True
except JSONDecodeError :
return False
class StreamingParseResult:
    """Result of streaming incremental parsing.

    Carries plain text to pass through unchanged plus any tool-call items
    produced by this increment.
    """

    def __init__(
        self, normal_text: str = "", calls: Optional[List[ToolCallItem]] = None
    ):
        self.normal_text = normal_text
        # Fall back to a fresh list so instances never share a mutable default.
        self.calls = calls if calls else []
class BaseFormatDetector:
    """Base class providing two sets of interfaces: one-time and streaming incremental.

    Subclasses set ``bot_token``/``eot_token`` to the model-specific markers
    that open and close a tool-call section, and may override
    ``detect_and_parse`` for their own one-shot format.
    """

    def __init__(self):
        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self._buffer = ""
        # One dict per tool call seen so far (possibly partial JSON objects).
        self.prev_tool_call_arr: List[Dict] = []
        # Index of the tool call currently being streamed; -1 means "none yet".
        self.current_tool_id: int = -1
        # Whether the current tool's name has already been emitted downstream.
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: List[str] = (
            []
        )  # map what has been streamed for each tool so far to a list
        # Begin/end-of-tool-call markers; empty here, set by subclasses.
        self.bot_token = ""
        self.eot_token = ""

    def parse_base_json(self, action: Any, tools: List[Function]) -> List[ToolCallItem]:
        """Convert a decoded JSON action (dict or list of dicts) into ToolCallItems.

        A single dict with an unknown/missing name is logged and dropped; in a
        list, unknown names are silently skipped.

        NOTE(review): this reads ``tool.function.name`` although the annotation
        says ``List[Function]`` (whose name is ``tool.name``) — ``tools`` looks
        like a list of wrapper objects with a ``.function`` attribute; confirm
        against callers.
        """
        tool_indices = {
            tool.function.name: i for i, tool in enumerate(tools) if tool.function.name
        }
        if not isinstance(action, list):
            name = action.get("name")
            if not name or name not in tool_indices:
                logger.warning(f"Model attempted to call undefined function: {name}")
                return []
            return [
                ToolCallItem(
                    tool_index=tool_indices[name],
                    name=name,
                    # Accept either "parameters" or "arguments" from the model.
                    parameters=json.dumps(
                        action.get("parameters") or action.get("arguments", {}),
                        ensure_ascii=False,
                    ),
                )
            ]
        results = []
        for act in action:
            name = act.get("name")
            if name and name in tool_indices:
                results.append(
                    ToolCallItem(
                        tool_index=tool_indices[name],
                        name=name,
                        parameters=json.dumps(
                            act.get("parameters") or act.get("arguments", {}),
                            ensure_ascii=False,
                        ),
                    )
                )
        return results

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        Parses the text in one go. Returns success=True if the format matches, otherwise False.
        Note that leftover_text here represents "content that this parser will not consume further".
        """
        action = json.loads(text)
        return self.parse_base_json(action, tools)

    def parse_streaming_increment(
        self, new_text: str, tools: List[Function]
    ) -> StreamingParseResult:
        """
        Streaming incremental parsing with tool validation.

        Buffers ``new_text``, incrementally decodes the (possibly partial)
        JSON tool calls and emits only the not-yet-streamed delta of the
        current tool's name/arguments. State across calls lives in
        ``_buffer`` / ``prev_tool_call_arr`` / ``streamed_args_for_tool``.
        """
        # Append new text to buffer
        self._buffer += new_text
        current_text = self._buffer
        # No tool-call marker in sight: flush buffered text back as normal text
        # (stripping a stray end token if present).
        if not (self.bot_token in current_text or current_text.startswith("{")):
            self._buffer = ""
            if self.eot_token in new_text:
                new_text = new_text.replace(self.eot_token, "")
            return StreamingParseResult(normal_text=new_text)

        # Build tool indices if not already built
        if not hasattr(self, "_tool_indices"):
            self._tool_indices = {
                tool.function.name: i
                for i, tool in enumerate(tools)
                if tool.function and tool.function.name
            }

        # Until the tool name has been fully sent, forbid partial strings so a
        # truncated name is never emitted.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            tool_call_arr = []
            is_complete = []
            try:
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    (obj, end_idx) = _partial_json_loads(
                        current_text[start_idx:], flags
                    )
                    is_complete.append(
                        _is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # Multiple tool calls are separated by "; " in the stream.
                    start_idx += end_idx + len("; ")

                    # Validate tool name if present
                    if "name" in obj and obj["name"] not in self._tool_indices:
                        # Invalid tool name - reset state
                        self._buffer = ""
                        self.current_tool_id = -1
                        self.current_tool_name_sent = False
                        if self.streamed_args_for_tool:
                            self.streamed_args_for_tool.pop()
                        return StreamingParseResult()

                    # Handle parameters/arguments consistency
                    if "parameters" in obj:
                        assert (
                            "arguments" not in obj
                        ), "model generated both parameters and arguments"
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)

            except partial_json_parser.core.exceptions.MalformedJSON:
                return StreamingParseResult()

            if len(tool_call_arr) == 0:
                return StreamingParseResult()

            # NOTE(review): with current_tool_id == -1 this indexes the last
            # element; apparently harmless because the "new tool" branch below
            # handles that case first — confirm before reordering branches.
            current_tool_call: Dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # Handle new tool in array
            if len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1:

                # A new tool has started: first flush any remaining argument
                # text of the previous tool.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        argument_diff = cur_args_json[sent:]
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        self.streamed_args_for_tool[
                            self.current_tool_id
                        ] += argument_diff
                    else:
                        res = StreamingParseResult()
                else:
                    res = StreamingParseResult()

                # Advance to the newest tool and reset its per-tool state.
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                return res

            # Handle tool name
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")

                if function_name and function_name in self._tool_indices:
                    res = StreamingParseResult(
                        calls=[
                            ToolCallItem(
                                tool_index=self._tool_indices[function_name],
                                name=function_name,
                                parameters="",
                            )
                        ],
                    )
                    self.current_tool_name_sent = True
                else:
                    res = StreamingParseResult()

            # Handle streaming arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                res = StreamingParseResult()

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # Arguments are final: emit everything not yet sent and
                        # reset per-tool streaming state.
                        argument_diff = cur_args_json[sent:]
                        self._buffer = ""
                        self.prev_tool_call_arr[self.current_tool_id].clear()
                        self.current_tool_name_sent = False
                        self.streamed_args_for_tool[self.current_tool_id] = ""
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments)
                        if cur_args_json != prev_args_json:
                            # Only the stable common prefix is safe to emit
                            # while the JSON is still growing.
                            prefix = _find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        res = StreamingParseResult(
                            calls=[
                                ToolCallItem(
                                    tool_index=self.current_tool_id,
                                    name="",
                                    parameters=argument_diff,
                                )
                            ],
                        )
                        if not is_complete[self.current_tool_id]:
                            self.streamed_args_for_tool[
                                self.current_tool_id
                            ] += argument_diff

            self.prev_tool_call_arr = tool_call_arr
            return res

        except Exception as e:
            logger.error(f"Error in parse_streaming_increment: {e}")
            # Best-effort: swallow the error and emit nothing this increment.
            return StreamingParseResult()
class Qwen25Detector(BaseFormatDetector):
    """
    Detector for Qwen 2.5 models.
    Assumes function call format:
        <tool_call>{"name":"xxx", "arguments":{...}}</tool_call>
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "<tool_call>"
        self.eot_token = "</tool_call>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Parsed ToolCallItems (empty list when the format is absent).
        """
        if self.bot_token not in text:
            return []
        # Build the pattern from the configured tokens (previously hard-coded,
        # which could drift out of sync with __init__).
        pattern = re.escape(self.bot_token) + r"(.*?)" + re.escape(self.eot_token)
        match_result_list = re.findall(pattern, text, re.DOTALL)
        calls = []
        for match_result in match_result_list:
            match_result = json.loads(match_result)
            calls.extend(self.parse_base_json(match_result, tools))
        return calls
class MistralDetector(BaseFormatDetector):
    """
    Detector for Mistral models.
    Assumes function call format:
        [TOOL_CALLS] [{"name":"xxx", "arguments":{...}}]
    """

    def __init__(self):
        """
        Initializes the detector with necessary state variables.
        """
        super().__init__()
        self.bot_token = "[TOOL_CALLS] ["
        # Matches the outermost JSON array of tool-call objects.
        self.tool_call_regex = re.compile(r"\[{.*}\]", re.DOTALL)

    def _clean_text(self, text: str) -> str:
        """
        Clean text to only leave '[TOOL_CALLS] [{"name": xxx, "arguments": {xxx}}]'
        for example,
            text = '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]\n\nToday\'s weather in Boston is: {function call result} (in Fahrenheit)\n\nIf you prefer Celsius, please let me know.'
            return '[TOOL_CALLS] [{"name": "get_current_weather", "arguments": {"location": "Boston, MA", "unit": "fahrenheit"}}]'
        The key pattern is [TOOL_CALLS] [...]

        NOTE(review): the non-greedy ".*?" stops at the first "]", so argument
        values containing arrays may be truncated here — confirm upstream
        output never nests arrays inside arguments.
        """
        find_results = re.findall(r"\[TOOL_CALLS\] \[.*?\]", text, re.DOTALL)
        if len(find_results) > 0:
            return find_results[0]
        else:
            return ""

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """
        One-time parsing: Detects and parses tool calls in the provided text.

        :param text: The complete text to parse.
        :param tools: List of available tools.
        :return: Parsed ToolCallItems (empty list when the format is absent).
        """
        text = self._clean_text(text)
        tool_content = text.replace("[TOOL_CALLS]", "").strip()
        raw_tool_calls = self.tool_call_regex.findall(tool_content)
        calls = []
        if len(raw_tool_calls) > 0:
            # Only the first array match is parsed; it holds all tool calls.
            raw_tool_call = raw_tool_calls[0]
            function_call_arr = json.loads(raw_tool_call)
            for match_result in function_call_arr:
                calls.extend(self.parse_base_json(match_result, tools))
        return calls
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.
    Assumes function call format:
        <|python_tag|>{"name":"xxx", "arguments":{...}}
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def detect_and_parse(self, text: str, tools: List[Function]) -> List[ToolCallItem]:
        """Parse function calls from text, handling multiple JSON objects.

        :param text: Complete model output, possibly containing a
            ``<|python_tag|>`` section followed by ``;``-separated JSON objects.
        :param tools: List of available tools used to resolve tool indices.
        :return: Parsed ToolCallItems (empty list when none are found).
        """
        if self.bot_token not in text:
            return []

        # maxsplit=1 fixes a crash: the original two-value unpack raised
        # ValueError whenever the tag appeared more than once. Everything
        # after the first tag is treated as the action payload.
        _, action_text = text.split(self.bot_token, maxsplit=1)

        # Split by semicolon and process each part
        json_parts = [part.strip() for part in action_text.split(";") if part.strip()]

        all_actions = []
        for part in json_parts:
            try:
                # Parse each individual JSON object
                action = json.loads(part)
                all_actions.append(action)
            except json.JSONDecodeError as e:
                logger.warning(f"Failed to parse JSON part: {part}")
                logger.warning(f"JSON parse error: {str(e)}")
                continue

        # Only process if we found valid JSON objects
        if all_actions:
            return self.parse_base_json(all_actions, tools)
        return []
2025-01-26 09:57:51 -08:00
class MultiFormatParser:
    def __init__(self, detectors: List[BaseFormatDetector]):
        """
        :param detectors: A series of available Detector instances passed in
        """
        self.detectors = detectors

    def parse_once(self, text: str, tools: List[Function]):
        """
        One-time parsing: try each detector in order until one yields calls.

        Return: (final_text, all_calls)
        - final_text: text not consumed by any Detector (treated as normal text)
        - all_calls: all calls parsed by the winning Detector
        """
        matched_calls: List[ToolCallItem] = []
        leftover = text
        for detector in self.detectors:
            candidates = detector.detect_and_parse(text, tools)
            if candidates:  # first detector that parses successfully wins
                matched_calls = candidates
                break
        # leftover is the normal text not consumed by any Detector
        return leftover, matched_calls

    def parse_streaming_increment(self, new_text: str, tools: List[Function]):
        """
        Streaming incremental parsing: feed new_text to each detector's
        parse_streaming_increment and merge the produced normal_text / calls.
        """
        merged_text = ""
        merged_calls: List[ToolCallItem] = []
        for detector in self.detectors:
            outcome = detector.parse_streaming_increment(new_text, tools)
            # A result carrying calls means a successful parse; one carrying
            # only normal_text may or may not be from the matching detector.
            if outcome.normal_text:
                merged_text = outcome.normal_text
            if outcome.calls:
                merged_calls.extend(outcome.calls)
                merged_text = outcome.normal_text
                break
        return merged_text, merged_calls
class FunctionCallParser:
    """
    In streaming scenarios, each time new_text is received, it calls multi_format_parser.parse_streaming_increment
    and returns the resulting normal_text and calls to the upper layer (or SSE).
    """

    # Registry of parser name -> detector *class* (instantiated in __init__).
    # Annotation fixed: values are classes, not BaseFormatDetector instances.
    ToolCallParserEnum: Dict[str, type] = {
        "llama3": Llama32Detector,
        "qwen25": Qwen25Detector,
        "mistral": MistralDetector,
    }

    def __init__(self, tools: List[Function], tool_call_parser: Optional[str] = None):
        # tool_call_parser is required in practice: both the unknown-name and
        # missing-name paths raise ValueError.
        detectors = []
        if tool_call_parser:
            detector_class = self.ToolCallParserEnum.get(tool_call_parser)
            if detector_class:
                detectors.append(detector_class())
            else:
                raise ValueError(f"Unsupported tool_call_parser: {tool_call_parser}")
        else:
            raise ValueError("Tool Call Parser Not Given!")

        self.multi_format_parser = MultiFormatParser(detectors)
        self.tools = tools

    def parse_non_stream(self, full_text: str):
        """
        Non-streaming call: one-time parsing.

        Returns (normal_text, calls) from the underlying MultiFormatParser.
        """
        full_normal_text, calls = self.multi_format_parser.parse_once(
            full_text, self.tools
        )
        return full_normal_text, calls

    def parse_stream_chunk(self, chunk_text: str):
        """
        Streaming call: incremental parsing.

        Returns (normal_text, calls) for this chunk.
        """
        normal_text, calls = self.multi_format_parser.parse_streaming_increment(
            chunk_text, self.tools
        )
        return normal_text, calls