[router][grpc] Support tool call parser in streaming (#11160)
This commit is contained in:
@@ -4,6 +4,9 @@
|
||||
|
||||
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};
|
||||
|
||||
mod common;
|
||||
use common::create_test_tools;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_python_tag_format() {
|
||||
let parser = LlamaParser::new();
|
||||
@@ -228,29 +231,27 @@ async fn test_with_python_tag_prefix() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_simple() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let tools = create_test_tools();
|
||||
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
// Send complete JSON at once
|
||||
let full_json = r#"<|python_tag|>{"name": "search", "parameters": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
let result = parser.parse_incremental(full_json, &tools).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON input"),
|
||||
}
|
||||
assert!(
|
||||
!result.calls.is_empty(),
|
||||
"Expected tool call for complete JSON input"
|
||||
);
|
||||
assert_eq!(result.calls[0].name.as_ref().unwrap(), "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_partial() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let tools = create_test_tools();
|
||||
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
// Stream in chunks
|
||||
let chunks = vec![
|
||||
@@ -264,10 +265,12 @@ async fn test_llama_streaming_partial() {
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "calculate");
|
||||
got_complete = true;
|
||||
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
|
||||
if !result.calls.is_empty() {
|
||||
if let Some(name) = &result.calls[0].name {
|
||||
assert_eq!(name, "calculate");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,8 +279,9 @@ async fn test_llama_streaming_partial() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let tools = create_test_tools();
|
||||
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
// Stream plain JSON without python_tag
|
||||
let chunks = vec![
|
||||
@@ -291,10 +295,12 @@ async fn test_llama_streaming_plain_json() {
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
got_complete = true;
|
||||
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
|
||||
if !result.calls.is_empty() {
|
||||
if let Some(name) = &result.calls[0].name {
|
||||
assert_eq!(name, "search");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -303,8 +309,9 @@ async fn test_llama_streaming_plain_json() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let tools = create_test_tools();
|
||||
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
let chunks = vec![
|
||||
r#"Let me help you. "#,
|
||||
@@ -317,10 +324,12 @@ async fn test_llama_streaming_with_text_before() {
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "get_time");
|
||||
got_complete = true;
|
||||
let result = parser.parse_incremental(chunk, &tools).await.unwrap();
|
||||
if !result.calls.is_empty() {
|
||||
if let Some(name) = &result.calls[0].name {
|
||||
assert_eq!(name, "get_time");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -329,74 +338,63 @@ async fn test_llama_streaming_with_text_before() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_multiple_tools() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let tools = create_test_tools();
|
||||
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
let text =
|
||||
r#"<|python_tag|>{"name": "func1", "parameters": {}};{"name": "func2", "parameters": {}}"#;
|
||||
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
let result = parser.parse_incremental(text, &tools).await.unwrap();
|
||||
|
||||
// Should get first tool complete
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "func1");
|
||||
}
|
||||
_ => panic!("Expected first tool to be complete, got: {:?}", result),
|
||||
assert!(
|
||||
!result.calls.is_empty(),
|
||||
"Expected first tool to be complete"
|
||||
);
|
||||
if let Some(name) = &result.calls[0].name {
|
||||
assert_eq!(name, "func1");
|
||||
}
|
||||
|
||||
// Process remaining buffer to get second tool
|
||||
let result2 = parser.parse_incremental("", &mut state).await.unwrap();
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "func2");
|
||||
let result2 = parser.parse_incremental("", &tools).await.unwrap();
|
||||
if !result2.calls.is_empty() {
|
||||
if let Some(name) = &result2.calls[0].name {
|
||||
assert_eq!(name, "func2");
|
||||
}
|
||||
_ => panic!("Expected second tool to be complete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_multiple_tools_chunked() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
let mut parser = LlamaParser::new();
|
||||
|
||||
let tools = create_test_tools();
|
||||
|
||||
// First chunk - incomplete first JSON
|
||||
let chunk1 = r#"<|python_tag|>{"name": "get_weather", "parameters""#;
|
||||
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||
|
||||
// Should be incomplete or have tool name
|
||||
match result1 {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete
|
||||
| sglang_router_rs::tool_parser::StreamResult::ToolName { .. }
|
||||
| sglang_router_rs::tool_parser::StreamResult::ToolArguments { .. } => {
|
||||
// Expected - could get tool name or be incomplete or even partial args
|
||||
let result1 = parser.parse_incremental(chunk1, &tools).await.unwrap();
|
||||
if !result1.calls.is_empty() {
|
||||
if let Some(name) = &result1.calls[0].name {
|
||||
assert_eq!(name, "get_weather");
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected incomplete or tool name for partial JSON, got: {:?}",
|
||||
result1
|
||||
),
|
||||
}
|
||||
|
||||
// Second chunk - complete first JSON and separator
|
||||
let chunk2 = r#": {"city": "Paris"}};{"name": "#;
|
||||
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||
let result2 = parser.parse_incremental(chunk2, &tools).await.unwrap();
|
||||
|
||||
// Should get first tool complete
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "Paris");
|
||||
}
|
||||
_ => panic!("Expected first tool complete, got: {:?}", result2),
|
||||
// Should get parameters for first tool (name already sent in result1)
|
||||
if !result2.calls.is_empty() {
|
||||
let args: serde_json::Value = serde_json::from_str(&result2.calls[0].parameters).unwrap();
|
||||
assert_eq!(args["city"], "Paris");
|
||||
}
|
||||
|
||||
let chunk3 = r#""get_time", "parameters": {"timezone": "UTC"}}"#;
|
||||
let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();
|
||||
match result3 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_time");
|
||||
let result3 = parser.parse_incremental(chunk3, &tools).await.unwrap();
|
||||
if !result3.calls.is_empty() {
|
||||
if let Some(name) = &result3.calls[0].name {
|
||||
assert_eq!(name, "get_time");
|
||||
}
|
||||
_ => panic!("Expected tool to be complete, got: {:?}", result3),
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user