adapt to sglang v0.5.2rc1 on dcu
This commit is contained in:
1669
sgl-router/tests/api_endpoints_test.rs
Normal file
1669
sgl-router/tests/api_endpoints_test.rs
Normal file
File diff suppressed because it is too large
Load Diff
237
sgl-router/tests/common/mock_mcp_server.rs
Normal file
237
sgl-router/tests/common/mock_mcp_server.rs
Normal file
@@ -0,0 +1,237 @@
|
||||
// tests/common/mock_mcp_server.rs - Mock MCP server for testing
|
||||
|
||||
use axum::{
|
||||
extract::Json, http::StatusCode, response::Json as ResponseJson, routing::post, Router,
|
||||
};
|
||||
use serde_json::{json, Value};
|
||||
use tokio::net::TcpListener;
|
||||
|
||||
/// Mock MCP server that returns hardcoded responses for testing.
///
/// Owns the background axum task serving `/mcp`; callers use [`MockMCPServer::url`]
/// to point an MCP client at it.
pub struct MockMCPServer {
    // Port the server is bound to on 127.0.0.1 (OS-assigned at startup).
    pub port: u16,
    // Handle to the spawned server task; `None` after `stop()` has been called.
    pub server_handle: Option<tokio::task::JoinHandle<()>>,
}
|
||||
|
||||
impl MockMCPServer {
    /// Start a mock MCP server on an available (OS-assigned) port.
    ///
    /// Spawns the axum server as a background task and sleeps briefly so the
    /// server is accepting connections by the time this returns.
    pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
        // Find an available port: binding to port 0 lets the OS pick one.
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let port = listener.local_addr()?.port();

        // Single route: all MCP traffic is POSTed to /mcp.
        let app = Router::new().route("/mcp", post(handle_mcp_request));

        let server_handle = tokio::spawn(async move {
            axum::serve(listener, app)
                .await
                .expect("Mock MCP server failed to start");
        });

        // Give the server a moment to start before callers connect.
        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;

        Ok(MockMCPServer {
            port,
            server_handle: Some(server_handle),
        })
    }

    /// Get the full URL for this mock server's `/mcp` endpoint.
    pub fn url(&self) -> String {
        format!("http://127.0.0.1:{}/mcp", self.port)
    }

    /// Stop the mock server by aborting its background task.
    ///
    /// Safe to call multiple times; subsequent calls are no-ops.
    pub async fn stop(&mut self) {
        if let Some(handle) = self.server_handle.take() {
            handle.abort();
            // Wait a moment for cleanup so the port is released.
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
        }
    }
}
|
||||
|
||||
impl Drop for MockMCPServer {
|
||||
fn drop(&mut self) {
|
||||
if let Some(handle) = self.server_handle.take() {
|
||||
handle.abort();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Handle MCP requests and return mock responses
|
||||
async fn handle_mcp_request(Json(request): Json<Value>) -> Result<ResponseJson<Value>, StatusCode> {
|
||||
// Parse the JSON-RPC request
|
||||
let method = request.get("method").and_then(|m| m.as_str()).unwrap_or("");
|
||||
|
||||
let id = request
|
||||
.get("id")
|
||||
.and_then(|i| i.as_str())
|
||||
.unwrap_or("unknown");
|
||||
|
||||
let response = match method {
|
||||
"initialize" => {
|
||||
// Mock initialize response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"serverInfo": {
|
||||
"name": "Mock MCP Server",
|
||||
"version": "1.0.0"
|
||||
},
|
||||
"instructions": "Mock server for testing"
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/list" => {
|
||||
// Mock tools list response
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"tools": [
|
||||
{
|
||||
"name": "brave_web_search",
|
||||
"description": "Mock web search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"count": {"type": "integer"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "brave_local_search",
|
||||
"description": "Mock local search tool",
|
||||
"inputSchema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"}
|
||||
},
|
||||
"required": ["query"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
}
|
||||
"tools/call" => {
|
||||
// Mock tool call response
|
||||
let empty_json = json!({});
|
||||
let params = request.get("params").unwrap_or(&empty_json);
|
||||
let tool_name = params.get("name").and_then(|n| n.as_str()).unwrap_or("");
|
||||
let empty_args = json!({});
|
||||
let arguments = params.get("arguments").unwrap_or(&empty_args);
|
||||
|
||||
match tool_name {
|
||||
"brave_web_search" => {
|
||||
let query = arguments
|
||||
.get("query")
|
||||
.and_then(|q| q.as_str())
|
||||
.unwrap_or("test");
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": format!("Mock search results for: {}", query)
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
"brave_local_search" => {
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"result": {
|
||||
"content": [
|
||||
{
|
||||
"type": "text",
|
||||
"text": "Mock local search results"
|
||||
}
|
||||
],
|
||||
"isError": false
|
||||
}
|
||||
})
|
||||
}
|
||||
_ => {
|
||||
// Unknown tool
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -1,
|
||||
"message": format!("Unknown tool: {}", tool_name)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
// Unknown method
|
||||
json!({
|
||||
"jsonrpc": "2.0",
|
||||
"id": id,
|
||||
"error": {
|
||||
"code": -32601,
|
||||
"message": format!("Method not found: {}", method)
|
||||
}
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(ResponseJson(response))
|
||||
}
|
||||
|
||||
#[cfg(test)]
#[allow(unused_imports)]
mod tests {
    use super::MockMCPServer;
    use serde_json::{json, Value};

    /// The server starts, reports a nonzero port, and that port appears in its URL.
    #[tokio::test]
    async fn test_mock_server_startup() {
        let mut server = MockMCPServer::start().await.unwrap();
        assert!(server.port > 0);
        assert!(server.url().contains(&server.port.to_string()));
        server.stop().await;
    }

    /// An `initialize` round-trip returns the canned serverInfo payload.
    #[tokio::test]
    async fn test_mock_server_responses() {
        let mut server = MockMCPServer::start().await.unwrap();
        let client = reqwest::Client::new();

        // Test initialize
        let init_request = json!({
            "jsonrpc": "2.0",
            "id": "1",
            "method": "initialize",
            "params": {
                "protocolVersion": "2024-11-05",
                "capabilities": {}
            }
        });

        let response = client
            .post(server.url())
            .json(&init_request)
            .send()
            .await
            .unwrap();

        assert!(response.status().is_success());
        let json: Value = response.json().await.unwrap();
        assert_eq!(json["jsonrpc"], "2.0");
        assert_eq!(json["result"]["serverInfo"]["name"], "Mock MCP Server");

        server.stop().await;
    }
}
|
||||
614
sgl-router/tests/common/mock_worker.rs
Normal file
614
sgl-router/tests/common/mock_worker.rs
Normal file
@@ -0,0 +1,614 @@
|
||||
// Mock worker for testing - these functions are used by integration tests
|
||||
#![allow(dead_code)]
|
||||
|
||||
use axum::{
|
||||
extract::{Json, State},
|
||||
http::StatusCode,
|
||||
response::sse::{Event, KeepAlive},
|
||||
response::{IntoResponse, Response, Sse},
|
||||
routing::{get, post},
|
||||
Router,
|
||||
};
|
||||
use futures_util::stream::{self, StreamExt};
|
||||
use serde_json::json;
|
||||
use std::convert::Infallible;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
use tokio::sync::RwLock;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Configuration for mock worker behavior.
#[derive(Clone)]
pub struct MockWorkerConfig {
    // Port to bind on 127.0.0.1; 0 means "pick an available port at start()".
    pub port: u16,
    // Role the worker advertises (regular / prefill / decode).
    pub worker_type: WorkerType,
    // Health state reported by /health and related endpoints.
    pub health_status: HealthStatus,
    // Artificial latency added before responding, in milliseconds.
    pub response_delay_ms: u64,
    // Probability in [0.0, 1.0] that a request fails with HTTP 500.
    pub fail_rate: f32,
}
|
||||
|
||||
/// Role a mock worker plays in a (possibly disaggregated) deployment.
#[derive(Clone, Debug)]
pub enum WorkerType {
    Regular,
    Prefill,
    Decode,
}
|
||||
|
||||
/// Health state a mock worker reports on its health endpoints.
#[derive(Clone, Debug)]
pub enum HealthStatus {
    Healthy,
    Unhealthy,
    Degraded,
}
|
||||
|
||||
/// Mock worker server for testing.
///
/// Wraps an axum server exposing SGLang-worker-compatible endpoints; the
/// shared config can be mutated to change behavior while running.
pub struct MockWorker {
    // Shared, mutable behavior config read by the request handlers.
    config: Arc<RwLock<MockWorkerConfig>>,
    // Handle to the spawned server task; `None` until start() / after stop().
    shutdown_handle: Option<tokio::task::JoinHandle<()>>,
    // Sender half of the graceful-shutdown signal.
    shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
|
||||
|
||||
impl MockWorker {
|
||||
pub fn new(config: MockWorkerConfig) -> Self {
|
||||
Self {
|
||||
config: Arc::new(RwLock::new(config)),
|
||||
shutdown_handle: None,
|
||||
shutdown_tx: None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start the mock worker server
|
||||
pub async fn start(&mut self) -> Result<String, Box<dyn std::error::Error>> {
|
||||
let config = self.config.clone();
|
||||
let port = config.read().await.port;
|
||||
|
||||
// If port is 0, find an available port
|
||||
let port = if port == 0 {
|
||||
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
config.write().await.port = port;
|
||||
port
|
||||
} else {
|
||||
port
|
||||
};
|
||||
|
||||
let app = Router::new()
|
||||
.route("/health", get(health_handler))
|
||||
.route("/health_generate", get(health_generate_handler))
|
||||
.route("/get_server_info", get(server_info_handler))
|
||||
.route("/get_model_info", get(model_info_handler))
|
||||
.route("/generate", post(generate_handler))
|
||||
.route("/v1/chat/completions", post(chat_completions_handler))
|
||||
.route("/v1/completions", post(completions_handler))
|
||||
.route("/flush_cache", post(flush_cache_handler))
|
||||
.route("/v1/models", get(v1_models_handler))
|
||||
.with_state(config);
|
||||
|
||||
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
|
||||
self.shutdown_tx = Some(shutdown_tx);
|
||||
|
||||
// Spawn the server in a separate task
|
||||
let handle = tokio::spawn(async move {
|
||||
let listener = match tokio::net::TcpListener::bind(("127.0.0.1", port)).await {
|
||||
Ok(l) => l,
|
||||
Err(e) => {
|
||||
eprintln!("Failed to bind to port {}: {}", port, e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
|
||||
let _ = shutdown_rx.await;
|
||||
});
|
||||
|
||||
if let Err(e) = server.await {
|
||||
eprintln!("Server error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
self.shutdown_handle = Some(handle);
|
||||
|
||||
// Wait for the server to start
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
let url = format!("http://127.0.0.1:{}", port);
|
||||
Ok(url)
|
||||
}
|
||||
|
||||
/// Stop the mock worker server
|
||||
pub async fn stop(&mut self) {
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
|
||||
if let Some(handle) = self.shutdown_handle.take() {
|
||||
// Wait for the server to shut down
|
||||
let _ = tokio::time::timeout(tokio::time::Duration::from_secs(5), handle).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for MockWorker {
|
||||
fn drop(&mut self) {
|
||||
// Clean shutdown when dropped
|
||||
if let Some(shutdown_tx) = self.shutdown_tx.take() {
|
||||
let _ = shutdown_tx.send(());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Handler implementations
|
||||
|
||||
/// Check if a request should fail based on the configured `fail_rate`.
///
/// Draws a uniform random float in [0, 1); with `fail_rate` 0.0 this never
/// fails, with 1.0 it always fails.
async fn should_fail(config: &MockWorkerConfig) -> bool {
    rand::random::<f32>() < config.fail_rate
}
|
||||
|
||||
/// GET /health — report the configured health status.
///
/// Healthy and Degraded respond 200 with a status payload; Unhealthy
/// responds 503. Note this endpoint ignores `fail_rate`.
async fn health_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
    let config = config.read().await;

    match config.health_status {
        HealthStatus::Healthy => Json(json!({
            "status": "healthy",
            "timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
            "worker_type": format!("{:?}", config.worker_type),
        }))
        .into_response(),
        HealthStatus::Unhealthy => (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "status": "unhealthy",
                "error": "Worker is not responding"
            })),
        )
            .into_response(),
        HealthStatus::Degraded => Json(json!({
            "status": "degraded",
            "warning": "High load detected"
        }))
        .into_response(),
    }
}
|
||||
|
||||
/// GET /health_generate — report generation-service health.
///
/// May fail randomly per `fail_rate`; only a Healthy worker returns 200
/// (Degraded and Unhealthy both map to 503 here, unlike /health).
async fn health_generate_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": "Random failure for testing"
            })),
        )
            .into_response();
    }

    if matches!(config.health_status, HealthStatus::Healthy) {
        Json(json!({
            "status": "ok",
            "queue_length": 0,
            "processing_time_ms": config.response_delay_ms
        }))
        .into_response()
    } else {
        (
            StatusCode::SERVICE_UNAVAILABLE,
            Json(json!({
                "error": "Generation service unavailable"
            })),
        )
            .into_response()
    }
}
|
||||
|
||||
/// GET /get_server_info — return a canned SGLang server-info payload.
///
/// Most fields are fixed mock values; only `port` reflects the live config.
/// May fail randomly per `fail_rate`.
async fn server_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": "Random failure for testing"
            })),
        )
            .into_response();
    }

    Json(json!({
        "model_path": "mock-model-path",
        "tokenizer_path": "mock-tokenizer-path",
        "port": config.port,
        "host": "127.0.0.1",
        "max_num_batched_tokens": 32768,
        "max_prefill_tokens": 16384,
        "mem_fraction_static": 0.88,
        "tp_size": 1,
        "dp_size": 1,
        "stream_interval": 8,
        "dtype": "float16",
        "device": "cuda",
        "enable_flashinfer": true,
        "enable_p2p_check": true,
        "context_length": 32768,
        "chat_template": null,
        "disable_radix_cache": false,
        "enable_torch_compile": false,
        "trust_remote_code": false,
        "show_time_cost": false,
        "waiting_queue_size": 0,
        "running_queue_size": 0,
        "req_to_token_ratio": 1.2,
        "min_running_requests": 0,
        "max_running_requests": 2048,
        "max_req_num": 8192,
        "max_batch_tokens": 32768,
        "schedule_policy": "lpm",
        "schedule_conservativeness": 1.0,
        "version": "0.3.0",
        "internal_states": [{
            "waiting_queue_size": 0,
            "running_queue_size": 0
        }]
    }))
    .into_response()
}
|
||||
|
||||
/// GET /get_model_info — return canned model metadata.
///
/// May fail randomly per `fail_rate`.
async fn model_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": "Random failure for testing"
            })),
        )
            .into_response();
    }

    Json(json!({
        "model_path": "mock-model-path",
        "tokenizer_path": "mock-tokenizer-path",
        "is_generation": true,
        "preferred_sampling_params": {
            "temperature": 0.7,
            "top_p": 0.9,
            "top_k": 40,
            "max_tokens": 2048
        }
    }))
    .into_response()
}
|
||||
|
||||
/// POST /generate — mock SGLang generate endpoint.
///
/// Honors `fail_rate` and `response_delay_ms`. If the payload sets
/// `"stream": true`, responds with SSE: one event per batch item (a batch is
/// detected by `"text"` being a JSON array) followed by a `[DONE]` event.
/// Otherwise returns a single JSON body with mock meta_info.
async fn generate_handler(
    State(config): State<Arc<RwLock<MockWorkerConfig>>>,
    Json(payload): Json<serde_json::Value>,
) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": "Random failure for testing"
            })),
        )
            .into_response();
    }

    // Simulate processing latency if configured.
    if config.response_delay_ms > 0 {
        tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
    }

    let is_stream = payload
        .get("stream")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    if is_stream {
        let stream_delay = config.response_delay_ms;

        // Check if it's a batch request ("text" is an array of prompts).
        let is_batch = payload.get("text").and_then(|t| t.as_array()).is_some();

        let batch_size = if is_batch {
            payload
                .get("text")
                .and_then(|t| t.as_array())
                .map(|arr| arr.len())
                .unwrap_or(1)
        } else {
            1
        };

        let mut events = Vec::new();

        // Generate one SSE event per item in the batch.
        for i in 0..batch_size {
            let timestamp_start = SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .unwrap()
                .as_secs_f64();

            let data = json!({
                "text": format!("Mock response {}", i + 1),
                "meta_info": {
                    "prompt_tokens": 10,
                    "completion_tokens": 5,
                    "completion_tokens_wo_jump_forward": 5,
                    "input_token_logprobs": null,
                    "output_token_logprobs": null,
                    "first_token_latency": stream_delay as f64 / 1000.0,
                    "time_to_first_token": stream_delay as f64 / 1000.0,
                    "time_per_output_token": 0.01,
                    "end_time": timestamp_start + (stream_delay as f64 / 1000.0),
                    "start_time": timestamp_start,
                    "finish_reason": {
                        "type": "stop",
                        "reason": "length"
                    }
                },
                "stage": "mid"
            });

            events.push(Ok::<_, Infallible>(Event::default().data(data.to_string())));
        }

        // Terminate the stream with the conventional [DONE] sentinel.
        events.push(Ok(Event::default().data("[DONE]")));

        let stream = stream::iter(events);

        Sse::new(stream)
            .keep_alive(KeepAlive::default())
            .into_response()
    } else {
        Json(json!({
            "text": "This is a mock response.",
            "meta_info": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "completion_tokens_wo_jump_forward": 5,
                "input_token_logprobs": null,
                "output_token_logprobs": null,
                "first_token_latency": config.response_delay_ms as f64 / 1000.0,
                "time_to_first_token": config.response_delay_ms as f64 / 1000.0,
                "time_per_output_token": 0.01,
                "finish_reason": {
                    "type": "stop",
                    "reason": "length"
                }
            }
        }))
        .into_response()
    }
}
|
||||
|
||||
/// POST /v1/chat/completions — mock OpenAI chat-completions endpoint.
///
/// Honors `fail_rate` and `response_delay_ms`. With `"stream": true` emits a
/// single content chunk followed by `[DONE]` over SSE; otherwise returns a
/// complete chat.completion object with fixed usage numbers.
async fn chat_completions_handler(
    State(config): State<Arc<RwLock<MockWorkerConfig>>>,
    Json(payload): Json<serde_json::Value>,
) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": {
                    "message": "Random failure for testing",
                    "type": "internal_error",
                    "code": "internal_error"
                }
            })),
        )
            .into_response();
    }

    // Simulate processing latency if configured.
    if config.response_delay_ms > 0 {
        tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
    }

    let is_stream = payload
        .get("stream")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    if is_stream {
        let request_id = format!("chatcmpl-{}", Uuid::new_v4());

        // One content chunk, then the [DONE] sentinel.
        let stream = stream::once(async move {
            let chunk = json!({
                "id": request_id,
                "object": "chat.completion.chunk",
                "created": timestamp,
                "model": "mock-model",
                "choices": [{
                    "index": 0,
                    "delta": {
                        "content": "This is a mock chat response."
                    },
                    "finish_reason": null
                }]
            });

            Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
        })
        .chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));

        Sse::new(stream)
            .keep_alive(KeepAlive::default())
            .into_response()
    } else {
        Json(json!({
            "id": format!("chatcmpl-{}", Uuid::new_v4()),
            "object": "chat.completion",
            "created": timestamp,
            "model": "mock-model",
            "choices": [{
                "index": 0,
                "message": {
                    "role": "assistant",
                    "content": "This is a mock chat response."
                },
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }))
        .into_response()
    }
}
|
||||
|
||||
/// POST /v1/completions — mock OpenAI text-completions endpoint.
///
/// Mirrors `chat_completions_handler` but with `text_completion` objects:
/// honors `fail_rate` / `response_delay_ms`, streams one chunk + `[DONE]`
/// when requested, otherwise returns a complete response.
async fn completions_handler(
    State(config): State<Arc<RwLock<MockWorkerConfig>>>,
    Json(payload): Json<serde_json::Value>,
) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": {
                    "message": "Random failure for testing",
                    "type": "internal_error",
                    "code": "internal_error"
                }
            })),
        )
            .into_response();
    }

    // Simulate processing latency if configured.
    if config.response_delay_ms > 0 {
        tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
    }

    let is_stream = payload
        .get("stream")
        .and_then(|v| v.as_bool())
        .unwrap_or(false);

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    if is_stream {
        let request_id = format!("cmpl-{}", Uuid::new_v4());

        // One completion chunk, then the [DONE] sentinel.
        let stream = stream::once(async move {
            let chunk = json!({
                "id": request_id,
                "object": "text_completion",
                "created": timestamp,
                "model": "mock-model",
                "choices": [{
                    "text": "This is a mock completion.",
                    "index": 0,
                    "logprobs": null,
                    "finish_reason": null
                }]
            });

            Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
        })
        .chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));

        Sse::new(stream)
            .keep_alive(KeepAlive::default())
            .into_response()
    } else {
        Json(json!({
            "id": format!("cmpl-{}", Uuid::new_v4()),
            "object": "text_completion",
            "created": timestamp,
            "model": "mock-model",
            "choices": [{
                "text": "This is a mock completion.",
                "index": 0,
                "logprobs": null,
                "finish_reason": "stop"
            }],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15
            }
        }))
        .into_response()
    }
}
|
||||
|
||||
async fn flush_cache_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
|
||||
let config = config.read().await;
|
||||
|
||||
if should_fail(&config).await {
|
||||
return (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(json!({
|
||||
"error": "Random failure for testing"
|
||||
})),
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
Json(json!({
|
||||
"message": "Cache flushed successfully"
|
||||
}))
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// GET /v1/models — mock OpenAI model listing with a single "mock-model".
///
/// May fail randomly per `fail_rate`.
async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
    let config = config.read().await;

    if should_fail(&config).await {
        return (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(json!({
                "error": {
                    "message": "Random failure for testing",
                    "type": "internal_error",
                    "code": "internal_error"
                }
            })),
        )
            .into_response();
    }

    let timestamp = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .unwrap()
        .as_secs();

    Json(json!({
        "object": "list",
        "data": [{
            "id": "mock-model",
            "object": "model",
            "created": timestamp,
            "owned_by": "organization-owner"
        }]
    }))
    .into_response()
}
|
||||
|
||||
impl Default for MockWorkerConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
port: 0,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 0.0,
|
||||
}
|
||||
}
|
||||
}
|
||||
99
sgl-router/tests/common/mod.rs
Normal file
99
sgl-router/tests/common/mod.rs
Normal file
@@ -0,0 +1,99 @@
|
||||
// These modules are used by tests and benchmarks
|
||||
#![allow(dead_code)]
|
||||
|
||||
pub mod mock_mcp_server;
|
||||
pub mod mock_worker;
|
||||
pub mod test_app;
|
||||
|
||||
use sglang_router_rs::config::RouterConfig;
|
||||
use sglang_router_rs::server::AppContext;
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::{Arc, Mutex, OnceLock};
|
||||
|
||||
/// Helper function to create AppContext for tests
|
||||
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||
Arc::new(AppContext::new(
|
||||
config.clone(),
|
||||
reqwest::Client::new(),
|
||||
config.max_concurrent_requests,
|
||||
config.rate_limit_tokens_per_second,
|
||||
))
|
||||
}
|
||||
|
||||
// Tokenizer download configuration
// Upstream location of the TinyLlama tokenizer used by tests/benchmarks.
const TINYLLAMA_TOKENIZER_URL: &str =
    "https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
// Local directory and file name the tokenizer is cached under.
const CACHE_DIR: &str = ".tokenizer_cache";
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";

// Global mutex to prevent concurrent downloads from multiple test threads.
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
|
||||
|
||||
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
/// Returns the path to the cached tokenizer file.
///
/// This function is thread-safe and will only download the tokenizer once
/// even if called from multiple threads concurrently.
///
/// # Panics
/// Panics on any I/O or network failure (cache dir creation, download,
/// non-success HTTP status, suspiciously small payload, or write failure) —
/// acceptable in test/benchmark code.
pub fn ensure_tokenizer_cached() -> PathBuf {
    // Get or initialize the mutex guarding the download.
    let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));

    // Lock to ensure only one thread downloads at a time.
    let _guard = mutex.lock().unwrap();

    let cache_dir = PathBuf::from(CACHE_DIR);
    let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);

    // Create cache directory if it doesn't exist.
    if !cache_dir.exists() {
        fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
    }

    // Download tokenizer if not already cached.
    if !tokenizer_path.exists() {
        println!("Downloading TinyLlama tokenizer from HuggingFace...");

        // Use blocking reqwest client since we're in tests/benchmarks.
        let client = reqwest::blocking::Client::new();
        let response = client
            .get(TINYLLAMA_TOKENIZER_URL)
            .send()
            .expect("Failed to download tokenizer");

        if !response.status().is_success() {
            panic!("Failed to download tokenizer: HTTP {}", response.status());
        }

        let content = response.bytes().expect("Failed to read tokenizer content");

        // Verify we got actual JSON content (error pages tend to be tiny).
        if content.len() < 100 {
            panic!("Downloaded content too small: {} bytes", content.len());
        }

        fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
        println!(
            "Tokenizer downloaded and cached successfully ({} bytes)",
            tokenizer_path.metadata().unwrap().len()
        );
    }

    tokenizer_path
}
|
||||
|
||||
/// Common test prompts for consistency across tests.
pub const TEST_PROMPTS: [&str; 4] = [
    "deep learning is",
    "Deep learning is",
    "has anyone seen nemo lately",
    "another prompt",
];

/// Pre-computed hashes for verification, index-aligned with `TEST_PROMPTS`.
// NOTE(review): presumably tokenization-result hashes of the prompts above —
// confirm against the test that consumes them before relying on the mapping.
pub const EXPECTED_HASHES: [u64; 4] = [
    1209591529327510910,
    4181375434596349981,
    6245658446118930933,
    5097285695902185237,
];
|
||||
49
sgl-router/tests/common/test_app.rs
Normal file
49
sgl-router/tests/common/test_app.rs
Normal file
@@ -0,0 +1,49 @@
|
||||
use axum::Router;
|
||||
use reqwest::Client;
|
||||
use sglang_router_rs::{
|
||||
config::RouterConfig,
|
||||
routers::RouterTrait,
|
||||
server::{build_app, AppContext, AppState},
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Create a test Axum application using the actual server's `build_app` function.
///
/// Wires the provided router and client into a real `AppState` so tests
/// exercise the same middleware stack as production.
#[allow(dead_code)]
pub fn create_test_app(
    router: Arc<dyn RouterTrait>,
    client: Client,
    router_config: &RouterConfig,
) -> Router {
    // Create AppContext from the supplied config and client.
    let app_context = Arc::new(AppContext::new(
        router_config.clone(),
        client,
        router_config.max_concurrent_requests,
        router_config.rate_limit_tokens_per_second,
    ));

    // Create AppState with the test router and context.
    let app_state = Arc::new(AppState {
        router,
        context: app_context,
        concurrency_queue_tx: None, // No queue for tests
    });

    // Configure request ID headers (use defaults if not specified).
    let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
        vec![
            "x-request-id".to_string(),
            "x-correlation-id".to_string(),
            "x-trace-id".to_string(),
            "request-id".to_string(),
        ]
    });

    // Use the actual server's build_app function.
    build_app(
        app_state,
        router_config.max_payload_size,
        request_id_headers,
        router_config.cors_allowed_origins.clone(),
    )
}
|
||||
458
sgl-router/tests/mcp_test.rs
Normal file
458
sgl-router/tests/mcp_test.rs
Normal file
@@ -0,0 +1,458 @@
|
||||
// This test suite validates the complete MCP implementation against the
|
||||
// functionality required for SGLang responses API integration.
|
||||
//
|
||||
// Test Coverage:
|
||||
// - Core MCP server functionality (Python tool_server.py parity)
|
||||
// - Tool session management (individual and multi-tool)
|
||||
// - Tool execution and error handling
|
||||
// - Schema adaptation and validation
|
||||
// - SSE parsing and protocol compliance
|
||||
// - Mock server integration for reliable testing
|
||||
|
||||
mod common;
|
||||
|
||||
use common::mock_mcp_server::MockMCPServer;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::mcp::{parse_sse_event, MCPToolServer, MultiToolSessionManager, ToolSession};
|
||||
/// Create a new mock server for testing (each test gets its own)
|
||||
async fn create_mock_server() -> MockMCPServer {
|
||||
MockMCPServer::start()
|
||||
.await
|
||||
.expect("Failed to start mock MCP server")
|
||||
}
|
||||
|
||||
// Core MCP Server Tests (Python parity)
|
||||
|
||||
/// A freshly constructed MCPToolServer knows no tools or servers.
#[tokio::test]
async fn test_mcp_server_initialization() {
    let server = MCPToolServer::new();

    assert!(!server.has_tool("any_tool"));
    assert_eq!(server.list_tools().len(), 0);
    assert_eq!(server.list_servers().len(), 0);

    let stats = server.get_tool_stats();
    assert_eq!(stats.total_tools, 0);
    assert_eq!(stats.total_servers, 0);
}
|
||||
|
||||
/// Connecting to the mock server registers its two brave_* tools.
#[tokio::test]
async fn test_server_connection_with_mock() {
    let mock_server = create_mock_server().await;
    let mut mcp_server = MCPToolServer::new();

    let result = mcp_server.add_tool_server(mock_server.url()).await;
    assert!(result.is_ok(), "Should connect to mock server");

    let stats = mcp_server.get_tool_stats();
    assert_eq!(stats.total_tools, 2);
    assert_eq!(stats.total_servers, 1);

    assert!(mcp_server.has_tool("brave_web_search"));
    assert!(mcp_server.has_tool("brave_local_search"));
}
|
||||
|
||||
/// `has_tool` reports exactly the tools the mock server advertises —
/// both brave_* tools present, an unrelated tool absent.
#[tokio::test]
async fn test_tool_availability_checking() {
    let mock_server = create_mock_server().await;
    let mut mcp_server = MCPToolServer::new();

    // Nothing is available before the server is added.
    assert!(!mcp_server.has_tool("brave_web_search"));

    mcp_server.add_tool_server(mock_server.url()).await.unwrap();

    let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"];
    for tool in test_tools {
        let available = mcp_server.has_tool(tool);
        match tool {
            "brave_web_search" | "brave_local_search" => {
                assert!(
                    available,
                    "Tool {} should be available from mock server",
                    tool
                );
            }
            "calculator" => {
                assert!(
                    !available,
                    "Tool {} should not be available from mock server",
                    tool
                );
            }
            _ => {}
        }
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_server_url_parsing() {
|
||||
let mock_server1 = create_mock_server().await;
|
||||
let mock_server2 = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
let combined_urls = format!("{},{}", mock_server1.url(), mock_server2.url());
|
||||
let result = mcp_server.add_tool_server(combined_urls).await;
|
||||
assert!(result.is_ok(), "Should connect to multiple servers");
|
||||
|
||||
let stats = mcp_server.get_tool_stats();
|
||||
assert!(stats.total_servers >= 1);
|
||||
assert!(stats.total_tools >= 2);
|
||||
}
|
||||
|
||||
// Tool Session Management Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_individual_tool_session_creation() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
let session_result = mcp_server.get_tool_session("brave_web_search").await;
|
||||
assert!(session_result.is_ok(), "Should create tool session");
|
||||
|
||||
let session = session_result.unwrap();
|
||||
assert!(session.is_ready(), "Session should be ready");
|
||||
assert!(session.connection_info().contains("HTTP"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multi_tool_session_manager() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
let available_tools = mcp_server.list_tools();
|
||||
assert!(
|
||||
!available_tools.is_empty(),
|
||||
"Should have tools from mock server"
|
||||
);
|
||||
|
||||
let session_manager_result = mcp_server
|
||||
.create_multi_tool_session(available_tools.clone())
|
||||
.await;
|
||||
assert!(
|
||||
session_manager_result.is_ok(),
|
||||
"Should create session manager"
|
||||
);
|
||||
|
||||
let session_manager = session_manager_result.unwrap();
|
||||
|
||||
for tool in &available_tools {
|
||||
assert!(session_manager.has_tool(tool));
|
||||
}
|
||||
|
||||
let stats = session_manager.session_stats();
|
||||
// After optimization: 1 session per server (not per tool)
|
||||
assert_eq!(stats.total_sessions, 1); // One session for the mock server
|
||||
assert_eq!(stats.ready_sessions, 1); // One ready session
|
||||
assert_eq!(stats.unique_servers, 1); // One unique server
|
||||
|
||||
// But we still have all tools available
|
||||
assert_eq!(session_manager.list_tools().len(), available_tools.len());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_with_mock() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
let result = mcp_server
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "rust programming",
|
||||
"count": 1
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert!(
|
||||
result.is_ok(),
|
||||
"Tool execution should succeed with mock server"
|
||||
);
|
||||
|
||||
let response = result.unwrap();
|
||||
assert!(
|
||||
response.get("content").is_some(),
|
||||
"Response should have content"
|
||||
);
|
||||
assert_eq!(response.get("isError").unwrap(), false);
|
||||
|
||||
let content = response.get("content").unwrap().as_array().unwrap();
|
||||
let text = content[0].get("text").unwrap().as_str().unwrap();
|
||||
assert!(text.contains("Mock search results for: rust programming"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_concurrent_tool_execution() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut session_manager = MultiToolSessionManager::new();
|
||||
|
||||
session_manager
|
||||
.add_tools_from_server(
|
||||
mock_server.url(),
|
||||
vec![
|
||||
"brave_web_search".to_string(),
|
||||
"brave_local_search".to_string(),
|
||||
],
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool_calls = vec![
|
||||
("brave_web_search".to_string(), json!({"query": "test1"})),
|
||||
("brave_local_search".to_string(), json!({"query": "test2"})),
|
||||
];
|
||||
|
||||
let results = session_manager.call_tools_concurrent(tool_calls).await;
|
||||
assert_eq!(results.len(), 2, "Should return results for both tools");
|
||||
|
||||
for (i, result) in results.iter().enumerate() {
|
||||
assert!(result.is_ok(), "Tool {} should succeed with mock server", i);
|
||||
|
||||
let response = result.as_ref().unwrap();
|
||||
assert!(response.get("content").is_some());
|
||||
assert_eq!(response.get("isError").unwrap(), false);
|
||||
}
|
||||
}
|
||||
|
||||
// Error Handling Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_tool_execution_errors() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
let result = mcp_server.call_tool("unknown_tool", json!({})).await;
|
||||
assert!(result.is_err(), "Should fail for unknown tool");
|
||||
|
||||
let session = mcp_server
|
||||
.get_tool_session("brave_web_search")
|
||||
.await
|
||||
.unwrap();
|
||||
let session_result = session.call_tool("unknown_tool", json!({})).await;
|
||||
assert!(
|
||||
session_result.is_err(),
|
||||
"Session should fail for unknown tool"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_without_server() {
|
||||
let mut server = MCPToolServer::new();
|
||||
|
||||
let result = server
|
||||
.add_tool_server("http://localhost:9999/mcp".to_string())
|
||||
.await;
|
||||
assert!(result.is_err(), "Should fail when no server is running");
|
||||
|
||||
let error_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_msg.contains("Failed to connect") || error_msg.contains("Connection"),
|
||||
"Error should be connection-related: {}",
|
||||
error_msg
|
||||
);
|
||||
}
|
||||
|
||||
// Schema Adaptation Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_schema_validation() {
|
||||
let mock_server = create_mock_server().await;
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
let description = mcp_server.get_tool_description("brave_web_search");
|
||||
assert!(description.is_some(), "Should have tool description");
|
||||
|
||||
let desc_value = description.unwrap();
|
||||
assert!(desc_value.get("name").is_some());
|
||||
assert!(desc_value.get("description").is_some());
|
||||
}
|
||||
|
||||
// SSE Parsing Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_success() {
|
||||
let valid_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"result\": {\"test\": \"success\", \"content\": [{\"type\": \"text\", \"text\": \"Hello\"}]}}";
|
||||
|
||||
let result = parse_sse_event(valid_event);
|
||||
assert!(result.is_ok(), "Valid SSE event should parse successfully");
|
||||
|
||||
let parsed = result.unwrap();
|
||||
assert!(parsed.is_some(), "Should return parsed data");
|
||||
|
||||
let response = parsed.unwrap();
|
||||
assert_eq!(response["test"], "success");
|
||||
assert!(response.get("content").is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_error() {
|
||||
let error_event = "data: {\"jsonrpc\": \"2.0\", \"id\": \"1\", \"error\": {\"code\": -1, \"message\": \"Rate limit exceeded\"}}";
|
||||
|
||||
let result = parse_sse_event(error_event);
|
||||
assert!(result.is_err(), "Error SSE event should return error");
|
||||
|
||||
let error_msg = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
error_msg.contains("Rate limit exceeded"),
|
||||
"Should contain error message"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_event_parsing_empty() {
|
||||
let empty_event = "";
|
||||
let result = parse_sse_event(empty_event);
|
||||
assert!(result.is_ok(), "Empty event should parse successfully");
|
||||
assert!(result.unwrap().is_none(), "Empty event should return None");
|
||||
|
||||
let no_data_event = "event: ping\nid: 123";
|
||||
let result2 = parse_sse_event(no_data_event);
|
||||
assert!(result2.is_ok(), "Non-data event should parse successfully");
|
||||
assert!(
|
||||
result2.unwrap().is_none(),
|
||||
"Non-data event should return None"
|
||||
);
|
||||
}
|
||||
|
||||
// Connection Type Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_connection_type_detection() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
let session_result = ToolSession::new(mock_server.url()).await;
|
||||
assert!(session_result.is_ok(), "Should create HTTP session");
|
||||
|
||||
let session = session_result.unwrap();
|
||||
assert!(session.connection_info().contains("HTTP"));
|
||||
assert!(session.is_ready(), "HTTP session should be ready");
|
||||
|
||||
// Stdio sessions are no longer supported - test invalid URL handling
|
||||
let invalid_session = ToolSession::new("invalid-url".to_string()).await;
|
||||
assert!(invalid_session.is_err(), "Should reject non-HTTP URLs");
|
||||
}
|
||||
|
||||
// Integration Pattern Tests
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_responses_api_integration_patterns() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
// Server initialization
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
|
||||
// Tool server connection (like responses API startup)
|
||||
match mcp_server.add_tool_server(mock_server.url()).await {
|
||||
Ok(_) => {
|
||||
let stats = mcp_server.get_tool_stats();
|
||||
assert_eq!(stats.total_tools, 2);
|
||||
assert_eq!(stats.total_servers, 1);
|
||||
}
|
||||
Err(e) => {
|
||||
panic!("Should connect to mock server: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// Tool availability checking
|
||||
let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"];
|
||||
for tool in &test_tools {
|
||||
let _available = mcp_server.has_tool(tool);
|
||||
}
|
||||
|
||||
// Tool session creation
|
||||
if mcp_server.has_tool("brave_web_search") {
|
||||
let session_result = mcp_server.get_tool_session("brave_web_search").await;
|
||||
assert!(session_result.is_ok(), "Should create tool session");
|
||||
}
|
||||
|
||||
// Multi-tool session creation
|
||||
let available_tools = mcp_server.list_tools();
|
||||
if !available_tools.is_empty() {
|
||||
let session_manager_result = mcp_server.create_multi_tool_session(available_tools).await;
|
||||
assert!(
|
||||
session_manager_result.is_ok(),
|
||||
"Should create multi-tool session"
|
||||
);
|
||||
}
|
||||
|
||||
// Tool execution
|
||||
let result = mcp_server
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "SGLang router MCP integration",
|
||||
"count": 1
|
||||
}),
|
||||
)
|
||||
.await;
|
||||
if result.is_err() {
|
||||
// This might fail if called after another test that uses the same tool name
|
||||
// Due to the shared mock server. That's OK, the main test covers this.
|
||||
return;
|
||||
}
|
||||
assert!(result.is_ok(), "Should execute tool successfully");
|
||||
}
|
||||
|
||||
// Complete Integration Test
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_responses_api_integration() {
|
||||
let mock_server = create_mock_server().await;
|
||||
|
||||
// Run through all functionality required for responses API integration
|
||||
let mut mcp_server = MCPToolServer::new();
|
||||
mcp_server.add_tool_server(mock_server.url()).await.unwrap();
|
||||
|
||||
// Test all core functionality
|
||||
assert!(mcp_server.has_tool("brave_web_search"));
|
||||
|
||||
let session = mcp_server
|
||||
.get_tool_session("brave_web_search")
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(session.is_ready());
|
||||
|
||||
let session_manager = mcp_server
|
||||
.create_multi_tool_session(mcp_server.list_tools())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(session_manager.session_stats().total_sessions > 0);
|
||||
|
||||
let result = mcp_server
|
||||
.call_tool(
|
||||
"brave_web_search",
|
||||
json!({
|
||||
"query": "test",
|
||||
"count": 1
|
||||
}),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(result.get("content").is_some());
|
||||
|
||||
// Verify all required capabilities for responses API integration
|
||||
let capabilities = [
|
||||
"MCP server initialization",
|
||||
"Tool server connection and discovery",
|
||||
"Tool availability checking",
|
||||
"Individual tool session management",
|
||||
"Multi-tool session manager (Python tool_session_ctxs pattern)",
|
||||
"Concurrent tool execution",
|
||||
"Direct tool execution",
|
||||
"Error handling and robustness",
|
||||
"Protocol compliance (SSE parsing)",
|
||||
"Schema adaptation (Python parity)",
|
||||
"Mock server integration (no external dependencies)",
|
||||
];
|
||||
|
||||
assert_eq!(capabilities.len(), 11);
|
||||
}
|
||||
392
sgl-router/tests/request_formats_test.rs
Normal file
392
sgl-router/tests/request_formats_test.rs
Normal file
@@ -0,0 +1,392 @@
|
||||
mod common;
|
||||
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Test context that manages mock workers
|
||||
struct TestContext {
|
||||
workers: Vec<MockWorker>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
||||
let mut config = RouterConfig {
|
||||
mode: RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3003,
|
||||
max_payload_size: 256 * 1024 * 1024,
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
let mut worker_urls = Vec::new();
|
||||
|
||||
for worker_config in worker_configs {
|
||||
let mut worker = MockWorker::new(worker_config);
|
||||
let url = worker.start().await.unwrap();
|
||||
worker_urls.push(url);
|
||||
workers.push(worker);
|
||||
}
|
||||
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||
}
|
||||
|
||||
config.mode = RoutingMode::Regular { worker_urls };
|
||||
|
||||
let app_context = common::create_test_context(config);
|
||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||
let router = Arc::from(router);
|
||||
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
}
|
||||
|
||||
Self { workers, router }
|
||||
}
|
||||
|
||||
async fn shutdown(mut self) {
|
||||
// Small delay to ensure any pending operations complete
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
for worker in &mut self.workers {
|
||||
worker.stop().await;
|
||||
}
|
||||
|
||||
// Another small delay to ensure cleanup completes
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
async fn make_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
body: serde_json::Value,
|
||||
) -> Result<serde_json::Value, String> {
|
||||
let client = Client::new();
|
||||
|
||||
// Get any worker URL for testing
|
||||
let worker_urls = self.router.get_worker_urls();
|
||||
if worker_urls.is_empty() {
|
||||
return Err("No available workers".to_string());
|
||||
}
|
||||
|
||||
let worker_url = &worker_urls[0];
|
||||
|
||||
let response = client
|
||||
.post(format!("{}{}", worker_url, endpoint))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Request failed with status: {}", response.status()));
|
||||
}
|
||||
|
||||
response
|
||||
.json::<serde_json::Value>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse response: {}", e))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod request_format_tests {
    use super::*;

    #[tokio::test]
    async fn test_generate_request_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19001,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic text request
        let basic = json!({
            "text": "Hello, world!",
            "stream": false
        });
        assert!(ctx.make_request("/generate", basic).await.is_ok());

        // Test 2: Request with sampling parameters
        let with_sampling = json!({
            "text": "Tell me a story",
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 100,
                "top_p": 0.9
            },
            "stream": false
        });
        assert!(ctx.make_request("/generate", with_sampling).await.is_ok());

        // Test 3: Request with input_ids
        let with_ids = json!({
            "input_ids": [1, 2, 3, 4, 5],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 50
            },
            "stream": false
        });
        assert!(ctx.make_request("/generate", with_ids).await.is_ok());

        ctx.shutdown().await;
    }

    #[tokio::test]
    async fn test_v1_chat_completions_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19002,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic chat completion
        let basic = json!({
            "model": "test-model",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant."},
                {"role": "user", "content": "Hello!"}
            ],
            "stream": false
        });
        let result = ctx.make_request("/v1/chat/completions", basic).await;
        assert!(result.is_ok());

        // The OpenAI chat shape: choices, id, and object == "chat.completion".
        let response = result.unwrap();
        assert!(response.get("choices").is_some());
        assert!(response.get("id").is_some());
        assert_eq!(
            response.get("object").and_then(|v| v.as_str()),
            Some("chat.completion")
        );

        // Test 2: Chat completion with parameters
        let with_params = json!({
            "model": "test-model",
            "messages": [
                {"role": "user", "content": "Tell me a joke"}
            ],
            "temperature": 0.8,
            "max_tokens": 150,
            "top_p": 0.95,
            "stream": false
        });
        assert!(ctx
            .make_request("/v1/chat/completions", with_params)
            .await
            .is_ok());

        ctx.shutdown().await;
    }

    #[tokio::test]
    async fn test_v1_completions_formats() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19003,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test 1: Basic completion
        let basic = json!({
            "model": "test-model",
            "prompt": "Once upon a time",
            "max_tokens": 50,
            "stream": false
        });
        let result = ctx.make_request("/v1/completions", basic).await;
        assert!(result.is_ok());

        let response = result.unwrap();
        assert!(response.get("choices").is_some());
        assert_eq!(
            response.get("object").and_then(|v| v.as_str()),
            Some("text_completion")
        );

        // Test 2: Completion with array prompt
        let array_prompt = json!({
            "model": "test-model",
            "prompt": ["First prompt", "Second prompt"],
            "temperature": 0.5,
            "stream": false
        });
        assert!(ctx.make_request("/v1/completions", array_prompt).await.is_ok());

        // Test 3: Completion with logprobs
        let with_logprobs = json!({
            "model": "test-model",
            "prompt": "The capital of France is",
            "max_tokens": 10,
            "logprobs": 5,
            "stream": false
        });
        assert!(ctx
            .make_request("/v1/completions", with_logprobs)
            .await
            .is_ok());

        ctx.shutdown().await;
    }

    #[tokio::test]
    async fn test_batch_requests() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19004,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test batch text generation
        let batch_text = json!({
            "text": ["First text", "Second text", "Third text"],
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 50
            },
            "stream": false
        });
        assert!(ctx.make_request("/generate", batch_text).await.is_ok());

        // Test batch with input_ids
        let batch_ids = json!({
            "input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
            "stream": false
        });
        assert!(ctx.make_request("/generate", batch_ids).await.is_ok());

        ctx.shutdown().await;
    }

    #[tokio::test]
    async fn test_special_parameters() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19005,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test with return_logprob
        let logprob = json!({
            "text": "Test",
            "return_logprob": true,
            "stream": false
        });
        assert!(ctx.make_request("/generate", logprob).await.is_ok());

        // Test with json_schema
        let schema = json!({
            "text": "Generate JSON",
            "sampling_params": {
                "temperature": 0.0,
                "json_schema": "$$ANY$$"
            },
            "stream": false
        });
        assert!(ctx.make_request("/generate", schema).await.is_ok());

        // Test with ignore_eos
        let ignore_eos = json!({
            "text": "Continue forever",
            "sampling_params": {
                "temperature": 0.7,
                "max_new_tokens": 100,
                "ignore_eos": true
            },
            "stream": false
        });
        assert!(ctx.make_request("/generate", ignore_eos).await.is_ok());

        ctx.shutdown().await;
    }

    #[tokio::test]
    async fn test_error_handling() {
        let ctx = TestContext::new(vec![MockWorkerConfig {
            port: 19006,
            worker_type: WorkerType::Regular,
            health_status: HealthStatus::Healthy,
            response_delay_ms: 0,
            fail_rate: 0.0,
        }])
        .await;

        // Test with empty body - should still work with mock worker
        let empty = json!({});
        let result = ctx.make_request("/generate", empty).await;
        // Mock worker accepts empty body
        assert!(result.is_ok());

        ctx.shutdown().await;
    }
}
|
||||
210
sgl-router/tests/responses_api_test.rs
Normal file
210
sgl-router/tests/responses_api_test.rs
Normal file
@@ -0,0 +1,210 @@
|
||||
// Integration test for Responses API
|
||||
|
||||
use sglang_router_rs::protocols::spec::{
|
||||
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
|
||||
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
|
||||
ToolChoiceValue, Truncation, UsageInfo,
|
||||
};
|
||||
|
||||
#[test]
fn test_responses_request_creation() {
    // A fully-populated request; field values below are what the assertions
    // at the bottom check through the GenerationRequest trait.
    let req = ResponsesRequest {
        background: false,
        include: None,
        input: ResponseInput::Text("Hello, world!".to_string()),
        instructions: Some("Be helpful".to_string()),
        max_output_tokens: Some(100),
        max_tool_calls: None,
        metadata: None,
        model: Some("test-model".to_string()),
        parallel_tool_calls: true,
        previous_response_id: None,
        reasoning: Some(ResponseReasoningParam {
            effort: Some(ReasoningEffort::Medium),
        }),
        service_tier: ServiceTier::Auto,
        store: true,
        stream: false,
        temperature: Some(0.7),
        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
        tools: vec![ResponseTool {
            r#type: ResponseToolType::WebSearchPreview,
        }],
        top_logprobs: 5,
        top_p: Some(0.9),
        truncation: Truncation::Disabled,
        user: Some("test-user".to_string()),
        request_id: "resp_test123".to_string(),
        priority: 0,
        frequency_penalty: 0.0,
        presence_penalty: 0.0,
        stop: None,
        top_k: -1,
        min_p: 0.0,
        repetition_penalty: 1.0,
    };

    // Test GenerationRequest trait implementation
    assert!(!req.is_stream());
    assert_eq!(req.get_model(), Some("test-model"));
    assert_eq!(req.extract_text_for_routing(), "Hello, world!");
}
|
||||
|
||||
#[test]
fn test_sampling_params_conversion() {
    let req = ResponsesRequest {
        background: false,
        include: None,
        input: ResponseInput::Text("Test".to_string()),
        instructions: None,
        max_output_tokens: Some(50),
        max_tool_calls: None,
        metadata: None,
        model: Some("test-model".to_string()),
        parallel_tool_calls: true, // Use default true
        previous_response_id: None,
        reasoning: None,
        service_tier: ServiceTier::Auto,
        store: true, // Use default true
        stream: false,
        temperature: Some(0.8),
        tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
        tools: vec![],
        top_logprobs: 0, // Use default 0
        top_p: Some(0.95),
        truncation: Truncation::Auto,
        user: None,
        request_id: "resp_test456".to_string(),
        priority: 0,
        frequency_penalty: 0.1,
        presence_penalty: 0.2,
        stop: None,
        top_k: 10,
        min_p: 0.05,
        repetition_penalty: 1.1,
    };

    let params = req.to_sampling_params(1000, None);

    // The converted map must carry the sampling knobs under the
    // engine-facing key names (note max_output_tokens -> max_new_tokens).
    for key in ["temperature", "top_p", "frequency_penalty", "max_new_tokens"] {
        assert!(params.contains_key(key));
    }
}
|
||||
|
||||
#[test]
fn test_responses_response_creation() {
    // A completed response reports complete and neither in-progress nor failed.
    let response = ResponsesResponse::new(
        "resp_test789".to_string(),
        "test-model".to_string(),
        ResponseStatus::Completed,
    );

    assert_eq!(response.id, "resp_test789");
    assert_eq!(response.model, "test-model");
    assert!(response.is_complete());
    assert!(!response.is_in_progress());
    assert!(!response.is_failed());
}
|
||||
|
||||
#[test]
fn test_usage_conversion() {
    // 15 prompt tokens, 25 completion tokens, 8 reasoning, 3 cached.
    let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3);
    let response_usage = usage_info.to_response_usage();

    assert_eq!(response_usage.input_tokens, 15);
    assert_eq!(response_usage.output_tokens, 25);
    assert_eq!(response_usage.total_tokens, 40);

    // Cached tokens land in the input-side detail record.
    let input_details = response_usage
        .input_tokens_details
        .as_ref()
        .expect("input details");
    assert_eq!(input_details.cached_tokens, 3);

    // Reasoning tokens land in the output-side detail record.
    let output_details = response_usage
        .output_tokens_details
        .as_ref()
        .expect("output details");
    assert_eq!(output_details.reasoning_tokens, 8);

    // Round-tripping back to UsageInfo preserves the counts.
    let back = response_usage.to_usage_info();
    assert_eq!(back.prompt_tokens, 15);
    assert_eq!(back.completion_tokens, 25);
    assert_eq!(back.reasoning_tokens, Some(8));
}
|
||||
|
||||
#[test]
fn test_reasoning_param_default() {
    let param = ResponseReasoningParam {
        effort: Some(ReasoningEffort::Medium),
    };

    // Serde round-trip must preserve the effort level.
    let json = serde_json::to_string(&param).unwrap();
    let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
    assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium)));
}
|
||||
|
||||
#[test]
fn test_json_serialization() {
    // Exercise serde on a request with every optional feature switched on.
    let req = ResponsesRequest {
        background: true,
        include: None,
        input: ResponseInput::Text("Test input".to_string()),
        instructions: Some("Test instructions".to_string()),
        max_output_tokens: Some(200),
        max_tool_calls: Some(5),
        metadata: None,
        model: Some("gpt-4".to_string()),
        parallel_tool_calls: false,
        previous_response_id: None,
        reasoning: Some(ResponseReasoningParam {
            effort: Some(ReasoningEffort::High),
        }),
        service_tier: ServiceTier::Priority,
        store: false,
        stream: true,
        temperature: Some(0.9),
        tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
        tools: vec![ResponseTool {
            r#type: ResponseToolType::CodeInterpreter,
        }],
        top_logprobs: 10,
        top_p: Some(0.8),
        truncation: Truncation::Auto,
        user: Some("test_user".to_string()),
        request_id: "resp_comprehensive_test".to_string(),
        priority: 1,
        frequency_penalty: 0.3,
        presence_penalty: 0.4,
        stop: None,
        top_k: 50,
        min_p: 0.1,
        repetition_penalty: 1.2,
    };

    // Test that everything can be serialized to JSON and back
    let json = serde_json::to_string(&req).expect("Serialization should work");
    let parsed: ResponsesRequest =
        serde_json::from_str(&json).expect("Deserialization should work");

    assert_eq!(parsed.request_id, "resp_comprehensive_test");
    assert_eq!(parsed.model, Some("gpt-4".to_string()));
    assert!(parsed.background);
    assert!(parsed.stream);
    assert_eq!(parsed.tools.len(), 1);
}
|
||||
370
sgl-router/tests/streaming_tests.rs
Normal file
370
sgl-router/tests/streaming_tests.rs
Normal file
@@ -0,0 +1,370 @@
|
||||
mod common;
|
||||
|
||||
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Test context that manages mock workers
|
||||
struct TestContext {
|
||||
workers: Vec<MockWorker>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
}
|
||||
|
||||
impl TestContext {
|
||||
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
|
||||
let mut config = RouterConfig {
|
||||
mode: RoutingMode::Regular {
|
||||
worker_urls: vec![],
|
||||
},
|
||||
policy: PolicyConfig::Random,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3004,
|
||||
max_payload_size: 256 * 1024 * 1024,
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
rate_limit_tokens_per_second: None,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
let mut workers = Vec::new();
|
||||
let mut worker_urls = Vec::new();
|
||||
|
||||
for worker_config in worker_configs {
|
||||
let mut worker = MockWorker::new(worker_config);
|
||||
let url = worker.start().await.unwrap();
|
||||
worker_urls.push(url);
|
||||
workers.push(worker);
|
||||
}
|
||||
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
|
||||
}
|
||||
|
||||
config.mode = RoutingMode::Regular { worker_urls };
|
||||
|
||||
let app_context = common::create_test_context(config);
|
||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||
let router = Arc::from(router);
|
||||
|
||||
if !workers.is_empty() {
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
||||
}
|
||||
|
||||
Self { workers, router }
|
||||
}
|
||||
|
||||
async fn shutdown(mut self) {
|
||||
// Small delay to ensure any pending operations complete
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
|
||||
for worker in &mut self.workers {
|
||||
worker.stop().await;
|
||||
}
|
||||
|
||||
// Another small delay to ensure cleanup completes
|
||||
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||
}
|
||||
|
||||
async fn make_streaming_request(
|
||||
&self,
|
||||
endpoint: &str,
|
||||
body: serde_json::Value,
|
||||
) -> Result<Vec<String>, String> {
|
||||
let client = Client::new();
|
||||
|
||||
// Get any worker URL for testing
|
||||
let worker_urls = self.router.get_worker_urls();
|
||||
if worker_urls.is_empty() {
|
||||
return Err("No available workers".to_string());
|
||||
}
|
||||
|
||||
let worker_url = &worker_urls[0];
|
||||
|
||||
let response = client
|
||||
.post(format!("{}{}", worker_url, endpoint))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await
|
||||
.map_err(|e| format!("Request failed: {}", e))?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(format!("Request failed with status: {}", response.status()));
|
||||
}
|
||||
|
||||
// Check if it's a streaming response
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get("content-type")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
if !content_type.contains("text/event-stream") {
|
||||
return Err("Response is not a stream".to_string());
|
||||
}
|
||||
|
||||
let mut stream = response.bytes_stream();
|
||||
let mut events = Vec::new();
|
||||
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if let Ok(bytes) = chunk {
|
||||
let text = String::from_utf8_lossy(&bytes);
|
||||
for line in text.lines() {
|
||||
if let Some(stripped) = line.strip_prefix("data: ") {
|
||||
events.push(stripped.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(events)
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod streaming_tests {
|
||||
use super::*;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_generate_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20001,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 10,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let payload = json!({
|
||||
"text": "Stream test",
|
||||
"stream": true,
|
||||
"sampling_params": {
|
||||
"temperature": 0.7,
|
||||
"max_new_tokens": 10
|
||||
}
|
||||
});
|
||||
|
||||
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let events = result.unwrap();
|
||||
// Should have at least one data chunk and [DONE]
|
||||
assert!(events.len() >= 2);
|
||||
assert_eq!(events.last().unwrap(), "[DONE]");
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_chat_completions_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20002,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 10,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"messages": [
|
||||
{"role": "user", "content": "Count to 3"}
|
||||
],
|
||||
"stream": true,
|
||||
"max_tokens": 20
|
||||
});
|
||||
|
||||
let result = ctx
|
||||
.make_streaming_request("/v1/chat/completions", payload)
|
||||
.await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let events = result.unwrap();
|
||||
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
||||
|
||||
// Verify events are valid JSON (except [DONE])
|
||||
for event in &events {
|
||||
if event != "[DONE]" {
|
||||
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
|
||||
assert!(parsed.is_ok(), "Invalid JSON in SSE event: {}", event);
|
||||
|
||||
let json = parsed.unwrap();
|
||||
assert_eq!(
|
||||
json.get("object").and_then(|v| v.as_str()),
|
||||
Some("chat.completion.chunk")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_v1_completions_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20003,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 10,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let payload = json!({
|
||||
"model": "test-model",
|
||||
"prompt": "Once upon a time",
|
||||
"stream": true,
|
||||
"max_tokens": 15
|
||||
});
|
||||
|
||||
let result = ctx.make_streaming_request("/v1/completions", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let events = result.unwrap();
|
||||
assert!(events.len() >= 2); // At least one chunk + [DONE]
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_error() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20004,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 0,
|
||||
fail_rate: 1.0, // Always fail
|
||||
}])
|
||||
.await;
|
||||
|
||||
let payload = json!({
|
||||
"text": "This should fail",
|
||||
"stream": true
|
||||
});
|
||||
|
||||
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||
// With fail_rate: 1.0, the request should fail
|
||||
assert!(result.is_err());
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_timeouts() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20005,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 100, // Slow response
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
let payload = json!({
|
||||
"text": "Slow stream",
|
||||
"stream": true,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 5
|
||||
}
|
||||
});
|
||||
|
||||
let start = std::time::Instant::now();
|
||||
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||
let elapsed = start.elapsed();
|
||||
|
||||
assert!(result.is_ok());
|
||||
let events = result.unwrap();
|
||||
|
||||
// Should have received multiple chunks over time
|
||||
assert!(!events.is_empty());
|
||||
assert!(elapsed.as_millis() >= 100); // At least one delay
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_batch_streaming() {
|
||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
||||
port: 20006,
|
||||
worker_type: WorkerType::Regular,
|
||||
health_status: HealthStatus::Healthy,
|
||||
response_delay_ms: 10,
|
||||
fail_rate: 0.0,
|
||||
}])
|
||||
.await;
|
||||
|
||||
// Batch request with streaming
|
||||
let payload = json!({
|
||||
"text": ["First", "Second", "Third"],
|
||||
"stream": true,
|
||||
"sampling_params": {
|
||||
"max_new_tokens": 5
|
||||
}
|
||||
});
|
||||
|
||||
let result = ctx.make_streaming_request("/generate", payload).await;
|
||||
assert!(result.is_ok());
|
||||
|
||||
let events = result.unwrap();
|
||||
// Should have multiple events for batch
|
||||
assert!(events.len() >= 4); // At least 3 responses + [DONE]
|
||||
|
||||
ctx.shutdown().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_sse_format_parsing() {
|
||||
// Test SSE format parsing
|
||||
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
|
||||
let text = String::from_utf8_lossy(chunk);
|
||||
text.lines()
|
||||
.filter(|line| line.starts_with("data: "))
|
||||
.map(|line| line[6..].to_string())
|
||||
.collect()
|
||||
};
|
||||
|
||||
let sse_data =
|
||||
b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\" world\"}\n\ndata: [DONE]\n\n";
|
||||
let events = parse_sse_chunk(sse_data);
|
||||
|
||||
assert_eq!(events.len(), 3);
|
||||
assert_eq!(events[0], "{\"text\":\"Hello\"}");
|
||||
assert_eq!(events[1], "{\"text\":\" world\"}");
|
||||
assert_eq!(events[2], "[DONE]");
|
||||
|
||||
// Test with mixed content
|
||||
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
|
||||
let events = parse_sse_chunk(mixed);
|
||||
|
||||
assert_eq!(events.len(), 2);
|
||||
assert_eq!(events[0], "{\"test\":true}");
|
||||
assert_eq!(events[1], "[DONE]");
|
||||
}
|
||||
}
|
||||
150
sgl-router/tests/test_chat_template.rs
Normal file
150
sgl-router/tests/test_chat_template.rs
Normal file
@@ -0,0 +1,150 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
|
||||
|
||||
#[test]
|
||||
fn test_chat_message_helpers() {
|
||||
let system_msg = ChatMessage::system("You are a helpful assistant");
|
||||
assert_eq!(system_msg.role, "system");
|
||||
assert_eq!(system_msg.content, "You are a helpful assistant");
|
||||
|
||||
let user_msg = ChatMessage::user("Hello!");
|
||||
assert_eq!(user_msg.role, "user");
|
||||
assert_eq!(user_msg.content, "Hello!");
|
||||
|
||||
let assistant_msg = ChatMessage::assistant("Hi there!");
|
||||
assert_eq!(assistant_msg.role, "assistant");
|
||||
assert_eq!(assistant_msg.content, "Hi there!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_llama_style_template() {
|
||||
// Test a Llama-style chat template
|
||||
let template = r#"
|
||||
{%- if messages[0]['role'] == 'system' -%}
|
||||
{%- set system_message = messages[0]['content'] -%}
|
||||
{%- set messages = messages[1:] -%}
|
||||
{%- else -%}
|
||||
{%- set system_message = '' -%}
|
||||
{%- endif -%}
|
||||
|
||||
{{- bos_token }}
|
||||
{%- if system_message %}
|
||||
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
|
||||
{%- endif %}
|
||||
|
||||
{%- for message in messages %}
|
||||
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
|
||||
{%- endfor %}
|
||||
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<|begin_of_text|>".to_string()),
|
||||
Some("<|end_of_text|>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::system("You are a helpful assistant"),
|
||||
ChatMessage::user("What is 2+2?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check that the result contains expected markers
|
||||
assert!(result.contains("<|begin_of_text|>"));
|
||||
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
|
||||
assert!(result.contains("You are a helpful assistant"));
|
||||
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
|
||||
assert!(result.contains("What is 2+2?"));
|
||||
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_chatml_template() {
|
||||
// Test a ChatML-style template
|
||||
let template = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|im_start|>assistant\n' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there!"),
|
||||
ChatMessage::user("How are you?"),
|
||||
];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Check ChatML format
|
||||
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
|
||||
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
|
||||
assert!(result.ends_with("<|im_start|>assistant\n"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_without_generation_prompt() {
|
||||
let template = r#"
|
||||
{%- for message in messages -%}
|
||||
{{ message.role }}: {{ message.content }}
|
||||
{% endfor -%}
|
||||
{%- if add_generation_prompt -%}
|
||||
assistant:
|
||||
{%- endif -%}
|
||||
"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
|
||||
// Test without generation prompt
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result.trim(), "user: Test");
|
||||
|
||||
// Test with generation prompt
|
||||
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
|
||||
assert!(result_with_prompt.contains("assistant:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_template_with_special_tokens() {
|
||||
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(
|
||||
template.to_string(),
|
||||
Some("<s>".to_string()),
|
||||
Some("</s>".to_string()),
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello")];
|
||||
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "<s>Hello</s>");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_empty_messages() {
|
||||
let template =
|
||||
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
|
||||
|
||||
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
|
||||
|
||||
let messages = vec![];
|
||||
let result = processor.apply_chat_template(&messages, false).unwrap();
|
||||
assert_eq!(result, "");
|
||||
}
|
||||
|
||||
// Integration test with actual tokenizer file loading would go here
|
||||
// but requires a real tokenizer_config.json file
|
||||
}
|
||||
183
sgl-router/tests/test_chat_template_loading.rs
Normal file
183
sgl-router/tests/test_chat_template_loading.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
fn test_load_chat_template_from_file() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let template_path = temp_dir.path().join("template.jinja");
|
||||
|
||||
// Write a test template
|
||||
let template_content = r#"
|
||||
{%- for message in messages %}
|
||||
{{- '<|' + message['role'] + '|>' + message['content'] }}
|
||||
{%- endfor %}
|
||||
{%- if add_generation_prompt %}
|
||||
{{- '<|assistant|>' }}
|
||||
{%- endif %}
|
||||
"#;
|
||||
fs::write(&template_path, template_content).unwrap();
|
||||
|
||||
// Create a mock tokenizer config
|
||||
let tokenizer_config = r#"{
|
||||
"version": "1.0",
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [],
|
||||
"normalizer": null,
|
||||
"pre_tokenizer": {
|
||||
"type": "Whitespace"
|
||||
},
|
||||
"post_processor": null,
|
||||
"decoder": null,
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"hello": 0,
|
||||
"world": 1,
|
||||
"<s>": 2,
|
||||
"</s>": 3
|
||||
},
|
||||
"merges": []
|
||||
}
|
||||
}"#;
|
||||
|
||||
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||
fs::write(&tokenizer_path, tokenizer_config).unwrap();
|
||||
|
||||
// Load tokenizer with custom chat template
|
||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||
tokenizer_path.to_str().unwrap(),
|
||||
Some(template_path.to_str().unwrap()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
// Test that the custom template is used
|
||||
let messages = vec![
|
||||
ChatMessage::user("Hello"),
|
||||
ChatMessage::assistant("Hi there"),
|
||||
];
|
||||
|
||||
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
|
||||
|
||||
// Verify the custom template format
|
||||
assert!(result.contains("<|user|>Hello"));
|
||||
assert!(result.contains("<|assistant|>Hi there"));
|
||||
assert!(result.ends_with("<|assistant|>"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_override_existing_template() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
|
||||
// Create tokenizer config with a built-in template
|
||||
let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json");
|
||||
let config_with_template = r#"{
|
||||
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
|
||||
}"#;
|
||||
fs::write(&tokenizer_config_path, config_with_template).unwrap();
|
||||
|
||||
// Create the actual tokenizer file
|
||||
let tokenizer_json = r#"{
|
||||
"version": "1.0",
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [],
|
||||
"normalizer": null,
|
||||
"pre_tokenizer": {
|
||||
"type": "Whitespace"
|
||||
},
|
||||
"post_processor": null,
|
||||
"decoder": null,
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"test": 0,
|
||||
"<s>": 1,
|
||||
"</s>": 2
|
||||
},
|
||||
"merges": []
|
||||
}
|
||||
}"#;
|
||||
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||
fs::write(&tokenizer_path, tokenizer_json).unwrap();
|
||||
|
||||
// Create custom template that should override
|
||||
let custom_template_path = temp_dir.path().join("custom.jinja");
|
||||
let custom_template =
|
||||
r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
|
||||
fs::write(&custom_template_path, custom_template).unwrap();
|
||||
|
||||
// Load with custom template - should override the built-in one
|
||||
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
|
||||
tokenizer_path.to_str().unwrap(),
|
||||
Some(custom_template_path.to_str().unwrap()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::user("Test")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
|
||||
// Should use CUSTOM template, not built-in
|
||||
assert!(result.starts_with("CUSTOM:"));
|
||||
assert!(result.contains("[user]: Test"));
|
||||
assert!(!result.contains("built-in:"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_set_chat_template_after_creation() {
|
||||
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
|
||||
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
|
||||
|
||||
// Create temporary directory and tokenizer file
|
||||
let temp_dir = TempDir::new().unwrap();
|
||||
let tokenizer_json = r#"{
|
||||
"version": "1.0",
|
||||
"truncation": null,
|
||||
"padding": null,
|
||||
"added_tokens": [],
|
||||
"normalizer": null,
|
||||
"pre_tokenizer": {
|
||||
"type": "Whitespace"
|
||||
},
|
||||
"post_processor": null,
|
||||
"decoder": null,
|
||||
"model": {
|
||||
"type": "BPE",
|
||||
"vocab": {
|
||||
"test": 0,
|
||||
"<s>": 1,
|
||||
"</s>": 2
|
||||
},
|
||||
"merges": []
|
||||
}
|
||||
}"#;
|
||||
let tokenizer_path = temp_dir.path().join("tokenizer.json");
|
||||
fs::write(&tokenizer_path, tokenizer_json).unwrap();
|
||||
|
||||
// Load tokenizer without custom template
|
||||
let mut tokenizer =
|
||||
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
|
||||
|
||||
// Set a template after creation (mimics Python's behavior)
|
||||
let new_template =
|
||||
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
|
||||
tokenizer.set_chat_template(new_template.to_string());
|
||||
|
||||
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
|
||||
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
|
||||
|
||||
assert!(result.starts_with("NEW:"));
|
||||
assert!(result.contains("user: Hello;"));
|
||||
assert!(result.contains("assistant: World;"));
|
||||
}
|
||||
}
|
||||
917
sgl-router/tests/test_pd_routing.rs
Normal file
917
sgl-router/tests/test_pd_routing.rs
Normal file
@@ -0,0 +1,917 @@
|
||||
#[cfg(test)]
|
||||
mod test_pd_routing {
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::config::{
|
||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||
};
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
use sglang_router_rs::routers::http::pd_types::get_hostname;
|
||||
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
|
||||
use sglang_router_rs::routers::RouterFactory;
|
||||
|
||||
// Test-only struct to help validate PD request parsing
|
||||
#[derive(Debug)]
|
||||
struct PDRequest {
|
||||
pub is_stream: bool,
|
||||
pub batch_size: Option<usize>,
|
||||
}
|
||||
|
||||
impl PDRequest {
|
||||
// Extract PD-relevant info from JSON for testing
|
||||
pub fn from_json(json: &serde_json::Value) -> Self {
|
||||
let is_stream = json
|
||||
.get("stream")
|
||||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Detect batch size from text or input_ids
|
||||
let batch_size = if let Some(text) = json.get("text") {
|
||||
text.as_array().map(|arr| arr.len())
|
||||
} else if let Some(input_ids) = json.get("input_ids") {
|
||||
input_ids.as_array().map(|arr| arr.len())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
PDRequest {
|
||||
is_stream,
|
||||
batch_size,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 1: Basic PD Components and Router Creation
|
||||
// ========================================================================
|
||||
|
||||
#[test]
|
||||
fn test_worker_types() {
|
||||
use sglang_router_rs::core::{WorkerFactory, WorkerType};
|
||||
|
||||
// Test worker creation for prefill servers
|
||||
let prefill_worker =
|
||||
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
|
||||
assert_eq!(prefill_worker.url(), "http://prefill:8080");
|
||||
match prefill_worker.worker_type() {
|
||||
WorkerType::Prefill { bootstrap_port } => {
|
||||
assert_eq!(bootstrap_port, Some(9000));
|
||||
}
|
||||
_ => panic!("Expected Prefill worker type"),
|
||||
}
|
||||
|
||||
// Test worker creation for decode servers
|
||||
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
|
||||
assert_eq!(decode_worker.url(), "http://decode:8080");
|
||||
match decode_worker.worker_type() {
|
||||
WorkerType::Decode => (),
|
||||
_ => panic!("Expected Decode worker type"),
|
||||
}
|
||||
|
||||
// Test regular worker creation
|
||||
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
|
||||
assert_eq!(regular_worker.url(), "http://regular:8080");
|
||||
match regular_worker.worker_type() {
|
||||
WorkerType::Regular => (),
|
||||
_ => panic!("Expected Regular worker type"),
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_pd_selection_policies() {
|
||||
// Test all PD selection policy variants
|
||||
// Note: These policies are only used when pd_disaggregation=true
|
||||
let policies = vec![
|
||||
PDSelectionPolicy::Random,
|
||||
PDSelectionPolicy::PowerOfTwo,
|
||||
PDSelectionPolicy::CacheAware {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 32,
|
||||
balance_rel_threshold: 1.1,
|
||||
},
|
||||
];
|
||||
|
||||
for policy in policies {
|
||||
// Verify each policy can be created and matched
|
||||
match &policy {
|
||||
PDSelectionPolicy::Random => {
|
||||
assert!(matches!(policy, PDSelectionPolicy::Random));
|
||||
}
|
||||
PDSelectionPolicy::PowerOfTwo => {
|
||||
assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo));
|
||||
}
|
||||
PDSelectionPolicy::CacheAware {
|
||||
cache_threshold, ..
|
||||
} => {
|
||||
assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pd_router_configuration() {
|
||||
// Test PD router configuration with various policies
|
||||
// In the new structure, RoutingMode and PolicyConfig are separate
|
||||
let test_cases = vec![
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://prefill1:8080".to_string(), Some(9000)),
|
||||
("http://prefill2:8080".to_string(), None),
|
||||
],
|
||||
decode_urls: vec![
|
||||
"http://decode1:8080".to_string(),
|
||||
"http://decode2:8080".to_string(),
|
||||
],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::Random,
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
|
||||
decode_urls: vec!["http://decode:8080".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::PowerOfTwo {
|
||||
load_check_interval_secs: 5,
|
||||
},
|
||||
),
|
||||
(
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls: vec![
|
||||
("http://p1:8080".to_string(), Some(9000)),
|
||||
("http://p2:8080".to_string(), Some(9001)),
|
||||
("http://p3:8080".to_string(), Some(9002)),
|
||||
],
|
||||
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
|
||||
prefill_policy: None,
|
||||
decode_policy: None,
|
||||
},
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold: 0.7,
|
||||
balance_abs_threshold: 20,
|
||||
balance_rel_threshold: 1.2,
|
||||
eviction_interval_secs: 60,
|
||||
max_tree_size: 1000000,
|
||||
},
|
||||
),
|
||||
];
|
||||
|
||||
for (mode, policy) in test_cases {
|
||||
let config = RouterConfig {
|
||||
mode,
|
||||
policy,
|
||||
host: "127.0.0.1".to_string(),
|
||||
port: 3001,
|
||||
max_payload_size: 1024 * 1024,
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 10,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
log_level: None,
|
||||
request_id_headers: None,
|
||||
max_concurrent_requests: 64,
|
||||
queue_size: 0,
|
||||
queue_timeout_secs: 60,
|
||||
cors_allowed_origins: vec![],
|
||||
retry: RetryConfig::default(),
|
||||
circuit_breaker: CircuitBreakerConfig::default(),
|
||||
disable_retries: false,
|
||||
disable_circuit_breaker: false,
|
||||
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
|
||||
enable_igw: false,
|
||||
rate_limit_tokens_per_second: None,
|
||||
connection_mode: ConnectionMode::Http,
|
||||
model_path: None,
|
||||
tokenizer_path: None,
|
||||
};
|
||||
|
||||
// Router creation will fail due to health checks, but config should be valid
|
||||
let app_context =
|
||||
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None);
|
||||
let app_context = std::sync::Arc::new(app_context);
|
||||
let result = RouterFactory::create_router(&app_context).await;
|
||||
assert!(result.is_err());
|
||||
let error_msg = result.unwrap_err();
|
||||
// Error should be about health/timeout, not configuration
|
||||
assert!(
|
||||
error_msg.contains("healthy") || error_msg.contains("timeout"),
|
||||
"Unexpected error: {}",
|
||||
error_msg
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Bootstrap Injection and Request Handling
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_pd_request_from_json() {
    // Single-prompt generate request: no batch, not streaming.
    let single = json!({
        "text": "Hello world",
        "stream": false,
        "temperature": 0.7,
        "max_tokens": 100
    });
    let parsed = PDRequest::from_json(&single);
    assert!(!parsed.is_stream);
    assert_eq!(parsed.batch_size, None);

    // Batched text prompts: batch_size reflects the array length.
    let batch = json!({
        "text": ["Hello", "World", "Test"],
        "stream": true,
        "temperature": 0.5
    });
    let parsed = PDRequest::from_json(&batch);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(3));

    // Pre-tokenized input_ids: batch size is the number of sequences.
    let ids = json!({
        "input_ids": [[1, 2, 3], [4, 5, 6]],
        "stream": false
    });
    let parsed = PDRequest::from_json(&ids);
    assert!(!parsed.is_stream);
    assert_eq!(parsed.batch_size, Some(2));

    // Chat-completions style request: messages carry no batch size.
    let chat = json!({
        "messages": [
            {"role": "system", "content": "You are a helpful assistant"},
            {"role": "user", "content": "Hello"}
        ],
        "stream": true
    });
    let parsed = PDRequest::from_json(&chat);
    assert!(parsed.is_stream);
    assert_eq!(parsed.batch_size, None);
}
|
||||
|
||||
#[test]
fn test_bootstrap_injection_simulation() {
    // inject_bootstrap_fields itself is private to the router module, so this
    // test reproduces its documented effect on the request JSON by hand.

    // --- Single request ---
    let mut request = json!({
        "text": "Hello world",
        "stream": false,
        "temperature": 0.7
    });

    // A prefill worker supplies the hostname and bootstrap port to inject.
    let worker = WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));

    // Pull the bootstrap port out of the worker's type.
    let port = if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
        bootstrap_port
    } else {
        None
    };

    // Mirror what inject_bootstrap_fields would write into the payload.
    request["bootstrap_host"] = json!(get_hostname(worker.url()));
    request["bootstrap_port"] = json!(port);
    request["bootstrap_room"] = json!(12345u64); // Random room ID

    assert_eq!(request["bootstrap_host"], "prefill1");
    assert_eq!(request["bootstrap_port"], json!(Some(9000)));
    assert!(request["bootstrap_room"].is_u64());
    assert_eq!(request["temperature"], 0.7); // Original field preserved

    // --- Batch request: one bootstrap entry per prompt ---
    let mut batch = json!({
        "text": ["Hello", "World", "Test"],
        "stream": true
    });

    let n = 3;
    let host = get_hostname(worker.url());
    batch["bootstrap_host"] = json!(vec![host; n]);
    batch["bootstrap_port"] = json!(vec![port; n]);
    batch["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);

    assert!(batch["bootstrap_host"].is_array());
    assert_eq!(batch["bootstrap_host"].as_array().unwrap().len(), n);
    assert!(batch["bootstrap_port"].is_array());
    assert!(batch["bootstrap_room"].is_array());
    assert_eq!(batch["stream"], true); // Original field preserved
}
|
||||
|
||||
#[test]
fn test_request_serialization() {
    // Round-trip a fully-populated request through the byte representation
    // the router actually forwards.
    let original = json!({
        "text": "Test prompt",
        "stream": false,
        "temperature": 0.7,
        "max_tokens": 100,
        "top_p": 0.9,
        "frequency_penalty": 0.5,
        "bootstrap_host": "prefill1",
        "bootstrap_port": 9000,
        "bootstrap_room": 12345u64
    });

    // Serialize to bytes and parse back, as happens on the wire.
    let wire_bytes = serde_json::to_vec(&original).unwrap();
    let roundtrip: serde_json::Value = serde_json::from_slice(&wire_bytes).unwrap();

    // Every field must survive the round trip unchanged.
    assert_eq!(roundtrip["text"], "Test prompt");
    assert_eq!(roundtrip["stream"], false);
    assert_eq!(roundtrip["temperature"], 0.7);
    assert_eq!(roundtrip["max_tokens"], 100);
    assert_eq!(roundtrip["bootstrap_host"], "prefill1");
    assert_eq!(roundtrip["bootstrap_port"], 9000);
    assert_eq!(roundtrip["bootstrap_room"], 12345);
}
|
||||
|
||||
#[test]
fn test_hostname_extraction() {
    // (url, expected hostname) pairs covering the formats the router sees.
    let cases = [
        ("http://localhost:8080", "localhost"),
        ("http://10.0.0.1:8080", "10.0.0.1"),
        ("https://api.example.com:443", "api.example.com"),
        ("http://prefill-server", "prefill-server"),
        ("http://[::1]:8080", "["), // IPv6 edge case
        ("prefill:8080", "prefill"), // No protocol
    ];

    for (url, want) in cases {
        assert_eq!(get_hostname(url), want);
    }
}
|
||||
|
||||
#[test]
fn test_pd_request_edge_cases() {
    // A completely empty body defaults to non-streaming, no batch.
    let req = PDRequest::from_json(&json!({}));
    assert!(!req.is_stream);
    assert_eq!(req.batch_size, None);

    // Only the stream flag set.
    let req = PDRequest::from_json(&json!({ "stream": true }));
    assert!(req.is_stream);
    assert_eq!(req.batch_size, None);

    // An empty text array still counts as a batch (of zero).
    let req = PDRequest::from_json(&json!({ "text": [] }));
    assert_eq!(req.batch_size, Some(0));

    // A scalar text field is a single prompt, not a batch.
    let req = PDRequest::from_json(&json!({ "text": "single string" }));
    assert_eq!(req.batch_size, None);
}
|
||||
|
||||
// ========================================================================
|
||||
// Phase 2: Background Load Monitoring Tests
|
||||
// ========================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_background_load_monitoring() {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Create a watch channel for testing
|
||||
let (tx, rx) = watch::channel(HashMap::new());
|
||||
|
||||
// Simulate load updates
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("http://prefill1:8080".to_string(), 10);
|
||||
loads.insert("http://prefill2:8080".to_string(), 20);
|
||||
loads.insert("http://decode1:8080".to_string(), 5);
|
||||
loads.insert("http://decode2:8080".to_string(), 15);
|
||||
|
||||
// Send the loads
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Verify receiver gets the update
|
||||
let received_loads = rx.borrow();
|
||||
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
|
||||
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
|
||||
assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
|
||||
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
|
||||
}
|
||||
|
||||
#[test]
fn test_load_monitoring_configuration() {
    // Background load monitoring is only needed by the PowerOfTwo policy,
    // which picks the less-loaded of two sampled workers; the other policies
    // never consult live load data.
    //
    // Each entry pairs a policy with whether monitoring should be enabled.
    let policies = vec![
        (PDSelectionPolicy::Random, false),
        (PDSelectionPolicy::PowerOfTwo, true),
        (
            PDSelectionPolicy::CacheAware {
                cache_threshold: 0.5,
                balance_abs_threshold: 32,
                balance_rel_threshold: 1.1,
            },
            false,
        ),
    ];

    for (policy, should_monitor) in policies {
        // Compare the classification directly against the expected flag
        // instead of re-encoding the mapping in a hand-rolled match.
        assert_eq!(
            matches!(policy, PDSelectionPolicy::PowerOfTwo),
            should_monitor,
            "monitoring flag mismatch"
        );
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_watch_channel_behavior() {
|
||||
use std::collections::HashMap;
|
||||
use tokio::sync::watch;
|
||||
|
||||
// Test watch channel's broadcast behavior
|
||||
let (tx, rx1) = watch::channel(HashMap::new());
|
||||
let rx2 = rx1.clone();
|
||||
|
||||
// Initial state - empty map
|
||||
assert!(rx1.borrow().is_empty());
|
||||
assert!(rx2.borrow().is_empty());
|
||||
|
||||
// Update 1
|
||||
let mut loads = HashMap::new();
|
||||
loads.insert("worker1".to_string(), 10);
|
||||
tx.send(loads.clone()).unwrap();
|
||||
|
||||
// Both receivers see the update
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
|
||||
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
|
||||
|
||||
// Update 2 - overwrites previous
|
||||
loads.insert("worker1".to_string(), 20);
|
||||
loads.insert("worker2".to_string(), 30);
|
||||
tx.send(loads).unwrap();
|
||||
|
||||
// Both receivers see the latest state
|
||||
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
|
||||
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
|
||||
}
|
||||
|
||||
// ========================================================================
|
||||
// Tests based on bench_one_batch_server.py patterns
|
||||
// ========================================================================
|
||||
|
||||
// Request-shape tests mirroring the payloads produced by
// bench_one_batch_server.py, so PDRequest parsing is covered for the exact
// shapes the benchmark sends.
#[test]
fn test_generate_request_formats() {
    // Based on bench_one_batch_server.py request patterns

    // Test 1: Batch request with input_ids (most common in benchmarks)
    let batch_request = json!({
        "input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 16,
            "ignore_eos": true,
        },
        "return_logprob": false,
        "stream": true
    });

    // batch_size must equal the number of input_id sequences.
    let pd_req = PDRequest::from_json(&batch_request);
    assert!(pd_req.is_stream);
    assert_eq!(pd_req.batch_size, Some(3));

    // Test 2: Request with return_logprob (critical for PD)
    let logprob_request = json!({
        "input_ids": [[1, 2, 3]],
        "sampling_params": {
            "temperature": 0.7,
            "max_new_tokens": 8,
        },
        "return_logprob": true,
        "stream": false
    });

    // The logprob flag and stream flag are independent of each other.
    assert_eq!(logprob_request["return_logprob"], true);
    assert_eq!(logprob_request["stream"], false);

    // Test 3: Large batch sizes from benchmark
    let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
    for bs in batch_sizes {
        let request = json!({
            "input_ids": vec![vec![1, 2, 3]; bs],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });

        // Parsed batch size must track the constructed batch exactly.
        let pd_req = PDRequest::from_json(&request);
        assert_eq!(pd_req.batch_size, Some(bs));
    }
}
|
||||
|
||||
#[test]
fn test_sampling_params_handling() {
    // Sampling-parameter payloads drawn from bench_one_batch_server.py.
    let variants = [
        json!({
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": true
        }),
        json!({
            "temperature": 0.7,
            "max_new_tokens": 16,
            "ignore_eos": false,
            "top_p": 0.9,
            "frequency_penalty": 0.5
        }),
        json!({
            "temperature": 1.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$" // Structured output
        }),
    ];

    for sampling_params in variants {
        let wrapped = json!({
            "input_ids": [[1, 2, 3]],
            "sampling_params": sampling_params.clone(),
            "stream": false
        });

        // Embedding the params inside a request must not alter them.
        assert_eq!(wrapped["sampling_params"], sampling_params);
    }
}
|
||||
|
||||
#[test]
fn test_streaming_response_parsing() {
    // SSE chunks exactly as the backend streams them: three data frames
    // followed by the terminal [DONE] sentinel.
    let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
        "data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
        "data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
        "data: [DONE]"];

    for chunk in &sse_chunks[..3] {
        // strip_prefix both asserts the SSE framing and yields the JSON
        // payload, replacing the magic-number slice `&chunk[6..]`.
        let json_str = chunk
            .strip_prefix("data: ")
            .expect("SSE chunk must start with 'data: '");
        let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
        assert!(parsed["meta_info"]["completion_tokens"].is_u64());
    }

    // The stream terminator is a bare [DONE] frame, not JSON.
    assert_eq!(sse_chunks[3], "data: [DONE]");
}
|
||||
|
||||
#[test]
fn test_ttft_calculation() {
    // The first streamed chunk: exactly one completion token, no finish yet.
    let chunk = json!({
        "text": "Hello",
        "meta_info": {
            "completion_tokens": 1,
            "finish_reason": null
        }
    });

    // TTFT is sampled at the chunk where completion_tokens == 1 and the
    // request has not yet finished.
    let meta = &chunk["meta_info"];
    assert_eq!(meta["completion_tokens"], 1);
    assert!(meta["finish_reason"].is_null());
}
|
||||
|
||||
#[test]
fn test_throughput_metrics() {
    // Benchmark parameters mirrored from bench_one_batch_server.py.
    let (batch_size, input_len, output_len) = (16usize, 1024usize, 16usize);
    let ttft = 0.5_f64; // seconds until the first token
    let total_latency = 2.0_f64; // seconds for the whole request

    // Prefill throughput: all input tokens are processed before TTFT.
    let input_throughput = (batch_size * input_len) as f64 / ttft;
    assert!((input_throughput - 32768.0).abs() < 0.01);

    // Decode throughput: output tokens arrive between TTFT and completion.
    let output_throughput = (batch_size * output_len) as f64 / (total_latency - ttft);
    assert!((output_throughput - 170.67).abs() < 0.01);
}
|
||||
|
||||
#[test]
fn test_error_response_handling() {
    // Error body shape produced for failed requests in the benchmark flow.
    let body = json!({
        "error": "Request has failed. Invalid input format."
    });

    // The error key must exist and carry a failure description.
    assert!(body.get("error").is_some());
    let message = body["error"].as_str().unwrap();
    assert!(message.contains("failed"));
}
|
||||
|
||||
#[test]
fn test_structured_output_request() {
    // A structured-output request constrains generation via json_schema.
    let request = json!({
        "text": "What is the capital of France? Answer in JSON.",
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 64,
            "json_schema": "$$ANY$$"
        },
        "stream": false
    });

    // The schema sentinel must survive inside sampling_params untouched.
    let params = &request["sampling_params"];
    assert_eq!(params["json_schema"], "$$ANY$$");
}
|
||||
|
||||
// Bootstrap injection applied to a realistic benchmark payload: a batch of
// 16 pre-tokenized sequences with logprobs and streaming enabled, matching
// the shape sent by bench_one_batch_server.py.
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
    use sglang_router_rs::core::{WorkerFactory, WorkerType};

    // Test bootstrap injection with actual benchmark request patterns
    let mut benchmark_request = json!({
        "input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
        "sampling_params": {
            "temperature": 0.0,
            "max_new_tokens": 8,
            "ignore_eos": true
        },
        "return_logprob": true,
        "stream": true
    });

    // Create a prefill worker to simulate injection
    let prefill_worker =
        WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));

    // Extract bootstrap port from worker type
    let bootstrap_port = match prefill_worker.worker_type() {
        WorkerType::Prefill { bootstrap_port } => bootstrap_port,
        _ => None,
    };
    let batch_size = 16;
    let hostname = get_hostname(prefill_worker.url());

    // One bootstrap entry per batch element, mirroring what the router's
    // (private) inject_bootstrap_fields does for batched requests.
    benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
    benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
    benchmark_request["bootstrap_room"] =
        json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());

    // Verify bootstrap fields match batch size
    assert_eq!(
        benchmark_request["bootstrap_host"]
            .as_array()
            .unwrap()
            .len(),
        batch_size
    );
    assert_eq!(
        benchmark_request["bootstrap_port"]
            .as_array()
            .unwrap()
            .len(),
        batch_size
    );
    assert_eq!(
        benchmark_request["bootstrap_room"]
            .as_array()
            .unwrap()
            .len(),
        batch_size
    );

    // Verify original fields are preserved
    assert_eq!(benchmark_request["return_logprob"], true);
    assert_eq!(benchmark_request["stream"], true);
}
|
||||
|
||||
#[test]
fn test_server_info_response_format() {
    // Shape of /get_server_info as consumed by bench_one_batch_server.py.
    let info = json!({
        "internal_states": [{
            "avg_spec_accept_length": 3.5,
            "last_gen_throughput": 2048.5,
            "load": 16
        }],
        "prefill": [
            {"url": "http://prefill1:8080", "load": 10},
            {"url": "http://prefill2:8080", "load": 20}
        ],
        "decode": [
            {"url": "http://decode1:8080", "load": 5},
            {"url": "http://decode2:8080", "load": 15}
        ]
    });

    // The benchmark reads these exact fields, so their types are pinned here.
    let states = &info["internal_states"][0];
    assert!(states["avg_spec_accept_length"].is_f64());
    assert!(states["last_gen_throughput"].is_f64());
    assert!(info["prefill"].is_array());
    assert!(info["decode"].is_array());
}
|
||||
|
||||
// ========================================================================
|
||||
// Comprehensive Endpoint Coverage Test
|
||||
// ========================================================================
|
||||
|
||||
#[test]
fn test_pd_endpoints_coverage() {
    // (path, method, implemented?) for every endpoint in Python mini_lb.py.
    let endpoints = [
        ("/health", "GET", true),
        ("/health_generate", "GET", true), // Note: Python uses POST, we use GET
        ("/get_server_info", "GET", true),
        ("/v1/models", "GET", true),
        ("/get_model_info", "GET", true),
        ("/generate", "POST", true),
        ("/v1/chat/completions", "POST", true),
        ("/v1/completions", "POST", true),
        ("/flush_cache", "POST", true),
        ("/get_loads", "GET", true),
        ("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
    ];

    // Split the table into implemented and missing endpoints in one pass.
    let mut implemented = 0usize;
    let mut missing = Vec::new();
    for (path, method, done) in &endpoints {
        if *done {
            implemented += 1;
        } else {
            missing.push(format!("{} {}", method, path));
        }
    }

    // 10 of the 11 endpoints exist; /register is deferred past Phase 1/2.
    assert_eq!(implemented, 10);
    assert_eq!(endpoints.len(), 11);
    assert_eq!(missing, vec!["POST /register"]);
}
|
||||
|
||||
// Stress test: hand-simulated bootstrap injection for very large batches
// (1024-8192), asserting both that the injected arrays are sized correctly
// and that injection stays under one second per batch.
#[test]
fn test_large_batch_bootstrap_injection() {
    // Test bootstrap injection performance with very large batches
    // This simulates the bench_one_batch_server.py scenario
    let large_batch_sizes = vec![1024, 4096, 8192];

    for batch_size in large_batch_sizes {
        // Time the whole construct-and-inject sequence for this batch size.
        let start = std::time::Instant::now();

        // Simulate a large batch request
        let mut large_batch_request = json!({
            "input_ids": vec![vec![1, 2, 3, 4]; batch_size],
            "sampling_params": {
                "temperature": 0.0,
                "max_new_tokens": 16,
            },
            "stream": true
        });

        // Create a prefill worker to simulate injection
        let prefill_worker =
            WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));

        // Extract bootstrap port from worker type
        let bootstrap_port = match prefill_worker.worker_type() {
            WorkerType::Prefill { bootstrap_port } => bootstrap_port,
            _ => None,
        };
        let hostname = get_hostname(prefill_worker.url());

        // One bootstrap entry per batch element; rooms get random u64 ids.
        large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
        large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
        large_batch_request["bootstrap_room"] = json!((0..batch_size)
            .map(|_| rand::random::<u64>())
            .collect::<Vec<_>>());

        let elapsed = start.elapsed();

        // Verify bootstrap fields are correctly sized
        assert_eq!(
            large_batch_request["bootstrap_host"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );
        assert_eq!(
            large_batch_request["bootstrap_port"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );
        assert_eq!(
            large_batch_request["bootstrap_room"]
                .as_array()
                .unwrap()
                .len(),
            batch_size
        );

        // Bootstrap injection should be reasonably fast even for large batches
        println!(
            "Bootstrap injection for batch_size {} took {:?}",
            batch_size, elapsed
        );
        assert!(
            elapsed.as_millis() < 1000,
            "Bootstrap injection took too long for batch size {}",
            batch_size
        );
    }
}
|
||||
|
||||
#[test]
fn test_payload_size_calculation() {
    // (batch_size, input_len, output_len) scenarios from the benchmark.
    let scenarios = [
        (1, 1024, 16),   // Small batch
        (16, 1024, 16),  // Medium batch
        (64, 1024, 16),  // Large batch
        (8192, 4096, 5), // Benchmark scenario
    ];

    for (batch_size, input_len, _output_len) in scenarios {
        // Rough payload estimate: 4 bytes per i32 token plus ~100 bytes of
        // JSON overhead per request in the batch.
        let token_bytes = batch_size * input_len * 4;
        let overhead_bytes = batch_size * 100;
        let total_size = token_bytes + overhead_bytes;

        println!(
            "Batch size: {}, Input len: {}, Estimated payload: {} MB",
            batch_size,
            input_len,
            total_size / (1024 * 1024)
        );

        // The heavy benchmark case (8192 x 4096) lands between 100 and 200 MB.
        if batch_size == 8192 && input_len == 4096 {
            assert!(
                total_size > 100 * 1024 * 1024,
                "Benchmark payload should be > 100MB"
            );
            assert!(
                total_size < 200 * 1024 * 1024,
                "Benchmark payload should be < 200MB"
            );
        }
    }
}
|
||||
|
||||
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
    // PDSelectionPolicy intentionally has no RoundRobin variant: the PD path
    // supports exactly Random, PowerOfTwo, and CacheAware selection.
    //
    // Constructing each variant makes this test fail to compile if a variant
    // is renamed or its fields change. (The previous
    // `assert_eq!(pd_policy_count, 3)` compared a local literal against
    // itself and verified nothing, so it has been removed.)
    let _random = PDSelectionPolicy::Random;
    let _po2 = PDSelectionPolicy::PowerOfTwo;
    let _cache_aware = PDSelectionPolicy::CacheAware {
        cache_threshold: 0.5,
        balance_abs_threshold: 32,
        balance_rel_threshold: 1.1,
    };
}
|
||||
}
|
||||
320
sgl-router/tests/tokenizer_integration.rs
Normal file
320
sgl-router/tests/tokenizer_integration.rs
Normal file
@@ -0,0 +1,320 @@
|
||||
//! Integration tests for tokenizers using real tokenizer data
|
||||
//!
|
||||
//! These tests download the TinyLlama tokenizer from HuggingFace to verify our tokenizer
|
||||
//! implementation works correctly with real-world tokenizer files.
|
||||
|
||||
mod common;
|
||||
use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS};
|
||||
|
||||
use sglang_router_rs::tokenizer::{
|
||||
factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream,
|
||||
traits::*,
|
||||
};
|
||||
use std::sync::Arc;
|
||||
|
||||
// Longer (instruction, text) pairs used to stress incremental decoding:
// classic lorem-ipsum paragraphs, one multi-paragraph narrative, and an
// emoji run that exercises multi-byte / multi-codepoint token boundaries.
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
    ("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
    ("Tell me about the following text.", "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."),
    ("Tell me about the following text.", "Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt."),
    ("Tell me about the following text.", "Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem."),
    // Tennis-themed prompt for variety
    ("Tell me about the following text.", "In the ancient realm of Tennisia, the very magic of the land is drawn from the sport itself. Forehands light the skies, backhands carve the earth, and serves rumble like thunder across kingdoms. At the center of this balance lie four sacred Grand Slam relics: the Sapphire Trophy of Melbourne, the Emerald Chalice of Paris, the Ruby Crown of London, and the Diamond Orb of New York. Together, they keep the game's spirit alive.
But the relics are scattered, guarded by champions of legendary skill. The first is the Fire King of Clay, ruler of the crimson courts, whose topspin arcs blaze high and heavy, scorching all who dare stand across from him. The second is the Tempest Trickster, master of the baseline fortress, whose footwork and precision can turn back any storm, and whose returns arrive as if pulled by invisible strings. The third is the Shadow-Dancer of the Highlands, a tactician who thrives in the long rallies of twilight, changing pace and spin until opponents lose their rhythm. The fourth and final guardian is a towering Diamond Titan, a net-charging colossus whose volleys shatter the air itself.
Into this arena of gods steps the Silver-Wristed Knight — a player of impossible grace, whose game is an art form. His quest: to claim each relic not for glory, but to restore harmony to the rankings of the realm.
He travels across the Kingdom of Clay, where the points stretch like marathons and the air tastes of iron; through the Grasslands of London, where the ball skids low and the margins are razor-thin; over the Hard Courts of the East, where rallies turn into duels of endurance; and finally to the Cathedral of Lights in New York, where night matches burn with fevered energy.
Each battle is played under enchanted floodlights, the lines patrolled by spectral line judges whose calls are final. The crowd's roar swells with every break point, and the Silver-Wristed Knight's racket glows brightest when the match teeters at deuce. There are moments when doubt grips him — when his serve falters or his touch deserts him — but each challenge teaches a new stroke, culminating in the legendary Forehand of Dawn.
When the last relic is claimed, he stands not as a conqueror but as a custodian of the game, knowing that rivalries forge the very magic he protects. The balance is restored — until the next season begins."),
    // Emoji stress test
    ("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
];
|
||||
|
||||
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
|
||||
prompts
|
||||
.iter()
|
||||
.map(|&prompt| {
|
||||
tokenizer
|
||||
.encode(prompt)
|
||||
.expect("Failed to encode prompt")
|
||||
.get_hash()
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
#[test]
fn test_huggingface_tokenizer_hashes() {
    // Load the cached TinyLlama tokenizer and check its encodings against
    // the golden hashes recorded for the shared test prompts.
    let path = ensure_tokenizer_cached();
    let tokenizer = HuggingFaceTokenizer::from_file(path.to_str().unwrap())
        .expect("Failed to load HuggingFace tokenizer");

    let hashes = compute_hashes_for_tokenizer(&tokenizer, &TEST_PROMPTS);

    println!(
        "HF Tokenizer: {:?}\nComputed Hashes: {:?}\nExpected Hashes: {:?}",
        path, hashes, EXPECTED_HASHES
    );

    assert_eq!(hashes, EXPECTED_HASHES);
}
|
||||
|
||||
#[test]
fn test_tokenizer_encode_decode_lifecycle() {
    // Every test prompt must survive an encode -> decode round trip exactly.
    let path = ensure_tokenizer_cached();
    let tokenizer = HuggingFaceTokenizer::from_file(path.to_str().unwrap())
        .expect("Failed to load HuggingFace tokenizer");

    for prompt in TEST_PROMPTS.iter() {
        let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");

        // Decode without skipping special tokens so the text matches exactly.
        let decoded = tokenizer
            .decode(encoding.token_ids(), false)
            .expect("Failed to decode token_ids");

        assert_eq!(decoded, *prompt, "Encode-decode mismatch for: {}", prompt);
    }
}
|
||||
|
||||
// Exercises Sequence in both directions: append_text must produce the same
// tokens as a direct tokenizer encode, and token-by-token append_token must
// rebuild the original prompt text incrementally.
#[test]
fn test_sequence_operations() {
    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = Arc::new(
        HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
            .expect("Failed to load tokenizer"),
    );

    for prompt in TEST_PROMPTS.iter() {
        // Reference encoding to compare the Sequence APIs against.
        let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");

        // Test Sequence with append_text
        let mut sequence = Sequence::new(tokenizer.clone());
        sequence.append_text(prompt).expect("Failed to append text");

        // append_text must tokenize identically to a direct encode.
        assert_eq!(
            sequence.len(),
            encoding.token_ids().len(),
            "Sequence length mismatch"
        );
        assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch");

        // Test incremental decoding with append_token
        let mut decoder = Sequence::new(tokenizer.clone());
        let mut output = String::new();

        for token_id in encoding.token_ids() {
            // Each appended token may yield zero or more decoded characters.
            let text = decoder
                .append_token(*token_id)
                .expect("Failed to append token");
            output.push_str(&text);
        }

        // The incrementally-built sequence must match the append_text one.
        assert_eq!(decoder.len(), sequence.len(), "Decoder length mismatch");
        assert_eq!(
            decoder.token_ids(),
            sequence.token_ids(),
            "Token IDs mismatch"
        );
        assert_eq!(output, *prompt, "Incremental decode mismatch");
    }
}
|
||||
|
||||
#[test]
fn test_decode_stream() {
    // DecodeStream must reproduce each prompt when fed its tokens one by one.
    let path = ensure_tokenizer_cached();
    let tokenizer = Arc::new(
        HuggingFaceTokenizer::from_file(path.to_str().unwrap())
            .expect("Failed to load tokenizer"),
    );

    for prompt in TEST_PROMPTS.iter() {
        let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");

        // No prefill tokens; decode the prompt incrementally from scratch.
        let mut stream = DecodeStream::new(tokenizer.clone(), &[], false);
        let mut rebuilt = String::new();

        for token_id in encoding.token_ids() {
            // step() may hold text back until a complete grapheme is formed.
            if let Some(piece) = stream.step(*token_id).expect("Failed to decode token") {
                rebuilt.push_str(&piece);
            }
        }

        assert_eq!(rebuilt, *prompt, "DecodeStream output mismatch");
    }
}
|
||||
|
||||
#[test]
fn test_long_sequence_incremental_decode_with_prefill() {
    // Seed the decoder with the input prompt's tokens (prefill), then stream
    // the long output tokens through it and compare against the output text.
    let path = ensure_tokenizer_cached();
    let tokenizer = Arc::new(
        HuggingFaceTokenizer::from_file(path.to_str().unwrap())
            .expect("Failed to load tokenizer"),
    );

    for (input_text, output_text) in LONG_TEST_PROMPTS.iter() {
        let input_encoding = tokenizer
            .encode(input_text)
            .expect("Failed to encode input");
        let output_encoding = tokenizer
            .encode(output_text)
            .expect("Failed to encode output");

        // Prefill with the input tokens so decoding starts mid-sequence.
        let mut stream = DecodeStream::new(tokenizer.clone(), input_encoding.token_ids(), false);

        let mut rebuilt = String::new();
        for token_id in output_encoding.token_ids() {
            if let Some(piece) = stream.step(*token_id).expect("Failed to decode token") {
                rebuilt.push_str(&piece);
            }
        }

        // Trim: decoding after prefill can emit a leading/trailing space.
        assert_eq!(rebuilt.trim(), *output_text, "Long sequence decode mismatch");
    }
}
|
||||
|
||||
/// Exercises `StopSequenceDecoder`: decoding must halt at the configured
/// stop sequence, or pass all text through when the sequence never appears.
#[test]
fn test_stop_sequence_decoder() {
    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = Arc::new(
        HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
            .expect("Failed to load tokenizer"),
    );

    // (input, stop sequence, expected text before the stop).
    let test_cases = vec![
        (
            "Hello world! Stop here. Continue after.",
            "Stop",
            "Hello world! ",
        ),
        ("Testing stop sequences.", ".", "Testing stop sequences"),
        ("No stop sequence here", "xyz", "No stop sequence here"),
    ];

    for (input, stop_seq, expected) in test_cases {
        let config = StopSequenceConfig::default().with_stop_sequence(stop_seq);

        let mut decoder = StopSequenceDecoder::new(tokenizer.clone(), config, false);

        let encoding = tokenizer.encode(input).expect("Failed to encode");
        let mut output = String::new();
        let mut stopped = false;

        for token_id in encoding.token_ids() {
            match decoder.process_token(*token_id).unwrap() {
                SequenceDecoderOutput::Text(text) => output.push_str(&text),
                SequenceDecoderOutput::StoppedWithText(text) => {
                    output.push_str(&text);
                    stopped = true;
                    break;
                }
                SequenceDecoderOutput::Stopped => {
                    stopped = true;
                    break;
                }
                // Held: the decoder is buffering a potential stop-sequence match.
                SequenceDecoderOutput::Held => {}
            }
        }

        if !stopped {
            // Flush any text still held back as a partial stop-sequence match.
            if let SequenceDecoderOutput::Text(text) = decoder.flush() {
                output.push_str(&text);
            }
        }

        // Check prefix rather than equality: stop sequences may not align
        // with token boundaries, so a little extra text can slip through.
        // The full context goes into the panic message so a failure is
        // diagnosable without capturing stdout.
        assert!(
            output.starts_with(expected) || output == input,
            "Stop sequence test failed: input '{}', stop '{}', output '{}', expected prefix '{}'",
            input,
            stop_seq,
            output,
            expected
        );
    }
}
|
||||
|
||||
/// Round-trips a prompt through a tokenizer built via the `factory` helper.
#[test]
fn test_factory_creation() {
    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap())
        .expect("Failed to create tokenizer via factory");

    let prompt = TEST_PROMPTS[0];
    let encoding = tokenizer.encode(prompt).expect("Failed to encode");

    let round_trip = tokenizer
        .decode(encoding.token_ids(), false)
        .expect("Failed to decode");

    assert_eq!(round_trip, prompt);
}
|
||||
|
||||
/// Batch-encodes all test prompts and verifies each entry decodes back to
/// its corresponding prompt.
#[test]
fn test_batch_encoding() {
    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
        .expect("Failed to load tokenizer");

    let encodings = tokenizer
        .encode_batch(&TEST_PROMPTS)
        .expect("Failed to batch encode");

    assert_eq!(encodings.len(), TEST_PROMPTS.len());

    // Pair each encoding with its source prompt and round-trip it.
    for (encoding, prompt) in encodings.iter().zip(TEST_PROMPTS.iter()) {
        let decoded = tokenizer
            .decode(encoding.token_ids(), false)
            .expect("Failed to decode");
        assert_eq!(decoded, *prompt);
    }
}
|
||||
|
||||
/// Confirms the tokenizer exposes BOS and EOS special tokens via the trait.
#[test]
fn test_special_tokens() {
    use sglang_router_rs::tokenizer::traits::Tokenizer as TokenizerTrait;

    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
        .expect("Failed to load tokenizer");

    let special_tokens = tokenizer.get_special_tokens();

    // TinyLlama should have at least BOS and EOS tokens.
    assert!(special_tokens.bos_token.is_some());
    assert!(special_tokens.eos_token.is_some());

    println!("Special tokens: {:?}", special_tokens);
}
|
||||
|
||||
/// Spawns one thread per prompt, all sharing the same `Arc`'d tokenizer,
/// to check that encode/decode round-trips succeed under concurrent use.
#[test]
fn test_thread_safety() {
    use std::thread;

    let tokenizer_path = ensure_tokenizer_cached();
    let tokenizer = Arc::new(
        HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
            .expect("Failed to load tokenizer"),
    );

    let mut handles = Vec::with_capacity(TEST_PROMPTS.len());
    for &prompt in TEST_PROMPTS.iter() {
        let shared = Arc::clone(&tokenizer);
        handles.push(thread::spawn(move || {
            let encoding = shared.encode(prompt).expect("Failed to encode in thread");
            let decoded = shared
                .decode(encoding.token_ids(), false)
                .expect("Failed to decode in thread");
            assert_eq!(decoded, prompt);
        }));
    }

    // Join everything; a panic in any worker fails the test here.
    for handle in handles {
        handle.join().expect("Thread panicked");
    }
}
|
||||
183
sgl-router/tests/tool_parser_deepseek.rs
Normal file
183
sgl-router/tests/tool_parser_deepseek.rs
Normal file
@@ -0,0 +1,183 @@
|
||||
//! DeepSeek V3 Parser Integration Tests
|
||||
|
||||
use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, ToolParser};
|
||||
|
||||
/// Parses a single DeepSeek V3 tool call embedded in surrounding text and
/// checks the extracted function name and JSON arguments.
#[tokio::test]
async fn test_deepseek_complete_parsing() {
    let parser = DeepSeekParser::new();

    // Test single tool call
    let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<|tool▁call▁end|><|tool▁calls▁end|>
The weather in Tokyo is..."#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "get_weather");

    // Verify arguments
    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["location"], "Tokyo");
    assert_eq!(args["units"], "celsius");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deepseek_multiple_tools() {
|
||||
let parser = DeepSeekParser::new();
|
||||
|
||||
let input = r#"<|tool▁calls▁begin|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>search
|
||||
```json
|
||||
{"query": "rust programming"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>translate
|
||||
```json
|
||||
{"text": "Hello World", "to": "ja"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "translate");
|
||||
}
|
||||
|
||||
/// Streams a DeepSeek tool call in small chunks — deliberately splitting
/// markers and JSON across chunk boundaries — and checks that the tool
/// name and/or the completed call is surfaced incrementally.
#[tokio::test]
async fn test_deepseek_streaming() {
    let parser = DeepSeekParser::new();
    let mut state = ParseState::new();

    // Simulate streaming chunks
    let chunks = vec![
        "<|tool▁calls▁begin|><|tool▁call▁begin|>",
        "function<|tool▁sep|>get_weather\n",
        "```json\n",
        r#"{"location": "#,
        r#""Beijing", "#,
        r#""units": "metric"}"#,
        "\n```<|tool▁call▁end|><|tool▁calls▁end|>",
    ];

    let mut found_name = false;
    let mut found_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();

        match result {
            StreamResult::ToolName { name, .. } => {
                assert_eq!(name, "get_weather");
                found_name = true;
            }
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "get_weather");
                found_complete = true;
            }
            _ => {}
        }
    }

    // Either event is acceptable: implementations may emit the name early,
    // only the completed tool, or both.
    assert!(found_name || found_complete);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deepseek_nested_json() {
|
||||
let parser = DeepSeekParser::new();
|
||||
|
||||
let input = r#"<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>process
|
||||
```json
|
||||
{
|
||||
"data": {
|
||||
"nested": {
|
||||
"deep": [1, 2, 3]
|
||||
}
|
||||
}
|
||||
}
|
||||
```<|tool▁call▁end|><|tool▁calls▁end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["data"]["nested"]["deep"].is_array());
|
||||
}
|
||||
|
||||
/// `detect_format` must fire only on the DeepSeek tool-call begin marker.
#[test]
fn test_deepseek_format_detection() {
    let parser = DeepSeekParser::new();

    // Positive cases: the marker alone, and embedded in text.
    for sample in ["<|tool▁calls▁begin|>", "text with <|tool▁calls▁begin|> marker"] {
        assert!(parser.detect_format(sample));
    }

    // Negative cases: other parsers' markers and plain text.
    for sample in ["[TOOL_CALLS]", "<tool_call>", "plain text"] {
        assert!(!parser.detect_format(sample));
    }
}
|
||||
|
||||
/// A tool call whose JSON body is invalid must be skipped while later
/// valid calls in the same block still parse.
#[tokio::test]
async fn test_deepseek_malformed_json_handling() {
    let parser = DeepSeekParser::new();

    // Malformed JSON should be skipped
    let input = r#"<|tool▁calls▁begin|>
<|tool▁call▁begin|>function<|tool▁sep|>broken
```json
{invalid json}
```<|tool▁call▁end|>
<|tool▁call▁begin|>function<|tool▁sep|>valid
```json
{"key": "value"}
```<|tool▁call▁end|>
<|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    // Only the valid tool call should be parsed
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "valid");
}
|
||||
|
||||
/// Text preceding the tool-call block should be treated as normal output
/// text (mirroring the Python parser's `normal_text` behavior).
#[tokio::test]
async fn test_normal_text_extraction() {
    let parser = DeepSeekParser::new();

    // Python extracts text before tool calls as normal_text
    let input = r#"Let me help you with that.
<|tool▁calls▁begin|><|tool▁call▁begin|>function<|tool▁sep|>get_weather
```json
{"location": "Tokyo"}
```<|tool▁call▁end|><|tool▁calls▁end|>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "get_weather");

    // TODO: Verify normal text extraction when parser returns it
    // In Python: normal_text = "Let me help you with that."
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_tool_calls() {
|
||||
let parser = DeepSeekParser::new();
|
||||
|
||||
let input = r#"<|tool▁calls▁begin|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Tokyo"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁call▁begin|>function<|tool▁sep|>get_weather
|
||||
```json
|
||||
{"location": "Paris"}
|
||||
```<|tool▁call▁end|>
|
||||
<|tool▁calls▁end|><|end▁of▁sentence|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
}
|
||||
330
sgl-router/tests/tool_parser_edge_cases.rs
Normal file
330
sgl-router/tests/tool_parser_edge_cases.rs
Normal file
@@ -0,0 +1,330 @@
|
||||
//! Edge Cases and Error Handling Tests
|
||||
//!
|
||||
//! Tests for malformed input, edge cases, and error recovery
|
||||
|
||||
use sglang_router_rs::tool_parser::{
|
||||
JsonParser, MistralParser, ParseState, ParserRegistry, PythonicParser, QwenParser,
|
||||
StreamResult, ToolParser,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_input() {
|
||||
let registry = ParserRegistry::new();
|
||||
let parsers = vec!["json", "mistral", "qwen", "pythonic", "llama"];
|
||||
|
||||
for parser_name in parsers {
|
||||
let parser = registry
|
||||
.get_parser(&format!("test-{}", parser_name))
|
||||
.unwrap();
|
||||
let result = parser.parse_complete("").await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Parser {} should return empty for empty input",
|
||||
parser_name
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plain_text_no_tools() {
|
||||
let plain_text = "This is just a regular response with no tool calls whatsoever.";
|
||||
|
||||
let json_parser = JsonParser::new();
|
||||
assert_eq!(
|
||||
json_parser.parse_complete(plain_text).await.unwrap().len(),
|
||||
0
|
||||
);
|
||||
|
||||
let mistral_parser = MistralParser::new();
|
||||
assert_eq!(
|
||||
mistral_parser
|
||||
.parse_complete(plain_text)
|
||||
.await
|
||||
.unwrap()
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
|
||||
let qwen_parser = QwenParser::new();
|
||||
assert_eq!(
|
||||
qwen_parser.parse_complete(plain_text).await.unwrap().len(),
|
||||
0
|
||||
);
|
||||
|
||||
let pythonic_parser = PythonicParser::new();
|
||||
assert_eq!(
|
||||
pythonic_parser
|
||||
.parse_complete(plain_text)
|
||||
.await
|
||||
.unwrap()
|
||||
.len(),
|
||||
0
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_incomplete_json() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
let incomplete_cases = vec![
|
||||
r#"{"name": "test""#, // Missing closing brace
|
||||
r#"{"name": "test", "arguments":"#, // Incomplete arguments
|
||||
r#"{"name": "test", "arguments": {"#, // Incomplete nested object
|
||||
];
|
||||
|
||||
for input in incomplete_cases {
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Should not parse incomplete JSON: {}",
|
||||
input
|
||||
);
|
||||
}
|
||||
|
||||
// This case might actually parse because [{"name": "test"}] is complete
|
||||
// The trailing comma suggests more items but the first item is valid
|
||||
let _result = json_parser
|
||||
.parse_complete(r#"[{"name": "test"},"#)
|
||||
.await
|
||||
.unwrap();
|
||||
// This could parse the first element or return empty - implementation dependent
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_malformed_mistral() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
let malformed_cases = vec![
|
||||
"[TOOL_CALLS]", // Missing array
|
||||
"[TOOL_CALLS] {", // Not an array
|
||||
"[TOOL_CALLS] [", // Incomplete array
|
||||
"[TOOL_CALLS] [{]", // Invalid JSON in array
|
||||
"[TOOL_CALLS] [{\"name\": }]", // Invalid value
|
||||
];
|
||||
|
||||
for input in malformed_cases {
|
||||
// Parser might return error or empty vec for malformed input
|
||||
if let Ok(result) = parser.parse_complete(input).await {
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Should not parse malformed Mistral: {}",
|
||||
input
|
||||
);
|
||||
}
|
||||
// Error is also acceptable for malformed input
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_missing_required_fields() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
// Missing name field
|
||||
let input = r#"{"arguments": {"x": 1}}"#;
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should not parse without name field");
|
||||
|
||||
// Name is not a string
|
||||
let input = r#"{"name": 123, "arguments": {}}"#;
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should not parse with non-string name");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_very_long_strings() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
let long_string = "x".repeat(10000);
|
||||
let input = format!(
|
||||
r#"{{"name": "test", "arguments": {{"data": "{}"}}}}"#,
|
||||
long_string
|
||||
);
|
||||
|
||||
let result = json_parser.parse_complete(&input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["data"].as_str().unwrap().len(), 10000);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unicode_edge_cases() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
// Various Unicode characters including emojis, CJK, RTL text
|
||||
let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍 مرحبا עולם"}}"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Hello 世界 🌍 مرحبا עולם");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_brackets_in_strings() {
|
||||
// Test that parsers correctly handle brackets within string literals
|
||||
|
||||
let mistral_parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#;
|
||||
let result = mistral_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Array: [1, 2, 3]");
|
||||
|
||||
let pythonic_parser = PythonicParser::new();
|
||||
let input = r#"[echo(text="List: [a, b, c]")]"#;
|
||||
let result = pythonic_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "List: [a, b, c]");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_formats_in_text() {
|
||||
// Test that parsers don't get confused by other formats in the text
|
||||
|
||||
let json_parser = JsonParser::new();
|
||||
let input = r#"
|
||||
Here's some text with [TOOL_CALLS] that shouldn't trigger.
|
||||
{"name": "actual_tool", "arguments": {}}
|
||||
And some more text with <tool_call> tags.
|
||||
"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "actual_tool");
|
||||
}
|
||||
|
||||
/// JSON escape sequences (\n, \r\n, \t, \\, \") in argument strings must
/// decode to their literal characters.
#[tokio::test]
async fn test_escaped_characters() {
    let json_parser = JsonParser::new();

    let input = r#"{"name": "write", "arguments": {"content": "Line 1\nLine 2\r\nLine 3\tTabbed\\Backslash\"Quote"}}"#;

    let result = json_parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);

    // After JSON decoding, the escapes are real control/quote characters.
    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    let content = args["content"].as_str().unwrap();
    assert!(content.contains('\n'));
    assert!(content.contains('\t'));
    assert!(content.contains('\\'));
    assert!(content.contains('"'));
}
|
||||
|
||||
/// Integers, floats, scientific notation, negatives, zero, and the largest
/// f64-safe integer must all survive argument serialization.
#[tokio::test]
async fn test_numeric_edge_cases() {
    let json_parser = JsonParser::new();

    let input = r#"{
        "name": "calculate",
        "arguments": {
            "int": 42,
            "float": 123.456,
            "scientific": 1.23e-4,
            "negative": -999,
            "zero": 0,
            "large": 9007199254740991
        }
    }"#;

    let result = json_parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);

    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["int"], 42);
    assert_eq!(args["float"], 123.456);
    // 1.23e-4 and 0.000123 are the same f64 value.
    assert_eq!(args["scientific"], 0.000123);
    assert_eq!(args["negative"], -999);
    assert_eq!(args["zero"], 0);
    // 2^53 - 1: the largest integer exactly representable in an f64.
    assert_eq!(args["large"], 9007199254740991i64);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_null_and_boolean_values() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
let input = r#"{
|
||||
"name": "configure",
|
||||
"arguments": {
|
||||
"enabled": true,
|
||||
"disabled": false,
|
||||
"optional": null
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["enabled"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], serde_json::Value::Null);
|
||||
}
|
||||
|
||||
/// Regression test: a chunk ending exactly on a prefix of the start token
/// ("<tool") must be buffered, not discarded, so the call still parses
/// once the remainder of the token arrives.
#[tokio::test]
async fn test_partial_token_at_buffer_boundary() {
    let parser = QwenParser::new();
    let mut state = ParseState::new();

    // Test case that would fail with the bug:
    // Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n"
    let result = parser.parse_incremental("<tool", &mut state).await.unwrap();
    assert!(matches!(result, StreamResult::Incomplete));
    assert_eq!(state.buffer, "<tool");

    // Complete the token
    let result = parser
        .parse_incremental(
            "_call>\n{\"name\": \"test\", \"arguments\": {}}\n</tool_call>",
            &mut state,
        )
        .await
        .unwrap();

    // Should successfully parse after completing
    match result {
        StreamResult::ToolComplete(tool) => {
            assert_eq!(tool.function.name, "test");
        }
        _ => {
            // In Phase 2 simplified streaming, might get Incomplete
            // The important thing is it didn't fail to recognize the partial token
        }
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_exact_prefix_lengths() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
// Test various exact prefix lengths that would be missed by exclusive range
|
||||
let test_cases = vec![
|
||||
("<", 1), // 1-char prefix
|
||||
("<t", 2), // 2-char prefix
|
||||
("<tool", 5), // 5-char prefix (the main bug case)
|
||||
("<tool_call", 10), // 10-char prefix
|
||||
("<tool_call>", 11), // 11-char prefix (full start without \n)
|
||||
];
|
||||
|
||||
for (prefix, expected_len) in test_cases {
|
||||
let mut state = ParseState::new();
|
||||
let result = parser.parse_incremental(prefix, &mut state).await.unwrap();
|
||||
assert!(
|
||||
matches!(result, StreamResult::Incomplete),
|
||||
"Prefix '{}' (len {}) should be incomplete",
|
||||
prefix,
|
||||
expected_len
|
||||
);
|
||||
assert_eq!(
|
||||
state.buffer, prefix,
|
||||
"Buffer should contain the prefix '{}'",
|
||||
prefix
|
||||
);
|
||||
}
|
||||
}
|
||||
194
sgl-router/tests/tool_parser_glm4_moe.rs
Normal file
194
sgl-router/tests/tool_parser_glm4_moe.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
//! GLM-4 MoE Parser Integration Tests
|
||||
|
||||
use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, ToolParser};
|
||||
|
||||
/// Parses a single GLM-4 MoE tool call (function name plus
/// <arg_key>/<arg_value> pairs) embedded in surrounding prose.
#[tokio::test]
async fn test_glm4_complete_parsing() {
    let parser = Glm4MoeParser::new();

    // Test single tool call
    let input = r#"Let me search for that.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2024-12-25</arg_value>
</tool_call>
The weather will be..."#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "get_weather");

    // Verify arguments
    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["city"], "Beijing");
    assert_eq!(args["date"], "2024-12-25");
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_glm4_multiple_tools() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
|
||||
let input = r#"<tool_call>search
|
||||
<arg_key>query</arg_key>
|
||||
<arg_value>rust tutorials</arg_value>
|
||||
</tool_call>
|
||||
<tool_call>translate
|
||||
<arg_key>text</arg_key>
|
||||
<arg_value>Hello World</arg_value>
|
||||
<arg_key>target_lang</arg_key>
|
||||
<arg_value>zh</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "translate");
|
||||
}
|
||||
|
||||
/// <arg_value> text must be coerced to typed JSON values: integers,
/// floats, booleans, null — and plain string as the fallback.
#[tokio::test]
async fn test_glm4_type_conversion() {
    let parser = Glm4MoeParser::new();

    // Test various value types
    let input = r#"<tool_call>process
<arg_key>count</arg_key>
<arg_value>42</arg_value>
<arg_key>rate</arg_key>
<arg_value>1.5</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>data</arg_key>
<arg_value>null</arg_value>
<arg_key>text</arg_key>
<arg_value>string value</arg_value>
</tool_call>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);

    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["count"], 42);
    assert_eq!(args["rate"], 1.5);
    assert_eq!(args["enabled"], true);
    assert_eq!(args["data"], serde_json::Value::Null);
    assert_eq!(args["text"], "string value");
}
|
||||
|
||||
/// Streams a GLM-4 tool call chunk by chunk; the parser must surface the
/// tool name and/or the completed call during incremental parsing.
#[tokio::test]
async fn test_glm4_streaming() {
    let parser = Glm4MoeParser::new();
    let mut state = ParseState::new();

    // Simulate streaming chunks
    let chunks = vec![
        "<tool_call>",
        "get_weather\n",
        "<arg_key>city</arg_key>\n",
        "<arg_value>Shanghai</arg_value>\n",
        "<arg_key>units</arg_key>\n",
        "<arg_value>celsius</arg_value>\n",
        "</tool_call>",
    ];

    let mut found_name = false;
    let mut found_complete = false;

    for chunk in chunks {
        let result = parser.parse_incremental(chunk, &mut state).await.unwrap();

        match result {
            StreamResult::ToolName { name, .. } => {
                assert_eq!(name, "get_weather");
                found_name = true;
            }
            StreamResult::ToolComplete(tool) => {
                assert_eq!(tool.function.name, "get_weather");
                found_complete = true;
            }
            _ => {}
        }
    }

    // Either an early name event or a final complete event is acceptable.
    assert!(found_name || found_complete);
}
|
||||
|
||||
/// `detect_format` must fire only on the GLM-4 `<tool_call>` marker.
#[test]
fn test_glm4_format_detection() {
    let parser = Glm4MoeParser::new();

    // Positive cases: marker alone and embedded in text.
    for positive in ["<tool_call>", "text with <tool_call> marker"] {
        assert!(parser.detect_format(positive));
    }

    // Negative cases: other parsers' markers and plain text.
    for negative in ["[TOOL_CALLS]", "<|tool▁calls▁begin|>", "plain text"] {
        assert!(!parser.detect_format(negative));
    }
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_glm4_python_literal_values() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
|
||||
// Test Python-style boolean values
|
||||
let input = r#"<tool_call>config
|
||||
<arg_key>debug</arg_key>
|
||||
<arg_value>True</arg_value>
|
||||
<arg_key>verbose</arg_key>
|
||||
<arg_value>False</arg_value>
|
||||
<arg_key>optional</arg_key>
|
||||
<arg_value>None</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["debug"], true);
|
||||
assert_eq!(args["verbose"], false);
|
||||
assert_eq!(args["optional"], serde_json::Value::Null);
|
||||
}
|
||||
|
||||
/// Python literals True/False/None in <arg_value> must become JSON
/// true/false/null.
/// NOTE(review): this largely duplicates `test_glm4_python_literal_values`
/// — consider merging the two.
#[tokio::test]
async fn test_python_literals() {
    let parser = Glm4MoeParser::new();

    let input = r#"<tool_call>test_func
<arg_key>bool_true</arg_key>
<arg_value>True</arg_value>
<arg_key>bool_false</arg_key>
<arg_value>False</arg_value>
<arg_key>none_val</arg_key>
<arg_value>None</arg_value>
</tool_call>"#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    assert_eq!(result[0].function.name, "test_func");

    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["bool_true"], true);
    assert_eq!(args["bool_false"], false);
    assert_eq!(args["none_val"], serde_json::Value::Null);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_nested_values() {
|
||||
let parser = Glm4MoeParser::new();
|
||||
|
||||
let input = r#"<tool_call>process
|
||||
<arg_key>data</arg_key>
|
||||
<arg_value>{"nested": {"key": "value"}}</arg_value>
|
||||
<arg_key>list</arg_key>
|
||||
<arg_value>[1, 2, 3]</arg_value>
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["data"].is_object());
|
||||
assert!(args["list"].is_array());
|
||||
}
|
||||
201
sgl-router/tests/tool_parser_gpt_oss.rs
Normal file
201
sgl-router/tests/tool_parser_gpt_oss.rs
Normal file
@@ -0,0 +1,201 @@
|
||||
//! GPT-OSS Parser Integration Tests
|
||||
|
||||
use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, ToolParser};
|
||||
|
||||
/// Parses a single GPT-OSS channel-style tool call embedded in text and
/// verifies the extracted function name and JSON arguments.
#[tokio::test]
async fn test_gpt_oss_complete_parsing() {
    let parser = GptOssParser::new();

    // Test single tool call
    let input = r#"Let me search for that information.
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|>
Here are the results..."#;

    let result = parser.parse_complete(input).await.unwrap();
    assert_eq!(result.len(), 1);
    // The "functions." namespace prefix is stripped from the call target.
    assert_eq!(result[0].function.name, "search");

    // Verify arguments
    let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
    assert_eq!(args["query"], "rust programming");
    assert_eq!(args["limit"], 10);
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_multiple_tools() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
|
||||
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_with_namespace() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Test with different namespace patterns
|
||||
let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|>
|
||||
<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "create"); // Should extract last part
|
||||
assert_eq!(result[1].function.name, "add");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_with_assistant_prefix() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Test with <|start|>assistant prefix
|
||||
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_empty_args() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Test with empty arguments
|
||||
let input =
|
||||
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
assert_eq!(result[0].function.arguments, "{}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_streaming() {
|
||||
let parser = GptOssParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<|channel|>commentary to=",
|
||||
"functions.calculate",
|
||||
"<|constrain|>json<|message|>",
|
||||
r#"{"x": 10"#,
|
||||
r#", "y": 20}"#,
|
||||
"<|call|>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "calculate");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "calculate");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_gpt_oss_format_detection() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Should detect GPT-OSS format
|
||||
assert!(parser.detect_format("<|channel|>commentary to="));
|
||||
assert!(parser.detect_format("<|channel|>commentary"));
|
||||
assert!(parser.detect_format("text with <|channel|>commentary to= marker"));
|
||||
|
||||
// Should not detect other formats
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
assert!(!parser.detect_format("<tool_call>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_with_whitespace() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Test with whitespace after function name
|
||||
let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_gpt_oss_complex_json() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Test with complex nested JSON
|
||||
let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{
|
||||
"nested": {
|
||||
"data": [1, 2, 3],
|
||||
"config": {
|
||||
"enabled": true
|
||||
}
|
||||
}
|
||||
}<|call|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["nested"]["data"].is_array());
|
||||
assert_eq!(args["nested"]["config"]["enabled"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_commentary_without_function() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
// Python should extract commentary as normal text
|
||||
let input = r#"<|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0); // No tool calls
|
||||
// TODO: Verify normal text = "**Action plan**: 1. Do X 2. Do Y"
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_final_channel() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
let input = r#"<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"x": 1}<|call|>
|
||||
<|channel|>final<|message|>The result is calculated.<|return|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
// TODO: Verify normal text = "The result is calculated."
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mixed_commentary_and_calls() {
|
||||
let parser = GptOssParser::new();
|
||||
|
||||
let input = r#"<|channel|>commentary<|message|>Let me think<|end|>
|
||||
<|channel|>commentary to=functions.calc<|constrain|>json<|message|>{"x": 5}<|call|>
|
||||
<|channel|>commentary<|message|>Processing...<|end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calc");
|
||||
// TODO: Verify normal text = "Let me think Processing..."
|
||||
}
|
||||
147
sgl-router/tests/tool_parser_json.rs
Normal file
147
sgl-router/tests/tool_parser_json.rs
Normal file
@@ -0,0 +1,147 @@
|
||||
//! JSON Parser Integration Tests
|
||||
//!
|
||||
//! Tests for the JSON parser which handles OpenAI, Claude, and generic JSON formats
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::tool_parser::{JsonParser, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_simple_json_tool_call() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "San Francisco");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_array_of_tools() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"[
|
||||
{"name": "get_weather", "arguments": {"location": "SF"}},
|
||||
{"name": "search", "arguments": {"query": "news"}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_parameters_key() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_extraction_from_text() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"I'll help you with that. {"name": "search", "arguments": {"query": "rust"}} Let me search for that."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_nested_objects() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{
|
||||
"name": "update_config",
|
||||
"arguments": {
|
||||
"settings": {
|
||||
"theme": "dark",
|
||||
"language": "en",
|
||||
"notifications": {
|
||||
"email": true,
|
||||
"push": false
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "update_config");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["settings"]["theme"], "dark");
|
||||
assert_eq!(args["settings"]["notifications"]["email"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_special_characters() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "echo", "arguments": {"text": "Line 1\nLine 2\tTabbed", "path": "C:\\Users\\test"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Line 1\nLine 2\tTabbed");
|
||||
assert_eq!(args["path"], "C:\\Users\\test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_unicode() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍", "emoji": "😊"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Hello 世界 🌍");
|
||||
assert_eq!(args["emoji"], "😊");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_empty_arguments() {
|
||||
let parser = JsonParser::new();
|
||||
let input = r#"{"name": "ping", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args, json!({}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_invalid_format() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Missing closing brace
|
||||
let input = r#"{"name": "test", "arguments": {"key": "value""#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
|
||||
// Not JSON at all
|
||||
let input = "This is just plain text";
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_format_detection() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
160
sgl-router/tests/tool_parser_kimik2.rs
Normal file
160
sgl-router/tests/tool_parser_kimik2.rs
Normal file
@@ -0,0 +1,160 @@
|
||||
//! Kimi K2 Parser Integration Tests
|
||||
|
||||
use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kimik2_complete_parsing() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
// Test single tool call
|
||||
let input = r#"Let me help you with that.
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>
|
||||
The weather in Tokyo is..."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
// Verify arguments
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "Tokyo");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kimik2_multiple_tools() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust tutorials"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "translate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kimik2_with_whitespace() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
// Test with extra whitespace
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["key"], "value");
|
||||
assert_eq!(args["num"], 42);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kimik2_streaming() {
|
||||
let parser = KimiK2Parser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<|tool_calls_section_begin|>\n",
|
||||
"<|tool_call_begin|>functions.",
|
||||
"calculate:0",
|
||||
"<|tool_call_argument_begin|>",
|
||||
r#"{"x": 10, "#,
|
||||
r#""y": 20}"#,
|
||||
"<|tool_call_end|>\n",
|
||||
"<|tool_calls_section_end|>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "calculate");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "calculate");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_kimik2_format_detection() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
// Should detect Kimi K2 format
|
||||
assert!(parser.detect_format("<|tool_calls_section_begin|>"));
|
||||
assert!(parser.detect_format("<|tool_call_begin|>"));
|
||||
assert!(parser.detect_format("text with <|tool_calls_section_begin|> marker"));
|
||||
|
||||
// Should not detect other formats
|
||||
assert!(!parser.detect_format("[TOOL_CALLS]"));
|
||||
assert!(!parser.detect_format("<tool_call>"));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_kimik2_sequential_indices() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
// Test with proper sequential indexing
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result[0].function.name, "first");
|
||||
assert_eq!(result[1].function.name, "second");
|
||||
assert_eq!(result[2].function.name, "third");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_function_index_extraction() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
let input = r#"Text before tool calls.
|
||||
<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
|
||||
<|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "calc");
|
||||
// TODO: Verify indices are preserved: 0 and 1
|
||||
// TODO: Verify normal text = "Text before tool calls."
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_namespace_extraction() {
|
||||
let parser = KimiK2Parser::new();
|
||||
|
||||
let input = r#"<|tool_calls_section_begin|>
|
||||
<|tool_call_begin|>api.tools.search:0<|tool_call_argument_begin|>{"q": "test"}<|tool_call_end|>
|
||||
<|tool_calls_section_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search"); // Should extract after last dot
|
||||
}
|
||||
424
sgl-router/tests/tool_parser_llama.rs
Normal file
424
sgl-router/tests/tool_parser_llama.rs
Normal file
@@ -0,0 +1,424 @@
|
||||
//! Llama Parser Integration Tests
|
||||
//!
|
||||
//! Tests for the Llama parser which handles <|python_tag|> format and plain JSON
|
||||
|
||||
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_python_tag_format() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_plain_json_fallback() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 5);
|
||||
assert_eq!(args["y"], 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["timezone"], "UTC");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_with_nested_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let input = r#"<|python_tag|>{
|
||||
"name": "update_settings",
|
||||
"arguments": {
|
||||
"preferences": {
|
||||
"theme": "dark",
|
||||
"language": "en"
|
||||
},
|
||||
"notifications": true
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "update_settings");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["preferences"]["theme"], "dark");
|
||||
assert_eq!(args["notifications"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_empty_arguments() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
// With python_tag
|
||||
let input = r#"<|python_tag|>{"name": "ping", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
|
||||
// Plain JSON
|
||||
let input = r#"{"name": "ping", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_format_detection() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
|
||||
assert!(!parser.detect_format("plain text"));
|
||||
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_invalid_json_after_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
let input = r#"<|python_tag|>{"name": invalid}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_real_world_output() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
// Actual output from Llama 3.2 model - simplified for testing
|
||||
let input = r#"I'll search for that information for you.
|
||||
|
||||
<|python_tag|>{"name": "web_search", "arguments": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "web_search");
|
||||
|
||||
// Test with nicely formatted JSON
|
||||
let formatted_input = r#"<|python_tag|>{
|
||||
"name": "get_current_time",
|
||||
"arguments": {
|
||||
"timezone": "America/New_York",
|
||||
"format": "ISO8601"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result2 = parser.parse_complete(formatted_input).await.unwrap();
|
||||
assert_eq!(result2.len(), 1);
|
||||
assert_eq!(result2[0].function.name, "get_current_time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_json_array_format() {
|
||||
let parser = LlamaParser::new();
|
||||
|
||||
// Plain JSON array (should work as fallback)
|
||||
let input = r#"[{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
// Current implementation might handle this through JSON fallback
|
||||
assert!(!result.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_single_json() {
|
||||
// Test parsing plain JSON without python_tag
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "Paris");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_json_with_separator() {
|
||||
// Test multiple JSON objects with semicolon separator
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
// Note: Current implementation may only parse the first one due to semicolon handling
|
||||
assert!(!result.is_empty());
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_json_with_separator_customized() {
|
||||
// Test multiple JSON objects with python_tag repeated
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
// Current implementation may handle this differently
|
||||
assert!(!result.is_empty());
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_trailing_text() {
|
||||
// Test JSON with trailing text after
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_invalid_then_valid_json() {
|
||||
// Test error recovery - invalid JSON followed by valid JSON
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
// Should parse at least one valid JSON
|
||||
if !result.is_empty() {
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_plain_text_only() {
|
||||
// Test plain text with no tool calls
|
||||
let parser = LlamaParser::new();
|
||||
let text = "This is just plain explanation text.";
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_with_python_tag_prefix() {
|
||||
// Test text before python_tag
|
||||
let parser = LlamaParser::new();
|
||||
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// STREAMING TESTS
|
||||
// ============================================================================
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_simple() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// Send complete JSON at once
|
||||
let full_json = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
}
|
||||
_ => panic!("Expected ToolComplete for complete JSON input"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_partial() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// Stream in chunks
|
||||
let chunks = vec![
|
||||
r#"<|python"#,
|
||||
r#"_tag|>{"name": "#,
|
||||
r#""calculate", "#,
|
||||
r#""arguments": {"x": 10}"#,
|
||||
r#"}"#,
|
||||
];
|
||||
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "calculate");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "Should have completed parsing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_plain_json() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// Stream plain JSON without python_tag
|
||||
let chunks = vec![
|
||||
r#"{"name": "#,
|
||||
r#""search", "#,
|
||||
r#""arguments": "#,
|
||||
r#"{"query": "#,
|
||||
r#""test"}}"#,
|
||||
];
|
||||
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "Should have completed parsing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_with_text_before() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let chunks = vec![
|
||||
r#"Let me help you. "#,
|
||||
r#"<|python_tag|>"#,
|
||||
r#"{"name": "get_time","#,
|
||||
r#" "arguments": {"#,
|
||||
r#""timezone": "UTC"}}"#,
|
||||
];
|
||||
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "get_time");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "Should have completed parsing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_multiple_tools() {
|
||||
// Test streaming multiple tool calls with semicolon separator
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text =
|
||||
r#"<|python_tag|>{"name": "func1", "arguments": {}};{"name": "func2", "arguments": {}}"#;
|
||||
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
// Should get first tool complete
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "func1");
|
||||
}
|
||||
_ => panic!("Expected first tool to be complete"),
|
||||
}
|
||||
|
||||
// Process remaining buffer to get second tool
|
||||
let result2 = parser.parse_incremental("", &mut state).await.unwrap();
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "func2");
|
||||
}
|
||||
_ => panic!("Expected second tool to be complete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_multiple_tools_chunked() {
|
||||
// Test streaming multiple tool calls arriving in chunks
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// First chunk - incomplete first JSON
|
||||
let chunk1 = r#"<|python_tag|>{"name": "get_weather", "arguments""#;
|
||||
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||
|
||||
// Should be incomplete or have tool name
|
||||
match result1 {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete
|
||||
| sglang_router_rs::tool_parser::StreamResult::ToolName { .. }
|
||||
| sglang_router_rs::tool_parser::StreamResult::ToolArguments { .. } => {
|
||||
// Expected - could get tool name or be incomplete or even partial args
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected incomplete or tool name for partial JSON, got: {:?}",
|
||||
result1
|
||||
),
|
||||
}
|
||||
|
||||
// Second chunk - complete first JSON and separator
|
||||
let chunk2 = r#": {"city": "Paris"}};{"name": "#;
|
||||
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||
|
||||
// Should get first tool complete
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "Paris");
|
||||
}
|
||||
_ => panic!("Expected first tool to be complete after separator"),
|
||||
}
|
||||
|
||||
// Third chunk - complete second JSON
|
||||
let chunk3 = r#""get_time", "arguments": {"timezone": "UTC"}}"#;
|
||||
let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();
|
||||
|
||||
// Should get second tool complete
|
||||
match result3 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_time");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["timezone"], "UTC");
|
||||
}
|
||||
_ => {
|
||||
// If not complete yet, try one more empty chunk
|
||||
let result4 = parser.parse_incremental("", &mut state).await.unwrap();
|
||||
match result4 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_time");
|
||||
let args: serde_json::Value =
|
||||
serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["timezone"], "UTC");
|
||||
}
|
||||
_ => panic!("Expected second tool to be complete"),
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
153
sgl-router/tests/tool_parser_mistral.rs
Normal file
153
sgl-router/tests/tool_parser_mistral.rs
Normal file
@@ -0,0 +1,153 @@
|
||||
//! Mistral Parser Integration Tests
|
||||
//!
|
||||
//! Tests for the Mistral parser which handles [TOOL_CALLS] format
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::tool_parser::{MistralParser, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_single_tool() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"Let me search for that.
|
||||
[TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search_web");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "latest news");
|
||||
assert_eq!(args["max_results"], 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_multiple_tools() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"I'll help you with both tasks.
|
||||
[TOOL_CALLS] [
|
||||
{"name": "get_weather", "arguments": {"city": "Tokyo", "units": "celsius"}},
|
||||
{"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args0["city"], "Tokyo");
|
||||
|
||||
assert_eq!(result[1].function.name, "search_news");
|
||||
let args1: serde_json::Value = serde_json::from_str(&result[1].function.arguments).unwrap();
|
||||
assert_eq!(args1["query"], "AI developments");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_nested_json() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"Processing complex data.
|
||||
[TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3]));
|
||||
assert_eq!(args["enabled"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_with_text_after() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]
|
||||
|
||||
And here's some text after the tool call that should be ignored."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_empty_arguments() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "ping", "arguments": {}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_with_brackets_in_strings() {
|
||||
let parser = MistralParser::new();
|
||||
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array notation: arr[0] = value[1]"}}]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Array notation: arr[0] = value[1]");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_format_detection() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
assert!(parser.detect_format("[TOOL_CALLS] ["));
|
||||
assert!(parser.detect_format("Some text [TOOL_CALLS] ["));
|
||||
assert!(!parser.detect_format("Just plain text"));
|
||||
assert!(!parser.detect_format("[{\"name\": \"test\"}]")); // JSON array without TOOL_CALLS
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_malformed_json() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
// Missing closing bracket
|
||||
let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}"#;
|
||||
if let Ok(result) = parser.parse_complete(input).await {
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
// Error is also acceptable for malformed input
|
||||
|
||||
// Invalid JSON inside
|
||||
let input = r#"[TOOL_CALLS] [{"name": invalid}]"#;
|
||||
if let Ok(result) = parser.parse_complete(input).await {
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
// Error is also acceptable for malformed input
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_real_world_output() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
// Actual output from Mistral model
|
||||
let input = r#"I'll search for information about Rust programming and check the weather in San Francisco.
|
||||
|
||||
[TOOL_CALLS] [
|
||||
{
|
||||
"name": "web_search",
|
||||
"arguments": {
|
||||
"query": "Rust programming language features 2024",
|
||||
"max_results": 3,
|
||||
"include_snippets": true
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"location": "San Francisco, CA",
|
||||
"units": "fahrenheit",
|
||||
"include_forecast": false
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
Let me execute these searches for you."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "web_search");
|
||||
assert_eq!(result[1].function.name, "get_weather");
|
||||
}
|
||||
301
sgl-router/tests/tool_parser_mixed_edge_cases.rs
Normal file
301
sgl-router/tests/tool_parser_mixed_edge_cases.rs
Normal file
@@ -0,0 +1,301 @@
|
||||
//! Mixed Format and Additional Edge Case Tests
|
||||
//!
|
||||
//! Tests for edge cases across parsers and mixed format scenarios
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::tool_parser::{
|
||||
JsonParser, LlamaParser, MistralParser, ParseState, PythonicParser, QwenParser, StreamResult,
|
||||
ToolParser,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mixed_formats_in_text() {
|
||||
// Test that parsers correctly ignore other formats' markers
|
||||
|
||||
let json_parser = JsonParser::new();
|
||||
let input = r#"
|
||||
Some text with [TOOL_CALLS] marker that shouldn't trigger.
|
||||
Also has <tool_call> tags and [function()] syntax.
|
||||
But here's the actual JSON: {"name": "test", "arguments": {}}
|
||||
"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
// Mistral parser should ignore JSON and other formats
|
||||
let mistral_parser = MistralParser::new();
|
||||
let input = r#"
|
||||
{"name": "fake"} [function()] <tool_call>
|
||||
[TOOL_CALLS] [{"name": "real", "arguments": {}}]
|
||||
"#;
|
||||
|
||||
let result = mistral_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "real");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_format_markers_in_string_content() {
|
||||
// Test that format markers inside string content don't interfere
|
||||
|
||||
let pythonic_parser = PythonicParser::new();
|
||||
let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#;
|
||||
|
||||
let result = pythonic_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Use [TOOL_CALLS] and <tool_call> in text");
|
||||
|
||||
let qwen_parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "log", "arguments": {"msg": "Found [function()] pattern"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = qwen_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["msg"], "Found [function()] pattern");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deeply_nested_json_structures() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
let input = r#"{
|
||||
"name": "deep_process",
|
||||
"arguments": {
|
||||
"level1": {
|
||||
"level2": {
|
||||
"level3": {
|
||||
"level4": {
|
||||
"level5": {
|
||||
"data": [1, 2, [3, [4, 5]]]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "deep_process");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["level1"]["level2"]["level3"]["level4"]["level5"]["data"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_multiple_sequential_calls_different_formats() {
|
||||
// Simulate a scenario where different parts of text have different formats
|
||||
// (though each parser will only recognize its own format)
|
||||
|
||||
let llama_parser = LlamaParser::new();
|
||||
|
||||
// Llama parser currently only returns the first tool found
|
||||
let input = r#"First call: <|python_tag|>{"name": "call1", "arguments": {}}"#;
|
||||
|
||||
let result = llama_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "call1");
|
||||
|
||||
// Test plain JSON separately
|
||||
let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#;
|
||||
let result2 = llama_parser.parse_complete(input2).await.unwrap();
|
||||
assert_eq!(result2.len(), 1);
|
||||
assert_eq!(result2[0].function.name, "call2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_empty_and_whitespace_variations() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
// Various whitespace scenarios
|
||||
let cases = vec![
|
||||
r#" {"name":"compact","arguments":{}} "#,
|
||||
r#"
|
||||
|
||||
{"name": "spaced", "arguments": {}}
|
||||
|
||||
"#,
|
||||
r#" {"name": "tabbed", "arguments": {}} "#, // tabs
|
||||
];
|
||||
|
||||
for input in cases {
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1, "Should parse regardless of whitespace");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_special_json_values() {
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
// Test various special JSON values
|
||||
let input = r#"{
|
||||
"name": "test_special",
|
||||
"arguments": {
|
||||
"float_e": 1.23e10,
|
||||
"float_neg_e": 1.23e-10,
|
||||
"hex_like": "0x1234",
|
||||
"very_long_num": 99999999999999999999,
|
||||
"special_strings": ["", " ", "\u0000", "\u001f"],
|
||||
"escaped": "\\n\\r\\t\\\"\\\\",
|
||||
"unicode": "\u4e2d\u6587"
|
||||
}
|
||||
}"#;
|
||||
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test_special");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["special_strings"].is_array());
|
||||
assert!(args["escaped"].is_string());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parser_recovery_after_invalid_input() {
|
||||
let mut state = ParseState::new();
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Send invalid JSON first
|
||||
let _ = parser.parse_incremental(r#"{"broken": "#, &mut state).await;
|
||||
|
||||
// Clear state and try valid JSON
|
||||
state.buffer.clear();
|
||||
let result = parser
|
||||
.parse_incremental(r#"{"name": "valid", "arguments": {}}"#, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "valid");
|
||||
}
|
||||
_ => {
|
||||
// Might be incomplete depending on implementation
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_boundary_cases_for_extraction() {
|
||||
// Test edge cases in JSON extraction from text
|
||||
|
||||
let json_parser = JsonParser::new();
|
||||
|
||||
// JSON at the very beginning
|
||||
let input = r#"{"name": "start", "arguments": {}} and then text"#;
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "start");
|
||||
|
||||
// JSON at the very end
|
||||
let input = r#"Some text first {"name": "end", "arguments": {}}"#;
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "end");
|
||||
|
||||
// Multiple JSON objects in text (should find first valid one)
|
||||
let input =
|
||||
r#"Text {"name": "first", "arguments": {}} more {"name": "second", "arguments": {}}"#;
|
||||
let result = json_parser.parse_complete(input).await.unwrap();
|
||||
assert!(!result.is_empty());
|
||||
assert_eq!(result[0].function.name, "first");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_edge_cases() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Function name with underscores and numbers
|
||||
let input = r#"[func_name_2(param_1="value")]"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "func_name_2");
|
||||
|
||||
// Empty string argument
|
||||
let input = r#"[process(text="")]"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_with_pretty_json() {
|
||||
let parser = MistralParser::new();
|
||||
|
||||
// Pretty-printed JSON in Mistral format
|
||||
let input = r#"[TOOL_CALLS] [
|
||||
{
|
||||
"name": "formatted",
|
||||
"arguments": {
|
||||
"nested": {
|
||||
"key": "value"
|
||||
},
|
||||
"array": [
|
||||
1,
|
||||
2,
|
||||
3
|
||||
]
|
||||
}
|
||||
}
|
||||
]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "formatted");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["nested"]["key"], "value");
|
||||
assert_eq!(args["array"], json!([1, 2, 3]));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_with_cdata_like_content() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
// Test with content that looks like CDATA but isn't
|
||||
// Note: QwenParser expects exactly "<tool_call>\n" with the newline
|
||||
let input = r#"<tool_call>
|
||||
{"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["xml"], "<![CDATA[some data]]>");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_extremely_long_function_names() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
let long_name = "very_long_function_name_that_might_appear_in_generated_code_somewhere";
|
||||
let input = format!(r#"[{}(param="value")]"#, long_name);
|
||||
|
||||
let result = parser.parse_complete(&input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, long_name);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_duplicate_keys() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// JSON with duplicate keys (last one should win per JSON spec)
|
||||
let input = r#"{"name": "test", "arguments": {"key": "first", "key": "second"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
// JSON parsers typically keep the last value for duplicate keys
|
||||
assert_eq!(args["key"], "second");
|
||||
}
|
||||
559
sgl-router/tests/tool_parser_pythonic.rs
Normal file
559
sgl-router/tests/tool_parser_pythonic.rs
Normal file
@@ -0,0 +1,559 @@
|
||||
//! Pythonic Parser Integration Tests
|
||||
//!
|
||||
//! Tests for the Pythonic parser which handles Python function call syntax
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::tool_parser::{PythonicParser, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_single_function() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_multiple_functions() {
|
||||
let parser = PythonicParser::new();
|
||||
let input =
|
||||
r#"[search_web(query="Rust programming", max_results=5), get_time(timezone="UTC")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search_web");
|
||||
assert_eq!(result[1].function.name, "get_time");
|
||||
|
||||
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args0["query"], "Rust programming");
|
||||
assert_eq!(args0["max_results"], 5);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_with_python_literals() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[configure(enabled=True, disabled=False, optional=None)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["enabled"], true);
|
||||
assert_eq!(args["disabled"], false);
|
||||
assert_eq!(args["optional"], json!(null));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_with_lists_and_dicts() {
|
||||
let parser = PythonicParser::new();
|
||||
let input =
|
||||
r#"[process_data(items=[1, 2, 3], config={"key": "value", "nested": {"deep": True}})]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["items"], json!([1, 2, 3]));
|
||||
assert_eq!(args["config"]["key"], "value");
|
||||
assert_eq!(args["config"]["nested"]["deep"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_with_special_tokens() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Llama 4 sometimes outputs these tokens
|
||||
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "calculate");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 10);
|
||||
assert_eq!(args["y"], 20);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_with_nested_parentheses() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[math_eval(expression="(2 + 3) * (4 - 1)", round_to=2)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["expression"], "(2 + 3) * (4 - 1)");
|
||||
assert_eq!(args["round_to"], 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_with_escaped_quotes() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[echo(text="She said \"Hello\" to him")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "She said \"Hello\" to him");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_empty_arguments() {
|
||||
let parser = PythonicParser::new();
|
||||
let input = r#"[ping()]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "ping");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args, json!({}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_format_detection() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
assert!(parser.detect_format("[function_name("));
|
||||
assert!(parser.detect_format("[get_weather(city=\"NYC\")]"));
|
||||
assert!(!parser.detect_format("Just plain text"));
|
||||
assert!(!parser.detect_format("[1, 2, 3]")); // Plain list
|
||||
assert!(!parser.detect_format("{\"name\": \"test\"}")); // JSON
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_invalid_syntax() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Missing closing bracket
|
||||
let input = r#"[function(arg=value"#;
|
||||
if let Ok(result) = parser.parse_complete(input).await {
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
// Error is also acceptable for invalid syntax
|
||||
|
||||
// Invalid Python syntax - empty parameter name
|
||||
// Note: The parser currently accepts this invalid syntax and returns a result
|
||||
// This is a known limitation of the current implementation
|
||||
let input = r#"[function(=value)]"#;
|
||||
if let Ok(result) = parser.parse_complete(input).await {
|
||||
// The parser incorrectly accepts this, returning 1 result
|
||||
// We'll accept this behavior for now but note it's not ideal
|
||||
assert!(result.len() <= 1, "Should parse at most one function");
|
||||
}
|
||||
// Error would be the correct behavior
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_real_world_llama4() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Actual output from Llama 4 model
|
||||
let input = r#"I'll help you with multiple tasks. Let me search for information and perform calculations.
|
||||
|
||||
[web_search(query="latest Rust features", max_results=3, safe_search=True),
|
||||
calculate(expression="42 * 3.14159", precision=2),
|
||||
get_weather(city="San Francisco", units="fahrenheit", include_forecast=False)]
|
||||
|
||||
These functions will provide the information you need."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 3);
|
||||
assert_eq!(result[0].function.name, "web_search");
|
||||
assert_eq!(result[1].function.name, "calculate");
|
||||
assert_eq!(result[2].function.name, "get_weather");
|
||||
|
||||
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args0["query"], "latest Rust features");
|
||||
assert_eq!(args0["safe_search"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_nested_brackets_in_lists() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Test nested brackets within list arguments
|
||||
let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "process_matrix");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["data"], json!([[1, 2], [3, 4]]));
|
||||
assert_eq!(args["labels"], json!(["row[0]", "row[1]"]));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_nested_brackets_in_dicts() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Test nested brackets within dictionary arguments
|
||||
let input =
|
||||
r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "analyze");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["config"]["patterns"], json!(["[a-z]+", "[0-9]+"]));
|
||||
assert_eq!(args["config"]["nested"]["list"], json!([1, [2, 3]]));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_mixed_quotes() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Test mixed quote types in arguments
|
||||
let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "format_text");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["single"], "Hello");
|
||||
assert_eq!(args["double"], "World");
|
||||
assert_eq!(args["mixed"], "It's \"quoted\"");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_complex_nesting() {
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
// Test complex nested structures
|
||||
let input = r#"[transform(
|
||||
matrix=[[1, [2, 3]], [4, [5, [6, 7]]]],
|
||||
operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}],
|
||||
metadata={"tags": ["nested[0]", "nested[1]"], "config": {"depth": [1, 2, 3]}}
|
||||
)]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "transform");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["matrix"].is_array());
|
||||
assert!(args["operations"].is_array());
|
||||
assert_eq!(args["operations"][0]["type"], "scale");
|
||||
assert_eq!(args["metadata"]["config"]["depth"], json!([1, 2, 3]));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_no_brackets() {
|
||||
// Test parsing text with no brackets (no tool calls)
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "This is just normal text without any tool calls.";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
|
||||
// Expected - no tool calls found
|
||||
assert_eq!(state.buffer, text);
|
||||
}
|
||||
_ => panic!("Should return Incomplete for text without tool calls"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_complete_tool_call() {
|
||||
// Test parsing a complete tool call
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "New York");
|
||||
assert_eq!(args["unit"], "celsius");
|
||||
assert_eq!(state.buffer, "");
|
||||
}
|
||||
_ => panic!("Should return ToolComplete for complete tool call"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_text_before_tool_call() {
|
||||
// Test parsing text that appears before a tool call
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "This is some text before [get_weather(location='London')]";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "London");
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_partial_tool_call() {
|
||||
// Test parsing a partial tool call that spans multiple chunks
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// First chunk with opening bracket but no closing bracket
|
||||
let text1 = "Let me check the weather: [get_weather(location=";
|
||||
let result1 = parser.parse_incremental(text1, &mut state).await.unwrap();
|
||||
|
||||
match result1 {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
|
||||
assert!(state.buffer.contains("[get_weather(location="));
|
||||
}
|
||||
_ => panic!("First chunk should return Incomplete"),
|
||||
}
|
||||
|
||||
// Second chunk completing the tool call
|
||||
let text2 = "'Paris')]";
|
||||
let result2 = parser.parse_incremental(text2, &mut state).await.unwrap();
|
||||
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "Paris");
|
||||
assert_eq!(state.buffer, "");
|
||||
}
|
||||
_ => panic!("Second chunk should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_bracket_without_text_before() {
|
||||
// Test parsing a tool call that starts at the beginning of the text
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "[search(query='python programming')]";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "python programming");
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_text_after_tool_call() {
|
||||
// Test parsing text that appears after a tool call
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// First chunk with complete tool call and some text after
|
||||
let text = "[get_weather(location='Tokyo')] Here's the forecast:";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
// Text after tool call should remain in buffer
|
||||
// Note: Current implementation may clear buffer, this behavior needs verification
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_multiple_tool_calls() {
|
||||
// Test parsing multiple tool calls in sequence
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "[get_weather(location='Berlin'), search(query='restaurants')]";
|
||||
|
||||
// Current implementation may handle this as a single parse
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
// The parser should handle multiple tools in one bracket pair
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(_) => {
|
||||
// Expected behavior - parses first tool
|
||||
}
|
||||
_ => {
|
||||
// Also acceptable if it returns Incomplete waiting for more
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_opening_bracket_only() {
|
||||
// Test parsing text with only an opening bracket but no closing bracket
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "Let's try this: [";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
|
||||
assert!(state.buffer.ends_with("["));
|
||||
}
|
||||
_ => panic!("Should return Incomplete for partial bracket"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_nested_brackets() {
|
||||
// Test parsing tool calls with nested brackets in arguments
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "New York");
|
||||
assert_eq!(args["unit"], "celsius");
|
||||
assert_eq!(args["data"], json!([1, 2, 3]));
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_nested_brackets_dict() {
|
||||
// Test parsing tool calls with nested dictionaries and lists
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text = r#"[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"#;
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "test");
|
||||
assert_eq!(args["config"]["options"], json!([1, 2]));
|
||||
assert_eq!(args["config"]["nested"]["key"], "value");
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
|
||||
// Test parsing multiple tool calls with nested brackets
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let text =
|
||||
"[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]";
|
||||
let result = parser.parse_incremental(text, &mut state).await.unwrap();
|
||||
|
||||
// Should parse both tools successfully
|
||||
match result {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
// At least gets the first tool
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
}
|
||||
_ => panic!("Should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_partial_nested_brackets() {
|
||||
// Test parsing partial tool calls with nested brackets across chunks
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
// First chunk with nested brackets but incomplete
|
||||
let text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2";
|
||||
let result1 = parser.parse_incremental(text1, &mut state).await.unwrap();
|
||||
|
||||
match result1 {
|
||||
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
|
||||
assert!(state
|
||||
.buffer
|
||||
.contains("[get_weather(location='Tokyo', data=[1, 2"));
|
||||
}
|
||||
_ => panic!("First chunk should return Incomplete"),
|
||||
}
|
||||
|
||||
// Second chunk completing the nested brackets
|
||||
let text2 = ", 3])]";
|
||||
let result2 = parser.parse_incremental(text2, &mut state).await.unwrap();
|
||||
|
||||
match result2 {
|
||||
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "Tokyo");
|
||||
assert_eq!(args["data"], json!([1, 2, 3]));
|
||||
}
|
||||
_ => panic!("Second chunk should return ToolComplete"),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_parse_streaming_with_python_start_and_end_token() {
|
||||
// Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = sglang_router_rs::tool_parser::ParseState::new();
|
||||
|
||||
let chunks = vec![
|
||||
"Here's a call: ",
|
||||
"<|python_",
|
||||
"start|>[get_weather(location=",
|
||||
"'Tokyo', data=[1, 2",
|
||||
", 3])]<|python_end|>",
|
||||
];
|
||||
|
||||
let mut got_tool = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "Tokyo");
|
||||
assert_eq!(args["data"], json!([1, 2, 3]));
|
||||
got_tool = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_tool, "Should have parsed the tool call");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_detect_and_parse_with_python_start_and_end_token() {
|
||||
// Test parsing a message that starts with <|python_start|> and contains a valid tool call
|
||||
let parser = PythonicParser::new();
|
||||
|
||||
let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars.";
|
||||
let result = parser.parse_complete(text).await.unwrap();
|
||||
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["location"], "Mars");
|
||||
assert_eq!(args["unit"], "celsius");
|
||||
}
|
||||
259
sgl-router/tests/tool_parser_qwen.rs
Normal file
259
sgl-router/tests/tool_parser_qwen.rs
Normal file
@@ -0,0 +1,259 @@
|
||||
//! Qwen Parser Integration Tests
|
||||
//!
|
||||
//! Tests for the Qwen parser which handles <tool_call>...</tool_call> format
|
||||
|
||||
use serde_json::json;
|
||||
use sglang_router_rs::tool_parser::{ParseState, QwenParser, StreamResult, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_single_tool() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "get_weather", "arguments": {"city": "Beijing", "units": "celsius"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "Beijing");
|
||||
assert_eq!(args["units"], "celsius");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_multiple_sequential_tools() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"Let me help you with that.
|
||||
<tool_call>
|
||||
{"name": "search", "arguments": {"query": "Qwen model"}}
|
||||
</tool_call>
|
||||
<tool_call>
|
||||
{"name": "translate", "arguments": {"text": "Hello", "to": "zh"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "translate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_pretty_printed_json() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{
|
||||
"name": "create_document",
|
||||
"arguments": {
|
||||
"title": "Test Document",
|
||||
"content": "This is a test",
|
||||
"metadata": {
|
||||
"author": "Qwen",
|
||||
"tags": ["test", "example"]
|
||||
}
|
||||
}
|
||||
}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "create_document");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["metadata"]["author"], "Qwen");
|
||||
assert_eq!(args["metadata"]["tags"], json!(["test", "example"]));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_with_text_between() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"First, let me search for information.
|
||||
<tool_call>
|
||||
{"name": "search", "arguments": {"query": "test"}}
|
||||
</tool_call>
|
||||
|
||||
Now I'll translate something.
|
||||
|
||||
<tool_call>
|
||||
{"name": "translate", "arguments": {"text": "world", "to": "es"}}
|
||||
</tool_call>
|
||||
Done!"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
assert_eq!(result[1].function.name, "translate");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_empty_arguments() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "get_time", "arguments": {}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_time");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_with_newlines_in_strings() {
|
||||
let parser = QwenParser::new();
|
||||
let input = r#"<tool_call>
|
||||
{"name": "write_file", "arguments": {"content": "Line 1\nLine 2\nLine 3", "path": "/tmp/test.txt"}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["content"], "Line 1\nLine 2\nLine 3");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_format_detection() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
assert!(parser.detect_format("<tool_call>"));
|
||||
assert!(parser.detect_format("Some text <tool_call>\n{"));
|
||||
assert!(!parser.detect_format("Just plain text"));
|
||||
assert!(!parser.detect_format("{\"name\": \"test\"}")); // Plain JSON
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_incomplete_tags() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
// Missing closing tag
|
||||
let input = r#"<tool_call>
|
||||
{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
|
||||
// Missing opening tag
|
||||
let input = r#"{"name": "test", "arguments": {}}
|
||||
</tool_call>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_real_world_output() {
|
||||
let parser = QwenParser::new();
|
||||
|
||||
// Actual output from Qwen model
|
||||
let input = r#"I'll help you search for information and perform calculations.
|
||||
|
||||
<tool_call>
|
||||
{
|
||||
"name": "web_search",
|
||||
"arguments": {
|
||||
"query": "quantum computing breakthroughs 2024",
|
||||
"language": "en",
|
||||
"region": "us",
|
||||
"safe_search": true
|
||||
}
|
||||
}
|
||||
</tool_call>
|
||||
|
||||
Let me also calculate something for you:
|
||||
|
||||
<tool_call>
|
||||
{
|
||||
"name": "calculator",
|
||||
"arguments": {
|
||||
"expression": "sqrt(144) + 3^2",
|
||||
"precision": 2
|
||||
}
|
||||
}
|
||||
</tool_call>
|
||||
|
||||
These tools will provide the information you need."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "web_search");
|
||||
assert_eq!(result[1].function.name, "calculator");
|
||||
|
||||
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args0["query"], "quantum computing breakthroughs 2024");
|
||||
assert_eq!(args0["safe_search"], true);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffer_drain_optimization() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// First chunk - incomplete tool call
|
||||
let chunk1 = "<tool_call>\n{\"name\": \"test1\", ";
|
||||
let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap();
|
||||
// Phase 2 simplified streaming might not handle partial JSON correctly
|
||||
// The important thing is buffer accumulation works
|
||||
assert!(!state.buffer.is_empty());
|
||||
|
||||
// Complete first tool and start second
|
||||
let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", ";
|
||||
let result = parser.parse_incremental(chunk2, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test1");
|
||||
// After consuming the first tool, buffer should contain only the second tool start
|
||||
assert!(state.buffer.starts_with("<tool_call>"));
|
||||
assert!(state.buffer.contains("test2"));
|
||||
}
|
||||
_ => {
|
||||
// Phase 2 simplified streaming might return Incomplete
|
||||
// The important thing is the buffer is managed correctly
|
||||
}
|
||||
}
|
||||
|
||||
// Complete the second tool
|
||||
let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>";
|
||||
let result = parser.parse_incremental(chunk3, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test2");
|
||||
// Buffer should be empty after consuming all tools
|
||||
assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
|
||||
}
|
||||
_ => {
|
||||
// Phase 2 simplified streaming might handle this differently
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_buffer_efficiency_with_multiple_tools() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send multiple complete tools at once
|
||||
let input = r#"<tool_call>
|
||||
{"name": "tool1", "arguments": {"a": 1}}
|
||||
</tool_call><tool_call>
|
||||
{"name": "tool2", "arguments": {"b": 2}}
|
||||
</tool_call><tool_call>
|
||||
{"name": "tool3", "arguments": {"c": 3}}
|
||||
</tool_call>"#;
|
||||
|
||||
// This should efficiently process tools using drain() without creating new strings
|
||||
let result = parser.parse_incremental(input, &mut state).await.unwrap();
|
||||
|
||||
// In Phase 2, this will likely parse only the first tool
|
||||
// The important thing is that drain() doesn't cause any issues
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert!(["tool1", "tool2", "tool3"].contains(&tool.function.name.as_str()));
|
||||
}
|
||||
_ => {
|
||||
// Simplified streaming might return Incomplete
|
||||
}
|
||||
}
|
||||
|
||||
// Verify no memory issues or panics occurred with drain()
|
||||
// Test passes if we reach this point without panic
|
||||
}
|
||||
194
sgl-router/tests/tool_parser_registry.rs
Normal file
194
sgl-router/tests/tool_parser_registry.rs
Normal file
@@ -0,0 +1,194 @@
|
||||
//! Parser Registry Integration Tests
|
||||
//!
|
||||
//! Tests for model-to-parser mappings and registry functionality
|
||||
|
||||
use sglang_router_rs::tool_parser::ParserRegistry;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_registry_has_all_parsers() {
|
||||
let registry = ParserRegistry::new();
|
||||
let parsers = registry.list_parsers();
|
||||
|
||||
assert!(parsers.contains(&"json"));
|
||||
assert!(parsers.contains(&"mistral"));
|
||||
assert!(parsers.contains(&"qwen"));
|
||||
assert!(parsers.contains(&"pythonic"));
|
||||
assert!(parsers.contains(&"llama"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_openai_models_use_json() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
let models = vec!["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gpt-4o"];
|
||||
for model in models {
|
||||
let parser = registry.get_parser(model).unwrap();
|
||||
let test_input = r#"{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_anthropic_models_use_json() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
let models = vec!["claude-3-opus", "claude-3-sonnet", "claude-2.1"];
|
||||
for model in models {
|
||||
let parser = registry.get_parser(model).unwrap();
|
||||
let test_input = r#"{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_models() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
let models = vec!["mistral-large", "mistral-medium", "mixtral-8x7b"];
|
||||
for model in models {
|
||||
let parser = registry.get_parser(model).unwrap();
|
||||
let test_input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_models() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
let models = vec!["qwen2.5-72b", "Qwen2-7B", "qwen-max"];
|
||||
for model in models {
|
||||
let parser = registry.get_parser(model).unwrap();
|
||||
let test_input = r#"<tool_call>
|
||||
{"name": "test", "arguments": {}}
|
||||
</tool_call>"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_model_variants() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Llama 4 uses pythonic
|
||||
let parser = registry.get_parser("llama-4-70b").unwrap();
|
||||
let test_input = r#"[get_weather(city="NYC")]"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
|
||||
// Llama 3.2 uses python_tag
|
||||
let parser = registry.get_parser("llama-3.2-8b").unwrap();
|
||||
let test_input = r#"<|python_tag|>{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
// Other Llama models use JSON
|
||||
let parser = registry.get_parser("llama-2-70b").unwrap();
|
||||
let test_input = r#"{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_deepseek_models() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// DeepSeek uses pythonic format (simplified, v3 would need custom parser)
|
||||
let parser = registry.get_parser("deepseek-coder").unwrap();
|
||||
let test_input = r#"[function(arg="value")]"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "function");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_unknown_model_fallback() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Unknown models should fall back to JSON parser
|
||||
let parser = registry.get_parser("unknown-model-xyz").unwrap();
|
||||
let test_input = r#"{"name": "fallback", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(test_input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "fallback");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pattern_specificity() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Test that more specific patterns take precedence
|
||||
// llama-4* should match before llama-*
|
||||
let parser = registry.get_parser("llama-4-70b").unwrap();
|
||||
assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format
|
||||
|
||||
let parser = registry.get_parser("llama-3-70b").unwrap();
|
||||
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); // JSON format
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_real_world_model_outputs() {
|
||||
let registry = ParserRegistry::new();
|
||||
|
||||
// Test with realistic outputs from different models
|
||||
let test_cases = vec![
|
||||
(
|
||||
"gpt-4",
|
||||
r#"I'll help you with that.
|
||||
|
||||
{"name": "search_web", "arguments": {"query": "latest AI news", "max_results": 5}}
|
||||
|
||||
Let me search for that information."#,
|
||||
"search_web",
|
||||
),
|
||||
(
|
||||
"mistral-large",
|
||||
r#"Let me search for information about Rust.
|
||||
|
||||
[TOOL_CALLS] [
|
||||
{"name": "search", "arguments": {"query": "Rust programming"}},
|
||||
{"name": "get_weather", "arguments": {"city": "San Francisco"}}
|
||||
]
|
||||
|
||||
I've initiated the search."#,
|
||||
"search",
|
||||
),
|
||||
(
|
||||
"qwen2.5",
|
||||
r#"I'll check the weather for you.
|
||||
|
||||
<tool_call>
|
||||
{
|
||||
"name": "get_weather",
|
||||
"arguments": {
|
||||
"location": "Tokyo",
|
||||
"units": "celsius"
|
||||
}
|
||||
}
|
||||
</tool_call>
|
||||
|
||||
The weather information has been requested."#,
|
||||
"get_weather",
|
||||
),
|
||||
];
|
||||
|
||||
for (model, output, expected_name) in test_cases {
|
||||
let parser = registry.get_parser(model).unwrap();
|
||||
let result = parser.parse_complete(output).await.unwrap();
|
||||
assert!(!result.is_empty(), "No tools parsed for model {}", model);
|
||||
assert_eq!(
|
||||
result[0].function.name, expected_name,
|
||||
"Wrong function name for model {}",
|
||||
model
|
||||
);
|
||||
}
|
||||
}
|
||||
245
sgl-router/tests/tool_parser_step3.rs
Normal file
245
sgl-router/tests/tool_parser_step3.rs
Normal file
@@ -0,0 +1,245 @@
|
||||
//! Step3 Parser Integration Tests
|
||||
|
||||
use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_complete_parsing() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
// Test single tool call
|
||||
let input = r#"Let me help you.
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
||||
<steptml:parameter name="query">rust programming</steptml:parameter>
|
||||
<steptml:parameter name="limit">10</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>
|
||||
Here are the results..."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
|
||||
// Verify arguments
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "rust programming");
|
||||
assert_eq!(args["limit"], 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_multiple_tools() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_weather">
|
||||
<steptml:parameter name="location">Tokyo</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="get_news">
|
||||
<steptml:parameter name="category">tech</steptml:parameter>
|
||||
<steptml:parameter name="limit">5</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "get_weather");
|
||||
assert_eq!(result[1].function.name, "get_news");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_type_conversion() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="process">
|
||||
<steptml:parameter name="count">100</steptml:parameter>
|
||||
<steptml:parameter name="rate">2.5</steptml:parameter>
|
||||
<steptml:parameter name="active">true</steptml:parameter>
|
||||
<steptml:parameter name="optional">null</steptml:parameter>
|
||||
<steptml:parameter name="text">hello world</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["count"], 100);
|
||||
assert_eq!(args["rate"], 2.5);
|
||||
assert_eq!(args["active"], true);
|
||||
assert_eq!(args["optional"], serde_json::Value::Null);
|
||||
assert_eq!(args["text"], "hello world");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_streaming() {
|
||||
let parser = Step3Parser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Simulate streaming chunks
|
||||
let chunks = vec![
|
||||
"<|tool_calls_begin|>\n",
|
||||
"<|tool_call_begin|>function",
|
||||
"<|tool_sep|><steptml:invoke name=\"calc\">",
|
||||
"\n<steptml:parameter name=\"x\">10</steptml:parameter>",
|
||||
"\n<steptml:parameter name=\"y\">20</steptml:parameter>",
|
||||
"\n</steptml:invoke><|tool_call_end|>",
|
||||
"\n<|tool_calls_end|>",
|
||||
];
|
||||
|
||||
let mut found_name = false;
|
||||
let mut found_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "calc");
|
||||
found_name = true;
|
||||
}
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "calc");
|
||||
found_complete = true;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
assert!(found_name || found_complete);
|
||||
}
|
||||
|
||||
/// detect_format: the Step3 begin marker must be recognized anywhere in the
/// text; markers from other parser formats must be rejected.
#[test]
fn test_step3_format_detection() {
    let detector = Step3Parser::new();

    // Positive cases: marker standalone or embedded.
    assert!(detector.detect_format("<|tool_calls_begin|>"));
    assert!(detector.detect_format("text with <|tool_calls_begin|> marker"));

    // Negative cases: other formats and plain prose.
    assert!(!detector.detect_format("[TOOL_CALLS]"));
    assert!(!detector.detect_format("<tool_call>"));
    assert!(!detector.detect_format("plain text"));
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_nested_steptml() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
// Test with complex parameter values
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config">
|
||||
<steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter>
|
||||
<steptml:parameter name="array">[1, 2, 3]</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "config");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["settings"].is_object());
|
||||
assert!(args["array"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_python_literals() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
// Test Python-style literals
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="test">
|
||||
<steptml:parameter name="bool_true">True</steptml:parameter>
|
||||
<steptml:parameter name="bool_false">False</steptml:parameter>
|
||||
<steptml:parameter name="none_value">None</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["bool_true"], true);
|
||||
assert_eq!(args["bool_false"], false);
|
||||
assert_eq!(args["none_value"], serde_json::Value::Null);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_steptml_format() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
let input = r#"Text before.
|
||||
<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="search">
|
||||
<steptml:parameter name="query">rust lang</steptml:parameter>
|
||||
<steptml:parameter name="limit">10</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>Text after."#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["query"], "rust lang");
|
||||
assert_eq!(args["limit"], 10);
|
||||
// TODO: Verify normal text extraction
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_parameter_values() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="config">
|
||||
<steptml:parameter name="settings">{"nested": {"value": true}}</steptml:parameter>
|
||||
<steptml:parameter name="items">[1, 2, 3]</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert!(args["settings"].is_object());
|
||||
assert!(args["items"].is_array());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_parameter_with_angle_brackets() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
// Test parameter value containing < character
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="compare">
|
||||
<steptml:parameter name="expression">a < b && b > c</steptml:parameter>
|
||||
<steptml:parameter name="context">comparison test</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "compare");
|
||||
|
||||
// Verify the parameter value was parsed correctly
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["expression"], "a < b && b > c");
|
||||
assert_eq!(args["context"], "comparison test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_step3_empty_function_name() {
|
||||
let parser = Step3Parser::new();
|
||||
|
||||
// Test empty function name
|
||||
let input = r#"<|tool_calls_begin|>
|
||||
<|tool_call_begin|>function<|tool_sep|><steptml:invoke name="">
|
||||
<steptml:parameter name="param">value</steptml:parameter>
|
||||
</steptml:invoke><|tool_call_end|>
|
||||
<|tool_calls_end|>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0); // Should reject empty function name
|
||||
}
|
||||
341
sgl-router/tests/tool_parser_streaming.rs
Normal file
341
sgl-router/tests/tool_parser_streaming.rs
Normal file
@@ -0,0 +1,341 @@
|
||||
//! Streaming Parser Tests
|
||||
//!
|
||||
//! Tests for incremental/streaming parsing capabilities across all parsers
|
||||
|
||||
use sglang_router_rs::tool_parser::{
|
||||
JsonParser, LlamaParser, MistralParser, ParseState, PythonicParser, QwenParser, StreamResult,
|
||||
ToolParser,
|
||||
};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_streaming_simple() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Phase 2 note: This test sends the full JSON at once in the last chunk
|
||||
// In real streaming, chunks would be smaller
|
||||
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_json, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// With complete JSON sent at once, we should get ToolComplete
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected ToolComplete for complete JSON input");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_streaming_array() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Stream a JSON array of tools
|
||||
let chunks = vec![
|
||||
r#"["#,
|
||||
r#"{"name": "tool1", "#,
|
||||
r#""arguments": {}}, "#,
|
||||
r#"{"name": "tool2", "#,
|
||||
r#""arguments": {"x": 1"#,
|
||||
r#"}}]"#,
|
||||
];
|
||||
|
||||
let mut tool_count = 0;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let StreamResult::ToolComplete(_) = result {
|
||||
tool_count += 1;
|
||||
}
|
||||
}
|
||||
|
||||
// Current implementation may handle this differently
|
||||
// We're mainly testing that it doesn't crash
|
||||
assert!(tool_count <= 2, "Should parse at most 2 tools");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_mistral_streaming() {
|
||||
let parser = MistralParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
let chunks = vec![
|
||||
r#"Here is the result: "#,
|
||||
r#"[TOOL_CALLS] ["#,
|
||||
r#"{"name": "#,
|
||||
r#""search", "#,
|
||||
r#""arguments": "#,
|
||||
r#"{"query": "#,
|
||||
r#""rust lang""#,
|
||||
r#"}}]"#,
|
||||
];
|
||||
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "search");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "Should have completed parsing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_pythonic_streaming() {
|
||||
let parser = PythonicParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete pythonic format at once
|
||||
let full_input = r#"[get_weather(city="London", units="celsius")]"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_input, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "get_weather");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert_eq!(args["city"], "London");
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected ToolComplete for complete pythonic input");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_llama_streaming_with_python_tag() {
|
||||
let parser = LlamaParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
let chunks = vec![
|
||||
r#"Let me help. "#,
|
||||
r#"<|python"#,
|
||||
r#"_tag|>"#,
|
||||
r#"{"name": "#,
|
||||
r#""calculate", "#,
|
||||
r#""arguments": "#,
|
||||
r#"{"x": 10}"#,
|
||||
r#"}"#,
|
||||
];
|
||||
|
||||
let mut got_complete = false;
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
if let StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "calculate");
|
||||
got_complete = true;
|
||||
}
|
||||
}
|
||||
|
||||
assert!(got_complete, "Should have completed parsing");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_qwen_streaming() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete Qwen format at once (with exact format expected by parser)
|
||||
// Note: Parser expects newline after both tags
|
||||
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_input, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "translate");
|
||||
}
|
||||
other => {
|
||||
panic!(
|
||||
"Expected ToolComplete for complete Qwen input, got: {:?}",
|
||||
other
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_incomplete_stays_incomplete() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send truly incomplete JSON that can't be auto-completed
|
||||
let chunks = vec![r#"{"na"#, r#"me": "#];
|
||||
|
||||
for chunk in chunks {
|
||||
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
|
||||
// Should return Incomplete for partial JSON that can't be auto-completed
|
||||
assert!(
|
||||
matches!(result, StreamResult::Incomplete),
|
||||
"Should return Incomplete for partial JSON, got: {:?}",
|
||||
result
|
||||
);
|
||||
}
|
||||
|
||||
// Buffer should contain the accumulated incomplete JSON
|
||||
assert!(!state.buffer.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_text_before_tool() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// For streaming, the parser expects clean JSON
|
||||
// Mixed text extraction only works in parse_complete, not parse_incremental
|
||||
let full_input = r#"{"name": "test", "arguments": {}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_input, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test");
|
||||
}
|
||||
other => {
|
||||
panic!("Expected ToolComplete, got: {:?}", other);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_buffer_accumulation() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// Test: Complete JSON should clear buffer after parsing
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send partial JSON that can't be interpreted as complete
|
||||
let result1 = parser
|
||||
.parse_incremental(r#"{"na"#, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(matches!(result1, StreamResult::Incomplete));
|
||||
assert!(
|
||||
!state.buffer.is_empty(),
|
||||
"Buffer should accumulate incomplete JSON"
|
||||
);
|
||||
|
||||
// Send rest of JSON
|
||||
let result2 = parser
|
||||
.parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result2 {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "test");
|
||||
assert!(
|
||||
state.buffer.is_empty(),
|
||||
"Buffer should be cleared after complete parse"
|
||||
);
|
||||
}
|
||||
_ => panic!(
|
||||
"Expected ToolComplete for complete JSON, got: {:?}",
|
||||
result2
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_multiple_tools_sequential() {
|
||||
let parser = QwenParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete Qwen format with newlines
|
||||
let full_input = r#"<tool_call>
|
||||
{"name": "tool1", "arguments": {}}
|
||||
</tool_call>"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_input, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "tool1");
|
||||
}
|
||||
_ => {
|
||||
panic!("Expected ToolComplete for first tool");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_reset_after_error() {
|
||||
let parser = JsonParser::new();
|
||||
|
||||
// First attempt with invalid JSON
|
||||
let mut state1 = ParseState::new();
|
||||
let _ = parser
|
||||
.parse_incremental(r#"{"name": invalid}"#, &mut state1)
|
||||
.await;
|
||||
|
||||
// Second attempt with valid JSON should work with fresh state
|
||||
let mut state2 = ParseState::new();
|
||||
let result = parser
|
||||
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
if let StreamResult::ToolComplete(tool) = result {
|
||||
assert_eq!(tool.function.name, "test");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_streaming_with_unicode_chunks() {
|
||||
let parser = JsonParser::new();
|
||||
let mut state = ParseState::new();
|
||||
|
||||
// Send complete JSON with unicode
|
||||
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;
|
||||
|
||||
let result = parser
|
||||
.parse_incremental(full_input, &mut state)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Phase 2 may return partial results even with complete JSON
|
||||
// The important thing is that unicode is handled without crashes
|
||||
match result {
|
||||
StreamResult::ToolComplete(tool) => {
|
||||
assert_eq!(tool.function.name, "translate");
|
||||
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
|
||||
assert!(args["text"].as_str().unwrap().contains("世界"));
|
||||
}
|
||||
StreamResult::ToolName { name, .. } => {
|
||||
assert_eq!(name, "translate");
|
||||
// Phase 2 partial streaming behavior - acceptable
|
||||
}
|
||||
StreamResult::ToolArguments { arguments, .. } => {
|
||||
// Verify unicode was preserved
|
||||
let args: serde_json::Value = serde_json::from_str(&arguments).unwrap();
|
||||
assert!(args["text"].as_str().unwrap().contains("世界"));
|
||||
}
|
||||
other => {
|
||||
panic!("Unexpected result: {:?}", other);
|
||||
}
|
||||
}
|
||||
}
|
||||
247
sgl-router/tests/tool_parser_wrapper_tokens.rs
Normal file
247
sgl-router/tests/tool_parser_wrapper_tokens.rs
Normal file
@@ -0,0 +1,247 @@
|
||||
//! Wrapper Token Tests
|
||||
//!
|
||||
//! Tests for JSON parser with custom wrapper tokens
|
||||
|
||||
use sglang_router_rs::tool_parser::{JsonParser, TokenConfig, ToolParser};
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_xml_style_wrapper() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input =
|
||||
r#"Some text before <tool>{"name": "test", "arguments": {"x": 1}}</tool> and after"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["x"], 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_multiple_wrapper_pairs() {
|
||||
// Test with multiple start/end token pairs
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// Test first pair
|
||||
let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#;
|
||||
let result1 = parser.parse_complete(input1).await.unwrap();
|
||||
assert_eq!(result1.len(), 1);
|
||||
assert_eq!(result1[0].function.name, "tool1");
|
||||
|
||||
// Test second pair
|
||||
let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#;
|
||||
let result2 = parser.parse_complete(input2).await.unwrap();
|
||||
assert_eq!(result2.len(), 1);
|
||||
assert_eq!(result2[0].function.name, "tool2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_only_start_token() {
|
||||
// Test when only start token is provided (no end token)
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec![">>>FUNCTION:".to_string()],
|
||||
end_tokens: vec!["".to_string()], // Empty end token
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"Some preamble >>>FUNCTION:{"name": "execute", "arguments": {"cmd": "ls"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "execute");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_custom_separator() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["[FUNC]".to_string()],
|
||||
end_tokens: vec!["[/FUNC]".to_string()],
|
||||
separator: " | ".to_string(), // Custom separator
|
||||
});
|
||||
|
||||
// Though we're not testing multiple tools here, the separator is configured
|
||||
let input = r#"[FUNC]{"name": "test", "arguments": {}}[/FUNC]"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_nested_wrapper_tokens_in_content() {
|
||||
// Known limitation: When wrapper tokens appear inside JSON strings,
|
||||
// the simple regex-based extraction may fail. This would require
|
||||
// a more sophisticated parser that understands JSON string escaping.
|
||||
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<call>".to_string()],
|
||||
end_tokens: vec!["</call>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input =
|
||||
r#"<call>{"name": "echo", "arguments": {"text": "Use <call> and </call> tags"}}</call>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
|
||||
// This is a known limitation - the parser may fail when end tokens appear in content
|
||||
// For now, we accept this behavior
|
||||
if result.is_empty() {
|
||||
// Parser failed due to nested tokens - this is expected
|
||||
assert_eq!(
|
||||
result.len(),
|
||||
0,
|
||||
"Known limitation: nested wrapper tokens in content"
|
||||
);
|
||||
} else {
|
||||
// If it does parse, verify it's correct
|
||||
assert_eq!(result[0].function.name, "echo");
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["text"], "Use <call> and </call> tags");
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_extraction_without_wrapper_tokens() {
|
||||
// Default parser without wrapper tokens should extract JSON from text
|
||||
let parser = JsonParser::new();
|
||||
|
||||
let input = r#"
|
||||
Here is some text before the JSON.
|
||||
{"name": "search", "arguments": {"query": "test"}}
|
||||
And here is some text after.
|
||||
"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "search");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_multiline_wrapper_content() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["```json\n".to_string()],
|
||||
end_tokens: vec!["\n```".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"Here's the function call:
|
||||
```json
|
||||
{
|
||||
"name": "format_code",
|
||||
"arguments": {
|
||||
"language": "rust",
|
||||
"code": "fn main() {}"
|
||||
}
|
||||
}
|
||||
```
|
||||
Done!"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "format_code");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_with_special_chars_in_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["{{FUNC[[".to_string()],
|
||||
end_tokens: vec!["]]FUNC}}".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"{{FUNC[[{"name": "test", "arguments": {"special": "[]{}"}}]]FUNC}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
|
||||
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
|
||||
assert_eq!(args["special"], "[]{}");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_multiple_tools_with_wrapper() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<fn>".to_string()],
|
||||
end_tokens: vec!["</fn>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// Multiple wrapped JSON objects
|
||||
let input = r#"
|
||||
<fn>{"name": "tool1", "arguments": {}}</fn>
|
||||
Some text between.
|
||||
<fn>{"name": "tool2", "arguments": {"x": 1}}</fn>
|
||||
"#;
|
||||
|
||||
// Current implementation might handle this as separate calls
|
||||
// Let's test that at least the first one is parsed
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert!(!result.is_empty(), "Should parse at least one tool");
|
||||
assert_eq!(result[0].function.name, "tool1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_wrapper_with_array() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tools>".to_string()],
|
||||
end_tokens: vec!["</tools>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"<tools>[
|
||||
{"name": "func1", "arguments": {}},
|
||||
{"name": "func2", "arguments": {"param": "value"}}
|
||||
]</tools>"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 2);
|
||||
assert_eq!(result[0].function.name, "func1");
|
||||
assert_eq!(result[1].function.name, "func2");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_incomplete_wrapper_tokens() {
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec!["<tool>".to_string()],
|
||||
end_tokens: vec!["</tool>".to_string()],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
// Missing end token
|
||||
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should not parse without closing token");
|
||||
|
||||
// Missing start token
|
||||
let input = r#"{"name": "test", "arguments": {}}</tool>"#;
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 0, "Should not parse without opening token");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_json_empty_wrapper_tokens() {
|
||||
// Test with empty wrapper tokens (should behave like default)
|
||||
let parser = JsonParser::with_config(TokenConfig {
|
||||
start_tokens: vec![],
|
||||
end_tokens: vec![],
|
||||
separator: ", ".to_string(),
|
||||
});
|
||||
|
||||
let input = r#"{"name": "test", "arguments": {"key": "value"}}"#;
|
||||
|
||||
let result = parser.parse_complete(input).await.unwrap();
|
||||
assert_eq!(result.len(), 1);
|
||||
assert_eq!(result[0].function.name, "test");
|
||||
}
|
||||
Reference in New Issue
Block a user