[router][grpc] Support tool call parser in streaming (#11160)
This commit is contained in:
@@ -15,11 +15,13 @@ use rustpython_parser::{parse, Mode};
|
||||
use serde_json::{Map, Number, Value};
|
||||
use std::sync::OnceLock;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
state::ParseState,
|
||||
parsers::helpers,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall, ToolCallItem},
|
||||
};
|
||||
|
||||
static PYTHONIC_BLOCK_REGEX: OnceLock<Regex> = OnceLock::new();
|
||||
@@ -37,13 +39,23 @@ fn pythonic_block_regex() -> &'static Regex {
|
||||
}
|
||||
|
||||
/// Parser for Pythonic tool call format
|
||||
#[derive(Default)]
|
||||
pub struct PythonicParser;
|
||||
pub struct PythonicParser {
|
||||
/// Buffer for accumulating chunks
|
||||
buffer: String,
|
||||
}
|
||||
|
||||
impl Default for PythonicParser {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
impl PythonicParser {
|
||||
/// Create a new Pythonic parser
|
||||
pub fn new() -> Self {
|
||||
Self
|
||||
Self {
|
||||
buffer: String::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract the first pythonic tool call block and return it along with the
|
||||
@@ -105,23 +117,90 @@ impl ToolParser for PythonicParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
self.buffer.push_str(chunk);
|
||||
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if let Some((tool_calls_text, _)) = self.extract_tool_calls(&cleaned) {
|
||||
if let Ok(tools) = self.parse_tool_call_block(&tool_calls_text) {
|
||||
if let Some(tool) = tools.into_iter().next() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
let cleaned = Self::strip_special_tokens(&self.buffer);
|
||||
|
||||
// Look for opening bracket
|
||||
if let Some(start) = cleaned.find('[') {
|
||||
let normal_text = if start > 0 {
|
||||
cleaned[..start].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
// Look for matching closing bracket
|
||||
if let Some(end) = find_matching_bracket(&cleaned, start) {
|
||||
// Found complete tool call - extract it and parse using parse_complete
|
||||
let call_text = &cleaned[start..=end];
|
||||
|
||||
match self.parse_complete(call_text).await {
|
||||
Ok((_, calls)) => {
|
||||
// Update buffer with remaining text after tool call
|
||||
let remaining_text = &cleaned[end + 1..];
|
||||
self.buffer = remaining_text.to_string();
|
||||
|
||||
// Validate tool names and convert ToolCall to ToolCallItem
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
let items: Vec<ToolCallItem> = calls
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(idx, tool)| {
|
||||
if !tool_indices.contains_key(&tool.function.name) {
|
||||
tracing::warn!(
|
||||
"Invalid tool name '{}' - skipping",
|
||||
tool.function.name
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(ToolCallItem {
|
||||
tool_index: idx,
|
||||
name: Some(tool.function.name),
|
||||
parameters: tool.function.arguments,
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: items,
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse pythonic tool call: {}", e);
|
||||
// Clear buffer on error
|
||||
self.buffer.clear();
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// We have an opening bracket but no closing bracket yet
|
||||
// Put back everything from the bracket onwards
|
||||
self.buffer = cleaned[start..].to_string();
|
||||
|
||||
if !normal_text.is_empty() {
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
}
|
||||
|
||||
// Still accumulating a potential tool call
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
// No tool call bracket found
|
||||
self.buffer.clear();
|
||||
Ok(StreamingParseResult {
|
||||
normal_text: cleaned,
|
||||
calls: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
@@ -134,6 +213,25 @@ impl ToolParser for PythonicParser {
|
||||
}
|
||||
}
|
||||
|
||||
/// Find the matching closing bracket for the opening bracket at start position.
|
||||
/// Properly handles nested brackets.
|
||||
fn find_matching_bracket(buffer: &str, start: usize) -> Option<usize> {
|
||||
let mut bracket_count = 0;
|
||||
let chars: Vec<char> = buffer.chars().collect();
|
||||
|
||||
for (i, &ch) in chars.iter().enumerate().skip(start) {
|
||||
if ch == '[' {
|
||||
bracket_count += 1;
|
||||
} else if ch == ']' {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
}
|
||||
None // No matching bracket found
|
||||
}
|
||||
|
||||
fn parse_python_expression(source: &str) -> ToolParserResult<Expr> {
|
||||
let module = parse(source, Mode::Expression, "<pythonic_tool_call>")
|
||||
.map_err(|err| ToolParserError::ParsingFailed(err.to_string()))?;
|
||||
|
||||
Reference in New Issue
Block a user