[router][grpc] Fix streaming bugs: empty tool names, state pollution, and panics (#11373)

This commit is contained in:
Chang Su
2025-10-09 03:53:23 -07:00
committed by GitHub
parent a4b424c632
commit ab926dd697
33 changed files with 1145 additions and 534 deletions

View File

@@ -1,11 +1,11 @@
use thiserror::Error;
/// Result type for tool parser operations
pub type ToolParserResult<T> = Result<T, ToolParserError>;
pub type ParserResult<T> = Result<T, ParserError>;
/// Errors that can occur during tool parsing
#[derive(Debug, Error)]
pub enum ToolParserError {
pub enum ParserError {
#[error("Parsing failed: {0}")]
ParsingFailed(String),

View File

@@ -11,25 +11,25 @@ use crate::tool_parser::parsers::{
use crate::tool_parser::traits::ToolParser;
/// Type alias for pooled parser instances.
pub type PooledToolParser = Arc<Mutex<Box<dyn ToolParser>>>;
pub type PooledParser = Arc<Mutex<Box<dyn ToolParser>>>;
/// Type alias for parser creator functions.
type ParserCreator = Arc<dyn Fn() -> Box<dyn ToolParser> + Send + Sync>;
/// Registry for model-specific tool parsers with pooling support.
#[derive(Clone)]
pub struct ToolParserRegistry {
pub struct ParserRegistry {
/// Creator functions for parsers (used when pool is empty)
creators: Arc<RwLock<HashMap<String, ParserCreator>>>,
/// Pooled parser instances for reuse
pool: Arc<RwLock<HashMap<String, PooledToolParser>>>,
pool: Arc<RwLock<HashMap<String, PooledParser>>>,
/// Model pattern to parser name mappings
model_mapping: Arc<RwLock<HashMap<String, String>>>,
/// Default parser name
default_parser: Arc<RwLock<String>>,
}
impl ToolParserRegistry {
impl ParserRegistry {
/// Create a new empty registry.
pub fn new() -> Self {
Self {
@@ -57,7 +57,7 @@ impl ToolParserRegistry {
/// Get a pooled parser by exact name.
/// Returns a shared parser instance from the pool, creating one if needed.
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledToolParser> {
pub fn get_pooled_parser(&self, name: &str) -> Option<PooledParser> {
// First check if we have a pooled instance
{
let pool = self.pool.read().unwrap();
@@ -81,8 +81,91 @@ impl ToolParserRegistry {
}
}
/// Check if a parser with the given name is registered.
pub fn has_parser(&self, name: &str) -> bool {
    // A parser is "registered" iff a creator function exists for it.
    self.creators.read().unwrap().contains_key(name)
}
/// Create a fresh (non-pooled) parser instance by exact name.
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
pub fn create_parser(&self, name: &str) -> Option<Box<dyn ToolParser>> {
    // Look up the creator under a short-lived read lock and invoke it.
    self.creators
        .read()
        .unwrap()
        .get(name)
        .map(|creator| creator())
}
/// Check if a parser can be created for a specific model without actually creating it.
/// Returns true if a parser is available (registered) for this model.
pub fn has_parser_for_model(&self, model: &str) -> bool {
    let mapping = self.model_mapping.read().unwrap();

    // 1) Exact model-name match.
    if let Some(parser_name) = mapping.get(model) {
        if self.has_parser(parser_name) {
            return true;
        }
    }

    // 2) Wildcard prefix match ("foo*"), preferring the longest pattern.
    let prefix_hit = mapping
        .iter()
        .filter(|(pattern, _)| {
            pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
        })
        .max_by_key(|(pattern, _)| pattern.len());
    if let Some((_, parser_name)) = prefix_hit {
        if self.has_parser(parser_name) {
            return true;
        }
    }

    // 3) Fall back to checking the default parser.
    let default = self.default_parser.read().unwrap().clone();
    self.has_parser(&default)
}
/// Create a fresh (non-pooled) parser instance for a specific model.
/// Returns a new parser instance for each call - useful for streaming where state isolation is needed.
pub fn create_for_model(&self, model: &str) -> Option<Box<dyn ToolParser>> {
    let mapping = self.model_mapping.read().unwrap();

    // 1) Exact model-name match.
    if let Some(parser_name) = mapping.get(model) {
        if let Some(parser) = self.create_parser(parser_name) {
            return Some(parser);
        }
    }

    // 2) Wildcard prefix match ("foo*"); the longest (most specific) pattern wins.
    let prefix_hit = mapping
        .iter()
        .filter(|(pattern, _)| {
            pattern.ends_with('*') && model.starts_with(&pattern[..pattern.len() - 1])
        })
        .max_by_key(|(pattern, _)| pattern.len());
    if let Some((_, parser_name)) = prefix_hit {
        if let Some(parser) = self.create_parser(parser_name) {
            return Some(parser);
        }
    }

    // 3) Fall back to the default parser.
    let default = self.default_parser.read().unwrap().clone();
    self.create_parser(&default)
}
/// Get parser for a specific model
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledToolParser> {
pub fn get_pooled_for_model(&self, model: &str) -> Option<PooledParser> {
// Try exact match first
{
let mapping = self.model_mapping.read().unwrap();
@@ -127,7 +210,7 @@ impl ToolParserRegistry {
}
}
impl Default for ToolParserRegistry {
impl Default for ParserRegistry {
fn default() -> Self {
Self::new()
}
@@ -135,14 +218,14 @@ impl Default for ToolParserRegistry {
/// Factory for creating tool parsers based on model type.
#[derive(Clone)]
pub struct ToolParserFactory {
registry: ToolParserRegistry,
pub struct ParserFactory {
registry: ParserRegistry,
}
impl ToolParserFactory {
impl ParserFactory {
/// Create a new factory with default parsers registered.
pub fn new() -> Self {
let registry = ToolParserRegistry::new();
let registry = ParserRegistry::new();
// Register default parsers
registry.register_parser("json", || Box::new(JsonParser::new()));
@@ -172,7 +255,7 @@ impl ToolParserFactory {
Self { registry }
}
fn register_default_mappings(registry: &ToolParserRegistry) {
fn register_default_mappings(registry: &ParserRegistry) {
// OpenAI models
registry.map_model("gpt-4*", "json");
registry.map_model("gpt-3.5*", "json");
@@ -229,7 +312,7 @@ impl ToolParserFactory {
/// Get a pooled parser for the given model ID.
/// Returns a shared instance that can be used concurrently.
/// Falls back to JSON parser if model is not recognized.
pub fn get_pooled(&self, model_id: &str) -> PooledToolParser {
pub fn get_pooled(&self, model_id: &str) -> PooledParser {
self.registry
.get_pooled_for_model(model_id)
.unwrap_or_else(|| {
@@ -241,7 +324,7 @@ impl ToolParserFactory {
}
/// Get the internal registry for custom registration.
pub fn registry(&self) -> &ToolParserRegistry {
pub fn registry(&self) -> &ParserRegistry {
&self.registry
}
@@ -299,7 +382,7 @@ impl ToolParserFactory {
}
}
impl Default for ToolParserFactory {
impl Default for ParserFactory {
fn default() -> Self {
Self::new()
}

View File

@@ -16,8 +16,8 @@ pub mod parsers;
mod tests;
// Re-export commonly used types
pub use errors::{ToolParserError, ToolParserResult};
pub use factory::{PooledToolParser, ToolParserFactory, ToolParserRegistry};
pub use errors::{ParserError, ParserResult};
pub use factory::{ParserFactory, ParserRegistry, PooledParser};
pub use traits::{PartialJsonParser, ToolParser};
pub use types::{FunctionCall, PartialToolCall, StreamingParseResult, ToolCall};

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
@@ -78,15 +78,15 @@ impl DeepSeekParser {
}
/// Parse a single tool call block - throws error if parsing fails
fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> {
fn parse_tool_call(&self, block: &str) -> ParserResult<ToolCall> {
let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
ToolParserError::ParsingFailed("Failed to match tool call pattern".to_string())
ParserError::ParsingFailed("Failed to match tool call pattern".to_string())
})?;
// Get function type (should be "function")
let func_type = captures.get(1).map_or("", |m| m.as_str());
if func_type != "function" {
return Err(ToolParserError::ParsingFailed(format!(
return Err(ParserError::ParsingFailed(format!(
"Invalid function type: {}",
func_type
)));
@@ -95,7 +95,7 @@ impl DeepSeekParser {
// Get function name
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
if func_name.is_empty() {
return Err(ToolParserError::ParsingFailed(
return Err(ParserError::ParsingFailed(
"Empty function name".to_string(),
));
}
@@ -105,7 +105,7 @@ impl DeepSeekParser {
// Parse JSON arguments
let value = serde_json::from_str::<Value>(json_args)
.map_err(|e| ToolParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
.map_err(|e| ParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
// Create arguments object
let args = if value.is_object() {
@@ -115,8 +115,8 @@ impl DeepSeekParser {
serde_json::json!({ "value": value })
};
let arguments = serde_json::to_string(&args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
let arguments =
serde_json::to_string(&args).map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(ToolCall {
function: FunctionCall {
@@ -135,7 +135,7 @@ impl Default for DeepSeekParser {
#[async_trait]
impl ToolParser for DeepSeekParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
}
@@ -168,7 +168,7 @@ impl ToolParser for DeepSeekParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -314,4 +314,12 @@ impl ToolParser for DeepSeekParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
/// Clear all per-request streaming state so this parser instance can be
/// reused for a fresh stream: pending buffer, tracked tool calls, the
/// current tool index (-1 = no active tool), the name-sent flag, and the
/// already-streamed argument prefixes.
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.current_tool_name_sent = false;
self.streamed_args_for_tool.clear();
}
}

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
@@ -72,7 +72,7 @@ impl Glm4MoeParser {
}
/// Parse arguments from key-value pairs
fn parse_arguments(&self, args_text: &str) -> ToolParserResult<serde_json::Map<String, Value>> {
fn parse_arguments(&self, args_text: &str) -> ParserResult<serde_json::Map<String, Value>> {
let mut arguments = serde_json::Map::new();
for capture in self.arg_extractor.captures_iter(args_text) {
@@ -110,7 +110,7 @@ impl Glm4MoeParser {
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
if let Some(captures) = self.func_detail_extractor.captures(block) {
// Get function name
let func_name = captures.get(1).map_or("", |m| m.as_str()).trim();
@@ -122,7 +122,7 @@ impl Glm4MoeParser {
let arguments = self.parse_arguments(args_text)?;
let arguments_str = serde_json::to_string(&arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -137,7 +137,7 @@ impl Glm4MoeParser {
/// Parse and return StreamingParseResult (mirrors Python's detect_and_parse)
/// Parse all tool calls from text (shared logic for complete and incremental parsing)
fn parse_tool_calls_from_text(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
fn parse_tool_calls_from_text(&self, text: &str) -> ParserResult<Vec<ToolCall>> {
let mut tools = Vec::new();
for mat in self.tool_call_extractor.find_iter(text) {
@@ -163,7 +163,7 @@ impl Default for Glm4MoeParser {
#[async_trait]
impl ToolParser for Glm4MoeParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains GLM-4 MoE format
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
@@ -188,7 +188,7 @@ impl ToolParser for Glm4MoeParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Python logic: Wait for complete tool call, then parse it all at once
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -315,4 +315,11 @@ impl ToolParser for Glm4MoeParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
/// Clear all per-request streaming state so this parser instance can be
/// reused for a fresh stream: pending buffer, tracked tool calls, the
/// current tool index (-1 = no active tool), and the already-streamed
/// argument prefixes.
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.streamed_args_for_tool.clear();
}
}

View File

@@ -3,7 +3,7 @@ use async_trait::async_trait;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
errors::ParserResult,
traits::{TokenToolParser, ToolParser},
types::{StreamingParseResult, ToolCall},
};
@@ -23,7 +23,7 @@ impl GptOssHarmonyParser {
#[async_trait]
impl ToolParser for GptOssHarmonyParser {
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Temporary stub: fall back to returning the raw text with no tool calls.
// Later phases will decode Harmony tokens into structured tool calls.
Ok((output.to_string(), Vec::new()))
@@ -33,7 +33,7 @@ impl ToolParser for GptOssHarmonyParser {
&mut self,
_chunk: &str,
_tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Temporary stub until the Harmony streaming pipeline is implemented.
Ok(StreamingParseResult::default())
}
@@ -54,7 +54,7 @@ impl TokenToolParser for GptOssHarmonyParser {
async fn parse_complete_tokens(
&self,
_tokens: &[u32],
) -> ToolParserResult<(String, Vec<ToolCall>)> {
) -> ParserResult<(String, Vec<ToolCall>)> {
// Placeholder until Harmony integration lands. Returning an empty tool list ensures
// that enabling the parser without full implementation results in a no-op rather
// than a runtime panic.
@@ -65,7 +65,7 @@ impl TokenToolParser for GptOssHarmonyParser {
&mut self,
_tokens: &[u32],
_tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
Ok(StreamingParseResult::default())
}
}

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
@@ -76,7 +76,7 @@ impl Default for GptOssParser {
#[async_trait]
impl ToolParser for GptOssParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains GPT-OSS format
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
@@ -100,7 +100,7 @@ impl ToolParser for GptOssParser {
} else {
match serde_json::from_str::<Value>(args_content) {
Ok(value) => serde_json::to_string(&value)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?,
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?,
Err(_) => {
// Skip malformed JSON
continue;
@@ -126,7 +126,7 @@ impl ToolParser for GptOssParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
// Check for tool markers
@@ -211,7 +211,7 @@ impl ToolParser for GptOssParser {
partial_args
};
match self.partial_json.parse_value(json_part) {
match self.partial_json.parse_value(json_part, true) {
Ok((value, _consumed)) => {
let args_str = serde_json::to_string(&value)
.unwrap_or_else(|_| "{}".to_string());

View File

@@ -2,7 +2,7 @@ use crate::protocols::spec::Tool;
use serde_json::Value;
use std::collections::HashMap;
use crate::tool_parser::errors::{ToolParserError, ToolParserResult};
use crate::tool_parser::errors::{ParserError, ParserResult};
use crate::tool_parser::types::{StreamingParseResult, ToolCallItem};
/// Get a mapping of tool names to their indices
@@ -14,6 +14,16 @@ pub fn get_tool_indices(tools: &[Tool]) -> HashMap<String, usize> {
.collect()
}
/// Find the common prefix of two strings
/// Used for incremental argument streaming when partial JSON returns different intermediate states
pub fn find_common_prefix(s1: &str, s2: &str) -> String {
    // Walk both strings in lockstep, accumulating characters until they diverge.
    let mut prefix = String::new();
    for (a, b) in s1.chars().zip(s2.chars()) {
        if a != b {
            break;
        }
        prefix.push(a);
    }
    prefix
}
/// Get unstreamed tool call arguments
/// Returns tool call items for arguments that have been parsed but not yet streamed
/// This ensures tool calls are properly completed even if the model generates final arguments in the last chunk
@@ -96,7 +106,7 @@ pub fn reset_parser_state(
) {
buffer.clear();
prev_tool_call_arr.clear();
*current_tool_id = 0;
*current_tool_id = -1;
*current_tool_name_sent = false;
streamed_args_for_tool.clear();
}
@@ -169,7 +179,7 @@ pub fn normalize_arguments_field(mut obj: Value) -> Value {
///
/// # Returns
/// - `Ok(StreamingParseResult)` with any tool call items to stream
/// - `Err(ToolParserError)` if JSON parsing or serialization fails
/// - `Err(ParserError)` if JSON parsing or serialization fails
#[allow(clippy::too_many_arguments)]
pub fn handle_json_tool_streaming(
current_text: &str,
@@ -181,7 +191,7 @@ pub fn handle_json_tool_streaming(
current_tool_name_sent: &mut bool,
streamed_args_for_tool: &mut Vec<String>,
prev_tool_call_arr: &mut Vec<Value>,
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Check if we have content to parse
if start_idx >= current_text.len() {
return Ok(StreamingParseResult::default());
@@ -190,8 +200,12 @@ pub fn handle_json_tool_streaming(
// Extract JSON string from current position
let json_str = &current_text[start_idx..];
// When current_tool_name_sent is false, don't allow partial strings to avoid
// parsing incomplete tool names as empty strings
let allow_partial_strings = *current_tool_name_sent;
// Parse partial JSON
let (obj, end_idx) = match partial_json.parse_value(json_str) {
let (obj, end_idx) = match partial_json.parse_value(json_str, allow_partial_strings) {
Ok(result) => result,
Err(_) => {
return Ok(StreamingParseResult::default());
@@ -252,49 +266,68 @@ pub fn handle_json_tool_streaming(
.map(|s| s.len())
.unwrap_or(0);
let cur_args_json = serde_json::to_string(cur_arguments)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
// Compute diff: everything after what we've already sent
let diff = cur_args_json[sent..].to_string();
// Get prev_arguments (matches Python's structure)
let prev_arguments = if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id].get("arguments")
} else {
None
};
// Send diff if there's new content
if !diff.is_empty() {
// Only accumulate if not complete
if !is_complete && tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].push_str(&diff);
}
// Calculate diff: everything after what we've already sent
let mut argument_diff = None;
result.calls.push(ToolCallItem {
tool_index: tool_id,
name: None,
parameters: diff,
});
}
// If JSON is complete, advance to next tool
if is_complete {
// Remove processed portion, keep unprocessed content
*buffer = current_text[start_idx + end_idx..].to_string();
// Python: argument_diff = cur_args_json[sent:]
// Rust needs bounds check (Python returns "" automatically)
argument_diff = if sent < cur_args_json.len() {
Some(cur_args_json[sent..].to_string())
} else {
Some(String::new())
};
} else if let Some(prev_args) = prev_arguments {
let prev_args_json = serde_json::to_string(prev_args)
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
// Clear completed tool data
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = Value::Null;
if cur_args_json != prev_args_json {
let prefix = find_common_prefix(&prev_args_json, &cur_args_json);
argument_diff = if sent < prefix.len() {
Some(prefix[sent..].to_string())
} else {
Some(String::new())
};
}
*current_tool_name_sent = false;
if tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].clear();
}
*current_tool_id += 1;
}
}
// Update prev_tool_call_arr with current state
if *current_tool_id >= 0 {
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
let tool_id = *current_tool_id as usize;
// Send diff if present
if let Some(diff) = argument_diff {
if !diff.is_empty() {
if tool_id < streamed_args_for_tool.len() {
streamed_args_for_tool[tool_id].push_str(&diff);
}
result.calls.push(ToolCallItem {
tool_index: tool_id,
name: None,
parameters: diff,
});
}
}
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = current_tool_call;
// Update prev_tool_call_arr with current state
if *current_tool_id >= 0 {
ensure_capacity(*current_tool_id, prev_tool_call_arr, streamed_args_for_tool);
if tool_id < prev_tool_call_arr.len() {
prev_tool_call_arr[tool_id] = current_tool_call;
}
}
// If complete, advance to next tool
if is_complete {
*buffer = current_text[start_idx + end_idx..].to_string();
*current_tool_name_sent = false;
*current_tool_id += 1;
}
}
@@ -371,7 +404,7 @@ mod tests {
assert_eq!(buffer, "");
assert_eq!(prev_tools.len(), 0);
assert_eq!(current_tool_id, 0);
assert_eq!(current_tool_id, -1);
assert!(!current_tool_name_sent);
assert_eq!(streamed_args.len(), 0);
}

View File

@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
@@ -117,7 +117,7 @@ impl JsonParser {
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
// Check if this looks like a tool call
let name = obj
.get("name")
@@ -134,7 +134,7 @@ impl JsonParser {
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -148,7 +148,7 @@ impl JsonParser {
}
/// Parse JSON value(s) into tool calls
fn parse_json_value(&self, value: &Value) -> ToolParserResult<Vec<ToolCall>> {
fn parse_json_value(&self, value: &Value) -> ParserResult<Vec<ToolCall>> {
let mut tools = Vec::new();
match value {
@@ -184,11 +184,11 @@ impl Default for JsonParser {
#[async_trait]
impl ToolParser for JsonParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Always use extract_json_from_text to handle both pure JSON and mixed content
if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) {
let parsed = serde_json::from_str::<Value>(&extracted_json)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
.map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| self.parse_json_value(&v));
match parsed {
@@ -205,7 +205,7 @@ impl ToolParser for JsonParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -264,4 +264,14 @@ impl ToolParser for JsonParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
/// Reset all streaming state via the shared helper, which clears the
/// buffer and tracked tool calls, sets the current tool index back to -1,
/// clears the name-sent flag, and drops streamed argument prefixes.
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
}

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
errors::ParserResult,
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
@@ -102,7 +102,7 @@ impl Default for KimiK2Parser {
#[async_trait]
impl ToolParser for KimiK2Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
}
@@ -161,7 +161,7 @@ impl ToolParser for KimiK2Parser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -333,4 +333,13 @@ impl ToolParser for KimiK2Parser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
/// Clear all per-request streaming state so this parser instance can be
/// reused for a fresh stream; also clears `last_arguments`, which this
/// parser tracks in addition to the common streaming fields.
fn reset(&mut self) {
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.current_tool_name_sent = false;
self.streamed_args_for_tool.clear();
self.last_arguments.clear();
}
}

View File

@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
@@ -70,7 +70,7 @@ impl LlamaParser {
}
/// Parse a single JSON object into a ToolCall (Llama format: name + parameters)
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
// Llama format only: {"name": "function_name", "parameters": {...}}
let name = obj.get("name").and_then(|v| v.as_str());
@@ -81,7 +81,7 @@ impl LlamaParser {
// Convert parameters to JSON string
let arguments = serde_json::to_string(parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -95,7 +95,7 @@ impl LlamaParser {
}
/// Parse semicolon-separated JSON objects
fn parse_semicolon_separated(&self, content: &str) -> ToolParserResult<Vec<ToolCall>> {
fn parse_semicolon_separated(&self, content: &str) -> ParserResult<Vec<ToolCall>> {
let mut all_tools = Vec::new();
// Split by semicolon and parse each JSON object
@@ -131,7 +131,7 @@ impl Default for LlamaParser {
#[async_trait]
impl ToolParser for LlamaParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Extract normal text and JSON content
let (normal_text, json_content) =
if let Some((normal, json)) = self.extract_content_after_python_tag(text) {
@@ -149,7 +149,7 @@ impl ToolParser for LlamaParser {
} else {
// Try single JSON object
let parsed = serde_json::from_str::<Value>(json_content.trim())
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
.map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| {
self.parse_single_object(&v)
.map(|opt| opt.map_or_else(Vec::new, |tool| vec![tool]))
@@ -173,7 +173,7 @@ impl ToolParser for LlamaParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -231,4 +231,14 @@ impl ToolParser for LlamaParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
}

View File

@@ -4,7 +4,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
@@ -111,9 +111,9 @@ impl MistralParser {
}
/// Parse tool calls from a JSON array
fn parse_json_array(&self, json_str: &str) -> ToolParserResult<Vec<ToolCall>> {
fn parse_json_array(&self, json_str: &str) -> ParserResult<Vec<ToolCall>> {
let value: Value = serde_json::from_str(json_str)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
let mut tools = Vec::new();
@@ -134,7 +134,7 @@ impl MistralParser {
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
@@ -144,7 +144,7 @@ impl MistralParser {
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -166,7 +166,7 @@ impl Default for MistralParser {
#[async_trait]
impl ToolParser for MistralParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Mistral format
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
@@ -199,7 +199,7 @@ impl ToolParser for MistralParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -256,4 +256,14 @@ impl ToolParser for MistralParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
}

View File

@@ -18,7 +18,7 @@ use std::sync::OnceLock;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
@@ -74,7 +74,7 @@ impl PythonicParser {
.replace("<|python_end|>", "")
}
fn parse_tool_call_block(&self, block: &str) -> ToolParserResult<Vec<ToolCall>> {
fn parse_tool_call_block(&self, block: &str) -> ParserResult<Vec<ToolCall>> {
let expr = parse_python_expression(block)?;
match expr {
Expr::List(list_expr) => list_expr
@@ -83,7 +83,7 @@ impl PythonicParser {
.enumerate()
.map(|(idx, call_expr)| build_tool_call(call_expr, idx))
.collect(),
_ => Err(ToolParserError::ParsingFailed(
_ => Err(ParserError::ParsingFailed(
"Expected a list of function calls in pythonic tool call".to_string(),
)),
}
@@ -92,7 +92,7 @@ impl PythonicParser {
#[async_trait]
impl ToolParser for PythonicParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
let cleaned = Self::strip_special_tokens(text);
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
@@ -120,7 +120,7 @@ impl ToolParser for PythonicParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
let cleaned = Self::strip_special_tokens(&self.buffer);
@@ -232,23 +232,23 @@ fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
None // No matching bracket found
}
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
fn parse_python_expression(source: &str) -> ParserResult<Expr> {
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
.map_err(|err| ParserError::ParsingFailed(err.to_string()))?;
match module {
Mod::Expression(expr_mod) => Ok(*expr_mod.body),
_ => Err(ToolParserError::ParsingFailed(
_ => Err(ParserError::ParsingFailed(
"Expected a Python expression".to_string(),
)),
}
}
fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
fn build_tool_call(expr: Expr, _index: usize) -> ParserResult<ToolCall> {
match expr {
Expr::Call(call_expr) => {
if !call_expr.args.is_empty() {
return Err(ToolParserError::ParsingFailed(
return Err(ParserError::ParsingFailed(
"Positional arguments are not supported in pythonic tool calls".to_string(),
));
}
@@ -256,7 +256,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
let function_name = match *call_expr.func {
Expr::Name(name_expr) => name_expr.id.to_string(),
_ => {
return Err(ToolParserError::ParsingFailed(
return Err(ParserError::ParsingFailed(
"Unsupported function reference in pythonic tool call".to_string(),
))
}
@@ -265,7 +265,7 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
let mut arguments_map = Map::with_capacity(call_expr.keywords.len());
for keyword in call_expr.keywords {
let arg_name = keyword.arg.ok_or_else(|| {
ToolParserError::ParsingFailed(
ParserError::ParsingFailed(
"pythonic tool calls do not support **kwargs".to_string(),
)
})?;
@@ -283,13 +283,13 @@ fn build_tool_call(expr: Expr, _index: usize) -> ToolParserResult<ToolCall> {
},
})
}
_ => Err(ToolParserError::ParsingFailed(
_ => Err(ParserError::ParsingFailed(
"Expected function calls inside pythonic tool call list".to_string(),
)),
}
}
fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
fn expression_to_json(expr: &Expr) -> ParserResult<Value> {
match expr {
Expr::Constant(expr_constant) => constant_to_json(&expr_constant.value),
Expr::List(list_expr) => collect_sequence(&list_expr.elts).map(Value::Array),
@@ -300,81 +300,75 @@ fn expression_to_json(expr: &Expr) -> ToolParserResult<Value> {
Expr::UnaryOp(unary_expr) => match unary_expr.op {
UnaryOp::USub => match unary_expr.operand.as_ref() {
Expr::Constant(const_expr) => negate_constant(&const_expr.value),
_ => Err(ToolParserError::ParsingFailed(
_ => Err(ParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(),
)),
},
UnaryOp::UAdd => expression_to_json(unary_expr.operand.as_ref()),
_ => Err(ToolParserError::ParsingFailed(format!(
_ => Err(ParserError::ParsingFailed(format!(
"Unsupported unary operator in pythonic tool call: {:?}",
unary_expr.op
))),
},
Expr::Name(name_expr) => Ok(Value::String(name_expr.id.to_string())),
_ => Err(ToolParserError::ParsingFailed(format!(
_ => Err(ParserError::ParsingFailed(format!(
"Unsupported expression in pythonic tool call: {:?}",
expr
))),
}
}
fn constant_to_json(constant: &Constant) -> ToolParserResult<Value> {
fn constant_to_json(constant: &Constant) -> ParserResult<Value> {
match constant {
Constant::None => Ok(Value::Null),
Constant::Bool(b) => Ok(Value::Bool(*b)),
Constant::Int(value) => Ok(integer_constant_to_value(value, false)),
Constant::Float(f) => Number::from_f64(*f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed(
"Invalid float literal in pythonic tool call".to_string(),
)
ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string())
}),
Constant::Str(s) => Ok(Value::String(s.clone())),
Constant::Bytes(bytes) => Ok(Value::String(String::from_utf8_lossy(bytes).into_owned())),
Constant::Tuple(values) => constant_tuple_to_array(values).map(Value::Array),
Constant::Ellipsis | Constant::Complex { .. } => Err(ToolParserError::ParsingFailed(
Constant::Ellipsis | Constant::Complex { .. } => Err(ParserError::ParsingFailed(
"Unsupported literal in pythonic tool call".to_string(),
)),
}
}
fn negate_constant(constant: &Constant) -> ToolParserResult<Value> {
fn negate_constant(constant: &Constant) -> ParserResult<Value> {
match constant {
Constant::Int(value) => Ok(integer_constant_to_value(value, true)),
Constant::Float(f) => Number::from_f64(-f).map(Value::Number).ok_or_else(|| {
ToolParserError::ParsingFailed(
"Invalid float literal in pythonic tool call".to_string(),
)
ParserError::ParsingFailed("Invalid float literal in pythonic tool call".to_string())
}),
_ => Err(ToolParserError::ParsingFailed(
_ => Err(ParserError::ParsingFailed(
"Unsupported unary operand in pythonic tool call".to_string(),
)),
}
}
fn value_to_key_string(value: Value) -> ToolParserResult<String> {
fn value_to_key_string(value: Value) -> ParserResult<String> {
match value {
Value::String(s) => Ok(s),
Value::Number(num) => Ok(num.to_string()),
Value::Bool(b) => Ok(b.to_string()),
Value::Null => Ok("null".to_string()),
other => Err(ToolParserError::ParsingFailed(format!(
other => Err(ParserError::ParsingFailed(format!(
"Unsupported key type in pythonic tool call: {:?}",
other
))),
}
}
fn collect_sequence(elements: &[Expr]) -> ToolParserResult<Vec<Value>> {
fn collect_sequence(elements: &[Expr]) -> ParserResult<Vec<Value>> {
elements.iter().map(expression_to_json).collect()
}
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<String, Value>> {
fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ParserResult<Map<String, Value>> {
let mut map = Map::with_capacity(keys.len());
for (key_expr, value_expr) in keys.iter().zip(values.iter()) {
let key_expr = key_expr.as_ref().ok_or_else(|| {
ToolParserError::ParsingFailed(
"pythonic tool calls do not support **kwargs".to_string(),
)
ParserError::ParsingFailed("pythonic tool calls do not support **kwargs".to_string())
})?;
let key_value = expression_to_json(key_expr)?;
let key = value_to_key_string(key_value)?;
@@ -384,7 +378,7 @@ fn collect_dict(keys: &[Option<Expr>], values: &[Expr]) -> ToolParserResult<Map<
Ok(map)
}
fn constant_tuple_to_array(values: &[Constant]) -> ToolParserResult<Vec<Value>> {
fn constant_tuple_to_array(values: &[Constant]) -> ParserResult<Vec<Value>> {
values.iter().map(constant_to_json).collect()
}

View File

@@ -5,7 +5,7 @@ use serde_json::Value;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
partial_json::PartialJson,
traits::ToolParser,
@@ -76,7 +76,7 @@ impl QwenParser {
}
/// Parse a single JSON object into a ToolCall
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
fn parse_single_object(&self, obj: &Value) -> ParserResult<Option<ToolCall>> {
let name = obj.get("name").and_then(|v| v.as_str());
if let Some(name) = name {
@@ -86,7 +86,7 @@ impl QwenParser {
// Convert arguments to JSON string
let arguments = serde_json::to_string(args)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -108,7 +108,7 @@ impl Default for QwenParser {
#[async_trait]
impl ToolParser for QwenParser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
// Check if text contains Qwen format
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
@@ -123,7 +123,7 @@ impl ToolParser for QwenParser {
for captures in self.extractor.captures_iter(text) {
if let Some(json_str) = captures.get(1) {
let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
.map_err(|e| ParserError::ParsingFailed(e.to_string()))
.and_then(|v| self.parse_single_object(&v));
match parsed {
@@ -149,7 +149,7 @@ impl ToolParser for QwenParser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
// Append new text to buffer
self.buffer.push_str(chunk);
let current_text = &self.buffer.clone();
@@ -240,4 +240,14 @@ impl ToolParser for QwenParser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
fn reset(&mut self) {
helpers::reset_parser_state(
&mut self.buffer,
&mut self.prev_tool_call_arr,
&mut self.current_tool_id,
&mut self.current_tool_name_sent,
&mut self.streamed_args_for_tool,
);
}
}

View File

@@ -6,7 +6,7 @@ use std::collections::HashMap;
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
parsers::helpers,
traits::ToolParser,
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
@@ -108,7 +108,7 @@ impl Step3Parser {
fn parse_partial_tool_call(
&mut self,
tool_indices: &HashMap<String, usize>,
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
let mut calls = Vec::new();
// Check if we have tool_sep (means we're past the type declaration)
@@ -321,7 +321,7 @@ impl Step3Parser {
fn parse_steptml_parameters(
&self,
params_text: &str,
) -> ToolParserResult<serde_json::Map<String, Value>> {
) -> ParserResult<serde_json::Map<String, Value>> {
let mut parameters = serde_json::Map::new();
for capture in self.param_extractor.captures_iter(params_text) {
@@ -359,7 +359,7 @@ impl Step3Parser {
}
/// Parse a single tool call block
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
fn parse_tool_call(&self, block: &str) -> ParserResult<Option<ToolCall>> {
// Check if it contains function marker and tool separator
if !block.contains("function") || !block.contains("<tool_sep>") {
return Ok(None);
@@ -393,7 +393,7 @@ impl Step3Parser {
let parameters = self.parse_steptml_parameters(params_text)?;
let arguments_str = serde_json::to_string(&parameters)
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
.map_err(|e| ParserError::ParsingFailed(e.to_string()))?;
Ok(Some(ToolCall {
function: FunctionCall {
@@ -415,7 +415,7 @@ impl Default for Step3Parser {
#[async_trait]
impl ToolParser for Step3Parser {
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
async fn parse_complete(&self, text: &str) -> ParserResult<(String, Vec<ToolCall>)> {
if !self.has_tool_markers(text) {
return Ok((text.to_string(), vec![]));
}
@@ -449,7 +449,7 @@ impl ToolParser for Step3Parser {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult> {
) -> ParserResult<StreamingParseResult> {
self.buffer.push_str(chunk);
// Build tool indices for validation
@@ -555,4 +555,20 @@ impl ToolParser for Step3Parser {
fn get_unstreamed_tool_args(&self) -> Option<Vec<ToolCallItem>> {
helpers::get_unstreamed_args(&self.prev_tool_call_arr, &self.streamed_args_for_tool)
}
fn reset(&mut self) {
// Reset standard state
self.buffer.clear();
self.prev_tool_call_arr.clear();
self.current_tool_id = -1;
self.streamed_args_for_tool.clear();
// Reset Step3-specific fields
self.in_tool_block = false;
self.tool_block_finished = false;
self.current_function_name.clear();
self.current_parameters.clear();
self.in_tool_call = false;
self.function_name_sent = false;
}
}

View File

@@ -1,5 +1,5 @@
use crate::tool_parser::{
errors::{ToolParserError, ToolParserResult},
errors::{ParserError, ParserResult},
traits::PartialJsonParser,
};
use serde_json::{Map, Value};
@@ -22,8 +22,22 @@ impl PartialJson {
}
/// Parse potentially incomplete JSON, returning parsed value and consumed bytes
pub fn parse_value(&self, input: &str) -> ToolParserResult<(Value, usize)> {
let mut parser = Parser::new(input, self.max_depth, self.allow_incomplete);
///
/// # Arguments
/// * `input` - The JSON string to parse
/// * `allow_partial_strings` - When false, incomplete strings cause parsing to stop
/// (matches Python's Allow.ALL & ~Allow.STR behavior)
pub fn parse_value(
&self,
input: &str,
allow_partial_strings: bool,
) -> ParserResult<(Value, usize)> {
let mut parser = Parser::new(
input,
self.max_depth,
self.allow_incomplete,
allow_partial_strings,
);
let value = parser.parse_value(0)?;
Ok((value, parser.position))
}
@@ -36,8 +50,9 @@ impl Default for PartialJson {
}
impl PartialJsonParser for PartialJson {
fn parse(&self, input: &str) -> ToolParserResult<(Value, usize)> {
self.parse_value(input)
fn parse(&self, input: &str) -> ParserResult<(Value, usize)> {
// Default to allowing partial strings
self.parse_value(input, true)
}
fn is_complete(&self, input: &str) -> bool {
@@ -56,15 +71,22 @@ struct Parser<'a> {
position: usize,
max_depth: usize,
allow_incomplete: bool,
allow_partial_strings: bool,
}
impl<'a> Parser<'a> {
fn new(input: &'a str, max_depth: usize, allow_incomplete: bool) -> Self {
fn new(
input: &'a str,
max_depth: usize,
allow_incomplete: bool,
allow_partial_strings: bool,
) -> Self {
Self {
chars: input.chars().peekable(),
position: 0,
max_depth,
allow_incomplete,
allow_partial_strings,
}
}
@@ -88,9 +110,9 @@ impl<'a> Parser<'a> {
}
}
fn parse_value(&mut self, depth: usize) -> ToolParserResult<Value> {
fn parse_value(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
return Err(ParserError::DepthExceeded(self.max_depth));
}
self.skip_whitespace();
@@ -106,17 +128,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed(
"Unexpected character".into(),
))
Err(ParserError::ParsingFailed("Unexpected character".into()))
}
}
}
}
fn parse_object(&mut self, depth: usize) -> ToolParserResult<Value> {
fn parse_object(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
return Err(ParserError::DepthExceeded(self.max_depth));
}
let mut object = Map::new();
@@ -140,7 +160,7 @@ impl<'a> Parser<'a> {
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
_ => return Err(ToolParserError::ParsingFailed("Expected string key".into())),
_ => return Err(ParserError::ParsingFailed("Expected string key".into())),
};
self.skip_whitespace();
@@ -152,7 +172,7 @@ impl<'a> Parser<'a> {
object.insert(key, Value::Null);
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ':'".into()));
return Err(ParserError::ParsingFailed("Expected ':'".into()));
}
self.advance();
self.skip_whitespace();
@@ -161,8 +181,13 @@ impl<'a> Parser<'a> {
let value = match self.parse_value(depth) {
Ok(v) => v,
Err(_) if self.allow_incomplete => {
// Add null for incomplete value
object.insert(key, Value::Null);
// When allow_partial_strings is false, don't add the key with Null
// Just return the object without this incomplete key-value pair
// This matches Python's behavior: Allow.ALL & ~Allow.STR
if self.allow_partial_strings {
// Add null for incomplete value
object.insert(key, Value::Null);
}
return Ok(Value::Object(object));
}
Err(e) => return Err(e),
@@ -192,15 +217,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete {
return Ok(Value::Object(object));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or '}'".into()));
return Err(ParserError::ParsingFailed("Expected ',' or '}'".into()));
}
}
}
}
fn parse_array(&mut self, depth: usize) -> ToolParserResult<Value> {
fn parse_array(&mut self, depth: usize) -> ParserResult<Value> {
if depth > self.max_depth {
return Err(ToolParserError::DepthExceeded(self.max_depth));
return Err(ParserError::DepthExceeded(self.max_depth));
}
let mut array = Vec::new();
@@ -249,15 +274,15 @@ impl<'a> Parser<'a> {
if self.allow_incomplete {
return Ok(Value::Array(array));
}
return Err(ToolParserError::ParsingFailed("Expected ',' or ']'".into()));
return Err(ParserError::ParsingFailed("Expected ',' or ']'".into()));
}
}
}
}
fn parse_string(&mut self) -> ToolParserResult<Value> {
fn parse_string(&mut self) -> ParserResult<Value> {
if self.peek() != Some('"') {
return Err(ToolParserError::ParsingFailed("Expected '\"'".into()));
return Err(ParserError::ParsingFailed("Expected '\"'".into()));
}
// Consume opening quote
@@ -301,14 +326,14 @@ impl<'a> Parser<'a> {
}
// Incomplete string
if self.allow_incomplete {
if self.allow_incomplete && self.allow_partial_strings {
Ok(Value::String(string))
} else {
Err(ToolParserError::ParsingFailed("Unterminated string".into()))
Err(ParserError::ParsingFailed("Unterminated string".into()))
}
}
fn parse_unicode_escape(&mut self) -> ToolParserResult<char> {
fn parse_unicode_escape(&mut self) -> ParserResult<char> {
let mut hex = String::new();
for _ in 0..4 {
if let Some(ch) = self.peek() {
@@ -327,17 +352,17 @@ impl<'a> Parser<'a> {
u32::from_str_radix(&hex, 16)
.ok()
.and_then(char::from_u32)
.ok_or_else(|| ToolParserError::ParsingFailed("Invalid unicode escape".into()))
.ok_or_else(|| ParserError::ParsingFailed("Invalid unicode escape".into()))
} else if self.allow_incomplete {
Ok('\u{FFFD}') // Replacement character
} else {
Err(ToolParserError::ParsingFailed(
Err(ParserError::ParsingFailed(
"Incomplete unicode escape".into(),
))
}
}
fn parse_number(&mut self) -> ToolParserResult<Value> {
fn parse_number(&mut self) -> ParserResult<Value> {
let mut number = String::new();
// Handle negative sign
@@ -410,11 +435,11 @@ impl<'a> Parser<'a> {
} else if self.allow_incomplete {
Ok(Value::Number(serde_json::Number::from(0)))
} else {
Err(ToolParserError::ParsingFailed("Invalid number".into()))
Err(ParserError::ParsingFailed("Invalid number".into()))
}
}
fn parse_bool(&mut self) -> ToolParserResult<Value> {
fn parse_bool(&mut self) -> ParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like a boolean
@@ -435,7 +460,7 @@ impl<'a> Parser<'a> {
|| (self.allow_incomplete && ("true".starts_with(&word) || "false".starts_with(&word)));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid boolean".into()));
return Err(ParserError::ParsingFailed("Invalid boolean".into()));
}
// Now actually consume the characters
@@ -458,14 +483,14 @@ impl<'a> Parser<'a> {
} else if "false".starts_with(partial) {
Ok(Value::Bool(false))
} else {
Err(ToolParserError::ParsingFailed("Invalid boolean".into()))
Err(ParserError::ParsingFailed("Invalid boolean".into()))
}
}
_ => Err(ToolParserError::ParsingFailed("Invalid boolean".into())),
_ => Err(ParserError::ParsingFailed("Invalid boolean".into())),
}
}
fn parse_null(&mut self) -> ToolParserResult<Value> {
fn parse_null(&mut self) -> ParserResult<Value> {
let mut word = String::new();
// Peek at upcoming characters to validate it looks like "null"
@@ -484,7 +509,7 @@ impl<'a> Parser<'a> {
let is_valid = word == "null" || (self.allow_incomplete && "null".starts_with(&word));
if !is_valid {
return Err(ToolParserError::ParsingFailed("Invalid null".into()));
return Err(ParserError::ParsingFailed("Invalid null".into()));
}
// Now actually consume the characters
@@ -501,7 +526,7 @@ impl<'a> Parser<'a> {
if word == "null" || (self.allow_incomplete && "null".starts_with(&word)) {
Ok(Value::Null)
} else {
Err(ToolParserError::ParsingFailed("Invalid null".into()))
Err(ParserError::ParsingFailed("Invalid null".into()))
}
}
}

View File

@@ -7,7 +7,7 @@ use crate::tool_parser::traits::ToolParser;
#[tokio::test]
async fn test_tool_parser_factory() {
let factory = ToolParserFactory::new();
let factory = ParserFactory::new();
// Test that we can get a pooled parser
let pooled_parser = factory.get_pooled("gpt-4");
@@ -17,7 +17,7 @@ async fn test_tool_parser_factory() {
#[tokio::test]
async fn test_tool_parser_factory_model_mapping() {
let factory = ToolParserFactory::new();
let factory = ParserFactory::new();
// Test model mapping
factory.registry().map_model("test-model", "json");
@@ -54,22 +54,22 @@ fn test_partial_json_parser() {
let parser = PartialJson::default();
let input = r#"{"name": "test", "value": 42}"#;
let (value, consumed) = parser.parse_value(input).unwrap();
let (value, consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "test");
assert_eq!(value["value"], 42);
assert_eq!(consumed, input.len());
let input = r#"{"name": "test", "value": "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "test");
assert!(value["value"].is_null());
let input = r#"{"name": "tes"#;
let (value, _consumed) = parser.parse_value(input).unwrap();
let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert_eq!(value["name"], "tes");
let input = r#"[1, 2, "#;
let (value, _consumed) = parser.parse_value(input).unwrap();
let (value, _consumed) = parser.parse_value(input, true).unwrap();
assert!(value.is_array());
assert_eq!(value[0], 1);
assert_eq!(value[1], 2);
@@ -83,17 +83,17 @@ fn test_partial_json_depth_limit() {
// This should work (simple object)
let input = r#"{"a": 1}"#;
let result = parser.parse_value(input);
let result = parser.parse_value(input, true);
assert!(result.is_ok());
// This should work (nested to depth 3)
let input = r#"{"a": {"b": {"c": 1}}}"#;
let result = parser.parse_value(input);
let result = parser.parse_value(input, true);
assert!(result.is_ok());
// This should fail (nested to depth 4, exceeds limit)
let input = r#"{"a": {"b": {"c": {"d": 1}}}}"#;
let result = parser.parse_value(input);
let result = parser.parse_value(input, true);
assert!(result.is_err());
}
@@ -244,7 +244,7 @@ fn test_json_parser_format_detection() {
#[tokio::test]
async fn test_factory_with_json_parser() {
let factory = ToolParserFactory::new();
let factory = ParserFactory::new();
// Should get JSON parser for OpenAI models
let pooled_parser = factory.get_pooled("gpt-4-turbo");

View File

@@ -1,6 +1,6 @@
use crate::protocols::spec::Tool;
use crate::tool_parser::{
errors::ToolParserResult,
errors::ParserResult,
types::{StreamingParseResult, ToolCall},
};
use async_trait::async_trait;
@@ -10,7 +10,7 @@ use async_trait::async_trait;
pub trait ToolParser: Send + Sync {
/// Parse complete tool calls from final output
/// Returns (remaining_normal_text, tool_calls) tuple
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>;
async fn parse_complete(&self, output: &str) -> ParserResult<(String, Vec<ToolCall>)>;
/// Parse tool calls from model output (streaming)
/// Parsers now maintain internal state, so self is mutable
@@ -22,7 +22,7 @@ pub trait ToolParser: Send + Sync {
&mut self,
chunk: &str,
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult>;
) -> ParserResult<StreamingParseResult>;
/// Check if text contains tool calls in this parser's format
fn has_tool_markers(&self, text: &str) -> bool;
@@ -38,12 +38,18 @@ pub trait ToolParser: Send + Sync {
fn get_unstreamed_tool_args(&self) -> Option<Vec<crate::tool_parser::types::ToolCallItem>> {
None
}
/// Reset the parser state for reuse across requests.
/// This should clear all buffers and reset state to initial values.
fn reset(&mut self) {
// Default no-op implementation
}
}
/// Trait for partial JSON parsing
pub trait PartialJsonParser: Send + Sync {
/// Parse potentially incomplete JSON
fn parse(&self, input: &str) -> ToolParserResult<(serde_json::Value, usize)>;
fn parse(&self, input: &str) -> ParserResult<(serde_json::Value, usize)>;
/// Check if JSON is complete
fn is_complete(&self, input: &str) -> bool;
@@ -55,10 +61,7 @@ pub trait PartialJsonParser: Send + Sync {
#[async_trait]
pub trait TokenToolParser: ToolParser {
/// Parse complete tool calls when provided with raw token IDs.
async fn parse_complete_tokens(
&self,
tokens: &[u32],
) -> ToolParserResult<(String, Vec<ToolCall>)>;
async fn parse_complete_tokens(&self, tokens: &[u32]) -> ParserResult<(String, Vec<ToolCall>)>;
/// Streaming parser entrypoint for token chunks.
/// Parsers maintain internal state, so self is mutable
@@ -66,5 +69,5 @@ pub trait TokenToolParser: ToolParser {
&mut self,
tokens: &[u32],
tools: &[Tool],
) -> ToolParserResult<StreamingParseResult>;
) -> ParserResult<StreamingParseResult>;
}