[router][tool call] Improve normal content extraction and error handling (non-stream) (#11050)
This commit is contained in:
@@ -50,52 +50,58 @@ impl DeepSeekParser {
|
||||
text.contains("<|tool▁calls▁begin|>")
|
||||
}
|
||||
|
||||
/// Parse a single tool call block
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<Option<ToolCall>> {
|
||||
if let Some(captures) = self.func_detail_extractor.captures(block) {
|
||||
// Get function type (should be "function")
|
||||
let func_type = captures.get(1).map_or("", |m| m.as_str());
|
||||
if func_type != "function" {
|
||||
return Ok(None);
|
||||
}
|
||||
/// Parse a single tool call block - throws error if parsing fails
|
||||
fn parse_tool_call(&self, block: &str) -> ToolParserResult<ToolCall> {
|
||||
let captures = self.func_detail_extractor.captures(block).ok_or_else(|| {
|
||||
ToolParserError::ParsingFailed("Failed to match tool call pattern".to_string())
|
||||
})?;
|
||||
|
||||
// Get function name
|
||||
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
|
||||
|
||||
// Get JSON arguments
|
||||
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
|
||||
|
||||
// Parse JSON arguments
|
||||
match serde_json::from_str::<Value>(json_args) {
|
||||
Ok(value) => {
|
||||
// Create arguments object
|
||||
let args = if value.is_object() {
|
||||
value
|
||||
} else {
|
||||
// If not an object, wrap it
|
||||
serde_json::json!({ "value": value })
|
||||
};
|
||||
|
||||
let arguments = serde_json::to_string(&args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID
|
||||
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
Ok(Some(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
}))
|
||||
}
|
||||
Err(_) => Ok(None),
|
||||
}
|
||||
} else {
|
||||
Ok(None)
|
||||
// Get function type (should be "function")
|
||||
let func_type = captures.get(1).map_or("", |m| m.as_str());
|
||||
if func_type != "function" {
|
||||
return Err(ToolParserError::ParsingFailed(format!(
|
||||
"Invalid function type: {}",
|
||||
func_type
|
||||
)));
|
||||
}
|
||||
|
||||
// Get function name
|
||||
let func_name = captures.get(2).map_or("", |m| m.as_str()).trim();
|
||||
if func_name.is_empty() {
|
||||
return Err(ToolParserError::ParsingFailed(
|
||||
"Empty function name".to_string(),
|
||||
));
|
||||
}
|
||||
|
||||
// Get JSON arguments
|
||||
let json_args = captures.get(3).map_or("{}", |m| m.as_str()).trim();
|
||||
|
||||
// Parse JSON arguments
|
||||
let value = serde_json::from_str::<Value>(json_args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(format!("Invalid JSON: {}", e)))?;
|
||||
|
||||
// Create arguments object
|
||||
let args = if value.is_object() {
|
||||
value
|
||||
} else {
|
||||
// If not an object, wrap it
|
||||
serde_json::json!({ "value": value })
|
||||
};
|
||||
|
||||
let arguments = serde_json::to_string(&args)
|
||||
.map_err(|e| ToolParserError::ParsingFailed(e.to_string()))?;
|
||||
|
||||
// Generate ID
|
||||
let id = format!("deepseek_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
Ok(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name.to_string(),
|
||||
arguments,
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -108,39 +114,30 @@ impl Default for DeepSeekParser {
|
||||
#[async_trait]
|
||||
impl ToolParser for DeepSeekParser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains DeepSeek format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
// Find where tool calls begin
|
||||
let idx = text.find("<|tool▁calls▁begin|>").unwrap();
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
for mat in matches.iter() {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
// Try to extract tool calls, log warnings for failures
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
match self.parse_tool_call(mat.as_str()) {
|
||||
Ok(tool) => tools.push(tool),
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() || matches.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].start();
|
||||
let last_end = matches.last().unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
@@ -185,11 +182,16 @@ impl ToolParser for DeepSeekParser {
|
||||
// Extract and parse the complete tool call
|
||||
let tool_call_text = &state.buffer[call_start_abs..call_end_abs];
|
||||
|
||||
if let Some(tool) = self.parse_tool_call(tool_call_text)? {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
match self.parse_tool_call(tool_call_text) {
|
||||
Ok(tool) => {
|
||||
// Remove the processed part from buffer
|
||||
state.buffer.drain(..call_end_abs);
|
||||
return Ok(StreamResult::ToolComplete(tool));
|
||||
}
|
||||
Err(_) => {
|
||||
// Parsing failed, skip this tool call
|
||||
state.buffer.drain(..call_end_abs);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Tool call not complete yet, try to extract partial info
|
||||
@@ -248,51 +250,3 @@ impl ToolParser for DeepSeekParser {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_deepseek_single_tool() {
|
||||
let parser = DeepSeekParser::new();
|
||||
let input = r#"Some text
|
||||
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Tokyo", "units": "celsius"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>More text"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_deepseek_multiple_tools() {
|
||||
let parser = DeepSeekParser::new();
|
||||
let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Tokyo"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Paris"}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
assert!(tools[1].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = DeepSeekParser::new();
|
||||
assert!(parser.detect_format("<|tool▁calls▁begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,34 +136,27 @@ impl ToolParser for Glm4MoeParser {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.find_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
// Find where tool calls begin
|
||||
let idx = text.find("<tool_call>").unwrap();
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
for mat in matches.iter() {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
match self.parse_tool_call(mat.as_str()) {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].start();
|
||||
let last_end = matches.last().unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
@@ -247,80 +240,3 @@ impl ToolParser for Glm4MoeParser {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_single_tool() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"Some text
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
<arg_key>date</arg_key>
|
||||
<arg_value>2024-06-27</arg_value>
|
||||
</tool_call>More text"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert!(tools[0].function.arguments.contains("2024-06-27"));
|
||||
assert_eq!(normal_text, "Some text\nMore text"); // Text before and after tool call
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_multiple_tools() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Beijing</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>get_weather
|
||||
<arg_key>city</arg_key>
|
||||
<arg_value>Shanghai</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert!(tools[1].function.arguments.contains("Shanghai"));
|
||||
assert_eq!(normal_text, ""); // Pure tool calls, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_glm4_mixed_types() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
let input = r#"<tool_call>process_data
|
||||
<arg_key>count</arg_key>
|
||||
<arg_value>42</arg_value>
|
||||
<arg_key>active</arg_key>
|
||||
<arg_value>true</arg_value>
|
||||
<arg_key>name</arg_key>
|
||||
<arg_value>test</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["name"], "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -227,66 +227,3 @@ impl ToolParser for GptOssParser {
|
||||
self.has_tool_markers(text) || text.contains("<|channel|>commentary")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_single_tool() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"Some text
|
||||
<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "San Francisco"}<|call|>
|
||||
More text"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("San Francisco"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_multiple_tools() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
|
||||
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "search");
|
||||
assert!(tools[0].function.arguments.contains("Paris"));
|
||||
assert!(tools[1].function.arguments.contains("Paris tourism"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_with_prefix() {
|
||||
let parser = GptOssParser::new();
|
||||
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_gpt_oss_empty_args() {
|
||||
let parser = GptOssParser::new();
|
||||
let input =
|
||||
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_time");
|
||||
assert_eq!(tools[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = GptOssParser::new();
|
||||
assert!(parser.detect_format("<|channel|>commentary to="));
|
||||
assert!(parser.detect_format("<|channel|>commentary"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -615,155 +615,3 @@ impl ToolParser for JsonParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_single_tool_call() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_json_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test extraction of JSON from mixed text
|
||||
let input =
|
||||
r#"Here is some text before {"name": "test", "arguments": {}} and some text after."#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "test");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Here is some text before and some text after."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extract_json_array_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test extraction of JSON array from mixed text
|
||||
let input = r#"Prefix text [{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}] suffix text"#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
assert_eq!(tool_calls.len(), 2);
|
||||
assert_eq!(tool_calls[0].function.name, "func1");
|
||||
assert_eq!(tool_calls[1].function.name, "func2");
|
||||
assert_eq!(normal_text, "Prefix text suffix text");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tool_calls() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "search", "arguments": {"query": "news"}}
|
||||
]"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 2);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(tool_calls[1].function.name, "search");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_parameters_key() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "calculate");
|
||||
assert!(tool_calls[0].function.arguments.contains("10"));
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_wrapper_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}</tool>"#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "test");
|
||||
assert_eq!(normal_text, ""); // Wrapper tokens with no extra text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_start_token_invalid_json() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<|python_tag|>".to_string()],
|
||||
end_tokens: vec!["".to_string()],
|
||||
separator: ";".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"Hello world <|python_tag|>this is not valid json at all"#;
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 0);
|
||||
assert_eq!(normal_text, input); // Should return entire original text when JSON parsing fails
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_normal_text() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"Here is the weather data: {"name": "get_weather", "arguments": {"location": "SF"}} Let me know if you need more info."#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_weather");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Here is the weather data: Let me know if you need more info."
|
||||
); // Normal text is now extracted when JSON is found in mixed content
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_parse() {
|
||||
// Just verify that streaming eventually produces a complete tool call
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete JSON in one go
|
||||
// TODO simplified version, address more complex version
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "SF"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should get a complete tool immediately with complete JSON
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
assert!(tool.function.arguments.contains("SF"));
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON input"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -80,17 +80,17 @@ impl Default for KimiK2Parser {
|
||||
#[async_trait]
|
||||
impl ToolParser for KimiK2Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Kimi K2 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.tool_call_extractor.captures_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
// Find where tool calls begin
|
||||
let idx = text.find("<|tool_calls_section_begin|>").unwrap();
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
// Extract all tool calls using collected matches
|
||||
for captures in matches.iter() {
|
||||
// Try to extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for captures in self.tool_call_extractor.captures_iter(text) {
|
||||
if let (Some(id_match), Some(args_match)) = (
|
||||
captures.name("tool_call_id"),
|
||||
captures.name("function_arguments"),
|
||||
@@ -100,42 +100,41 @@ impl ToolParser for KimiK2Parser {
|
||||
|
||||
// Parse function ID
|
||||
if let Some((func_name, _index)) = self.parse_function_id(function_id) {
|
||||
// Validate JSON arguments
|
||||
if serde_json::from_str::<serde_json::Value>(function_args).is_ok() {
|
||||
// Generate unique ID
|
||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
||||
// Try to parse JSON arguments
|
||||
match serde_json::from_str::<serde_json::Value>(function_args) {
|
||||
Ok(_) => {
|
||||
// Generate unique ID
|
||||
let id = format!("kimi_call_{}", uuid::Uuid::new_v4());
|
||||
|
||||
tools.push(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name,
|
||||
arguments: function_args.to_string(),
|
||||
},
|
||||
});
|
||||
tools.push(ToolCall {
|
||||
id,
|
||||
r#type: "function".to_string(),
|
||||
function: FunctionCall {
|
||||
name: func_name,
|
||||
arguments: function_args.to_string(),
|
||||
},
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
tracing::warn!(
|
||||
"Failed to parse JSON arguments for {}: {}",
|
||||
func_name,
|
||||
e
|
||||
);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tracing::warn!("Failed to parse function ID: {}", function_id);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() || matches.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].get(0).unwrap().start();
|
||||
let last_end = matches.last().unwrap().get(0).unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
@@ -248,57 +247,3 @@ impl ToolParser for KimiK2Parser {
|
||||
self.has_tool_markers(text) || text.contains("<|tool_call_begin|>")
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_single_tool() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"Some text
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>More text"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_multiple_tools() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.calculate:1<|tool_call_argument_begin|>{"expression": "2+2"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_kimi_with_whitespace() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value"} <|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = KimiK2Parser::new();
|
||||
assert!(parser.detect_format("<|tool_calls_section_begin|>"));
|
||||
assert!(parser.detect_format("<|tool_call_begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -101,70 +101,3 @@ impl ToolParser for LlamaParser {
|
||||
&& (text.contains(r#""name""#) || text.contains(r#""function""#)))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_python_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "search");
|
||||
assert!(tool_calls[0].function.arguments.contains("weather"));
|
||||
assert_eq!(normal_text, ""); // Pure python_tag with JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "calculate");
|
||||
assert_eq!(normal_text, ""); // Pure JSON should have no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let (normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "get_time");
|
||||
assert_eq!(normal_text, "Let me help you with that. ");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_call_with_semicolon() {
|
||||
let parser = LlamaParser::new();
|
||||
// Note: Llama 3.2 doesn't handle multiple calls well
|
||||
let input = r#"<|python_tag|>{"name": "func1", "arguments": {"x": 1}};"#;
|
||||
|
||||
let (_normal_text, tool_calls) = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// We expect this to either parse the first JSON object or fail gracefully
|
||||
// Since the semicolon makes it invalid JSON, it will likely return empty
|
||||
// This is acceptable as Llama 3.2 doesn't reliably support parallel calls
|
||||
|
||||
// If it parses anything, it should be func1
|
||||
if !tool_calls.is_empty() {
|
||||
assert_eq!(tool_calls[0].function.name, "func1");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -175,8 +175,9 @@ impl ToolParser for MistralParser {
|
||||
|
||||
match self.parse_json_array(json_array) {
|
||||
Ok(tools) => Ok((normal_text_before, tools)),
|
||||
Err(_) => {
|
||||
Err(e) => {
|
||||
// If JSON parsing fails, return the original text as normal text
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
@@ -309,67 +310,3 @@ impl ToolParser for MistralParser {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_mistral_format() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "get_weather", "arguments": {"location": "Paris", "units": "celsius"}}]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Paris"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [
|
||||
{"name": "search", "arguments": {"query": "rust programming"}},
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_brackets_in_json() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "process", "arguments": {"data": [1, 2, [3, 4]], "config": {"nested": [5, 6]}}}]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process");
|
||||
// JSON serialization removes spaces, so check for [3,4] without spaces
|
||||
assert!(tools[0].function.arguments.contains("[3,4]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_escaped_quotes_in_strings() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"message": "He said \"Hello [World]\""}}]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "echo");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#));
|
||||
assert!(
|
||||
parser.detect_format(r#"Some text [TOOL_CALLS] [{"name": "test", "arguments": {}}]"#)
|
||||
);
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
}
|
||||
}
|
||||
|
||||
@@ -84,8 +84,21 @@ impl ToolParser for PythonicParser {
|
||||
let cleaned = Self::strip_special_tokens(text);
|
||||
|
||||
if let Some((tool_calls_text, normal_text)) = self.extract_tool_calls(&cleaned) {
|
||||
let calls = self.parse_tool_call_block(&tool_calls_text)?;
|
||||
Ok((normal_text, calls))
|
||||
match self.parse_tool_call_block(&tool_calls_text) {
|
||||
Ok(calls) => {
|
||||
if calls.is_empty() {
|
||||
// No tools successfully parsed despite having markers
|
||||
Ok((text.to_string(), vec![]))
|
||||
} else {
|
||||
Ok((normal_text, calls))
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
// Log warning and return entire text as fallback
|
||||
tracing::warn!("Failed to parse pythonic tool calls: {}", e);
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Ok((text.to_string(), vec![]))
|
||||
}
|
||||
@@ -329,84 +342,3 @@ where
|
||||
Value::String(value.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_function_call() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[search_web(query="Rust programming", max_results=5)]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "search_web");
|
||||
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "Rust programming");
|
||||
assert_eq!(args["max_results"], 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_function_calls() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="Tokyo"), search(query="news")]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert_eq!(tools[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_python_literals() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[test(flag=True, disabled=False, optional=None)]"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
|
||||
let args: Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["flag"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert!(args["optional"].is_null());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_strip_special_tokens() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = "<|python_start|>[call(arg=1)]<|python_end|>";
|
||||
|
||||
assert!(parser.detect_format(input));
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_detect_format() {
|
||||
let parser = PythonicParser::new();
|
||||
assert!(parser.detect_format("[foo(bar=1)]"));
|
||||
assert!(!parser.detect_format("No python here"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_incremental() {
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
let chunk1 = "[call(arg=";
|
||||
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||
assert!(matches!(result1, StreamResult::Incomplete));
|
||||
|
||||
let chunk2 = "1)]";
|
||||
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||
|
||||
match result2 {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "call");
|
||||
}
|
||||
other => panic!("Expected ToolComplete, got {:?}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,43 +134,35 @@ impl ToolParser for QwenParser {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Collect matches with positions and parse tools in one pass
|
||||
let matches: Vec<_> = self.extractor.captures_iter(text).collect();
|
||||
let mut tools = Vec::new();
|
||||
// Find where the first tool call begins
|
||||
let idx = text.find("<tool_call>").unwrap(); // Safe because has_tool_markers checked
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
for (index, captures) in matches.iter().enumerate() {
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for (index, captures) in self.extractor.captures_iter(text).enumerate() {
|
||||
if let Some(json_str) = captures.get(1) {
|
||||
match serde_json::from_str::<Value>(json_str.as_str().trim()) {
|
||||
Ok(value) => {
|
||||
if let Some(tool) = self.parse_single_object(&value, index)? {
|
||||
tools.push(tool);
|
||||
Ok(value) => match self.parse_single_object(&value, index) {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
// JSON parsing failed, might be incomplete
|
||||
},
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse JSON in tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Extract normal text using first and last match positions
|
||||
let normal_text = if tools.is_empty() {
|
||||
text.to_string()
|
||||
} else {
|
||||
let first_start = matches[0].get(0).unwrap().start();
|
||||
let last_end = matches.last().unwrap().get(0).unwrap().end();
|
||||
let before = if first_start > 0 {
|
||||
&text[..first_start]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if last_end < text.len() {
|
||||
&text[last_end..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
format!("{}{}", before, after)
|
||||
};
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
@@ -299,140 +291,3 @@ impl ToolParser for QwenParser {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_qwen_format() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "get_weather", "arguments": {"location": "Beijing", "units": "celsius"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Beijing"));
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_multiple_tools() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "search", "arguments": {"query": "rust programming"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "calculate", "arguments": {"expression": "2 + 2"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
assert_eq!(normal_text, ""); // Pure tool calls, no normal text
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_normal_text() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"Let me help you with that.
|
||||
<tool_call>
|
||||
{"name": "get_info", "arguments": {"topic": "Rust"}}
|
||||
</tool_call>
|
||||
Here are the results."#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_info");
|
||||
assert_eq!(
|
||||
normal_text,
|
||||
"Let me help you with that.\n\nHere are the results."
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_json_structures() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{
|
||||
"name": "process_data",
|
||||
"arguments": {
|
||||
"data": {
|
||||
"nested": {
|
||||
"array": [1, 2, 3],
|
||||
"object": {"key": "value"}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
</tool_call>"#;
|
||||
|
||||
let (normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
assert!(tools[0].function.arguments.contains("nested"));
|
||||
assert_eq!(normal_text, ""); // Pure tool call, no normal text
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"<tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call>"#
|
||||
));
|
||||
|
||||
assert!(parser.detect_format(
|
||||
r#"Text before <tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call> text after"#
|
||||
));
|
||||
|
||||
assert!(!parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
|
||||
// Partial format should still be detected
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_partial() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<tool_call>\n",
|
||||
r#"{"name": "search","#,
|
||||
r#" "arguments": {"query":"#,
|
||||
r#" "rust"}}"#,
|
||||
"\n</tool_call>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "search");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete); // At least one should be found
|
||||
}
|
||||
}
|
||||
|
||||
@@ -158,46 +158,33 @@ impl Default for Step3Parser {
|
||||
#[async_trait]
|
||||
impl ToolParser for Step3Parser {
|
||||
async fn parse_complete(&self, text: &str) -> ToolParserResult<(String, Vec<ToolCall>)> {
|
||||
// Check if text contains Step3 format
|
||||
if !self.has_tool_markers(text) {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
// Find the tool calls section
|
||||
if let Some(start_pos) = text.find("<|tool_calls_begin|>") {
|
||||
let search_from = start_pos + "<|tool_calls_begin|>".len();
|
||||
// Find where tool calls begin
|
||||
let idx = text.find("<|tool_calls_begin|>").unwrap();
|
||||
let normal_text = text[..idx].to_string();
|
||||
|
||||
// Find the end of tool calls section
|
||||
if let Some(end_pos) = text[search_from..].find("<|tool_calls_end|>") {
|
||||
let tool_section = &text[search_from..search_from + end_pos];
|
||||
let end_abs = search_from + end_pos + "<|tool_calls_end|>".len();
|
||||
|
||||
// Extract all tool call blocks
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(tool_section) {
|
||||
if let Some(tool) = self.parse_tool_call(mat.as_str())? {
|
||||
tools.push(tool);
|
||||
}
|
||||
// Extract tool calls
|
||||
let mut tools = Vec::new();
|
||||
for mat in self.tool_call_extractor.find_iter(text) {
|
||||
match self.parse_tool_call(mat.as_str()) {
|
||||
Ok(Some(tool)) => tools.push(tool),
|
||||
Ok(None) => continue,
|
||||
Err(e) => {
|
||||
tracing::warn!("Failed to parse tool call: {}", e);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Extract normal text before start and after end
|
||||
let before = if start_pos > 0 {
|
||||
&text[..start_pos]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let after = if end_abs < text.len() {
|
||||
&text[end_abs..]
|
||||
} else {
|
||||
""
|
||||
};
|
||||
let normal_text = format!("{}{}", before, after);
|
||||
|
||||
return Ok((normal_text, tools));
|
||||
}
|
||||
}
|
||||
|
||||
Ok((text.to_string(), vec![]))
|
||||
// If no tools were successfully parsed despite having markers, return entire text as fallback
|
||||
if tools.is_empty() {
|
||||
return Ok((text.to_string(), vec![]));
|
||||
}
|
||||
|
||||
Ok((normal_text, tools))
|
||||
}
|
||||
|
||||
async fn parse_incremental(
|
||||
@@ -297,76 +284,3 @@ impl ToolParser for Step3Parser {
|
||||
self.has_tool_markers(text)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_single_tool() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"Some text
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">
|
||||
<steptml:parameter name="location">Tokyo</steptml:parameter>
|
||||
<steptml:parameter name="units">celsius</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>More text"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "get_weather");
|
||||
assert!(tools[0].function.arguments.contains("Tokyo"));
|
||||
assert!(tools[0].function.arguments.contains("celsius"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_multiple_tools() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
||||
<steptml:parameter name="query">rust programming</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="calculate">
|
||||
<steptml:parameter name="expression">2 + 2</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 2);
|
||||
assert_eq!(tools[0].function.name, "search");
|
||||
assert_eq!(tools[1].function.name, "calculate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_step3_mixed_types() {
|
||||
let parser = Step3Parser::new();
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="process_data">
|
||||
<steptml:parameter name="count">42</steptml:parameter>
|
||||
<steptml:parameter name="active">true</steptml:parameter>
|
||||
<steptml:parameter name="rate">1.5</steptml:parameter>
|
||||
<steptml:parameter name="name">test</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let (_normal_text, tools) = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(tools.len(), 1);
|
||||
assert_eq!(tools[0].function.name, "process_data");
|
||||
|
||||
// Parse arguments to check types
|
||||
let args: serde_json::Value = serde_json::from_str(&tools[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 42);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["rate"], 1.5);
|
||||
assert_eq!(args["name"], "test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_detect_format() {
|
||||
let parser = Step3Parser::new();
|
||||
assert!(parser.detect_format("<|tool_calls_begin|>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user