[router][tool parser] Modify tool parser to return both normal text and tool calls (non-stream) (#10995)
This commit is contained in:
@@ -50,14 +50,6 @@ impl DeepSeekParser {
|
||||
text.contains("<|tool▁calls▁begin|>")
|
||||
}
|
||||
|
||||
/// Extract all tool call blocks from text
|
||||
fn extract_tool_calls<'a>(&self, text: &'a str) -> Vec<&'a str> {
|
||||
self.tool_call_extractor
|
||||
.find_iter(text)
|
||||
.map(|m| m.as_str())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Parse a single tool call block
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
if let Some(captures) = self.func_detail_extractor.captures(block) {
|
||||
@@ -115,23 +107,42 @@ impl Default for DeepSeekParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for DeepSeekParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains DeepSeek format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for block in tool_blocks {
|
||||
if let Some(tool) = self.parse_tool_call(block)? {
|
||||
for mat in matches.iter() {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() || matches.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].start();
|
||||
let last_end = matches.last().unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -241,10 +252,10 @@ mod tests {
|
||||
{"location": "Tokyo", "units": "celsius"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -259,12 +270,12 @@ mod tests {
|
||||
{"location": "Paris"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
assert!(result[1].function.arguments.contains("Paris"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
assert!(tools[1].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -130,21 +130,42 @@ impl Default for Glm4MoeParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for Glm4MoeParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains GLM-4 MoE format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
|
||||
for mat in matches.iter() {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].start();
|
||||
let last_end = matches.last().unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -232,11 +253,12 @@ mod tests {
|
||||
<arg_value>2024-06-27</arg_value>
|
||||
</tool_call>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
assert!(result[0].function.arguments.contains("2024-06-27"));
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert!(tools[0].function.arguments.contains("2024-06-27"));
|
||||
assert_eq!(normal_text, "Some text\nMore text"); // Text before and after tool call
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -251,12 +273,13 @@ mod tests {
|
||||
<arg_value>Shanghai</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
assert!(result[1].function.arguments.contains("Shanghai"));
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert!(tools[1].function.arguments.contains("Shanghai"));
|
||||
assert_eq!(normal_text, ""); // Pure tool calls, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -271,12 +294,13 @@ mod tests {
|
||||
<arg_value>test</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["name"], "test");
|
||||
|
||||
@@ -71,10 +71,10 @@ impl Default for GptOssParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for GptOssParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains GPT-OSS format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
let mut tools = Vec::new();
|
||||
@@ -119,7 +119,7 @@ impl ToolParser for GptOssParser {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
Ok((String::new(), tools)) // GPT-OSS parser returns empty normal text
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -239,10 +239,10 @@ mod tests {
|
||||
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|>
|
||||
More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("San Francisco"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -251,12 +251,12 @@ More text"#;
|
||||
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
|
||||
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("Paris"));
|
||||
assert!(result[1].function.arguments.contains("Paris tourism"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "search");
|
||||
assert!(tools[0].function.arguments.contains("Paris"));
|
||||
assert!(tools[1].function.arguments.contains("Paris tourism"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -264,9 +264,9 @@ More text"#;
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -275,10 +275,10 @@ More text"#;
|
||||
let input =
|
||||
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
assert_eq!(result[0].function.arguments, "{}");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_time");
|
||||
assert_eq!(tools[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -88,64 +88,65 @@ impl JsonParser {
|
||||
content.trim()
|
||||
}
|
||||
|
||||
/// Try to extract a JSON object or array from text that may contain other content
|
||||
fn extract_json_from_text(&self, text: &str) -> Option<String> {
|
||||
// Look for JSON object starting with {
|
||||
if let Some(start) = text.find('{') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
/// Try to extract a first valid JSON object or array from text that may contain other content
|
||||
/// Returns (json_string, normal_text) where normal_text is text before and after the JSON
|
||||
fn extract_json_from_text(&self, text: &str) -> Option<(String, String)> {
|
||||
let mut in_string = false;
|
||||
let mut escape = false;
|
||||
let mut stack: Vec<char> = Vec::with_capacity(8);
|
||||
let mut start: Option<usize> = None;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
for (i, ch) in text.char_indices() {
|
||||
if escape {
|
||||
escape = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape = true,
|
||||
'"' => in_string = !in_string,
|
||||
_ if in_string => {}
|
||||
'{' | '[' => {
|
||||
if start.is_none() {
|
||||
start = Some(i);
|
||||
}
|
||||
stack.push(ch);
|
||||
}
|
||||
'}' | ']' => {
|
||||
let Some(open) = stack.pop() else {
|
||||
// Stray closer - reset and continue looking for next valid JSON
|
||||
start = None;
|
||||
continue;
|
||||
};
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'{' if !in_string => depth += 1,
|
||||
'}' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
let valid = (open == '{' && ch == '}') || (open == '[' && ch == ']');
|
||||
if !valid {
|
||||
// Mismatch - reset and continue looking
|
||||
start = None;
|
||||
stack.clear();
|
||||
continue;
|
||||
}
|
||||
|
||||
if stack.is_empty() {
|
||||
let s = start.unwrap();
|
||||
let e = i + ch.len_utf8();
|
||||
let potential_json = &text[s..e];
|
||||
|
||||
// Validate that this is actually valid JSON before returning
|
||||
if serde_json::from_str::<Value>(potential_json).is_ok() {
|
||||
let json = potential_json.to_string();
|
||||
let normal = format!("{}{}", &text[..s], &text[e..]);
|
||||
return Some((json, normal));
|
||||
} else {
|
||||
// Not valid JSON, reset and continue looking
|
||||
start = None;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
// Look for JSON array starting with [
|
||||
if let Some(start) = text.find('[') {
|
||||
let mut depth = 0;
|
||||
let mut in_string = false;
|
||||
let mut escape_next = false;
|
||||
|
||||
for (i, ch) in text[start..].char_indices() {
|
||||
if escape_next {
|
||||
escape_next = false;
|
||||
continue;
|
||||
}
|
||||
|
||||
match ch {
|
||||
'\\' if in_string => escape_next = true,
|
||||
'"' if !in_string => in_string = true,
|
||||
'"' if in_string => in_string = false,
|
||||
'[' if !in_string => depth += 1,
|
||||
']' if !in_string => {
|
||||
depth -= 1;
|
||||
if depth == 0 {
|
||||
return Some(text[start..start + i + 1].to_string());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
@@ -241,16 +242,20 @@ impl Default for JsonParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for JsonParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers)
|
||||
if !self.token_config.start_tokens.is_empty() {
|
||||
let start_token = &self.token_config.start_tokens[0];
|
||||
if !start_token.is_empty() && text.matches(start_token).count() > 1 {
|
||||
// We have multiple occurrences of the start token
|
||||
let mut all_tools = Vec::new();
|
||||
let mut all_normal_text = String::new();
|
||||
let mut remaining = text;
|
||||
|
||||
while let Some(start_pos) = remaining.find(start_token.as_str()) {
|
||||
// Add text before this start token to normal text
|
||||
all_normal_text.push_str(&remaining[..start_pos]);
|
||||
|
||||
// Extract content after this start token
|
||||
let after_token = &remaining[start_pos + start_token.len()..];
|
||||
|
||||
@@ -264,12 +269,19 @@ impl ToolParser for JsonParser {
|
||||
let json_content = &after_token[..end_pos];
|
||||
|
||||
// Try to extract and parse JSON from this segment
|
||||
if let Some(extracted) = self.extract_json_from_text(json_content) {
|
||||
if let Some((extracted, segment_normal_text)) =
|
||||
self.extract_json_from_text(json_content)
|
||||
{
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
}
|
||||
// Add the normal text from this segment
|
||||
all_normal_text.push_str(&segment_normal_text);
|
||||
} else {
|
||||
// If no JSON found, add the entire content as normal text
|
||||
all_normal_text.push_str(json_content);
|
||||
}
|
||||
|
||||
// Move to the next segment
|
||||
@@ -279,9 +291,10 @@ impl ToolParser for JsonParser {
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tools.is_empty() {
|
||||
return Ok(all_tools);
|
||||
}
|
||||
// Add any remaining text
|
||||
all_normal_text.push_str(remaining);
|
||||
|
||||
return Ok((all_normal_text, all_tools));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -290,21 +303,30 @@ impl ToolParser for JsonParser {
|
||||
|
||||
// Try to parse as JSON first
|
||||
match serde_json::from_str::<Value>(json_content) {
|
||||
Ok(value) => self.parse_json_value(&value),
|
||||
Ok(value) => {
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
Ok((String::new(), tools))
|
||||
}
|
||||
Err(_) => {
|
||||
// If parse failed, check if we have multiple JSON objects separated by the configured separator
|
||||
// This handles cases like: {"name": "func1", ...};{"name": "func2", ...}
|
||||
// Only do this if we can reasonably expect multiple complete JSON objects
|
||||
// (i.e., text starts and ends with JSON-like structure)
|
||||
if !self.token_config.separator.is_empty()
|
||||
&& json_content.contains(&self.token_config.separator)
|
||||
&& json_content.trim().starts_with('{')
|
||||
&& json_content.trim().ends_with('}')
|
||||
{
|
||||
let mut all_tools = Vec::new();
|
||||
|
||||
// Split by separator and try to parse each part
|
||||
let parts: Vec<&str> =
|
||||
json_content.split(&self.token_config.separator).collect();
|
||||
let mut normal_parts = Vec::new();
|
||||
|
||||
for part in parts {
|
||||
let trimmed = part.trim();
|
||||
if trimmed.is_empty() {
|
||||
normal_parts.push(trimmed.to_string());
|
||||
continue;
|
||||
}
|
||||
|
||||
@@ -313,32 +335,40 @@ impl ToolParser for JsonParser {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
} else if let Some(extracted) = self.extract_json_from_text(trimmed) {
|
||||
normal_parts.push(trimmed.to_string());
|
||||
} else if let Some((extracted, part_normal_text)) =
|
||||
self.extract_json_from_text(trimmed)
|
||||
{
|
||||
// Try extracting JSON from this part
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
if let Ok(tools) = self.parse_json_value(&value) {
|
||||
all_tools.extend(tools);
|
||||
}
|
||||
}
|
||||
normal_parts.push(part_normal_text);
|
||||
} else {
|
||||
normal_parts.push(trimmed.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
if !all_tools.is_empty() {
|
||||
return Ok(all_tools);
|
||||
}
|
||||
// Rejoin with the original separator to preserve it
|
||||
let all_normal_text = normal_parts.join(&self.token_config.separator);
|
||||
|
||||
return Ok((all_normal_text, all_tools));
|
||||
}
|
||||
|
||||
// If no wrapper tokens configured and parse failed,
|
||||
// try to extract JSON from mixed text
|
||||
// If no wrapper tokens configured and parse failed, try to extract JSON from mixed text
|
||||
if self.token_config.start_tokens.is_empty() {
|
||||
if let Some(extracted) = self.extract_json_from_text(text) {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted) {
|
||||
return self.parse_json_value(&value);
|
||||
if let Some((extracted_json, normal_text)) = self.extract_json_from_text(text) {
|
||||
if let Ok(value) = serde_json::from_str::<Value>(&extracted_json) {
|
||||
let tools = self.parse_json_value(&value)?;
|
||||
return Ok((normal_text, tools));
|
||||
}
|
||||
}
|
||||
}
|
||||
// Not valid JSON, return empty
|
||||
Ok(vec![])
|
||||
|
||||
// No valid JSON found, return original text as normal text
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -538,9 +568,41 @@ mod tests {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_json_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test extraction of JSON from mixed text
|
||||
let input =
|
||||
r#"Here is some text before {"name": "test", "arguments": {}} and some text after."#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "test");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Here is some text before and some text after."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_json_array_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test extraction of JSON array from mixed text
|
||||
let input = r#"Prefix text [{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}] suffix text"#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(tool_calls.len(), 2);
|
||||
assert_eq!(tool_calls[0].function.name, "func1");
|
||||
assert_eq!(tool_calls[1].function.name, "func2");
|
||||
assert_eq!(normal_text, "Prefix text suffix text");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -551,10 +613,11 @@ mod tests {
|
||||
{"name": "search", "arguments": {"query": "news"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 2);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(tool_calls[1].function.name, "search");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -562,10 +625,11 @@ mod tests {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
assert!(result[0].function.arguments.contains("10"));
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "calculate");
|
||||
assert!(tool_calls[0].function.arguments.contains("10"));
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -577,9 +641,38 @@ mod tests {
|
||||
});
|
||||
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "test");
|
||||
assert_eq!(normal_text, ""); // Wrapper tokens with no extra text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_start_token_invalid_json() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||
end_tokens: vec!["".to_string()],
|
||||
separator: ";".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"Hello world <|python_tag|>this is not valid json at all"#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 0);
|
||||
assert_eq!(normal_text, input); // Should return entire original text when JSON parsing fails
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"Here is the weather data: {"name": "get_weather", "arguments": {"location": "SF"}} Let me know if you need more info."#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Here is the weather data: Let me know if you need more info."
|
||||
); // Normal text is now extracted when JSON is found in mixed content
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -79,16 +79,18 @@ impl Default for KimiK2Parser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for KimiK2Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Kimi K2 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.captures_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
|
||||
// Extract all tool calls
|
||||
for captures in self.tool_call_extractor.captures_iter(text) {
|
||||
// Extract all tool calls using collected matches
|
||||
for captures in matches.iter() {
|
||||
if let (Some(id_match), Some(args_match)) = (
|
||||
captures.name("tool_call_id"),
|
||||
captures.name("function_arguments"),
|
||||
@@ -116,7 +118,26 @@ impl ToolParser for KimiK2Parser {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() || matches.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].get(0).unwrap().start();
|
||||
let last_end = matches.last().unwrap().get(0).unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -227,10 +248,10 @@ mod tests {
|
||||
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -241,10 +262,10 @@ mod tests {
|
||||
<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -254,9 +275,9 @@ mod tests {
|
||||
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -42,22 +42,32 @@ impl Default for LlamaParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for LlamaParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// First try with the configured python_tag parser
|
||||
let result = self.json_parser.parse_complete(text).await?;
|
||||
let (_json_normal_text, tools) = self.json_parser.parse_complete(text).await?;
|
||||
|
||||
if !result.is_empty() {
|
||||
return Ok(result);
|
||||
if !tools.is_empty() {
|
||||
// Extract normal text before the python tag
|
||||
// JsonParser doesn't preserve normal text for single start tokens, so we do it manually
|
||||
let normal_text = if let Some(tag_pos) = text.find("<|python_tag|>") {
|
||||
text[..tag_pos].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
return Ok((normal_text, tools));
|
||||
}
|
||||
|
||||
// If no results and text starts with '{', try plain JSON
|
||||
if text.trim_start().starts_with('{') {
|
||||
// Create a temporary plain JSON parser
|
||||
let plain_parser = JsonParser::new();
|
||||
return plain_parser.parse_complete(text).await;
|
||||
let (_json_normal_text, tools) = plain_parser.parse_complete(text).await?;
|
||||
// For plain JSON, don't extract normal text (consistent with JsonParser behavior)
|
||||
return Ok((String::new(), tools));
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
// No tool calls found, return original text as normal text
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -99,10 +109,11 @@ mod tests {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert!(result[0].function.arguments.contains("weather"));
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "search");
|
||||
assert!(tool_calls[0].function.arguments.contains("weather"));
|
||||
assert_eq!(normal_text, ""); // Pure python_tag with JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -110,9 +121,10 @@ mod tests {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "calculate");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -120,9 +132,10 @@ mod tests {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_time");
|
||||
assert_eq!(normal_text, "Let me help you with that. ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -141,15 +154,15 @@ mod tests {
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// We expect this to either parse the first JSON object or fail gracefully
|
||||
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||
|
||||
// If it parses anything, it should be func1
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "func1");
|
||||
if !tool_calls.is_empty() {
|
||||
assert_eq!(tool_calls[0].function.name, "func1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -38,6 +38,10 @@ impl MistralParser {
|
||||
/// - Escape sequences
|
||||
/// - Bracket depth
|
||||
fn extract_json_array<'a>(&self, text: &'a str) -> Option<&'a str> {
|
||||
self.extract_json_array_with_pos(text).map(|(_, json)| json)
|
||||
}
|
||||
|
||||
fn extract_json_array_with_pos<'a>(&self, text: &'a str) -> Option<(usize, &'a str)> {
|
||||
const BOT_TOKEN: &str = "[TOOL_CALLS] [";
|
||||
|
||||
// Find the start of the token
|
||||
@@ -78,7 +82,7 @@ impl MistralParser {
|
||||
bracket_count -= 1;
|
||||
if bracket_count == 0 {
|
||||
// Found the matching closing bracket
|
||||
return Some(&text[json_start..=i]);
|
||||
return Some((start_idx, &text[json_start..=i]));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -154,18 +158,31 @@ impl Default for MistralParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for MistralParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Mistral format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Extract JSON array from Mistral format
|
||||
if let Some(json_array) = self.extract_json_array(text) {
|
||||
self.parse_json_array(json_array)
|
||||
// Extract JSON array from Mistral format with position
|
||||
if let Some((start_idx, json_array)) = self.extract_json_array_with_pos(text) {
|
||||
// Extract normal text before BOT_TOKEN
|
||||
let normal_text_before = if start_idx > 0 {
|
||||
text[..start_idx].to_string()
|
||||
} else {
|
||||
String::new()
|
||||
};
|
||||
|
||||
match self.parse_json_array(json_array) {
|
||||
Ok(tools) => Ok((normal_text_before, tools)),
|
||||
Err(_) => {
|
||||
// If JSON parsing fails, return the original text as normal text
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Markers present but no complete array found
|
||||
Ok(vec![])
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -291,10 +308,10 @@ mod tests {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Paris"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -305,10 +322,10 @@ mod tests {
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -316,11 +333,11 @@ mod tests {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process");
|
||||
// JSON serialization removes spaces, so check for [3,4] without spaces
|
||||
assert!(result[0].function.arguments.contains("[3,4]"));
|
||||
assert!(tools[0].function.arguments.contains("[3,4]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -328,9 +345,9 @@ mod tests {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "echo");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -45,7 +45,8 @@ impl PythonicParser {
|
||||
}
|
||||
|
||||
/// Extract tool calls using bracket counting (similar to MistralParser)
|
||||
fn extract_tool_calls(&self, text: &str) -> Option<String> {
|
||||
/// Returns extracted tool call group with [] and normal content
|
||||
fn extract_tool_calls(&self, text: &str) -> Option<(String, String)> {
|
||||
// Find the start of a tool call list - look for [ followed by a function name
|
||||
let chars: Vec<char> = text.chars().collect();
|
||||
|
||||
@@ -103,7 +104,11 @@ impl PythonicParser {
|
||||
// Found the matching bracket
|
||||
let extracted: String = chars[start_idx..=i].iter().collect();
|
||||
if extracted.contains('(') && extracted.contains(')') {
|
||||
return Some(extracted);
|
||||
// Calculate normal text by removing the tool call portion
|
||||
let before = &text[..start_idx];
|
||||
let after = &text[(i + 1)..];
|
||||
let normal_text = format!("{}{}", before, after);
|
||||
return Some((extracted, normal_text));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -260,11 +265,11 @@ impl PythonicParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for PythonicParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
|
||||
// Extract tool calls using bracket counting
|
||||
if let Some(tool_calls_text) = self.extract_tool_calls(&cleaned) {
|
||||
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
|
||||
// Remove the outer brackets
|
||||
let tool_calls_str = &tool_calls_text[1..tool_calls_text.len() - 1];
|
||||
|
||||
@@ -318,9 +323,9 @@ impl ToolParser for PythonicParser {
|
||||
}
|
||||
}
|
||||
|
||||
Ok(calls)
|
||||
Ok((normal_text, calls))
|
||||
} else {
|
||||
Ok(vec![])
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -336,11 +341,11 @@ impl ToolParser for PythonicParser {
|
||||
// Try to parse if we have a complete tool call
|
||||
let cleaned = Self::strip_special_tokens(&state.buffer);
|
||||
if self.extract_tool_calls(&cleaned).is_some() {
|
||||
let result = self.parse_complete(&state.buffer).await?;
|
||||
if !result.is_empty() {
|
||||
let (_normal_text, tools) = self.parse_complete(&state.buffer).await?;
|
||||
if !tools.is_empty() {
|
||||
state.buffer.clear();
|
||||
return Ok(StreamResult::ToolComplete(
|
||||
result.into_iter().next().unwrap(),
|
||||
tools.into_iter().next().unwrap(),
|
||||
));
|
||||
}
|
||||
}
|
||||
@@ -369,11 +374,11 @@ mod tests {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search_web");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "search_web");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "Rust programming");
|
||||
assert_eq!(args["max_results"], 5);
|
||||
}
|
||||
@@ -383,10 +388,10 @@ mod tests {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -394,10 +399,10 @@ mod tests {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["flag"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], Value::Null);
|
||||
@@ -408,11 +413,11 @@ mod tests {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "calculate");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
@@ -422,12 +427,41 @@ mod tests {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
|
||||
let args: Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_normal_text_extraction() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Test with text before and after
|
||||
let input = r#"Please check the weather [get_weather(city="Tokyo")] and let me know."#;
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(normal_text, "Please check the weather and let me know.");
|
||||
|
||||
// Test with only normal text (no tool calls)
|
||||
let input_no_tools = "This is just normal text without any tool calls.";
|
||||
let (normal_text, tools) = parser.parse_complete(input_no_tools).await.unwrap();
|
||||
|
||||
assert_eq!(tools.len(), 0);
|
||||
assert_eq!(normal_text, input_no_tools);
|
||||
|
||||
// Test with multiple tool calls in single bracket group and normal text
|
||||
let input_multiple = r#"First, [search(query="rust"), calculate(x=5, y=10)] please."#;
|
||||
let (normal_text, tools) = parser.parse_complete(input_multiple).await.unwrap();
|
||||
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
assert_eq!(normal_text, "First, please.");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,32 +128,51 @@ impl Default for QwenParser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for QwenParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Qwen format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Extract all tool call blocks
|
||||
let tool_blocks = self.extract_tool_calls(text);
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.extractor.captures_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
|
||||
for (index, json_str) in tool_blocks.iter().enumerate() {
|
||||
// Parse each JSON block
|
||||
match serde_json::from_str::<Value>(json_str.trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, index)? {
|
||||
tools.push(tool);
|
||||
for (index, captures) in matches.iter().enumerate() {
|
||||
if let Some(json_str) = captures.get(1) {
|
||||
match serde_json::from_str::<Value>(json_str.as_str().trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, index)? {
|
||||
tools.push(tool);
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// Skip malformed JSON blocks
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(tools)
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].get(0).unwrap().start();
|
||||
let last_end = matches.last().unwrap().get(0).unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -276,10 +295,11 @@ mod tests {
|
||||
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Beijing"));
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -292,10 +312,11 @@ mod tests {
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
assert_eq!(normal_text, ""); // Pure tool calls, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -307,9 +328,13 @@ mod tests {
|
||||
</tool_call>
|
||||
Here are the results."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_info");
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_info");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Let me help you with that.\n\nHere are the results."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -329,10 +354,11 @@ Here are the results."#;
|
||||
}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
assert!(result[0].function.arguments.contains("nested"));
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
assert!(tools[0].function.arguments.contains("nested"));
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
||||
@@ -157,10 +157,10 @@ impl Default for Step3Parser {
|
||||
|
||||
#[async_trait]
|
||||
impl ToolParser for Step3Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<Vec<ToolCall>> {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Step3 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok(vec![]);
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Find the tool calls section
|
||||
@@ -170,6 +170,7 @@ impl ToolParser for Step3Parser {
|
||||
// Find the end of tool calls section
|
||||
if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") {
|
||||
let tool_section = &text[search_from..search_from + end_pos];
|
||||
let end_abs = search_from + end_pos + "<|tool_calls_end|>".len();
|
||||
|
||||
// Extract all tool call blocks
|
||||
let mut tools = Vec::new();
|
||||
@@ -179,11 +180,24 @@ impl ToolParser for Step3Parser {
|
||||
}
|
||||
}
|
||||
|
||||
return Ok(tools);
|
||||
// Extract normal text before start and after end
|
||||
let before = if start_pos > 0 {
|
||||
&text[..start_pos]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if end_abs < text.len() {
|
||||
&text[end_abs..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let normal_text = format!("{}{}", before, after);
|
||||
|
||||
return Ok((normal_text, tools));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(vec![])
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -289,11 +303,11 @@ mod tests {
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>More text"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("Tokyo"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
assert!(tools[0].function.arguments.contains("celsius"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -308,10 +322,10 @@ mod tests {
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -326,12 +340,12 @@ mod tests {
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_data");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["rate"], 1.5);
|
||||
|
||||
@@ -242,12 +242,12 @@ async fn test_json_parser_complete_single() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco", "units": "celsius"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("San Francisco"));
|
||||
assert!(tools[0].function.arguments.contains("celsius"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -259,11 +259,11 @@ async fn test_json_parser_complete_array() {
|
||||
{"name": "get_news", "arguments": {"query": "technology"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_news");
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "get_news");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -271,13 +271,13 @@ async fn test_json_parser_with_parameters() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20, "operation": "add"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
assert!(result[0].function.arguments.contains("10"));
|
||||
assert!(result[0].function.arguments.contains("20"));
|
||||
assert!(result[0].function.arguments.contains("add"));
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "calculate");
|
||||
assert!(tools[0].function.arguments.contains("10"));
|
||||
assert!(tools[0].function.arguments.contains("20"));
|
||||
assert!(tools[0].function.arguments.contains("add"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -289,10 +289,10 @@ async fn test_json_parser_with_tokens() {
|
||||
});
|
||||
|
||||
let input = r#"[TOOL_CALLS] [{"name": "search", "arguments": {"query": "rust programming"}}]"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -313,12 +313,12 @@ async fn test_multiline_json_with_tokens() {
|
||||
}
|
||||
}</tool>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert!(result[0].function.arguments.contains("San Francisco"));
|
||||
assert!(result[0].function.arguments.contains("celsius"));
|
||||
assert!(result[0].function.arguments.contains("true"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("San Francisco"));
|
||||
assert!(tools[0].function.arguments.contains("celsius"));
|
||||
assert!(tools[0].function.arguments.contains("true"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -342,12 +342,12 @@ async fn test_multiline_json_array() {
|
||||
}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "function1");
|
||||
assert_eq!(result[1].function.name, "function2");
|
||||
assert!(result[0].function.arguments.contains("value1"));
|
||||
assert!(result[1].function.arguments.contains("[1,2,3]"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "function1");
|
||||
assert_eq!(tools[1].function.name, "function2");
|
||||
assert!(tools[0].function.arguments.contains("value1"));
|
||||
assert!(tools[1].function.arguments.contains("[1,2,3]"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -397,9 +397,9 @@ async fn test_registry_with_json_parser() {
|
||||
let parser = registry.get_parser("gpt-4-turbo").unwrap();
|
||||
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -407,9 +407,9 @@ async fn test_json_parser_invalid_input() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Invalid JSON should return empty results
|
||||
assert_eq!(parser.parse_complete("not json").await.unwrap().len(), 0);
|
||||
assert_eq!(parser.parse_complete("{invalid}").await.unwrap().len(), 0);
|
||||
assert_eq!(parser.parse_complete("").await.unwrap().len(), 0);
|
||||
assert_eq!(parser.parse_complete("not json").await.unwrap().1.len(), 0);
|
||||
assert_eq!(parser.parse_complete("{invalid}").await.unwrap().1.len(), 0);
|
||||
assert_eq!(parser.parse_complete("").await.unwrap().1.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -418,11 +418,11 @@ async fn test_json_parser_empty_arguments() {
|
||||
|
||||
// Tool call with no arguments
|
||||
let input = r#"{"name": "get_time"}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
assert_eq!(result[0].function.arguments, "{}");
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_time");
|
||||
assert_eq!(tools[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -435,14 +435,14 @@ mod failure_cases {
|
||||
|
||||
// Missing name field
|
||||
let input = r#"{"arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should return empty for tool without name");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0, "Should return empty for tool without name");
|
||||
|
||||
// Empty name
|
||||
let input = r#"{"name": "", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1, "Should accept empty name string");
|
||||
assert_eq!(result[0].function.name, "");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1, "Should accept empty name string");
|
||||
assert_eq!(tools[0].function.name, "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -451,22 +451,22 @@ mod failure_cases {
|
||||
|
||||
// Arguments is a string instead of object
|
||||
let input = r#"{"name": "test", "arguments": "not an object"}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
// Should serialize the string as JSON
|
||||
assert!(result[0].function.arguments.contains("not an object"));
|
||||
assert!(tools[0].function.arguments.contains("not an object"));
|
||||
|
||||
// Arguments is a number
|
||||
let input = r#"{"name": "test", "arguments": 42}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.arguments, "42");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.arguments, "42");
|
||||
|
||||
// Arguments is null
|
||||
let input = r#"{"name": "test", "arguments": null}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.arguments, "null");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.arguments, "null");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -479,26 +479,26 @@ mod failure_cases {
|
||||
|
||||
// Missing end token
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
tools.len(),
|
||||
0,
|
||||
"Should fail to parse without complete wrapper"
|
||||
);
|
||||
|
||||
// Missing start token - parser looks for complete wrapper, so this won't parse
|
||||
let input = r#"{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
tools.len(),
|
||||
0,
|
||||
"Should not parse JSON with incomplete wrapper"
|
||||
);
|
||||
|
||||
// Mismatched tokens
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</wrong>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should fail with mismatched tokens");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0, "Should fail with mismatched tokens");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -507,18 +507,18 @@ mod failure_cases {
|
||||
|
||||
// Trailing comma
|
||||
let input = r#"{"name": "test", "arguments": {"x": 1,}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject JSON with trailing comma");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0, "Should reject JSON with trailing comma");
|
||||
|
||||
// Missing quotes on keys
|
||||
let input = r#"{name: "test", arguments: {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject invalid JSON syntax");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0, "Should reject invalid JSON syntax");
|
||||
|
||||
// Unclosed object
|
||||
let input = r#"{"name": "test", "arguments": {"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should reject incomplete JSON");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 0, "Should reject incomplete JSON");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -532,17 +532,17 @@ mod edge_cases {
|
||||
|
||||
// Unicode in function name
|
||||
let input = r#"{"name": "获取天气", "arguments": {"location": "北京"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "获取天气");
|
||||
assert!(result[0].function.arguments.contains("北京"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "获取天气");
|
||||
assert!(tools[0].function.arguments.contains("北京"));
|
||||
|
||||
// Emoji in arguments
|
||||
let input = r#"{"name": "send_message", "arguments": {"text": "Hello 👋 World 🌍"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("👋"));
|
||||
assert!(result[0].function.arguments.contains("🌍"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("👋"));
|
||||
assert!(tools[0].function.arguments.contains("🌍"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -551,22 +551,22 @@ mod edge_cases {
|
||||
|
||||
// Escaped quotes in arguments
|
||||
let input = r#"{"name": "echo", "arguments": {"text": "He said \"hello\""}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains(r#"\"hello\""#));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains(r#"\"hello\""#));
|
||||
|
||||
// Escaped backslashes
|
||||
let input = r#"{"name": "path", "arguments": {"dir": "C:\\Users\\test"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("\\\\"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("\\\\"));
|
||||
|
||||
// Newlines and tabs
|
||||
let input = r#"{"name": "format", "arguments": {"text": "line1\nline2\ttabbed"}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("\\n"));
|
||||
assert!(result[0].function.arguments.contains("\\t"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("\\n"));
|
||||
assert!(tools[0].function.arguments.contains("\\t"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -580,10 +580,10 @@ mod edge_cases {
|
||||
}
|
||||
large_args.push_str(r#""final": "value"}}"#);
|
||||
|
||||
let result = parser.parse_complete(&large_args).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
assert!(result[0].function.arguments.contains("field_999"));
|
||||
let (_normal_text, tools) = parser.parse_complete(&large_args).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process");
|
||||
assert!(tools[0].function.arguments.contains("field_999"));
|
||||
|
||||
// Large array of tool calls
|
||||
let mut large_array = "[".to_string();
|
||||
@@ -595,9 +595,9 @@ mod edge_cases {
|
||||
}
|
||||
large_array.push(']');
|
||||
|
||||
let result = parser.parse_complete(&large_array).await.unwrap();
|
||||
assert_eq!(result.len(), 100);
|
||||
assert_eq!(result[99].function.name, "func_99");
|
||||
let (_normal_text, tools) = parser.parse_complete(&large_array).await.unwrap();
|
||||
assert_eq!(tools.len(), 100);
|
||||
assert_eq!(tools[99].function.name, "func_99");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -612,10 +612,10 @@ mod edge_cases {
|
||||
{"key": "value", "another": "field"}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2, "Should only parse valid tool calls");
|
||||
assert_eq!(result[0].function.name, "tool1");
|
||||
assert_eq!(result[1].function.name, "tool2");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2, "Should only parse valid tool calls");
|
||||
assert_eq!(tools[0].function.name, "tool1");
|
||||
assert_eq!(tools[1].function.name, "tool2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -624,14 +624,14 @@ mod edge_cases {
|
||||
|
||||
// JSON with duplicate keys (last one wins in most parsers)
|
||||
let input = r#"{"name": "first", "name": "second", "arguments": {"x": 1, "x": 2}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(
|
||||
result[0].function.name, "second",
|
||||
tools[0].function.name, "second",
|
||||
"Last duplicate key should win"
|
||||
);
|
||||
assert!(
|
||||
result[0].function.arguments.contains("2"),
|
||||
tools[0].function.arguments.contains("2"),
|
||||
"Last duplicate value should win"
|
||||
);
|
||||
}
|
||||
@@ -642,15 +642,15 @@ mod edge_cases {
|
||||
|
||||
// Null values in arguments
|
||||
let input = r#"{"name": "test", "arguments": {"required": "value", "optional": null}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("null"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("null"));
|
||||
|
||||
// Array with null
|
||||
let input = r#"{"name": "test", "arguments": {"items": [1, null, "three"]}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("null"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("null"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -663,22 +663,22 @@ mod edge_cases {
|
||||
|
||||
// First pattern
|
||||
let input = r#"<<{"name": "test1", "arguments": {}}>>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test1");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test1");
|
||||
|
||||
// Second pattern
|
||||
let input = r#"<tool>{"name": "test2", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test2");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test2");
|
||||
|
||||
// Nested patterns (should use first match)
|
||||
let input = r#"<<tool>{"name": "test3", "arguments": {}}</tool>>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
// This is tricky - depends on regex behavior
|
||||
// The parser should handle this gracefully
|
||||
assert!(result.len() <= 1, "Should not parse multiple times");
|
||||
assert!(tools.len() <= 1, "Should not parse multiple times");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -743,25 +743,25 @@ mod edge_cases {
|
||||
|
||||
// Boolean values
|
||||
let input = r#"{"name": "toggle", "arguments": {"enabled": true, "disabled": false}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("true"));
|
||||
assert!(result[0].function.arguments.contains("false"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("true"));
|
||||
assert!(tools[0].function.arguments.contains("false"));
|
||||
|
||||
// Numbers (including float and negative)
|
||||
let input = r#"{"name": "calc", "arguments": {"int": 42, "float": 3.14, "negative": -17}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("42"));
|
||||
assert!(result[0].function.arguments.contains("3.14"));
|
||||
assert!(result[0].function.arguments.contains("-17"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("42"));
|
||||
assert!(tools[0].function.arguments.contains("3.14"));
|
||||
assert!(tools[0].function.arguments.contains("-17"));
|
||||
|
||||
// Empty arrays and objects
|
||||
let input = r#"{"name": "test", "arguments": {"empty_arr": [], "empty_obj": {}}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("[]"));
|
||||
assert!(result[0].function.arguments.contains("{}"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("[]"));
|
||||
assert!(tools[0].function.arguments.contains("{}"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -770,15 +770,15 @@ mod edge_cases {
|
||||
|
||||
// Using "function" instead of "name"
|
||||
let input = r#"{"function": "test_func", "arguments": {"x": 1}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test_func");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test_func");
|
||||
|
||||
// Both "name" and "function" present (name should take precedence)
|
||||
let input = r#"{"name": "primary", "function": "secondary", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "primary");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "primary");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -792,15 +792,15 @@ mod edge_cases {
|
||||
"key" : "value"
|
||||
}
|
||||
} "#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
|
||||
// Minified JSON (no whitespace)
|
||||
let input = r#"{"name":"compact","arguments":{"a":1,"b":2}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "compact");
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "compact");
|
||||
}
|
||||
}
|
||||
|
||||
@@ -830,9 +830,9 @@ mod stress_tests {
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert!(result[0].function.arguments.contains("deep"));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert!(tools[0].function.arguments.contains("deep"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
@@ -845,9 +845,9 @@ mod stress_tests {
|
||||
let parser_clone = parser.clone();
|
||||
let handle = tokio::spawn(async move {
|
||||
let input = format!(r#"{{"name": "func_{}", "arguments": {{}}}}"#, i);
|
||||
let result = parser_clone.parse_complete(&input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, format!("func_{}", i));
|
||||
let (_normal_text, tools) = parser_clone.parse_complete(&input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, format!("func_{}", i));
|
||||
});
|
||||
handles.push(handle);
|
||||
}
|
||||
|
||||
@@ -9,7 +9,8 @@ use async_trait::async_trait;
|
||||
#[async_trait]
|
||||
pub trait ToolParser: Send + Sync {
|
||||
/// Parse complete tool calls from final output
|
||||
async fn parse_complete(&self, output: &str) -> ToolParserResult<Vec<ToolCall>>;
|
||||
/// Returns (remaining_normal_text, tool_calls) tuple
|
||||
async fn parse_complete(&self, output: &str) -> ToolParserResult<(String, Vec<ToolCall>)>;
|
||||
|
||||
/// Parse tool calls from model output (streaming)
|
||||
async fn parse_incremental(
|
||||
|
||||
Reference in New Issue
Block a user