# Files
# sglang/python/sglang/srt/function_call/llama32_detector.py
#
# 75 lines
# 2.6 KiB
# Python

import json
import logging
from typing import List
from sglang.srt.function_call.base_format_detector import BaseFormatDetector
from sglang.srt.function_call.core_types import (
StreamingParseResult,
StructureInfo,
_GetInfoFunc,
)
from sglang.srt.function_call.ebnf_composer import EBNFComposer
from sglang.srt.openai_api.protocol import Tool
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
class Llama32Detector(BaseFormatDetector):
    """
    Detector for Llama 3.2 models.

    Assumes function call format:
      <|python_tag|>{"name":"xxx", "arguments":{...}}

    Several calls may follow one another, separated by ``;``. The
    ``<|python_tag|>`` prefix is optional: output that simply starts with
    ``{`` is also treated as a tool call.
    """

    # Separator placed between consecutive tool-call JSON objects. Shared by
    # the parser and the grammar builder so that grammar-constrained output
    # can always be parsed back.
    _TOOL_CALL_SEPARATOR = ";"

    def __init__(self):
        """Initialise the base detector and record the tool-call start token."""
        super().__init__()
        self.bot_token = "<|python_tag|>"

    def has_tool_call(self, text: str) -> bool:
        """Check if the text contains a Llama 3.2 format tool call."""
        # depending on the prompt format the Llama model may or may not
        # prefix the output with the <|python_tag|> token
        return self.bot_token in text or text.startswith("{")

    def detect_and_parse(self, text: str, tools: List[Tool]) -> StreamingParseResult:
        """Parse function calls from text, handling multiple JSON objects.

        Args:
            text: Complete model output.
            tools: Tool definitions, forwarded unchanged to
                ``self.parse_base_json``.

        Returns:
            A ``StreamingParseResult`` whose ``normal_text`` is everything
            before the first ``<|python_tag|>`` (empty when the output starts
            directly with ``{``) and whose ``calls`` is whatever
            ``self.parse_base_json`` returns for the decoded JSON values.
            When no tool call is detected, the whole text is returned as
            ``normal_text`` with no calls.
        """
        if not self.has_tool_call(text):
            return StreamingParseResult(normal_text=text, calls=[])

        if self.bot_token in text:
            # partition() splits on the FIRST tag only, so a repeated tag can
            # no longer raise "too many values to unpack".
            normal_text, _, action_text = text.partition(self.bot_token)
        else:
            normal_text, action_text = "", text

        all_actions = self._decode_json_objects(action_text)

        # Only process if we found valid JSON objects
        calls = self.parse_base_json(all_actions, tools) if all_actions else []
        return StreamingParseResult(normal_text=normal_text, calls=calls)

    def _decode_json_objects(self, action_text: str) -> list:
        """Decode every JSON value found in ``action_text``, in order.

        Values are decoded incrementally with ``JSONDecoder.raw_decode``
        instead of splitting the text on ``;`` first, so a ``;`` inside a
        JSON string (for example in an argument value) no longer breaks the
        call apart. Whitespace, ``;``/``,`` separators and repeated
        ``<|python_tag|>`` tokens between values are skipped. A fragment that
        is not valid JSON is logged and skipped up to the next ``;``; parsing
        then continues (best effort).
        """
        decoder = json.JSONDecoder()
        separator = self._TOOL_CALL_SEPARATOR
        token = self.bot_token
        decoded = []
        pos, end = 0, len(action_text)

        while pos < end:
            ch = action_text[pos]
            # "," is tolerated as well so comma-separated calls still parse.
            if ch.isspace() or ch == separator or ch == ",":
                pos += 1
                continue
            if action_text.startswith(token, pos):
                pos += len(token)
                continue
            try:
                # raw_decode returns the value and the index just past it.
                value, pos = decoder.raw_decode(action_text, pos)
            except json.JSONDecodeError as e:
                next_sep = action_text.find(separator, pos)
                bad_end = end if next_sep == -1 else next_sep
                logger.warning(
                    "Failed to parse JSON part: %s", action_text[pos:bad_end].strip()
                )
                logger.warning("JSON parse error: %s", e)
                # Always advances by at least one character, so the loop
                # terminates even on garbage input.
                pos = bad_end + len(separator)
            else:
                decoded.append(value)

        return decoded

    def structure_info(self) -> _GetInfoFunc:
        """Return a function mapping a tool name to its ``StructureInfo``.

        The resulting info opens with the tag plus the ``"name"`` field for
        the given tool, closes with ``}``, and is triggered by
        ``<|python_tag|>``.
        """
        return lambda name: StructureInfo(
            begin='<|python_tag|>{"name":"' + name + '", "arguments":',
            end="}",
            trigger="<|python_tag|>",
        )

    def build_ebnf(self, tools: List[Tool]):
        """Build the EBNF grammar for JSON-format calls to ``tools``.

        Uses the same call separator as ``detect_and_parse`` so that output
        generated under this grammar round-trips through the parser.
        """
        return EBNFComposer.build_ebnf(
            tools,
            function_format="json",
            tool_call_separator=self._TOOL_CALL_SEPARATOR,
        )