327 lines
12 KiB
Rust
327 lines
12 KiB
Rust
use async_trait::async_trait;
|
|
use regex::Regex;
|
|
use serde_json::Value;
|
|
|
|
use crate::{
|
|
protocols::common::Tool,
|
|
tool_parser::{
|
|
errors::{ParserError, ParserResult},
|
|
parsers::helpers,
|
|
traits::ToolParser,
|
|
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
|
},
|
|
};
|
|
|
|
/// GLM-4 MoE format parser for tool calls
|
|
///
|
|
/// Handles the GLM-4 MoE specific format:
|
|
/// `<tool_call>{name}\n<arg_key>{key}</arg_key>\n<arg_value>{value}</arg_value>\n</tool_call>`
|
|
///
|
|
/// Features:
|
|
/// - XML-style tags for tool calls
|
|
/// - Key-value pairs for arguments
|
|
/// - Support for multiple sequential tool calls
|
|
pub struct Glm4MoeParser {
|
|
/// Regex for extracting complete tool calls
|
|
tool_call_extractor: Regex,
|
|
/// Regex for extracting function details
|
|
func_detail_extractor: Regex,
|
|
/// Regex for extracting argument key-value pairs
|
|
arg_extractor: Regex,
|
|
|
|
/// Buffer for accumulating incomplete patterns across chunks
|
|
buffer: String,
|
|
|
|
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
|
prev_tool_call_arr: Vec<Value>,
|
|
|
|
/// Index of currently streaming tool call (-1 means no active tool)
|
|
current_tool_id: i32,
|
|
|
|
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
|
streamed_args_for_tool: Vec<String>,
|
|
|
|
/// Token configuration
|
|
bot_token: &'static str,
|
|
eot_token: &'static str,
|
|
}
|
|
|
|
impl Glm4MoeParser {
|
|
/// Create a new GLM-4 MoE parser
|
|
pub fn new() -> Self {
|
|
// Use (?s) flag for DOTALL mode to handle newlines
|
|
let tool_call_pattern = r"(?s)<tool_call>.*?</tool_call>";
|
|
let tool_call_extractor = Regex::new(tool_call_pattern).expect("Valid regex pattern");
|
|
|
|
let func_detail_pattern = r"(?s)<tool_call>([^\n]*)\n(.*)</tool_call>";
|
|
let func_detail_extractor = Regex::new(func_detail_pattern).expect("Valid regex pattern");
|
|
|
|
let arg_pattern = r"(?s)<arg_key>(.*?)</arg_key>\s*<arg_value>(.*?)</arg_value>";
|
|
let arg_extractor = Regex::new(arg_pattern).expect("Valid regex pattern");
|
|
|
|
Self {
|
|
tool_call_extractor,
|
|
func_detail_extractor,
|
|
arg_extractor,
|
|
buffer: String::new(),
|
|
prev_tool_call_arr: Vec::new(),
|
|
current_tool_id: -1,
|
|
streamed_args_for_tool: Vec::new(),
|
|
bot_token: "<tool_call>",
|
|
eot_token: "</tool_call>",
|
|
}
|
|
}
|
|
|
|
/// Parse arguments from key-value pairs
|
|
fn parse_arguments(&self, args_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
|
|
let mut arguments = serde_json::Map::new();
|
|
|
|
for capture in self.arg_extractor.captures_iter(args_text) {
|
|
let key = capture.get(1).map_or("", |m| m.as_str()).trim();
|
|
let value_str = capture.get(2).map_or("", |m| m.as_str()).trim();
|
|
|
|
// Try to parse the value as JSON first, fallback to string
|
|
let value = if let Ok(json_val) = serde_json::from_str::<Value>(value_str) {
|
|
json_val
|
|
} else {
|
|
// Try parsing as Python literal (similar to Python's ast.literal_eval)
|
|
if value_str == "true" || value_str == "True" {
|
|
Value::Bool(true)
|
|
} else if value_str == "false" || value_str == "False" {
|
|
Value::Bool(false)
|
|
} else if value_str == "null" || value_str == "None" {
|
|
Value::Null
|
|
} else if let Ok(num) = value_str.parse::<i64>() {
|
|
Value::Number(num.into())
|
|
} else if let Ok(num) = value_str.parse::<f64>() {
|
|
if let Some(n) = serde_json::Number::from_f64(num) {
|
|
Value::Number(n)
|
|
} else {
|
|
Value::String(value_str.to_string())
|
|
}
|
|
} else {
|
|
Value::String(value_str.to_string())
|
|
}
|
|
};
|
|
|
|
arguments.insert(key.to_string(), value);
|
|
}
|
|
|
|
Ok(arguments)
|
|
}
|
|
|
|
/// Parse a single tool call block
|
|
fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
|
|
if let Some(captures) = self.func_detail_extractor.captures(block) {
|
|
// Get function name
|
|
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
|
|
|
|
// Get arguments text
|
|
let args_text = captures.get(2).map_or("", |m| m.as_str());
|
|
|
|
// Parse arguments
|
|
let arguments = self.parse_arguments(args_text)?;
|
|
|
|
let arguments_str = serde_json::to_string(&arguments)
|
|
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
|
|
|
|
Ok(Some(ToolCall {
|
|
function: FunctionCall {
|
|
name: func_name.to_string(),
|
|
arguments: arguments_str,
|
|
},
|
|
}))
|
|
} else {
|
|
Ok(None)
|
|
}
|
|
}
|
|
|
|
/// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
|
|
/// Parse all tool calls from text (shared logic for complete and incremental parsing)
|
|
fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult<Vec<ToolCall>> {
|
|
let mut tools = Vec::new();
|
|
|
|
for mat in self.tool_call_extractor.find_iter(text) {
|
|
match self.parse_tool_call(mat.as_str()) {
|
|
Ok(Some(tool)) => tools.push(tool),
|
|
Ok(None) => continue,
|
|
Err(e) => {
|
|
tracing::warn!("Failed to parse tool call: {}", e);
|
|
continue;
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(tools)
|
|
}
|
|
}
|
|
|
|
impl Default for Glm4MoeParser {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl ToolParser for Glm4MoeParser {
|
|
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
|
|
// Check if text contains GLM-4 MoE format
|
|
if !self.has_tool_markers(text) {
|
|
return Ok((text.to_string(), vec![]));
|
|
}
|
|
|
|
// Find where tool calls begin
|
|
let idx = text.find("<tool_call>").unwrap();
|
|
let normal_text = text[..idx].to_string();
|
|
|
|
// Parse all tool calls using shared helper
|
|
let tools = self.parse_tool_calls_from_text(text)?;
|
|
|
|
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
|
if tools.is_empty() {
|
|
return Ok((text.to_string(), vec![]));
|
|
}
|
|
|
|
Ok((normal_text, tools))
|
|
}
|
|
|
|
async fn parse_incremental(
|
|
&mut self,
|
|
chunk: &str,
|
|
tools: &[Tool],
|
|
) -> ParserResult<StreamingParseResult> {
|
|
// Python logic: Wait for complete tool call, then parse it all at once
|
|
self.buffer.push_str(chunk);
|
|
let current_text = &self.buffer.clone();
|
|
|
|
// Check if we have bot_token
|
|
let start = current_text.find(self.bot_token);
|
|
if start.is_none() {
|
|
self.buffer.clear();
|
|
// If we're in the middle of streaming (current_tool_id > 0), don't return text
|
|
let normal_text = if self.current_tool_id > 0 {
|
|
String::new()
|
|
} else {
|
|
current_text.clone()
|
|
};
|
|
return Ok(StreamingParseResult {
|
|
normal_text,
|
|
calls: vec![],
|
|
});
|
|
}
|
|
|
|
// Check if we have eot_token (end of tool call)
|
|
let end = current_text.find(self.eot_token);
|
|
if let Some(end_pos) = end {
|
|
// We have a complete tool call!
|
|
|
|
// Initialize state if this is the first tool call
|
|
if self.current_tool_id == -1 {
|
|
self.current_tool_id = 0;
|
|
self.prev_tool_call_arr = Vec::new();
|
|
self.streamed_args_for_tool = vec![String::new()];
|
|
}
|
|
|
|
// Ensure we have enough entries in our tracking arrays
|
|
helpers::ensure_capacity(
|
|
self.current_tool_id,
|
|
&mut self.prev_tool_call_arr,
|
|
&mut self.streamed_args_for_tool,
|
|
);
|
|
|
|
// Parse the complete block using shared helper
|
|
let block_end = end_pos + self.eot_token.len();
|
|
let parsed_tools = self.parse_tool_calls_from_text(¤t_text[..block_end])?;
|
|
|
|
// Extract normal text before tool calls
|
|
let idx = current_text.find(self.bot_token);
|
|
let normal_text = if let Some(pos) = idx {
|
|
current_text[..pos].trim().to_string()
|
|
} else {
|
|
String::new()
|
|
};
|
|
|
|
// Build tool indices for validation
|
|
let tool_indices = helpers::get_tool_indices(tools);
|
|
|
|
let mut calls = Vec::new();
|
|
|
|
if !parsed_tools.is_empty() {
|
|
// Take the first tool and convert to ToolCallItem
|
|
let tool_call = &parsed_tools[0];
|
|
let tool_id = self.current_tool_id as usize;
|
|
|
|
// Validate tool name
|
|
if !tool_indices.contains_key(&tool_call.function.name) {
|
|
// Invalid tool name - skip this tool, preserve indexing for next tool
|
|
tracing::warn!("Invalid tool name '{}' - skipping", tool_call.function.name);
|
|
helpers::reset_current_tool_state(
|
|
&mut self.buffer,
|
|
&mut false, // glm4_moe doesn't track name_sent per tool
|
|
&mut self.streamed_args_for_tool,
|
|
&self.prev_tool_call_arr,
|
|
);
|
|
return Ok(StreamingParseResult::default());
|
|
}
|
|
|
|
calls.push(ToolCallItem {
|
|
tool_index: tool_id,
|
|
name: Some(tool_call.function.name.clone()),
|
|
parameters: tool_call.function.arguments.clone(),
|
|
});
|
|
|
|
// Store in tracking arrays
|
|
if self.prev_tool_call_arr.len() <= tool_id {
|
|
self.prev_tool_call_arr
|
|
.resize_with(tool_id + 1, || Value::Null);
|
|
}
|
|
|
|
// Parse parameters as JSON and store
|
|
if let Ok(args) = serde_json::from_str::<Value>(&tool_call.function.arguments) {
|
|
self.prev_tool_call_arr[tool_id] = serde_json::json!({
|
|
"name": tool_call.function.name,
|
|
"arguments": args,
|
|
});
|
|
}
|
|
|
|
if self.streamed_args_for_tool.len() <= tool_id {
|
|
self.streamed_args_for_tool
|
|
.resize_with(tool_id + 1, String::new);
|
|
}
|
|
self.streamed_args_for_tool[tool_id] = tool_call.function.arguments.clone();
|
|
|
|
self.current_tool_id += 1;
|
|
}
|
|
|
|
// Remove processed portion from buffer
|
|
self.buffer = current_text[block_end..].to_string();
|
|
return Ok(StreamingParseResult { normal_text, calls });
|
|
}
|
|
|
|
// No complete tool call yet - return normal text before start token
|
|
let start_pos = start.unwrap();
|
|
let normal_text = current_text[..start_pos].to_string();
|
|
self.buffer = current_text[start_pos..].to_string();
|
|
|
|
Ok(StreamingParseResult {
|
|
normal_text,
|
|
calls: vec![],
|
|
})
|
|
}
|
|
|
|
fn has_tool_markers(&self, text: &str) -> bool {
|
|
text.contains(self.bot_token)
|
|
}
|
|
|
|
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
|
|
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
|
|
}
|
|
|
|
fn reset(&mut self) {
|
|
self.buffer.clear();
|
|
self.prev_tool_call_arr.clear();
|
|
self.current_tool_id = -1;
|
|
self.streamed_args_for_tool.clear();
|
|
}
|
|
}
|