97 lines
3.6 KiB
Python
97 lines
3.6 KiB
Python
import json
|
|
import logging
|
|
from typing import List
|
|
|
|
from sglang.srt.entrypoints.openai.protocol import Tool
|
|
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
|
|
from sglang.srt.function_call.core_types import (
|
|
StreamingParseResult,
|
|
StructureInfo,
|
|
_GetInfoFunc,
|
|
)
|
|
from sglang.srt.function_call.ebnf_composer import EBNFComposer
|
|
|
|
# Module-level logger named after this module; used to report tool-call
# payloads that could not be decoded as JSON.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models with json tool call format.

    Format Structure:
    ```
    <|python_tag|>{"name":"xxx", "arguments":{...}}
    ```

    The ``<|python_tag|>`` prefix is optional: output that simply starts with
    ``{`` is also treated as a tool call. Several calls may be chained with
    ``self.tool_call_separator`` (``;``), optionally surrounded by whitespace.
    """

    def __init__(self):
        super().__init__()
        self.bot_token = "<|python_tag|>"
        # NOTE: technically Llama3.2 doesn't support well with parallel tool calls
        # They need specific prompt engineering to support parallel tool calls
        # Here we use ';' as the separator, which might have compatibility issues
        # if users define to use a different separator in their prompt
        self.tool_call_separator = ";"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return self.bot_token in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        Args:
            text: Complete model output.
            tools: Tools forwarded to ``parse_base_json`` together with the
                decoded JSON objects.

        Returns:
            A ``StreamingParseResult`` whose ``normal_text`` is the text before
            the ``<|python_tag|>`` token plus any stripped text left after the
            last successfully decoded object, and whose ``calls`` is the result
            of ``parse_base_json`` (or ``[]`` when nothing was decoded).
        """
        if self.bot_token not in text and not text.startswith("{"):
            return StreamingParseResult(normal_text=text, calls=[])

        if self.bot_token in text:
            normal_text, action_text = text.split(self.bot_token, maxsplit=1)
        else:
            normal_text, action_text = "", text

        all_actions, safe_idx = self._decode_json_objects(action_text)

        # Only process if we found valid JSON objects
        calls = self.parse_base_json(all_actions, tools) if all_actions else []
        # Use safe_idx so that the tail of an invalid JSON object is reported
        # as normal text instead of being silently dropped.
        trailing_text = action_text[safe_idx:].strip()
        return StreamingParseResult(
            normal_text=normal_text + trailing_text, calls=calls
        )

    def _decode_json_objects(self, action_text: str):
        """Decode the JSON objects contained in ``action_text``.

        Objects may follow each other directly or be separated by
        ``self.tool_call_separator`` and/or whitespace. When a position cannot
        be decoded as a JSON object, scanning resumes at the next occurrence
        of ``{"name":``; if there is none, scanning stops.

        Returns:
            A ``(objects, safe_idx)`` tuple: the decoded ``dict`` objects in
            order, and the index just past the last decoded object (including
            a separator/whitespace that directly followed it).
        """
        decoder = json.JSONDecoder()
        separator = self.tool_call_separator
        text_len = len(action_text)

        objects = []
        idx = 0
        safe_idx = 0  # the index right after the last valid JSON object
        while idx < text_len:
            try:
                # Passing idx avoids re-slicing the text on every iteration;
                # the returned end index is absolute.
                obj, end = decoder.raw_decode(action_text, idx)
            except json.JSONDecodeError as e:
                logger.warning(
                    "Failed to parse JSON part: %s, JSON parse error: %s",
                    action_text[idx:],
                    e,
                )
                obj = None

            if not isinstance(obj, dict):
                # Either nothing could be decoded here or the value is not a
                # JSON object (the only shape of a tool call in this format).
                # Find where next `{"name":` appears and try again.
                next_obj_start = action_text.find('{"name":', idx + 1)
                if next_obj_start == -1:
                    break
                idx = next_obj_start
                continue

            objects.append(obj)
            # raw_decode does not skip leading whitespace, so do it here, and
            # consume the separator only when it is actually present so that
            # no unrelated character is swallowed.
            idx = self._skip_whitespace(action_text, end)
            if separator and action_text.startswith(separator, idx):
                idx = self._skip_whitespace(action_text, idx + len(separator))
            safe_idx = idx

        return objects, safe_idx

    @staticmethod
    def _skip_whitespace(text: str, idx: int) -> int:
        """Return the first index at or after ``idx`` that is not whitespace."""
        text_len = len(text)
        while idx < text_len and text[idx].isspace():
            idx += 1
        return idx

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its ``StructureInfo``.

        The produced ``begin`` text is the ``<|python_tag|>`` token followed by
        the opening of the JSON call for that tool name; ``end`` closes the
        JSON object and ``trigger`` is the ``<|python_tag|>`` token.
        """
        return lambda name: StructureInfo(
            begin=self.bot_token + '{"name":"' + name + '", "arguments":',
            end="}",
            trigger=self.bot_token,
        )

    def build_ebnf(self, tools: List[Tool]):
        """Delegate to ``EBNFComposer.build_ebnf`` using the json function
        format and this detector's tool call separator."""
        return EBNFComposer.build_ebnf(
            tools,
            function_format="json",
            tool_call_separator=self.tool_call_separator,
        )
|