diff --git a/sgl-router/src/tool_parser/parsers/json_parser.rs b/sgl-router/src/tool_parser/parsers/json_parser.rs index 117435b7f..104383582 100644 --- a/sgl-router/src/tool_parser/parsers/json_parser.rs +++ b/sgl-router/src/tool_parser/parsers/json_parser.rs @@ -242,13 +242,92 @@ impl Default for JsonParser { #[async_trait] impl ToolParser for JsonParser { async fn parse_complete(&self, text: &str) -> ToolParserResult> { + // Check if we have multiple start tokens (e.g., multiple <|python_tag|> markers) + if !self.token_config.start_tokens.is_empty() { + let start_token = &self.token_config.start_tokens[0]; + if !start_token.is_empty() && text.matches(start_token).count() > 1 { + // We have multiple occurrences of the start token + let mut all_tools = Vec::new(); + let mut remaining = text; + + while let Some(start_pos) = remaining.find(start_token.as_str()) { + // Extract content after this start token + let after_token = &remaining[start_pos + start_token.len()..]; + + // Find where this JSON ends (look for the next start token or end of string) + let end_pos = if let Some(next_start) = after_token.find(start_token.as_str()) { + next_start + } else { + after_token.len() + }; + + let json_content = &after_token[..end_pos]; + + // Try to extract and parse JSON from this segment + if let Some(extracted) = self.extract_json_from_text(json_content) { + if let Ok(value) = serde_json::from_str::(&extracted) { + if let Ok(tools) = self.parse_json_value(&value) { + all_tools.extend(tools); + } + } + } + + // Move to the next segment + remaining = &remaining[start_pos + start_token.len() + end_pos..]; + if remaining.is_empty() { + break; + } + } + + if !all_tools.is_empty() { + return Ok(all_tools); + } + } + } + // Extract JSON content from wrapper tokens if present let json_content = self.extract_json_content(text); - // Try to parse as JSON + // Try to parse as JSON first match serde_json::from_str::(json_content) { Ok(value) => self.parse_json_value(&value), Err(_) => { + // If parse failed, check if we have multiple JSON objects separated by the configured separator + // This handles cases like: {"name": "func1", ...};{"name": "func2", ...} + if !self.token_config.separator.is_empty() + && json_content.contains(&self.token_config.separator) + { + let mut all_tools = Vec::new(); + + // Split by separator and try to parse each part + let parts: Vec<&str> = + json_content.split(&self.token_config.separator).collect(); + for part in parts { + let trimmed = part.trim(); + if trimmed.is_empty() { + continue; + } + + // Try to parse this part as JSON + if let Ok(value) = serde_json::from_str::(trimmed) { + if let Ok(tools) = self.parse_json_value(&value) { + all_tools.extend(tools); + } + } else if let Some(extracted) = self.extract_json_from_text(trimmed) { + // Try extracting JSON from this part + if let Ok(value) = serde_json::from_str::(&extracted) { + if let Ok(tools) = self.parse_json_value(&value) { + all_tools.extend(tools); + } + } + } + } + + if !all_tools.is_empty() { + return Ok(all_tools); + } + } + // If no wrapper tokens configured and parse failed, // try to extract JSON from mixed text if self.token_config.start_tokens.is_empty() { @@ -350,9 +429,11 @@ impl ToolParser for JsonParser { Value::Array(ref arr) => { // Check if array contains tool-like objects arr.iter().any(|v| { - v.as_object().is_some_and(|o| { - o.contains_key("name") || o.contains_key("function") - }) + if let Some(obj) = v.as_object() { + obj.contains_key("name") || obj.contains_key("function") + } else { + false + } }) } _ => false, diff --git a/sgl-router/tests/tool_parser_llama.rs b/sgl-router/tests/tool_parser_llama.rs index d99b87638..6222150ad 100644 --- a/sgl-router/tests/tool_parser_llama.rs +++ b/sgl-router/tests/tool_parser_llama.rs @@ -141,3 +141,214 @@ async fn test_llama_json_array_format() { // Current implementation might handle this through JSON fallback assert!(!result.is_empty()); } + +#[tokio::test] +async fn test_single_json() { + // Test parsing plain JSON without python_tag + let parser = LlamaParser::new(); + let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#; + + let result = parser.parse_complete(text).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); + + let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap(); + assert_eq!(args["city"], "Paris"); +} + +#[tokio::test] +async fn test_multiple_json_with_separator() { + // Test multiple JSON objects with semicolon separator + let parser = LlamaParser::new(); + let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#; + + let result = parser.parse_complete(text).await.unwrap(); + // Note: Current implementation may only parse the first one due to semicolon handling + assert!(!result.is_empty()); + assert_eq!(result[0].function.name, "get_weather"); +} + +#[tokio::test] +async fn test_multiple_json_with_separator_customized() { + // Test multiple JSON objects with python_tag repeated + let parser = LlamaParser::new(); + let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#; + + let result = parser.parse_complete(text).await.unwrap(); + // Current implementation may handle this differently + assert!(!result.is_empty()); + assert_eq!(result[0].function.name, "get_weather"); +} + +#[tokio::test] +async fn test_json_with_trailing_text() { + // Test JSON with trailing text after + let parser = LlamaParser::new(); + let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#; + + let result = parser.parse_complete(text).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); +} + +#[tokio::test] +async fn test_invalid_then_valid_json() { + // Test error recovery - invalid JSON followed by valid JSON + let parser = LlamaParser::new(); + let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#; + + let result = parser.parse_complete(text).await.unwrap(); + // Should parse at least one valid JSON + if !result.is_empty() { + assert_eq!(result[0].function.name, "get_weather"); + } +} + +#[tokio::test] +async fn test_plain_text_only() { + // Test plain text with no tool calls + let parser = LlamaParser::new(); + let text = "This is just plain explanation text."; + + let result = parser.parse_complete(text).await.unwrap(); + assert_eq!(result.len(), 0); +} + +#[tokio::test] +async fn test_with_python_tag_prefix() { + // Test text before python_tag + let parser = LlamaParser::new(); + let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#; + + let result = parser.parse_complete(text).await.unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0].function.name, "get_weather"); +} + +// ============================================================================ +// STREAMING TESTS +// ============================================================================ + +#[tokio::test] +async fn test_llama_streaming_simple() { + let parser = LlamaParser::new(); + let mut state = sglang_router_rs::tool_parser::ParseState::new(); + + // Send complete JSON at once + let full_json = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#; + + let result = parser + .parse_incremental(full_json, &mut state) + .await + .unwrap(); + + match result { + sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => { + assert_eq!(tool.function.name, "search"); + } + _ => panic!("Expected ToolComplete for complete JSON input"), + } +} + +#[tokio::test] +async fn test_llama_streaming_partial() { + let parser = LlamaParser::new(); + let mut state = sglang_router_rs::tool_parser::ParseState::new(); + + // Stream in chunks + let chunks = vec![ + r#"<|python"#, + r#"_tag|>{"name": "#, + r#""calculate", "#, + r#""arguments": {"x": 10}"#, + r#"}"#, + ]; + + let mut got_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result { + assert_eq!(tool.function.name, "calculate"); + got_complete = true; + } + } + + assert!(got_complete, "Should have completed parsing"); +} + +#[tokio::test] +async fn test_llama_streaming_plain_json() { + let parser = LlamaParser::new(); + let mut state = sglang_router_rs::tool_parser::ParseState::new(); + + // Stream plain JSON without python_tag + let chunks = vec![ + r#"{"name": "#, + r#""search", "#, + r#""arguments": "#, + r#"{"query": "#, + r#""test"}}"#, + ]; + + let mut got_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result { + assert_eq!(tool.function.name, "search"); + got_complete = true; + } + } + + assert!(got_complete, "Should have completed parsing"); +} + +#[tokio::test] +async fn test_llama_streaming_with_text_before() { + let parser = LlamaParser::new(); + let mut state = sglang_router_rs::tool_parser::ParseState::new(); + + let chunks = vec![ + r#"Let me help you. "#, + r#"<|python_tag|>"#, + r#"{"name": "get_time","#, + r#" "arguments": {"#, + r#""timezone": "UTC"}}"#, + ]; + + let mut got_complete = false; + + for chunk in chunks { + let result = parser.parse_incremental(chunk, &mut state).await.unwrap(); + if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result { + assert_eq!(tool.function.name, "get_time"); + got_complete = true; + } + } + + assert!(got_complete, "Should have completed parsing"); +} + +#[tokio::test] +async fn test_llama_streaming_multiple_tools() { + // Test streaming multiple tool calls with semicolon separator + let parser = LlamaParser::new(); + let mut state = sglang_router_rs::tool_parser::ParseState::new(); + + let text = + r#"<|python_tag|>{"name": "func1", "arguments": {}};{"name": "func2", "arguments": {}}"#; + + let result = parser.parse_incremental(text, &mut state).await.unwrap(); + + // Current implementation may handle this differently + match result { + sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => { + // At minimum should get first tool + assert_eq!(tool.function.name, "func1"); + } + _ => { + // Also acceptable if waiting for more + } + } +}