[router][grpc] Support tool call parser in streaming (#11160)

This commit is contained in:
Chang Su
2025-10-02 03:18:50 -07:00
committed by GitHub
parent 5e786cca3a
commit b658be6f6a
38 changed files with 3086 additions and 2245 deletions

View File

@@ -4,6 +4,9 @@
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};
mod common;
use common::create_test_tools;
#[tokio::test]
async fn test_llama_python_tag_format() {
let parser = LlamaParser::new();
@@ -228,29 +231,27 @@ async fn test_with_python_tag_prefix() {
// Streaming test: passes one complete `<|python_tag|>`-prefixed JSON tool call to
// `parse_incremental` in a single call and checks that a call named "search" is reported.
//
// NOTE(review): this span looks like diff output whose +/- markers were stripped, so two
// revisions of the body are interleaved and it does not compile as written. Presumably the
// lines using `ParseState` / `&mut state` / `StreamResult::ToolComplete` are the removed
// version and the lines using `create_test_tools()` / `&tools` / `result.calls` are the
// added version — confirm against the commit before relying on either.
#[tokio::test]
async fn test_llama_streaming_simple() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Send complete JSON at once
let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
let result = parser
.parse_incremental(full_json, &mut state)
.await
.unwrap();
let result = parser.parse_incremental(full_json, &tools).await.unwrap();
// Enum-based check: anything other than `ToolComplete` fails the test.
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "search");
}
_ => panic!("Expected ToolComplete for complete JSON input"),
}
// `calls`-based check: at least one call must be present, and the first one must carry
// a name (the `unwrap()` panics if `name` is `None`).
assert!(
!result.calls.is_empty(),
"Expected tool call for complete JSON input"
);
assert_eq!(result.calls[0].name.as_ref().unwrap(), "search");
}
#[tokio::test]
async fn test_llama_streaming_partial() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Stream in chunks
let chunks = vec![
@@ -264,10 +265,12 @@ async fn test_llama_streaming_partial() {
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "calculate");
got_complete = true;
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
if !result.calls.is_empty() {
if let Some(name) = &result.calls[0].name {
assert_eq!(name, "calculate");
got_complete = true;
}
}
}
@@ -276,8 +279,9 @@ async fn test_llama_streaming_partial() {
#[tokio::test]
async fn test_llama_streaming_plain_json() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let tools = create_test_tools();
let mut parser = LlamaParser::new();
// Stream plain JSON without python_tag
let chunks = vec![
@@ -291,10 +295,12 @@ async fn test_llama_streaming_plain_json() {
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "search");
got_complete = true;
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
if !result.calls.is_empty() {
if let Some(name) = &result.calls[0].name {
assert_eq!(name, "search");
got_complete = true;
}
}
}
@@ -303,8 +309,9 @@ async fn test_llama_streaming_plain_json() {
#[tokio::test]
async fn test_llama_streaming_with_text_before() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let chunks = vec![
r#"Let me help you. "#,
@@ -317,10 +324,12 @@ async fn test_llama_streaming_with_text_before() {
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "get_time");
got_complete = true;
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
if !result.calls.is_empty() {
if let Some(name) = &result.calls[0].name {
assert_eq!(name, "get_time");
got_complete = true;
}
}
}
@@ -329,74 +338,63 @@ async fn test_llama_streaming_with_text_before() {
// Streaming test: passes two `;`-separated tool calls ("func1" then "func2") in one input,
// then calls `parse_incremental` again with an empty string and checks the second result.
//
// NOTE(review): this span looks like diff output whose +/- markers were stripped, so two
// revisions of the body are interleaved (braces do not balance as written). Presumably the
// `&mut state` / `StreamResult::ToolComplete` lines are the removed version and the
// `&tools` / `.calls` lines are the added version — confirm against the commit.
#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let tools = create_test_tools();
let mut parser = LlamaParser::new();
let text =
r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;
let result = parser.parse_incremental(text, &mut state).await.unwrap();
let result = parser.parse_incremental(text, &tools).await.unwrap();
// Should get first tool complete
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "func1");
}
_ => panic!("Expected first tool to be complete, got: {:?}", result),
assert!(
!result.calls.is_empty(),
"Expected first tool to be complete"
);
// NOTE(review): the name comparison only runs when `name` is `Some`; a `None` name
// passes silently here.
if let Some(name) = &result.calls[0].name {
assert_eq!(name, "func1");
}
// Process remaining buffer to get second tool
let result2 = parser.parse_incremental("", &mut state).await.unwrap();
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "func2");
let result2 = parser.parse_incremental("", &tools).await.unwrap();
// NOTE(review): in the `.calls`-based form, both conditions below must hold before
// anything is asserted; if `calls` is empty or `name` is `None`, the "func2" check is
// skipped entirely, whereas the `match` form panics on any non-`ToolComplete` result.
if !result2.calls.is_empty() {
if let Some(name) = &result2.calls[0].name {
assert_eq!(name, "func2");
}
_ => panic!("Expected second tool to be complete"),
}
}
// Streaming test: passes two tool calls ("get_weather" then "get_time") to
// `parse_incremental` split across three chunks, with the chunk boundaries falling inside
// the JSON objects, and checks the names and the "city" argument reported along the way.
//
// NOTE(review): this span looks like diff output whose +/- markers were stripped, so two
// revisions of the body are interleaved (braces do not balance as written). Presumably the
// `&mut state` / `StreamResult::…` lines are the removed version and the `&tools` /
// `.calls` lines are the added version — confirm against the commit.
//
// NOTE(review): every `.calls`-based assertion below sits behind `if !….calls.is_empty()`
// and/or `if let Some(name)`, so that form of the test can pass without executing a single
// assertion.
#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let mut parser = LlamaParser::new();
let tools = create_test_tools();
// First chunk - incomplete first JSON
let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
// Should be incomplete or have tool name
match result1 {
sglang_router_rs::tool_parser::StreamResult::Incomplete
| sglang_router_rs::tool_parser::StreamResult::ToolName { .. }
| sglang_router_rs::tool_parser::StreamResult::ToolArguments { .. } => {
// Expected - could get tool name or be incomplete or even partial args
let result1 = parser.parse_incremental(chunk1, &tools).await.unwrap();
// If the first chunk already produced a call with a name, it must be "get_weather".
if !result1.calls.is_empty() {
if let Some(name) = &result1.calls[0].name {
assert_eq!(name, "get_weather");
}
_ => panic!(
"Expected incomplete or tool name for partial JSON, got: {:?}",
result1
),
}
// Second chunk - complete first JSON and separator
let chunk2 = r#": {"city": "Paris"}};{"name": "#;
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
let result2 = parser.parse_incremental(chunk2, &tools).await.unwrap();
// Should get first tool complete
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["city"], "Paris");
}
_ => panic!("Expected first tool complete, got: {:?}", result2),
// Should get parameters for first tool (name already sent in result1)
// NOTE(review): `parameters` from this one result is parsed as a complete JSON document —
// presumably the whole argument object is emitted in a single increment; confirm.
if !result2.calls.is_empty() {
let args: serde_json::Value = serde_json::from_str(&result2.calls[0].parameters).unwrap();
assert_eq!(args["city"], "Paris");
}
// Third chunk - finishes the second JSON object that chunk2 left open after `{"name": `
let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();
match result3 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_time");
let result3 = parser.parse_incremental(chunk3, &tools).await.unwrap();
if !result3.calls.is_empty() {
if let Some(name) = &result3.calls[0].name {
assert_eq!(name, "get_time");
}
_ => panic!("Expected tool to be complete, got: {:?}", result3),
}
}