[router][grpc] Support tool call parser in streaming (#11160)
This commit is contained in:
@@ -2,12 +2,14 @@ use async_trait::async_trait;
|
||||
use regex::Regex;
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::protocols::spec::Tool;
|
||||
|
||||
use crate::tool_parser::{
|
||||
errors::{ToolParserError, ToolParserResult},
|
||||
parsers::helpers,
|
||||
partial_json::PartialJson,
|
||||
state::ParseState,
|
||||
traits::ToolParser,
|
||||
types::{FunctionCall, StreamResult, ToolCall},
|
||||
types::{FunctionCall, StreamingParseResult, ToolCall},
|
||||
};
|
||||
|
||||
/// Qwen format parser for tool calls
|
||||
@@ -19,11 +21,36 @@ use crate::tool_parser::{
|
||||
/// - XML-style tags with JSON content
|
||||
/// - Support for multiple sequential tool calls
|
||||
/// - Newline-aware parsing
|
||||
/// - Buffering for partial end tokens
|
||||
pub struct QwenParser {
|
||||
/// Parser for handling incomplete JSON during streaming
|
||||
partial_json: PartialJson,
|
||||
/// Regex for extracting tool calls
|
||||
|
||||
/// Regex for extracting tool calls in parse_complete
|
||||
extractor: Regex,
|
||||
|
||||
/// Buffer for accumulating incomplete patterns across chunks
|
||||
buffer: String,
|
||||
|
||||
/// Stores complete tool call info (name and arguments) for each tool being parsed
|
||||
prev_tool_call_arr: Vec<Value>,
|
||||
|
||||
/// Index of currently streaming tool call (-1 means no active tool)
|
||||
current_tool_id: i32,
|
||||
|
||||
/// Flag for whether current tool's name has been sent to client
|
||||
current_tool_name_sent: bool,
|
||||
|
||||
/// Tracks raw JSON string content streamed to client for each tool's arguments
|
||||
streamed_args_for_tool: Vec<String>,
|
||||
|
||||
/// Buffer for normal text that might precede partial end tokens
|
||||
normal_text_buffer: String,
|
||||
|
||||
/// Token configuration
|
||||
|
||||
/// Marker that opens a tool call block. It is initialized to `"<tool_call>\n"` where the
/// parser is constructed, and the streaming path searches for it to find where the JSON
/// payload begins.
bot_token: &'static str,
|
||||
/// Marker that closes a tool call block, initialized to `"\n</tool_call>"`. The streaming
/// path drops its first byte (the leading newline) so that a bare `</tool_call>` left in
/// normal text can be detected and removed.
eot_token: &'static str,
|
||||
/// Text expected in front of a follow-up tool call, initialized to `"\n"`. It is only
/// consulted when a tool call is already active (`current_tool_id >= 0`).
tool_call_separator: &'static str,
|
||||
}
|
||||
|
||||
impl QwenParser {
|
||||
@@ -36,11 +63,20 @@ impl QwenParser {
|
||||
Self {
|
||||
partial_json: PartialJson::default(),
|
||||
extractor,
|
||||
buffer: String::new(),
|
||||
prev_tool_call_arr: Vec::new(),
|
||||
current_tool_id: -1,
|
||||
current_tool_name_sent: false,
|
||||
streamed_args_for_tool: Vec::new(),
|
||||
normal_text_buffer: String::new(),
|
||||
bot_token: "<tool_call>\n",
|
||||
eot_token: "\n</tool_call>",
|
||||
tool_call_separator: "\n",
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse a single JSON object into a ToolCall
|
||||
fn parse_single_object(&self, obj: &Value, index: usize) -> ToolParserResult<Option<ToolCall>> {
|
||||
fn parse_single_object(&self, obj: &Value) -> ToolParserResult<Option<ToolCall>> {
|
||||
let name = obj.get("name").and_then(|v| v.as_str());
|
||||
|
||||
if let Some(name) = name {
|
||||
@@ -52,8 +88,12 @@ impl QwenParser {
|
||||
let arguments = serde_json::to_string(args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID with index for multiple tools
|
||||
let id = format!("qwen_call_{}", index);
|
||||
// Generate unique ID
|
||||
let id = obj
|
||||
.get("id")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from)
|
||||
.unwrap_or_else(|| format!("qwen_call_{}", uuid::Uuid::new_v4()));
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
@@ -73,42 +113,9 @@ impl QwenParser {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
|
||||
/// Find the start position of a tool call
|
||||
fn find_tool_start(&self, text: &str) -> Option<usize> {
|
||||
text.find("<tool_call>\n")
|
||||
}
|
||||
|
||||
/// Find the end position of a tool call
|
||||
fn find_tool_end(&self, text: &str, start_pos: usize) -> Option<usize> {
|
||||
let search_from = start_pos + "<tool_call>\n".len();
|
||||
text[search_from..]
|
||||
.find("\n</tool_call>")
|
||||
.map(|pos| search_from + pos + "\n</tool_call>".len())
|
||||
}
|
||||
|
||||
/// Check if buffer ends with a partial token
|
||||
fn ends_with_partial_token(&self, buffer: &str) -> Option<usize> {
|
||||
// Check for partial start token
|
||||
let start_token = "<tool_call>\n";
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
for i in 1..=start_token.len().min(buffer.len()) {
|
||||
if start_token.starts_with(&buffer[buffer.len() - i..]) {
|
||||
return Some(i);
|
||||
}
|
||||
}
|
||||
|
||||
// Check for partial end token
|
||||
let end_token = "\n</tool_call>";
|
||||
// Only check if buffer ends with a partial match (not the complete token without newline)
|
||||
// If buffer ends with "</tool_call>", that's not a partial token - it's missing the newline
|
||||
if buffer.ends_with("</tool_call>") {
|
||||
// This is a complete end tag, just missing the leading newline
|
||||
// Not a partial token situation
|
||||
return None;
|
||||
}
|
||||
// Use inclusive range to check if entire buffer could be a prefix
|
||||
(1..=end_token.len().min(buffer.len()))
|
||||
.find(|&i| end_token.starts_with(&buffer[buffer.len() - i..]))
|
||||
/// Check if text has tool call
|
||||
fn has_tool_call(&self, text: &str) -> bool {
|
||||
text.contains("<tool_call>")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -132,17 +139,17 @@ impl ToolParser for QwenParser {
|
||||
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for (index, captures) in self.extractor.captures_iter(text).enumerate() {
|
||||
for captures in self.extractor.captures_iter(text) {
|
||||
if let Some(json_str) = captures.get(1) {
|
||||
let parsed = serde_json::from_str::<Value>(json_str.as_str().trim())
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))
|
||||
.and_then(|v| self.parse_single_object(&v, index));
|
||||
.and_then(|v| self.parse_single_object(&v));
|
||||
|
||||
match parsed {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call {}: {:?}", index, e);
|
||||
tracing::warn!("Failed to parse tool call: {:?}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
@@ -158,103 +165,91 @@ impl ToolParser for QwenParser {
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
&self,
|
||||
&mut self,
|
||||
chunk: &str,
|
||||
state: &mut ParseState,
|
||||
) -> ToolParserResult<StreamResult> {
|
||||
state.buffer.push_str(chunk);
|
||||
tools: &[Tool],
|
||||
) -> ToolParserResult<StreamingParseResult> {
|
||||
// Append new text to buffer
|
||||
self.buffer.push_str(chunk);
|
||||
let current_text = &self.buffer.clone();
|
||||
|
||||
// Check for partial token at end of buffer
|
||||
if let Some(_partial_len) = self.ends_with_partial_token(&state.buffer) {
|
||||
// Hold back the partial token
|
||||
return Ok(StreamResult::Incomplete);
|
||||
}
|
||||
// Check if current_text has tool_call
|
||||
let has_tool_start = self.has_tool_call(current_text)
|
||||
|| (self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator));
|
||||
|
||||
// Check if we have the start marker
|
||||
if !self.has_tool_markers(&state.buffer) {
|
||||
// No tool markers detected - return all buffered content as normal text
|
||||
let normal_text = std::mem::take(&mut state.buffer);
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
if !has_tool_start {
|
||||
// Only clear buffer if we're sure no tool call is starting
|
||||
if helpers::ends_with_partial_token(&self.buffer, self.bot_token).is_none() {
|
||||
let normal_text = self.buffer.clone();
|
||||
self.buffer.clear();
|
||||
|
||||
// Check for text before tool markers and extract it as normal text
|
||||
if let Some(marker_pos) = state.buffer.find("<tool_call>") {
|
||||
if marker_pos > 0 {
|
||||
// We have text before the tool marker - extract it as normal text
|
||||
let normal_text: String = state.buffer.drain(..marker_pos).collect();
|
||||
return Ok(StreamResult::NormalText(normal_text));
|
||||
}
|
||||
}
|
||||
|
||||
// Find start and end positions
|
||||
if let Some(start_pos) = self.find_tool_start(&state.buffer) {
|
||||
// Check if we have the complete tool call
|
||||
if let Some(end_pos) = self.find_tool_end(&state.buffer, start_pos) {
|
||||
// Extract the JSON content
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let json_end = end_pos - "\n</tool_call>".len();
|
||||
let json_str = &state.buffer[json_start..json_end];
|
||||
|
||||
// Parse the complete JSON
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, 0)? {
|
||||
// Clear the consumed part from buffer using drain for efficiency
|
||||
state.buffer.drain(..end_pos);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete or malformed
|
||||
// If we have what looks like a complete tool call block, treat as normal text
|
||||
if state.buffer[start_pos..end_pos].contains("\n</tool_call>") {
|
||||
let malformed_text: String = state.buffer.drain(..end_pos).collect();
|
||||
return Ok(StreamResult::NormalText(malformed_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
return Ok(StreamingParseResult {
|
||||
normal_text,
|
||||
calls: vec![],
|
||||
});
|
||||
} else {
|
||||
// We have start but no end yet - try partial parsing
|
||||
let json_start = start_pos + "<tool_call>\n".len();
|
||||
let partial_json = &state.buffer[json_start..];
|
||||
// Might be partial bot_token, keep buffering
|
||||
return Ok(StreamingParseResult::default());
|
||||
}
|
||||
}
|
||||
|
||||
// Remove trailing newline if present (might be start of end token)
|
||||
let partial_json = partial_json.trim_end();
|
||||
// Build tool indices
|
||||
let tool_indices = helpers::get_tool_indices(tools);
|
||||
|
||||
// Try to parse with partial JSON parser
|
||||
match self.partial_json.parse_value(partial_json) {
|
||||
Ok((value, _consumed)) => {
|
||||
// Extract tool name if available
|
||||
if let Some(name) = value.get("name").and_then(|v| v.as_str()) {
|
||||
// Check if we've already sent the name
|
||||
if !state.in_string {
|
||||
state.in_string = true; // Use as flag for "name sent"
|
||||
return Ok(StreamResult::ToolName {
|
||||
index: 0,
|
||||
name: name.to_string(),
|
||||
});
|
||||
}
|
||||
// Determine start index for JSON parsing
|
||||
let start_idx = if let Some(pos) = current_text.find(self.bot_token) {
|
||||
pos + self.bot_token.len()
|
||||
} else if self.current_tool_id >= 0 && current_text.starts_with(self.tool_call_separator) {
|
||||
self.tool_call_separator.len()
|
||||
} else {
|
||||
0
|
||||
};
|
||||
|
||||
// Check for arguments
|
||||
if let Some(args) = value.get("arguments") {
|
||||
if let Ok(args_str) = serde_json::to_string(args) {
|
||||
return Ok(StreamResult::ToolArguments {
|
||||
index: 0,
|
||||
arguments: args_str,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Failed to parse even as partial JSON
|
||||
// Keep buffering
|
||||
}
|
||||
let mut result = helpers::handle_json_tool_streaming(
|
||||
current_text,
|
||||
start_idx,
|
||||
&mut self.partial_json,
|
||||
&tool_indices,
|
||||
&mut self.buffer,
|
||||
&mut self.current_tool_id,
|
||||
&mut self.current_tool_name_sent,
|
||||
&mut self.streamed_args_for_tool,
|
||||
&mut self.prev_tool_call_arr,
|
||||
)?;
|
||||
|
||||
// Qwen-specific: Handle partial end tokens in normal text
|
||||
// After tool calls complete, normal text might contain partial "</tool_call>" tags
|
||||
if !result.normal_text.is_empty() {
|
||||
self.normal_text_buffer.push_str(&result.normal_text);
|
||||
|
||||
// Check if buffer contains complete end token (without leading newline)
|
||||
let end_token_without_newline = &self.eot_token[1..]; // "</tool_call>"
|
||||
if self.normal_text_buffer.contains(end_token_without_newline) {
|
||||
// Complete end token found - clean it and return
|
||||
let cleaned_text = self
|
||||
.normal_text_buffer
|
||||
.replace(end_token_without_newline, "");
|
||||
self.normal_text_buffer.clear();
|
||||
result.normal_text = cleaned_text;
|
||||
} else {
|
||||
// Check if buffer might contain partial end token at the end
|
||||
if let Some(partial_match_len) = helpers::ends_with_partial_token(
|
||||
&self.normal_text_buffer,
|
||||
end_token_without_newline,
|
||||
) {
|
||||
// Keep potential partial match in buffer, return the rest
|
||||
let split_point = self.normal_text_buffer.len() - partial_match_len;
|
||||
result.normal_text = self.normal_text_buffer[..split_point].to_string();
|
||||
self.normal_text_buffer = self.normal_text_buffer[split_point..].to_string();
|
||||
} else {
|
||||
// No partial match, return all buffered text
|
||||
result.normal_text = self.normal_text_buffer.clone();
|
||||
self.normal_text_buffer.clear();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(StreamResult::Incomplete)
|
||||
Ok(result)
|
||||
}
|
||||
|
||||
fn detect_format(&self, text: &str) -> bool {
|
||||
|
||||
Reference in New Issue
Block a user