sglangv0.5.2 & support Qwen3-Next-80B-A3B-Instruct

This commit is contained in:
maxiao1
2025-09-13 17:00:20 +08:00
commit 118f1fc726
2037 changed files with 515371 additions and 0 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,177 @@
// tests/common/mock_mcp_server.rs - Mock MCP server for testing
use rmcp::{
handler::server::{router::tool::ToolRouter, wrapper::Parameters},
model::*,
service::RequestContext,
tool, tool_handler, tool_router,
transport::streamable_http_server::{
session::local::LocalSessionManager, StreamableHttpService,
},
ErrorData as McpError, RoleServer, ServerHandler,
};
use tokio::net::TcpListener;
/// Mock MCP server that returns hardcoded responses for testing
pub struct MockMCPServer {
pub port: u16,
pub server_handle: Option<tokio::task::JoinHandle<()>>,
}
/// Simple test server with mock search tools
#[derive(Clone)]
pub struct MockSearchServer {
tool_router: ToolRouter<MockSearchServer>,
}
#[tool_router]
impl MockSearchServer {
pub fn new() -> Self {
Self {
tool_router: Self::tool_router(),
}
}
#[tool(description = "Mock web search tool")]
fn brave_web_search(
&self,
Parameters(params): Parameters<serde_json::Map<String, serde_json::Value>>,
) -> Result<CallToolResult, McpError> {
let query = params
.get("query")
.and_then(|v| v.as_str())
.unwrap_or("test");
Ok(CallToolResult::success(vec![Content::text(format!(
"Mock search results for: {}",
query
))]))
}
#[tool(description = "Mock local search tool")]
fn brave_local_search(
&self,
Parameters(_params): Parameters<serde_json::Map<String, serde_json::Value>>,
) -> Result<CallToolResult, McpError> {
Ok(CallToolResult::success(vec![Content::text(
"Mock local search results",
)]))
}
}
#[tool_handler]
impl ServerHandler for MockSearchServer {
fn get_info(&self) -> ServerInfo {
ServerInfo {
protocol_version: ProtocolVersion::V_2024_11_05,
capabilities: ServerCapabilities::builder().enable_tools().build(),
server_info: Implementation::from_build_env(),
instructions: Some("Mock server for testing".to_string()),
}
}
async fn initialize(
&self,
_request: InitializeRequestParam,
_context: RequestContext<RoleServer>,
) -> Result<InitializeResult, McpError> {
Ok(self.get_info())
}
}
impl MockMCPServer {
/// Start a mock MCP server on an available port
pub async fn start() -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
// Find an available port
let listener = TcpListener::bind("127.0.0.1:0").await?;
let port = listener.local_addr()?.port();
// Create the MCP service using rmcp's StreamableHttpService
let service = StreamableHttpService::new(
|| Ok(MockSearchServer::new()),
LocalSessionManager::default().into(),
Default::default(),
);
let app = axum::Router::new().nest_service("/mcp", service);
let server_handle = tokio::spawn(async move {
axum::serve(listener, app)
.await
.expect("Mock MCP server failed to start");
});
// Give the server a moment to start
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
Ok(MockMCPServer {
port,
server_handle: Some(server_handle),
})
}
/// Get the full URL for this mock server
pub fn url(&self) -> String {
format!("http://127.0.0.1:{}/mcp", self.port)
}
/// Stop the mock server
pub async fn stop(&mut self) {
if let Some(handle) = self.server_handle.take() {
handle.abort();
// Wait a moment for cleanup
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
}
}
}
impl Drop for MockMCPServer {
fn drop(&mut self) {
if let Some(handle) = self.server_handle.take() {
handle.abort();
}
}
}
#[cfg(test)]
mod tests {
#[allow(unused_imports)]
use super::MockMCPServer;
#[tokio::test]
async fn test_mock_server_startup() {
let mut server = MockMCPServer::start().await.unwrap();
assert!(server.port > 0);
assert!(server.url().contains(&server.port.to_string()));
server.stop().await;
}
#[tokio::test]
async fn test_mock_server_with_rmcp_client() {
let mut server = MockMCPServer::start().await.unwrap();
// Test that we can connect with rmcp client
use rmcp::transport::StreamableHttpClientTransport;
use rmcp::ServiceExt;
let transport = StreamableHttpClientTransport::from_uri(server.url().as_str());
let client = ().serve(transport).await;
assert!(client.is_ok(), "Should be able to connect to mock server");
if let Ok(client) = client {
// Test listing tools
let tools = client.peer().list_all_tools().await;
assert!(tools.is_ok(), "Should be able to list tools");
if let Ok(tools) = tools {
assert_eq!(tools.len(), 2, "Should have 2 tools");
assert!(tools.iter().any(|t| t.name == "brave_web_search"));
assert!(tools.iter().any(|t| t.name == "brave_local_search"));
}
// Shutdown by dropping the client
drop(client);
}
server.stop().await;
}
}

View File

@@ -0,0 +1,238 @@
//! Mock servers for testing
#![allow(dead_code)]
use axum::{
body::Body,
extract::{Request, State},
http::{HeaderValue, StatusCode},
response::sse::{Event, KeepAlive},
response::{IntoResponse, Response, Sse},
routing::post,
Json, Router,
};
use futures_util::stream::{self, StreamExt};
use serde_json::json;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::TcpListener;
/// Mock OpenAI API server for testing
pub struct MockOpenAIServer {
addr: SocketAddr,
_handle: tokio::task::JoinHandle<()>,
}
#[derive(Clone)]
struct MockServerState {
require_auth: bool,
expected_auth: Option<String>,
}
impl MockOpenAIServer {
/// Create and start a new mock OpenAI server
pub async fn new() -> Self {
Self::new_with_auth(None).await
}
/// Create and start a new mock OpenAI server with optional auth requirement
pub async fn new_with_auth(expected_auth: Option<String>) -> Self {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let state = Arc::new(MockServerState {
require_auth: expected_auth.is_some(),
expected_auth,
});
let app = Router::new()
.route("/v1/chat/completions", post(mock_chat_completions))
.route("/v1/completions", post(mock_completions))
.route("/v1/models", post(mock_models).get(mock_models))
.with_state(state);
let handle = tokio::spawn(async move {
axum::serve(listener, app).await.unwrap();
});
// Give the server a moment to start
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
Self {
addr,
_handle: handle,
}
}
/// Get the base URL for this mock server
pub fn base_url(&self) -> String {
format!("http://{}", self.addr)
}
}
/// Mock chat completions endpoint
async fn mock_chat_completions(req: Request<Body>) -> Response {
let (_, body) = req.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
};
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
Ok(req) => req,
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
};
// Extract model from request or use default (owned String to satisfy 'static in stream)
let model: String = request
.get("model")
.and_then(|v| v.as_str())
.unwrap_or("gpt-3.5-turbo")
.to_string();
// If stream requested, return SSE
let is_stream = request
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if is_stream {
let created = 1677652288u64;
// Single chunk then [DONE]
let model_chunk = model.clone();
let event_stream = stream::once(async move {
let chunk = json!({
"id": "chatcmpl-123456789",
"object": "chat.completion.chunk",
"created": created,
"model": model_chunk,
"choices": [{
"index": 0,
"delta": {
"content": "Hello!"
},
"finish_reason": null
}]
});
Ok::<_, std::convert::Infallible>(Event::default().data(chunk.to_string()))
})
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
Sse::new(event_stream)
.keep_alive(KeepAlive::default())
.into_response()
} else {
// Create a mock non-streaming response
let response = json!({
"id": "chatcmpl-123456789",
"object": "chat.completion",
"created": 1677652288,
"model": model,
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! I'm a mock OpenAI assistant. How can I help you today?"
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 9,
"completion_tokens": 12,
"total_tokens": 21
}
});
Json(response).into_response()
}
}
/// Mock completions endpoint (legacy)
async fn mock_completions(req: Request<Body>) -> Response {
let (_, body) = req.into_parts();
let body_bytes = match axum::body::to_bytes(body, usize::MAX).await {
Ok(bytes) => bytes,
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
};
let request: serde_json::Value = match serde_json::from_slice(&body_bytes) {
Ok(req) => req,
Err(_) => return StatusCode::BAD_REQUEST.into_response(),
};
let model = request["model"].as_str().unwrap_or("text-davinci-003");
let response = json!({
"id": "cmpl-123456789",
"object": "text_completion",
"created": 1677652288,
"model": model,
"choices": [{
"text": " This is a mock completion response.",
"index": 0,
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 5,
"completion_tokens": 7,
"total_tokens": 12
}
});
Json(response).into_response()
}
/// Mock models endpoint
async fn mock_models(State(state): State<Arc<MockServerState>>, req: Request<Body>) -> Response {
// Optionally enforce Authorization header
if state.require_auth {
let auth = req
.headers()
.get("authorization")
.or_else(|| req.headers().get("Authorization"))
.and_then(|v| v.to_str().ok())
.map(|s| s.to_string());
let auth_ok = match (&state.expected_auth, auth) {
(Some(expected), Some(got)) => &got == expected,
(None, Some(_)) => true,
_ => false,
};
if !auth_ok {
let mut response = Response::new(Body::from(
json!({
"error": {
"message": "Unauthorized",
"type": "invalid_request_error"
}
})
.to_string(),
));
*response.status_mut() = StatusCode::UNAUTHORIZED;
response
.headers_mut()
.insert("WWW-Authenticate", HeaderValue::from_static("Bearer"));
return response;
}
}
let response = json!({
"object": "list",
"data": [
{
"id": "gpt-4",
"object": "model",
"created": 1677610602,
"owned_by": "openai"
},
{
"id": "gpt-3.5-turbo",
"object": "model",
"created": 1677610602,
"owned_by": "openai"
}
]
});
Json(response).into_response()
}

View File

@@ -0,0 +1,614 @@
// Mock worker for testing - these functions are used by integration tests
#![allow(dead_code)]
use axum::{
extract::{Json, State},
http::StatusCode,
response::sse::{Event, KeepAlive},
response::{IntoResponse, Response, Sse},
routing::{get, post},
Router,
};
use futures_util::stream::{self, StreamExt};
use serde_json::json;
use std::convert::Infallible;
use std::sync::Arc;
use std::time::{SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;
use uuid::Uuid;
/// Configuration for mock worker behavior
#[derive(Clone)]
pub struct MockWorkerConfig {
pub port: u16,
pub worker_type: WorkerType,
pub health_status: HealthStatus,
pub response_delay_ms: u64,
pub fail_rate: f32,
}
#[derive(Clone, Debug)]
pub enum WorkerType {
Regular,
Prefill,
Decode,
}
#[derive(Clone, Debug)]
pub enum HealthStatus {
Healthy,
Unhealthy,
Degraded,
}
/// Mock worker server for testing
pub struct MockWorker {
config: Arc<RwLock<MockWorkerConfig>>,
shutdown_handle: Option<tokio::task::JoinHandle<()>>,
shutdown_tx: Option<tokio::sync::oneshot::Sender<()>>,
}
impl MockWorker {
pub fn new(config: MockWorkerConfig) -> Self {
Self {
config: Arc::new(RwLock::new(config)),
shutdown_handle: None,
shutdown_tx: None,
}
}
/// Start the mock worker server
pub async fn start(&mut self) -> Result<String, Box<dyn std::error::Error>> {
let config = self.config.clone();
let port = config.read().await.port;
// If port is 0, find an available port
let port = if port == 0 {
let listener = std::net::TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
config.write().await.port = port;
port
} else {
port
};
let app = Router::new()
.route("/health", get(health_handler))
.route("/health_generate", get(health_generate_handler))
.route("/get_server_info", get(server_info_handler))
.route("/get_model_info", get(model_info_handler))
.route("/generate", post(generate_handler))
.route("/v1/chat/completions", post(chat_completions_handler))
.route("/v1/completions", post(completions_handler))
.route("/flush_cache", post(flush_cache_handler))
.route("/v1/models", get(v1_models_handler))
.with_state(config);
let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel::<()>();
self.shutdown_tx = Some(shutdown_tx);
// Spawn the server in a separate task
let handle = tokio::spawn(async move {
let listener = match tokio::net::TcpListener::bind(("127.0.0.1", port)).await {
Ok(l) => l,
Err(e) => {
eprintln!("Failed to bind to port {}: {}", port, e);
return;
}
};
let server = axum::serve(listener, app).with_graceful_shutdown(async move {
let _ = shutdown_rx.await;
});
if let Err(e) = server.await {
eprintln!("Server error: {}", e);
}
});
self.shutdown_handle = Some(handle);
// Wait for the server to start
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
let url = format!("http://127.0.0.1:{}", port);
Ok(url)
}
/// Stop the mock worker server
pub async fn stop(&mut self) {
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
if let Some(handle) = self.shutdown_handle.take() {
// Wait for the server to shut down
let _ = tokio::time::timeout(tokio::time::Duration::from_secs(5), handle).await;
}
}
}
impl Drop for MockWorker {
fn drop(&mut self) {
// Clean shutdown when dropped
if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(());
}
}
}
// Handler implementations
/// Check if request should fail based on configured fail_rate
async fn should_fail(config: &MockWorkerConfig) -> bool {
rand::random::<f32>() < config.fail_rate
}
async fn health_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
match config.health_status {
HealthStatus::Healthy => Json(json!({
"status": "healthy",
"timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
"worker_type": format!("{:?}", config.worker_type),
}))
.into_response(),
HealthStatus::Unhealthy => (
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"status": "unhealthy",
"error": "Worker is not responding"
})),
)
.into_response(),
HealthStatus::Degraded => Json(json!({
"status": "degraded",
"warning": "High load detected"
}))
.into_response(),
}
}
async fn health_generate_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "Random failure for testing"
})),
)
.into_response();
}
if matches!(config.health_status, HealthStatus::Healthy) {
Json(json!({
"status": "ok",
"queue_length": 0,
"processing_time_ms": config.response_delay_ms
}))
.into_response()
} else {
(
StatusCode::SERVICE_UNAVAILABLE,
Json(json!({
"error": "Generation service unavailable"
})),
)
.into_response()
}
}
async fn server_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "Random failure for testing"
})),
)
.into_response();
}
Json(json!({
"model_path": "mock-model-path",
"tokenizer_path": "mock-tokenizer-path",
"port": config.port,
"host": "127.0.0.1",
"max_num_batched_tokens": 32768,
"max_prefill_tokens": 16384,
"mem_fraction_static": 0.88,
"tp_size": 1,
"dp_size": 1,
"stream_interval": 8,
"dtype": "float16",
"device": "cuda",
"enable_flashinfer": true,
"enable_p2p_check": true,
"context_length": 32768,
"chat_template": null,
"disable_radix_cache": false,
"enable_torch_compile": false,
"trust_remote_code": false,
"show_time_cost": false,
"waiting_queue_size": 0,
"running_queue_size": 0,
"req_to_token_ratio": 1.2,
"min_running_requests": 0,
"max_running_requests": 2048,
"max_req_num": 8192,
"max_batch_tokens": 32768,
"schedule_policy": "lpm",
"schedule_conservativeness": 1.0,
"version": "0.3.0",
"internal_states": [{
"waiting_queue_size": 0,
"running_queue_size": 0
}]
}))
.into_response()
}
async fn model_info_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "Random failure for testing"
})),
)
.into_response();
}
Json(json!({
"model_path": "mock-model-path",
"tokenizer_path": "mock-tokenizer-path",
"is_generation": true,
"preferred_sampling_params": {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"max_tokens": 2048
}
}))
.into_response()
}
async fn generate_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Json(payload): Json<serde_json::Value>,
) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "Random failure for testing"
})),
)
.into_response();
}
if config.response_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
}
let is_stream = payload
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
if is_stream {
let stream_delay = config.response_delay_ms;
// Check if it's a batch request
let is_batch = payload.get("text").and_then(|t| t.as_array()).is_some();
let batch_size = if is_batch {
payload
.get("text")
.and_then(|t| t.as_array())
.map(|arr| arr.len())
.unwrap_or(1)
} else {
1
};
let mut events = Vec::new();
// Generate events for each item in batch
for i in 0..batch_size {
let timestamp_start = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs_f64();
let data = json!({
"text": format!("Mock response {}", i + 1),
"meta_info": {
"prompt_tokens": 10,
"completion_tokens": 5,
"completion_tokens_wo_jump_forward": 5,
"input_token_logprobs": null,
"output_token_logprobs": null,
"first_token_latency": stream_delay as f64 / 1000.0,
"time_to_first_token": stream_delay as f64 / 1000.0,
"time_per_output_token": 0.01,
"end_time": timestamp_start + (stream_delay as f64 / 1000.0),
"start_time": timestamp_start,
"finish_reason": {
"type": "stop",
"reason": "length"
}
},
"stage": "mid"
});
events.push(Ok::<_, Infallible>(Event::default().data(data.to_string())));
}
// Add [DONE] event
events.push(Ok(Event::default().data("[DONE]")));
let stream = stream::iter(events);
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else {
Json(json!({
"text": "This is a mock response.",
"meta_info": {
"prompt_tokens": 10,
"completion_tokens": 5,
"completion_tokens_wo_jump_forward": 5,
"input_token_logprobs": null,
"output_token_logprobs": null,
"first_token_latency": config.response_delay_ms as f64 / 1000.0,
"time_to_first_token": config.response_delay_ms as f64 / 1000.0,
"time_per_output_token": 0.01,
"finish_reason": {
"type": "stop",
"reason": "length"
}
}
}))
.into_response()
}
}
async fn chat_completions_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Json(payload): Json<serde_json::Value>,
) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": "Random failure for testing",
"type": "internal_error",
"code": "internal_error"
}
})),
)
.into_response();
}
if config.response_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
}
let is_stream = payload
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if is_stream {
let request_id = format!("chatcmpl-{}", Uuid::new_v4());
let stream = stream::once(async move {
let chunk = json!({
"id": request_id,
"object": "chat.completion.chunk",
"created": timestamp,
"model": "mock-model",
"choices": [{
"index": 0,
"delta": {
"content": "This is a mock chat response."
},
"finish_reason": null
}]
});
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
})
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else {
Json(json!({
"id": format!("chatcmpl-{}", Uuid::new_v4()),
"object": "chat.completion",
"created": timestamp,
"model": "mock-model",
"choices": [{
"index": 0,
"message": {
"role": "assistant",
"content": "This is a mock chat response."
},
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
}))
.into_response()
}
}
async fn completions_handler(
State(config): State<Arc<RwLock<MockWorkerConfig>>>,
Json(payload): Json<serde_json::Value>,
) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": "Random failure for testing",
"type": "internal_error",
"code": "internal_error"
}
})),
)
.into_response();
}
if config.response_delay_ms > 0 {
tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
}
let is_stream = payload
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
if is_stream {
let request_id = format!("cmpl-{}", Uuid::new_v4());
let stream = stream::once(async move {
let chunk = json!({
"id": request_id,
"object": "text_completion",
"created": timestamp,
"model": "mock-model",
"choices": [{
"text": "This is a mock completion.",
"index": 0,
"logprobs": null,
"finish_reason": null
}]
});
Ok::<_, Infallible>(Event::default().data(chunk.to_string()))
})
.chain(stream::once(async { Ok(Event::default().data("[DONE]")) }));
Sse::new(stream)
.keep_alive(KeepAlive::default())
.into_response()
} else {
Json(json!({
"id": format!("cmpl-{}", Uuid::new_v4()),
"object": "text_completion",
"created": timestamp,
"model": "mock-model",
"choices": [{
"text": "This is a mock completion.",
"index": 0,
"logprobs": null,
"finish_reason": "stop"
}],
"usage": {
"prompt_tokens": 10,
"completion_tokens": 5,
"total_tokens": 15
}
}))
.into_response()
}
}
async fn flush_cache_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": "Random failure for testing"
})),
)
.into_response();
}
Json(json!({
"message": "Cache flushed successfully"
}))
.into_response()
}
async fn v1_models_handler(State(config): State<Arc<RwLock<MockWorkerConfig>>>) -> Response {
let config = config.read().await;
if should_fail(&config).await {
return (
StatusCode::INTERNAL_SERVER_ERROR,
Json(json!({
"error": {
"message": "Random failure for testing",
"type": "internal_error",
"code": "internal_error"
}
})),
)
.into_response();
}
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
Json(json!({
"object": "list",
"data": [{
"id": "mock-model",
"object": "model",
"created": timestamp,
"owned_by": "organization-owner"
}]
}))
.into_response()
}
impl Default for MockWorkerConfig {
fn default() -> Self {
Self {
port: 0,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}
}
}

View File

@@ -0,0 +1,103 @@
// These modules are used by tests and benchmarks
#![allow(dead_code)]
pub mod mock_mcp_server;
pub mod mock_openai_server;
pub mod mock_worker;
pub mod test_app;
use sglang_router_rs::config::RouterConfig;
use sglang_router_rs::server::AppContext;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, Mutex, OnceLock};
/// Helper function to create AppContext for tests
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
Arc::new(
AppContext::new(
config.clone(),
reqwest::Client::new(),
config.max_concurrent_requests,
config.rate_limit_tokens_per_second,
)
.expect("Failed to create AppContext in test"),
)
}
// Tokenizer download configuration
const TINYLLAMA_TOKENIZER_URL: &str =
"https://huggingface.co/TinyLlama/TinyLlama-1.1B-Chat-v1.0/resolve/main/tokenizer.json";
const CACHE_DIR: &str = ".tokenizer_cache";
const TINYLLAMA_TOKENIZER_FILENAME: &str = "tinyllama_tokenizer.json";
// Global mutex to prevent concurrent downloads
static DOWNLOAD_MUTEX: OnceLock<Mutex<()>> = OnceLock::new();
/// Downloads the TinyLlama tokenizer from HuggingFace if not already cached.
/// Returns the path to the cached tokenizer file.
///
/// This function is thread-safe and will only download the tokenizer once
/// even if called from multiple threads concurrently.
pub fn ensure_tokenizer_cached() -> PathBuf {
// Get or initialize the mutex
let mutex = DOWNLOAD_MUTEX.get_or_init(|| Mutex::new(()));
// Lock to ensure only one thread downloads at a time
let _guard = mutex.lock().unwrap();
let cache_dir = PathBuf::from(CACHE_DIR);
let tokenizer_path = cache_dir.join(TINYLLAMA_TOKENIZER_FILENAME);
// Create cache directory if it doesn't exist
if !cache_dir.exists() {
fs::create_dir_all(&cache_dir).expect("Failed to create cache directory");
}
// Download tokenizer if not already cached
if !tokenizer_path.exists() {
println!("Downloading TinyLlama tokenizer from HuggingFace...");
// Use blocking reqwest client since we're in tests/benchmarks
let client = reqwest::blocking::Client::new();
let response = client
.get(TINYLLAMA_TOKENIZER_URL)
.send()
.expect("Failed to download tokenizer");
if !response.status().is_success() {
panic!("Failed to download tokenizer: HTTP {}", response.status());
}
let content = response.bytes().expect("Failed to read tokenizer content");
// Verify we got actual JSON content
if content.len() < 100 {
panic!("Downloaded content too small: {} bytes", content.len());
}
fs::write(&tokenizer_path, content).expect("Failed to write tokenizer to cache");
println!(
"Tokenizer downloaded and cached successfully ({} bytes)",
tokenizer_path.metadata().unwrap().len()
);
}
tokenizer_path
}
/// Common test prompts for consistency across tests
pub const TEST_PROMPTS: [&str; 4] = [
"deep learning is",
"Deep learning is",
"has anyone seen nemo lately",
"another prompt",
];
/// Pre-computed hashes for verification
pub const EXPECTED_HASHES: [u64; 4] = [
1209591529327510910,
4181375434596349981,
6245658446118930933,
5097285695902185237,
];

View File

@@ -0,0 +1,52 @@
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
config::RouterConfig,
routers::RouterTrait,
server::{build_app, AppContext, AppState},
};
use std::sync::Arc;
/// Create a test Axum application using the actual server's build_app function
#[allow(dead_code)]
pub fn create_test_app(
router: Arc<dyn RouterTrait>,
client: Client,
router_config: &RouterConfig,
) -> Router {
// Create AppContext
let app_context = Arc::new(
AppContext::new(
router_config.clone(),
client,
router_config.max_concurrent_requests,
router_config.rate_limit_tokens_per_second,
)
.expect("Failed to create AppContext in test"),
);
// Create AppState with the test router and context
let app_state = Arc::new(AppState {
router,
context: app_context,
concurrency_queue_tx: None, // No queue for tests
});
// Configure request ID headers (use defaults if not specified)
let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
vec![
"x-request-id".to_string(),
"x-correlation-id".to_string(),
"x-trace-id".to_string(),
"request-id".to_string(),
]
});
// Use the actual server's build_app function
build_app(
app_state,
router_config.max_payload_size,
request_id_headers,
router_config.cors_allowed_origins.clone(),
)
}

View File

@@ -0,0 +1,457 @@
// This test suite validates the complete MCP implementation against the
// functionality required for SGLang responses API integration.
//
// Test Coverage:
// - Core MCP server functionality
// - Tool session management (individual and multi-tool)
// - Tool execution and error handling
// - Schema adaptation and validation
// - Mock server integration for reliable testing
mod common;
use common::mock_mcp_server::MockMCPServer;
use serde_json::json;
use sglang_router_rs::mcp::{McpClientManager, McpConfig, McpError, McpServerConfig, McpTransport};
use std::collections::HashMap;
/// Create a new mock server for testing (each test gets its own)
async fn create_mock_server() -> MockMCPServer {
MockMCPServer::start()
.await
.expect("Failed to start mock MCP server")
}
// Core MCP Server Tests
#[tokio::test]
async fn test_mcp_server_initialization() {
// Test that we can create an empty configuration
let config = McpConfig { servers: vec![] };
// Should fail with no servers
let result = McpClientManager::new(config).await;
assert!(result.is_err(), "Should fail with no servers configured");
}
#[tokio::test]
async fn test_server_connection_with_mock() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let result = McpClientManager::new(config).await;
assert!(result.is_ok(), "Should connect to mock server");
let mut manager = result.unwrap();
let servers = manager.list_servers();
assert_eq!(servers.len(), 1);
assert!(servers.contains(&"mock_server".to_string()));
let tools = manager.list_tools();
assert_eq!(tools.len(), 2, "Should have 2 tools from mock server");
assert!(manager.has_tool("brave_web_search"));
assert!(manager.has_tool("brave_local_search"));
manager.shutdown().await;
}
#[tokio::test]
async fn test_tool_availability_checking() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let mut manager = McpClientManager::new(config).await.unwrap();
let test_tools = vec!["brave_web_search", "brave_local_search", "calculator"];
for tool in test_tools {
let available = manager.has_tool(tool);
match tool {
"brave_web_search" | "brave_local_search" => {
assert!(
available,
"Tool {} should be available from mock server",
tool
);
}
"calculator" => {
assert!(
!available,
"Tool {} should not be available from mock server",
tool
);
}
_ => {}
}
}
manager.shutdown().await;
}
#[tokio::test]
async fn test_multi_server_connection() {
let mock_server1 = create_mock_server().await;
let mock_server2 = create_mock_server().await;
let config = McpConfig {
servers: vec![
McpServerConfig {
name: "mock_server_1".to_string(),
transport: McpTransport::Streamable {
url: mock_server1.url(),
token: None,
},
},
McpServerConfig {
name: "mock_server_2".to_string(),
transport: McpTransport::Streamable {
url: mock_server2.url(),
token: None,
},
},
],
};
// Note: This will fail to connect to both servers in the current implementation
// since they return the same tools. The manager will connect to the first one.
let result = McpClientManager::new(config).await;
if let Ok(mut manager) = result {
let servers = manager.list_servers();
assert!(!servers.is_empty(), "Should have at least one server");
let tools = manager.list_tools();
assert!(tools.len() >= 2, "Should have tools from servers");
manager.shutdown().await;
}
}
#[tokio::test]
async fn test_tool_execution_with_mock() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let mut manager = McpClientManager::new(config).await.unwrap();
let result = manager
.call_tool(
"brave_web_search",
Some(
json!({
"query": "rust programming",
"count": 1
})
.as_object()
.unwrap()
.clone(),
),
)
.await;
assert!(
result.is_ok(),
"Tool execution should succeed with mock server"
);
let response = result.unwrap();
assert!(!response.content.is_empty(), "Should have content");
// Check the content
if let rmcp::model::RawContent::Text(text) = &response.content[0].raw {
assert!(text
.text
.contains("Mock search results for: rust programming"));
} else {
panic!("Expected text content");
}
manager.shutdown().await;
}
#[tokio::test]
async fn test_concurrent_tool_execution() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let mut manager = McpClientManager::new(config).await.unwrap();
// Execute tools sequentially (true concurrent execution would require Arc<Mutex>)
let tool_calls = vec![
("brave_web_search", json!({"query": "test1"})),
("brave_local_search", json!({"query": "test2"})),
];
for (tool_name, args) in tool_calls {
let result = manager
.call_tool(tool_name, Some(args.as_object().unwrap().clone()))
.await;
assert!(result.is_ok(), "Tool {} should succeed", tool_name);
let response = result.unwrap();
assert!(!response.content.is_empty(), "Should have content");
}
manager.shutdown().await;
}
// Error Handling Tests
#[tokio::test]
async fn test_tool_execution_errors() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let mut manager = McpClientManager::new(config).await.unwrap();
// Try to call unknown tool
let result = manager
.call_tool("unknown_tool", Some(serde_json::Map::new()))
.await;
assert!(result.is_err(), "Should fail for unknown tool");
match result.unwrap_err() {
McpError::ToolNotFound(name) => {
assert_eq!(name, "unknown_tool");
}
_ => panic!("Expected ToolNotFound error"),
}
manager.shutdown().await;
}
#[tokio::test]
async fn test_connection_without_server() {
let config = McpConfig {
servers: vec![McpServerConfig {
name: "nonexistent".to_string(),
transport: McpTransport::Streamable {
url: "http://localhost:9999/mcp".to_string(),
token: None,
},
}],
};
let result = McpClientManager::new(config).await;
assert!(result.is_err(), "Should fail when no server is running");
if let Err(e) = result {
let error_msg = e.to_string();
assert!(
error_msg.contains("Failed to connect") || error_msg.contains("Connection"),
"Error should be connection-related: {}",
error_msg
);
}
}
// Schema Validation Tests
#[tokio::test]
async fn test_tool_info_structure() {
let mock_server = create_mock_server().await;
let config = McpConfig {
servers: vec![McpServerConfig {
name: "mock_server".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
let manager = McpClientManager::new(config).await.unwrap();
let tools = manager.list_tools();
let brave_search = tools
.iter()
.find(|t| t.name == "brave_web_search")
.expect("Should have brave_web_search tool");
assert_eq!(brave_search.name, "brave_web_search");
assert!(brave_search.description.contains("Mock web search"));
assert_eq!(brave_search.server, "mock_server");
assert!(brave_search.parameters.is_some());
}
// SSE Parsing Tests (simplified since we don't expose parse_sse_event)
#[tokio::test]
async fn test_sse_connection() {
let mock_server = create_mock_server().await;
// Test SSE transport configuration
let config = McpConfig {
servers: vec![McpServerConfig {
name: "sse_server".to_string(),
transport: McpTransport::Sse {
// Mock server doesn't support SSE, but we can test the config
url: format!("http://127.0.0.1:{}/sse", mock_server.port),
token: Some("test_token".to_string()),
},
}],
};
// This will fail to connect but tests the configuration
let result = McpClientManager::new(config).await;
assert!(result.is_err(), "Mock server doesn't support SSE");
}
// Connection Type Tests
#[tokio::test]
async fn test_transport_types() {
// Test different transport configurations
// HTTP/Streamable transport
let http_config = McpServerConfig {
name: "http_server".to_string(),
transport: McpTransport::Streamable {
url: "http://localhost:8080/mcp".to_string(),
token: Some("auth_token".to_string()),
},
};
assert_eq!(http_config.name, "http_server");
// SSE transport
let sse_config = McpServerConfig {
name: "sse_server".to_string(),
transport: McpTransport::Sse {
url: "http://localhost:8081/sse".to_string(),
token: None,
},
};
assert_eq!(sse_config.name, "sse_server");
// STDIO transport
let stdio_config = McpServerConfig {
name: "stdio_server".to_string(),
transport: McpTransport::Stdio {
command: "mcp-server".to_string(),
args: vec!["--port".to_string(), "8082".to_string()],
envs: HashMap::new(),
},
};
assert_eq!(stdio_config.name, "stdio_server");
}
// Integration Pattern Tests
#[tokio::test]
async fn test_complete_workflow() {
let mock_server = create_mock_server().await;
// 1. Initialize configuration
let config = McpConfig {
servers: vec![McpServerConfig {
name: "integration_test".to_string(),
transport: McpTransport::Streamable {
url: mock_server.url(),
token: None,
},
}],
};
// 2. Connect to server
let mut manager = McpClientManager::new(config)
.await
.expect("Should connect to mock server");
// 3. Verify server connection
let servers = manager.list_servers();
assert_eq!(servers.len(), 1);
assert_eq!(servers[0], "integration_test");
// 4. Check available tools
let tools = manager.list_tools();
assert_eq!(tools.len(), 2);
// 5. Verify specific tools exist
assert!(manager.has_tool("brave_web_search"));
assert!(manager.has_tool("brave_local_search"));
assert!(!manager.has_tool("nonexistent_tool"));
// 6. Execute a tool
let result = manager
.call_tool(
"brave_web_search",
Some(
json!({
"query": "SGLang router MCP integration",
"count": 1
})
.as_object()
.unwrap()
.clone(),
),
)
.await;
assert!(result.is_ok(), "Tool execution should succeed");
let response = result.unwrap();
assert!(!response.content.is_empty(), "Should return content");
// 7. Clean shutdown
manager.shutdown().await;
// Verify all required capabilities for responses API integration
let capabilities = [
"MCP server initialization",
"Tool server connection and discovery",
"Tool availability checking",
"Tool execution",
"Error handling and robustness",
"Multi-server support",
"Schema adaptation",
"Mock server integration (no external dependencies)",
];
assert_eq!(capabilities.len(), 8);
}

View File

@@ -0,0 +1,392 @@
mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
/// Test context that manages mock workers
struct TestContext {
workers: Vec<MockWorker>,
router: Arc<dyn RouterTrait>,
}
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3003,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
for worker_config in worker_configs {
let mut worker = MockWorker::new(worker_config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
config.mode = RoutingMode::Regular { worker_urls };
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self { workers, router }
}
async fn shutdown(mut self) {
// Small delay to ensure any pending operations complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
for worker in &mut self.workers {
worker.stop().await;
}
// Another small delay to ensure cleanup completes
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
async fn make_request(
&self,
endpoint: &str,
body: serde_json::Value,
) -> Result<serde_json::Value, String> {
let client = Client::new();
// Get any worker URL for testing
let worker_urls = self.router.get_worker_urls();
if worker_urls.is_empty() {
return Err("No available workers".to_string());
}
let worker_url = &worker_urls[0];
let response = client
.post(format!("{}{}", worker_url, endpoint))
.json(&body)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
response
.json::<serde_json::Value>()
.await
.map_err(|e| format!("Failed to parse response: {}", e))
}
}
#[cfg(test)]
mod request_format_tests {
use super::*;
#[tokio::test]
async fn test_generate_request_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test 1: Basic text request
let payload = json!({
"text": "Hello, world!",
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test 2: Request with sampling parameters
let payload = json!({
"text": "Tell me a story",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 100,
"top_p": 0.9
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test 3: Request with input_ids
let payload = json!({
"input_ids": [1, 2, 3, 4, 5],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 50
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_chat_completions_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test 1: Basic chat completion
let payload = json!({
"model": "test-model",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
],
"stream": false
});
let result = ctx.make_request("/v1/chat/completions", payload).await;
assert!(result.is_ok());
let response = result.unwrap();
assert!(response.get("choices").is_some());
assert!(response.get("id").is_some());
assert_eq!(
response.get("object").and_then(|v| v.as_str()),
Some("chat.completion")
);
// Test 2: Chat completion with parameters
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Tell me a joke"}
],
"temperature": 0.8,
"max_tokens": 150,
"top_p": 0.95,
"stream": false
});
let result = ctx.make_request("/v1/chat/completions", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_completions_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test 1: Basic completion
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"max_tokens": 50,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
let response = result.unwrap();
assert!(response.get("choices").is_some());
assert_eq!(
response.get("object").and_then(|v| v.as_str()),
Some("text_completion")
);
// Test 2: Completion with array prompt
let payload = json!({
"model": "test-model",
"prompt": ["First prompt", "Second prompt"],
"temperature": 0.5,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
// Test 3: Completion with logprobs
let payload = json!({
"model": "test-model",
"prompt": "The capital of France is",
"max_tokens": 10,
"logprobs": 5,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_batch_requests() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19004,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test batch text generation
let payload = json!({
"text": ["First text", "Second text", "Third text"],
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 50
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test batch with input_ids
let payload = json!({
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_special_parameters() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19005,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test with return_logprob
let payload = json!({
"text": "Test",
"return_logprob": true,
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test with json_schema
let payload = json!({
"text": "Generate JSON",
"sampling_params": {
"temperature": 0.0,
"json_schema": "$$ANY$$"
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test with ignore_eos
let payload = json!({
"text": "Continue forever",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 100,
"ignore_eos": true
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_error_handling() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19006,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test with empty body - should still work with mock worker
let payload = json!({});
let result = ctx.make_request("/generate", payload).await;
// Mock worker accepts empty body
assert!(result.is_ok());
ctx.shutdown().await;
}
}

View File

@@ -0,0 +1,210 @@
// Integration test for Responses API
use sglang_router_rs::protocols::spec::{
GenerationRequest, ReasoningEffort, ResponseInput, ResponseReasoningParam, ResponseStatus,
ResponseTool, ResponseToolType, ResponsesRequest, ResponsesResponse, ServiceTier, ToolChoice,
ToolChoiceValue, Truncation, UsageInfo,
};
#[test]
fn test_responses_request_creation() {
let request = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("Hello, world!".to_string()),
instructions: Some("Be helpful".to_string()),
max_output_tokens: Some(100),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true,
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
}),
service_tier: ServiceTier::Auto,
store: true,
stream: false,
temperature: Some(0.7),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![ResponseTool {
r#type: ResponseToolType::WebSearchPreview,
}],
top_logprobs: 5,
top_p: Some(0.9),
truncation: Truncation::Disabled,
user: Some("test-user".to_string()),
request_id: "resp_test123".to_string(),
priority: 0,
frequency_penalty: 0.0,
presence_penalty: 0.0,
stop: None,
top_k: -1,
min_p: 0.0,
repetition_penalty: 1.0,
};
// Test GenerationRequest trait implementation
assert!(!request.is_stream());
assert_eq!(request.get_model(), Some("test-model"));
let routing_text = request.extract_text_for_routing();
assert_eq!(routing_text, "Hello, world!");
}
#[test]
fn test_sampling_params_conversion() {
let request = ResponsesRequest {
background: false,
include: None,
input: ResponseInput::Text("Test".to_string()),
instructions: None,
max_output_tokens: Some(50),
max_tool_calls: None,
metadata: None,
model: Some("test-model".to_string()),
parallel_tool_calls: true, // Use default true
previous_response_id: None,
reasoning: None,
service_tier: ServiceTier::Auto,
store: true, // Use default true
stream: false,
temperature: Some(0.8),
tool_choice: ToolChoice::Value(ToolChoiceValue::Auto),
tools: vec![],
top_logprobs: 0, // Use default 0
top_p: Some(0.95),
truncation: Truncation::Auto,
user: None,
request_id: "resp_test456".to_string(),
priority: 0,
frequency_penalty: 0.1,
presence_penalty: 0.2,
stop: None,
top_k: 10,
min_p: 0.05,
repetition_penalty: 1.1,
};
let params = request.to_sampling_params(1000, None);
// Check that parameters are converted correctly
assert!(params.contains_key("temperature"));
assert!(params.contains_key("top_p"));
assert!(params.contains_key("frequency_penalty"));
assert!(params.contains_key("max_new_tokens"));
}
#[test]
fn test_responses_response_creation() {
let response = ResponsesResponse::new(
"resp_test789".to_string(),
"test-model".to_string(),
ResponseStatus::Completed,
);
assert_eq!(response.id, "resp_test789");
assert_eq!(response.model, "test-model");
assert!(response.is_complete());
assert!(!response.is_in_progress());
assert!(!response.is_failed());
}
#[test]
fn test_usage_conversion() {
let usage_info = UsageInfo::new_with_cached(15, 25, Some(8), 3);
let response_usage = usage_info.to_response_usage();
assert_eq!(response_usage.input_tokens, 15);
assert_eq!(response_usage.output_tokens, 25);
assert_eq!(response_usage.total_tokens, 40);
// Check details are converted correctly
assert!(response_usage.input_tokens_details.is_some());
assert_eq!(
response_usage
.input_tokens_details
.as_ref()
.unwrap()
.cached_tokens,
3
);
assert!(response_usage.output_tokens_details.is_some());
assert_eq!(
response_usage
.output_tokens_details
.as_ref()
.unwrap()
.reasoning_tokens,
8
);
// Test reverse conversion
let back_to_usage = response_usage.to_usage_info();
assert_eq!(back_to_usage.prompt_tokens, 15);
assert_eq!(back_to_usage.completion_tokens, 25);
assert_eq!(back_to_usage.reasoning_tokens, Some(8));
}
#[test]
fn test_reasoning_param_default() {
let param = ResponseReasoningParam {
effort: Some(ReasoningEffort::Medium),
};
// Test JSON serialization/deserialization preserves default
let json = serde_json::to_string(&param).unwrap();
let parsed: ResponseReasoningParam = serde_json::from_str(&json).unwrap();
assert!(matches!(parsed.effort, Some(ReasoningEffort::Medium)));
}
#[test]
fn test_json_serialization() {
let request = ResponsesRequest {
background: true,
include: None,
input: ResponseInput::Text("Test input".to_string()),
instructions: Some("Test instructions".to_string()),
max_output_tokens: Some(200),
max_tool_calls: Some(5),
metadata: None,
model: Some("gpt-4".to_string()),
parallel_tool_calls: false,
previous_response_id: None,
reasoning: Some(ResponseReasoningParam {
effort: Some(ReasoningEffort::High),
}),
service_tier: ServiceTier::Priority,
store: false,
stream: true,
temperature: Some(0.9),
tool_choice: ToolChoice::Value(ToolChoiceValue::Required),
tools: vec![ResponseTool {
r#type: ResponseToolType::CodeInterpreter,
}],
top_logprobs: 10,
top_p: Some(0.8),
truncation: Truncation::Auto,
user: Some("test_user".to_string()),
request_id: "resp_comprehensive_test".to_string(),
priority: 1,
frequency_penalty: 0.3,
presence_penalty: 0.4,
stop: None,
top_k: 50,
min_p: 0.1,
repetition_penalty: 1.2,
};
// Test that everything can be serialized to JSON and back
let json = serde_json::to_string(&request).expect("Serialization should work");
let parsed: ResponsesRequest =
serde_json::from_str(&json).expect("Deserialization should work");
assert_eq!(parsed.request_id, "resp_comprehensive_test");
assert_eq!(parsed.model, Some("gpt-4".to_string()));
assert!(parsed.background);
assert!(parsed.stream);
assert_eq!(parsed.tools.len(), 1);
}

View File

@@ -0,0 +1,370 @@
mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
/// Test context that manages mock workers
struct TestContext {
workers: Vec<MockWorker>,
router: Arc<dyn RouterTrait>,
}
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3004,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
rate_limit_tokens_per_second: None,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
for worker_config in worker_configs {
let mut worker = MockWorker::new(worker_config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
config.mode = RoutingMode::Regular { worker_urls };
let app_context = common::create_test_context(config);
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self { workers, router }
}
async fn shutdown(mut self) {
// Small delay to ensure any pending operations complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
for worker in &mut self.workers {
worker.stop().await;
}
// Another small delay to ensure cleanup completes
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
async fn make_streaming_request(
&self,
endpoint: &str,
body: serde_json::Value,
) -> Result<Vec<String>, String> {
let client = Client::new();
// Get any worker URL for testing
let worker_urls = self.router.get_worker_urls();
if worker_urls.is_empty() {
return Err("No available workers".to_string());
}
let worker_url = &worker_urls[0];
let response = client
.post(format!("{}{}", worker_url, endpoint))
.json(&body)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
// Check if it's a streaming response
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.contains("text/event-stream") {
return Err("Response is not a stream".to_string());
}
let mut stream = response.bytes_stream();
let mut events = Vec::new();
while let Some(chunk) = stream.next().await {
if let Ok(bytes) = chunk {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if let Some(stripped) = line.strip_prefix("data: ") {
events.push(stripped.to_string());
}
}
}
}
Ok(events)
}
}
#[cfg(test)]
mod streaming_tests {
use super::*;
#[tokio::test]
async fn test_generate_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"text": "Stream test",
"stream": true,
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 10
}
});
let result = ctx.make_streaming_request("/generate", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
// Should have at least one data chunk and [DONE]
assert!(events.len() >= 2);
assert_eq!(events.last().unwrap(), "[DONE]");
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_chat_completions_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Count to 3"}
],
"stream": true,
"max_tokens": 20
});
let result = ctx
.make_streaming_request("/v1/chat/completions", payload)
.await;
assert!(result.is_ok());
let events = result.unwrap();
assert!(events.len() >= 2); // At least one chunk + [DONE]
// Verify events are valid JSON (except [DONE])
for event in &events {
if event != "[DONE]" {
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
assert!(parsed.is_ok(), "Invalid JSON in SSE event: {}", event);
let json = parsed.unwrap();
assert_eq!(
json.get("object").and_then(|v| v.as_str()),
Some("chat.completion.chunk")
);
}
}
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_completions_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"stream": true,
"max_tokens": 15
});
let result = ctx.make_streaming_request("/v1/completions", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
assert!(events.len() >= 2); // At least one chunk + [DONE]
ctx.shutdown().await;
}
#[tokio::test]
async fn test_streaming_with_error() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20004,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 1.0, // Always fail
}])
.await;
let payload = json!({
"text": "This should fail",
"stream": true
});
let result = ctx.make_streaming_request("/generate", payload).await;
// With fail_rate: 1.0, the request should fail
assert!(result.is_err());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_streaming_timeouts() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20005,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 100, // Slow response
fail_rate: 0.0,
}])
.await;
let payload = json!({
"text": "Slow stream",
"stream": true,
"sampling_params": {
"max_new_tokens": 5
}
});
let start = std::time::Instant::now();
let result = ctx.make_streaming_request("/generate", payload).await;
let elapsed = start.elapsed();
assert!(result.is_ok());
let events = result.unwrap();
// Should have received multiple chunks over time
assert!(!events.is_empty());
assert!(elapsed.as_millis() >= 100); // At least one delay
ctx.shutdown().await;
}
#[tokio::test]
async fn test_batch_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20006,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
// Batch request with streaming
let payload = json!({
"text": ["First", "Second", "Third"],
"stream": true,
"sampling_params": {
"max_new_tokens": 5
}
});
let result = ctx.make_streaming_request("/generate", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
// Should have multiple events for batch
assert!(events.len() >= 4); // At least 3 responses + [DONE]
ctx.shutdown().await;
}
#[tokio::test]
async fn test_sse_format_parsing() {
// Test SSE format parsing
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
let text = String::from_utf8_lossy(chunk);
text.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line[6..].to_string())
.collect()
};
let sse_data =
b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\" world\"}\n\ndata: [DONE]\n\n";
let events = parse_sse_chunk(sse_data);
assert_eq!(events.len(), 3);
assert_eq!(events[0], "{\"text\":\"Hello\"}");
assert_eq!(events[1], "{\"text\":\" world\"}");
assert_eq!(events[2], "[DONE]");
// Test with mixed content
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
let events = parse_sse_chunk(mixed);
assert_eq!(events.len(), 2);
assert_eq!(events[0], "{\"test\":true}");
assert_eq!(events[1], "[DONE]");
}
}

View File

@@ -0,0 +1,150 @@
#[cfg(test)]
mod tests {
use sglang_router_rs::tokenizer::chat_template::{ChatMessage, ChatTemplateProcessor};
#[test]
fn test_chat_message_helpers() {
let system_msg = ChatMessage::system("You are a helpful assistant");
assert_eq!(system_msg.role, "system");
assert_eq!(system_msg.content, "You are a helpful assistant");
let user_msg = ChatMessage::user("Hello!");
assert_eq!(user_msg.role, "user");
assert_eq!(user_msg.content, "Hello!");
let assistant_msg = ChatMessage::assistant("Hi there!");
assert_eq!(assistant_msg.role, "assistant");
assert_eq!(assistant_msg.content, "Hi there!");
}
#[test]
fn test_llama_style_template() {
// Test a Llama-style chat template
let template = r#"
{%- if messages[0]['role'] == 'system' -%}
{%- set system_message = messages[0]['content'] -%}
{%- set messages = messages[1:] -%}
{%- else -%}
{%- set system_message = '' -%}
{%- endif -%}
{{- bos_token }}
{%- if system_message %}
{{- '<|start_header_id|>system<|end_header_id|>\n\n' + system_message + '<|eot_id|>' }}
{%- endif %}
{%- for message in messages %}
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] + '<|eot_id|>' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|start_header_id|>assistant<|end_header_id|>\n\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<|begin_of_text|>".to_string()),
Some("<|end_of_text|>".to_string()),
);
let messages = vec![
ChatMessage::system("You are a helpful assistant"),
ChatMessage::user("What is 2+2?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check that the result contains expected markers
assert!(result.contains("<|begin_of_text|>"));
assert!(result.contains("<|start_header_id|>system<|end_header_id|>"));
assert!(result.contains("You are a helpful assistant"));
assert!(result.contains("<|start_header_id|>user<|end_header_id|>"));
assert!(result.contains("What is 2+2?"));
assert!(result.contains("<|start_header_id|>assistant<|end_header_id|>"));
}
#[test]
fn test_chatml_template() {
// Test a ChatML-style template
let template = r#"
{%- for message in messages %}
{{- '<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>\n' }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|im_start|>assistant\n' }}
{%- endif %}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there!"),
ChatMessage::user("How are you?"),
];
let result = processor.apply_chat_template(&messages, true).unwrap();
// Check ChatML format
assert!(result.contains("<|im_start|>user\nHello<|im_end|>"));
assert!(result.contains("<|im_start|>assistant\nHi there!<|im_end|>"));
assert!(result.contains("<|im_start|>user\nHow are you?<|im_end|>"));
assert!(result.ends_with("<|im_start|>assistant\n"));
}
#[test]
fn test_template_without_generation_prompt() {
let template = r#"
{%- for message in messages -%}
{{ message.role }}: {{ message.content }}
{% endfor -%}
{%- if add_generation_prompt -%}
assistant:
{%- endif -%}
"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![ChatMessage::user("Test")];
// Test without generation prompt
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result.trim(), "user: Test");
// Test with generation prompt
let result_with_prompt = processor.apply_chat_template(&messages, true).unwrap();
assert!(result_with_prompt.contains("assistant:"));
}
#[test]
fn test_template_with_special_tokens() {
let template = r#"{{ bos_token }}{% for msg in messages %}{{ msg.content }}{{ eos_token }}{% endfor %}"#;
let processor = ChatTemplateProcessor::new(
template.to_string(),
Some("<s>".to_string()),
Some("</s>".to_string()),
);
let messages = vec![ChatMessage::user("Hello")];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "<s>Hello</s>");
}
#[test]
fn test_empty_messages() {
let template =
r#"{% for msg in messages %}{{ msg.role }}: {{ msg.content }}\n{% endfor %}"#;
let processor = ChatTemplateProcessor::new(template.to_string(), None, None);
let messages = vec![];
let result = processor.apply_chat_template(&messages, false).unwrap();
assert_eq!(result, "");
}
// Integration test with actual tokenizer file loading would go here
// but requires a real tokenizer_config.json file
}

View File

@@ -0,0 +1,183 @@
#[cfg(test)]
mod tests {
use std::fs;
use tempfile::TempDir;
#[test]
fn test_load_chat_template_from_file() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
let template_path = temp_dir.path().join("template.jinja");
// Write a test template
let template_content = r#"
{%- for message in messages %}
{{- '<|' + message['role'] + '|>' + message['content'] }}
{%- endfor %}
{%- if add_generation_prompt %}
{{- '<|assistant|>' }}
{%- endif %}
"#;
fs::write(&template_path, template_content).unwrap();
// Create a mock tokenizer config
let tokenizer_config = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"hello": 0,
"world": 1,
"<s>": 2,
"</s>": 3
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_config).unwrap();
// Load tokenizer with custom chat template
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(template_path.to_str().unwrap()),
)
.unwrap();
// Test that the custom template is used
let messages = vec![
ChatMessage::user("Hello"),
ChatMessage::assistant("Hi there"),
];
let result = tokenizer.apply_chat_template(&messages, true).unwrap();
// Verify the custom template format
assert!(result.contains("<|user|>Hello"));
assert!(result.contains("<|assistant|>Hi there"));
assert!(result.ends_with("<|assistant|>"));
}
#[test]
fn test_override_existing_template() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory
let temp_dir = TempDir::new().unwrap();
// Create tokenizer config with a built-in template
let tokenizer_config_path = temp_dir.path().join("tokenizer_config.json");
let config_with_template = r#"{
"chat_template": "built-in: {% for msg in messages %}{{ msg.content }}{% endfor %}"
}"#;
fs::write(&tokenizer_config_path, config_with_template).unwrap();
// Create the actual tokenizer file
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Create custom template that should override
let custom_template_path = temp_dir.path().join("custom.jinja");
let custom_template =
r#"CUSTOM: {% for msg in messages %}[{{ msg.role }}]: {{ msg.content }}{% endfor %}"#;
fs::write(&custom_template_path, custom_template).unwrap();
// Load with custom template - should override the built-in one
let tokenizer = HuggingFaceTokenizer::from_file_with_chat_template(
tokenizer_path.to_str().unwrap(),
Some(custom_template_path.to_str().unwrap()),
)
.unwrap();
let messages = vec![ChatMessage::user("Test")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
// Should use CUSTOM template, not built-in
assert!(result.starts_with("CUSTOM:"));
assert!(result.contains("[user]: Test"));
assert!(!result.contains("built-in:"));
}
#[test]
fn test_set_chat_template_after_creation() {
use sglang_router_rs::tokenizer::chat_template::ChatMessage;
use sglang_router_rs::tokenizer::huggingface::HuggingFaceTokenizer;
// Create temporary directory and tokenizer file
let temp_dir = TempDir::new().unwrap();
let tokenizer_json = r#"{
"version": "1.0",
"truncation": null,
"padding": null,
"added_tokens": [],
"normalizer": null,
"pre_tokenizer": {
"type": "Whitespace"
},
"post_processor": null,
"decoder": null,
"model": {
"type": "BPE",
"vocab": {
"test": 0,
"<s>": 1,
"</s>": 2
},
"merges": []
}
}"#;
let tokenizer_path = temp_dir.path().join("tokenizer.json");
fs::write(&tokenizer_path, tokenizer_json).unwrap();
// Load tokenizer without custom template
let mut tokenizer =
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap()).unwrap();
// Set a template after creation (mimics Python's behavior)
let new_template =
"NEW: {% for msg in messages %}{{ msg.role }}: {{ msg.content }}; {% endfor %}";
tokenizer.set_chat_template(new_template.to_string());
let messages = vec![ChatMessage::user("Hello"), ChatMessage::assistant("World")];
let result = tokenizer.apply_chat_template(&messages, false).unwrap();
assert!(result.starts_with("NEW:"));
assert!(result.contains("user: Hello;"));
assert!(result.contains("assistant: World;"));
}
}

View File

@@ -0,0 +1,419 @@
//! Comprehensive integration tests for OpenAI backend functionality
use axum::{
body::Body,
extract::Request,
http::{Method, StatusCode},
routing::post,
Router,
};
use serde_json::json;
use sglang_router_rs::{
config::{RouterConfig, RoutingMode},
protocols::spec::{
ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateRequest, UserMessageContent,
},
routers::{openai_router::OpenAIRouter, RouterTrait},
};
use std::sync::Arc;
use tower::ServiceExt;
mod common;
use common::mock_openai_server::MockOpenAIServer;
/// Helper function to create a minimal chat completion request for testing
fn create_minimal_chat_request() -> ChatCompletionRequest {
let val = json!({
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 100
});
serde_json::from_value(val).unwrap()
}
/// Helper function to create a minimal completion request for testing
fn create_minimal_completion_request() -> CompletionRequest {
CompletionRequest {
model: "gpt-3.5-turbo".to_string(),
prompt: sglang_router_rs::protocols::spec::StringOrArray::String("Hello".to_string()),
suffix: None,
max_tokens: Some(100),
temperature: None,
top_p: None,
n: None,
stream: false,
stream_options: None,
logprobs: None,
echo: false,
stop: None,
presence_penalty: None,
frequency_penalty: None,
best_of: None,
logit_bias: None,
user: None,
seed: None,
top_k: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
json_schema: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
skip_special_tokens: true,
lora_path: None,
session_params: None,
return_hidden_states: false,
other: serde_json::Map::new(),
}
}
// ============= Basic Unit Tests =============
/// Test basic OpenAI router creation and configuration
#[tokio::test]
async fn test_openai_router_creation() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None).await;
assert!(router.is_ok(), "Router creation should succeed");
let router = router.unwrap();
assert_eq!(router.router_type(), "openai");
assert!(!router.is_pd_mode());
}
/// Test health endpoints
#[tokio::test]
async fn test_openai_router_health() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
.await
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("/health")
.body(Body::empty())
.unwrap();
let response = router.health(req).await;
assert_eq!(response.status(), StatusCode::OK);
}
/// Test server info endpoint
#[tokio::test]
async fn test_openai_router_server_info() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
.await
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("/info")
.body(Body::empty())
.unwrap();
let response = router.get_server_info(req).await;
assert_eq!(response.status(), StatusCode::OK);
let (_, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
assert!(body_str.contains("openai"));
}
/// Test models endpoint
#[tokio::test]
async fn test_openai_router_models() {
// Use mock server for deterministic models response
let mock_server = MockOpenAIServer::new().await;
let router = OpenAIRouter::new(mock_server.base_url(), None)
.await
.unwrap();
let req = Request::builder()
.method(Method::GET)
.uri("/models")
.body(Body::empty())
.unwrap();
let response = router.get_models(req).await;
assert_eq!(response.status(), StatusCode::OK);
let (_, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
assert_eq!(models["object"], "list");
assert!(models["data"].is_array());
}
/// Test router factory with OpenAI routing mode
#[tokio::test]
async fn test_router_factory_openai_mode() {
let routing_mode = RoutingMode::OpenAI {
worker_urls: vec!["https://api.openai.com".to_string()],
};
let router_config =
RouterConfig::new(routing_mode, sglang_router_rs::config::PolicyConfig::Random);
let app_context = common::create_test_context(router_config);
let router = sglang_router_rs::routers::RouterFactory::create_router(&app_context).await;
assert!(
router.is_ok(),
"Router factory should create OpenAI router successfully"
);
let router = router.unwrap();
assert_eq!(router.router_type(), "openai");
}
/// Test that unsupported endpoints return proper error codes
#[tokio::test]
async fn test_unsupported_endpoints() {
let router = OpenAIRouter::new("https://api.openai.com".to_string(), None)
.await
.unwrap();
// Test generate endpoint (SGLang-specific, should not be supported)
let generate_request = GenerateRequest {
prompt: None,
text: Some("Hello world".to_string()),
input_ids: None,
parameters: None,
sampling_params: None,
stream: false,
return_logprob: false,
lora_path: None,
session_params: None,
return_hidden_states: false,
rid: None,
};
let response = router.route_generate(None, &generate_request).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
// Test completion endpoint (should also not be supported)
let completion_request = create_minimal_completion_request();
let response = router.route_completion(None, &completion_request).await;
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
}
// ============= Mock Server E2E Tests =============
/// Test chat completion with mock OpenAI server
#[tokio::test]
async fn test_openai_router_chat_completion_with_mock() {
// Start a mock OpenAI server
let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url();
// Create router pointing to mock server
let router = OpenAIRouter::new(base_url, None).await.unwrap();
// Create a minimal chat completion request
let mut chat_request = create_minimal_chat_request();
chat_request.messages = vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Hello, how are you?".to_string()),
name: None,
}];
chat_request.temperature = Some(0.7);
// Route the request
let response = router.route_chat(None, &chat_request).await;
// Should get a successful response from mock server
assert_eq!(response.status(), StatusCode::OK);
let (_, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let chat_response: serde_json::Value = serde_json::from_str(&body_str).unwrap();
// Verify it's a valid chat completion response
assert_eq!(chat_response["object"], "chat.completion");
assert_eq!(chat_response["model"], "gpt-3.5-turbo");
assert!(!chat_response["choices"].as_array().unwrap().is_empty());
}
/// Test full E2E flow with Axum server
#[tokio::test]
async fn test_openai_e2e_with_server() {
// Start mock OpenAI server
let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url();
// Create router
let router = OpenAIRouter::new(base_url, None).await.unwrap();
// Create Axum app with chat completions endpoint
let app = Router::new().route(
"/v1/chat/completions",
post({
let router = Arc::new(router);
move |req: Request<Body>| {
let router = router.clone();
async move {
let (parts, body) = req.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let chat_request: ChatCompletionRequest =
serde_json::from_str(&body_str).unwrap();
router.route_chat(Some(&parts.headers), &chat_request).await
}
}
}),
);
// Make a request to the server
let request = Request::builder()
.method(Method::POST)
.uri("/v1/chat/completions")
.header("content-type", "application/json")
.body(Body::from(
json!({
"model": "gpt-3.5-turbo",
"messages": [
{
"role": "user",
"content": "Hello, world!"
}
],
"max_tokens": 100
})
.to_string(),
))
.unwrap();
let response = app.oneshot(request).await.unwrap();
assert_eq!(response.status(), StatusCode::OK);
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let response_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
// Verify the response structure
assert_eq!(response_json["object"], "chat.completion");
assert_eq!(response_json["model"], "gpt-3.5-turbo");
assert!(!response_json["choices"].as_array().unwrap().is_empty());
}
/// Test streaming chat completions pass-through with mock server
#[tokio::test]
async fn test_openai_router_chat_streaming_with_mock() {
let mock_server = MockOpenAIServer::new().await;
let base_url = mock_server.base_url();
let router = OpenAIRouter::new(base_url, None).await.unwrap();
// Build a streaming chat request
let val = json!({
"model": "gpt-3.5-turbo",
"messages": [
{"role": "user", "content": "Hello"}
],
"max_tokens": 10,
"stream": true
});
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
let response = router.route_chat(None, &chat_request).await;
assert_eq!(response.status(), StatusCode::OK);
// Should be SSE
let headers = response.headers();
let ct = headers
.get("content-type")
.unwrap()
.to_str()
.unwrap()
.to_ascii_lowercase();
assert!(ct.contains("text/event-stream"));
// Read entire stream body and assert chunks + DONE
let body = axum::body::to_bytes(response.into_body(), usize::MAX)
.await
.unwrap();
let text = String::from_utf8(body.to_vec()).unwrap();
assert!(text.contains("chat.completion.chunk"));
assert!(text.contains("[DONE]"));
}
/// Test circuit breaker functionality
#[tokio::test]
async fn test_openai_router_circuit_breaker() {
// Create router with circuit breaker config
let cb_config = sglang_router_rs::config::CircuitBreakerConfig {
failure_threshold: 2,
success_threshold: 1,
timeout_duration_secs: 1,
window_duration_secs: 10,
};
let router = OpenAIRouter::new(
"http://invalid-url-that-will-fail".to_string(),
Some(cb_config),
)
.await
.unwrap();
let chat_request = create_minimal_chat_request();
// First few requests should fail and record failures
for _ in 0..3 {
let response = router.route_chat(None, &chat_request).await;
// Should get either an error or circuit breaker response
assert!(
response.status() == StatusCode::INTERNAL_SERVER_ERROR
|| response.status() == StatusCode::SERVICE_UNAVAILABLE
);
}
}
/// Test that Authorization header is forwarded in /v1/models
#[tokio::test]
async fn test_openai_router_models_auth_forwarding() {
// Start a mock server that requires Authorization
let expected_auth = "Bearer test-token".to_string();
let mock_server = MockOpenAIServer::new_with_auth(Some(expected_auth.clone())).await;
let router = OpenAIRouter::new(mock_server.base_url(), None)
.await
.unwrap();
// 1) Without auth header -> expect 401
let req = Request::builder()
.method(Method::GET)
.uri("/models")
.body(Body::empty())
.unwrap();
let response = router.get_models(req).await;
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
// 2) With auth header -> expect 200
let req = Request::builder()
.method(Method::GET)
.uri("/models")
.header("Authorization", expected_auth)
.body(Body::empty())
.unwrap();
let response = router.get_models(req).await;
assert_eq!(response.status(), StatusCode::OK);
let (_, body) = response.into_parts();
let body_bytes = axum::body::to_bytes(body, usize::MAX).await.unwrap();
let body_str = String::from_utf8(body_bytes.to_vec()).unwrap();
let models: serde_json::Value = serde_json::from_str(&body_str).unwrap();
assert_eq!(models["object"], "list");
}

View File

@@ -0,0 +1,918 @@
#[cfg(test)]
mod test_pd_routing {
use serde_json::json;
use sglang_router_rs::config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::http::pd_types::get_hostname;
use sglang_router_rs::routers::http::pd_types::PDSelectionPolicy;
use sglang_router_rs::routers::RouterFactory;
// Test-only struct to help validate PD request parsing
#[derive(Debug)]
struct PDRequest {
pub is_stream: bool,
pub batch_size: Option<usize>,
}
impl PDRequest {
// Extract PD-relevant info from JSON for testing
pub fn from_json(json: &serde_json::Value) -> Self {
let is_stream = json
.get("stream")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Detect batch size from text or input_ids
let batch_size = if let Some(text) = json.get("text") {
text.as_array().map(|arr| arr.len())
} else if let Some(input_ids) = json.get("input_ids") {
input_ids.as_array().map(|arr| arr.len())
} else {
None
};
PDRequest {
is_stream,
batch_size,
}
}
}
// ========================================================================
// Phase 1: Basic PD Components and Router Creation
// ========================================================================
#[test]
fn test_worker_types() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
// Test worker creation for prefill servers
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
assert_eq!(prefill_worker.url(), "http://prefill:8080");
match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => {
assert_eq!(bootstrap_port, Some(9000));
}
_ => panic!("Expected Prefill worker type"),
}
// Test worker creation for decode servers
let decode_worker = WorkerFactory::create_decode("http://decode:8080".to_string());
assert_eq!(decode_worker.url(), "http://decode:8080");
match decode_worker.worker_type() {
WorkerType::Decode => (),
_ => panic!("Expected Decode worker type"),
}
// Test regular worker creation
let regular_worker = WorkerFactory::create_regular("http://regular:8080".to_string());
assert_eq!(regular_worker.url(), "http://regular:8080");
match regular_worker.worker_type() {
WorkerType::Regular => (),
_ => panic!("Expected Regular worker type"),
}
}
#[test]
fn test_pd_selection_policies() {
// Test all PD selection policy variants
// Note: These policies are only used when pd_disaggregation=true
let policies = vec![
PDSelectionPolicy::Random,
PDSelectionPolicy::PowerOfTwo,
PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
},
];
for policy in policies {
// Verify each policy can be created and matched
match &policy {
PDSelectionPolicy::Random => {
assert!(matches!(policy, PDSelectionPolicy::Random));
}
PDSelectionPolicy::PowerOfTwo => {
assert!(matches!(policy, PDSelectionPolicy::PowerOfTwo));
}
PDSelectionPolicy::CacheAware {
cache_threshold, ..
} => {
assert!(*cache_threshold >= 0.0 && *cache_threshold <= 1.0);
}
}
}
}
#[tokio::test]
async fn test_pd_router_configuration() {
// Test PD router configuration with various policies
// In the new structure, RoutingMode and PolicyConfig are separate
let test_cases = vec![
(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://prefill1:8080".to_string(), Some(9000)),
("http://prefill2:8080".to_string(), None),
],
decode_urls: vec![
"http://decode1:8080".to_string(),
"http://decode2:8080".to_string(),
],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::Random,
),
(
RoutingMode::PrefillDecode {
prefill_urls: vec![("http://prefill:8080".to_string(), Some(9000))],
decode_urls: vec!["http://decode:8080".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::PowerOfTwo {
load_check_interval_secs: 5,
},
),
(
RoutingMode::PrefillDecode {
prefill_urls: vec![
("http://p1:8080".to_string(), Some(9000)),
("http://p2:8080".to_string(), Some(9001)),
("http://p3:8080".to_string(), Some(9002)),
],
decode_urls: vec!["http://d1:8080".to_string(), "http://d2:8080".to_string()],
prefill_policy: None,
decode_policy: None,
},
PolicyConfig::CacheAware {
cache_threshold: 0.7,
balance_abs_threshold: 20,
balance_rel_threshold: 1.2,
eviction_interval_secs: 60,
max_tree_size: 1000000,
},
),
];
for (mode, policy) in test_cases {
let config = RouterConfig {
mode,
policy,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 1024 * 1024,
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
queue_size: 0,
queue_timeout_secs: 60,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
circuit_breaker: CircuitBreakerConfig::default(),
disable_retries: false,
disable_circuit_breaker: false,
health_check: sglang_router_rs::config::HealthCheckConfig::default(),
enable_igw: false,
rate_limit_tokens_per_second: None,
connection_mode: ConnectionMode::Http,
model_path: None,
tokenizer_path: None,
};
// Router creation will fail due to health checks, but config should be valid
let app_context =
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64, None)
.expect("Failed to create AppContext");
let app_context = std::sync::Arc::new(app_context);
let result = RouterFactory::create_router(&app_context).await;
assert!(result.is_err());
let error_msg = result.unwrap_err();
// Error should be about health/timeout, not configuration
assert!(
error_msg.contains("healthy") || error_msg.contains("timeout"),
"Unexpected error: {}",
error_msg
);
}
}
// ========================================================================
// Phase 2: Bootstrap Injection and Request Handling
// ========================================================================
#[test]
fn test_pd_request_from_json() {
// Test PDRequest parsing from single text request
let single_json = json!({
"text": "Hello world",
"stream": false,
"temperature": 0.7,
"max_tokens": 100
});
let pd_req = PDRequest::from_json(&single_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test PDRequest parsing from batch text request
let batch_json = json!({
"text": ["Hello", "World", "Test"],
"stream": true,
"temperature": 0.5
});
let pd_req = PDRequest::from_json(&batch_json);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3));
// Test PDRequest parsing from input_ids request
let ids_json = json!({
"input_ids": [[1, 2, 3], [4, 5, 6]],
"stream": false
});
let pd_req = PDRequest::from_json(&ids_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(2));
// Test PDRequest parsing from chat request
let chat_json = json!({
"messages": [
{"role": "system", "content": "You are a helpful assistant"},
{"role": "user", "content": "Hello"}
],
"stream": true
});
let pd_req = PDRequest::from_json(&chat_json);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
}
#[test]
fn test_bootstrap_injection_simulation() {
// Since we can't test the actual inject_bootstrap_fields function here
// (it's private in the router module), we'll test the expected behavior
// Simulate bootstrap injection for single request
let mut single_json = json!({
"text": "Hello world",
"stream": false,
"temperature": 0.7
});
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill1:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
// Simulate what inject_bootstrap_fields would do
single_json["bootstrap_host"] = json!(get_hostname(prefill_worker.url()));
single_json["bootstrap_port"] = json!(bootstrap_port);
single_json["bootstrap_room"] = json!(12345u64); // Random room ID
// Verify bootstrap fields are added correctly
assert_eq!(single_json["bootstrap_host"], "prefill1");
assert_eq!(single_json["bootstrap_port"], json!(Some(9000)));
assert!(single_json["bootstrap_room"].is_u64());
assert_eq!(single_json["temperature"], 0.7); // Original field preserved
// Simulate bootstrap injection for batch request
let mut batch_json = json!({
"text": ["Hello", "World", "Test"],
"stream": true
});
let batch_size = 3;
let hostname = get_hostname(prefill_worker.url());
batch_json["bootstrap_host"] = json!(vec![hostname; batch_size]);
batch_json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
batch_json["bootstrap_room"] = json!(vec![111u64, 222u64, 333u64]);
// Verify batch bootstrap fields
assert!(batch_json["bootstrap_host"].is_array());
assert_eq!(
batch_json["bootstrap_host"].as_array().unwrap().len(),
batch_size
);
assert!(batch_json["bootstrap_port"].is_array());
assert!(batch_json["bootstrap_room"].is_array());
assert_eq!(batch_json["stream"], true); // Original field preserved
}
#[test]
fn test_request_serialization() {
// Test that requests can be properly serialized and deserialized
let request = json!({
"text": "Test prompt",
"stream": false,
"temperature": 0.7,
"max_tokens": 100,
"top_p": 0.9,
"frequency_penalty": 0.5,
"bootstrap_host": "prefill1",
"bootstrap_port": 9000,
"bootstrap_room": 12345u64
});
// Convert to bytes (as would happen in the router)
let bytes = serde_json::to_vec(&request).unwrap();
// Parse back from bytes
let parsed: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
// Verify all fields are preserved
assert_eq!(parsed["text"], "Test prompt");
assert_eq!(parsed["stream"], false);
assert_eq!(parsed["temperature"], 0.7);
assert_eq!(parsed["max_tokens"], 100);
assert_eq!(parsed["bootstrap_host"], "prefill1");
assert_eq!(parsed["bootstrap_port"], 9000);
assert_eq!(parsed["bootstrap_room"], 12345);
}
#[test]
fn test_hostname_extraction() {
// Test various URL formats
let test_cases = vec![
("http://localhost:8080", "localhost"),
("http://10.0.0.1:8080", "10.0.0.1"),
("https://api.example.com:443", "api.example.com"),
("http://prefill-server", "prefill-server"),
("http://[::1]:8080", "["), // IPv6 edge case
("prefill:8080", "prefill"), // No protocol
];
for (url, expected_hostname) in test_cases {
assert_eq!(get_hostname(url), expected_hostname);
}
}
#[test]
fn test_pd_request_edge_cases() {
// Test empty request
let empty_json = json!({});
let pd_req = PDRequest::from_json(&empty_json);
assert!(!pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test request with only stream field
let stream_only = json!({
"stream": true
});
let pd_req = PDRequest::from_json(&stream_only);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, None);
// Test request with empty text array
let empty_batch = json!({
"text": []
});
let pd_req = PDRequest::from_json(&empty_batch);
assert_eq!(pd_req.batch_size, Some(0));
// Test request with non-array text (should be None)
let non_array_text = json!({
"text": "single string"
});
let pd_req = PDRequest::from_json(&non_array_text);
assert_eq!(pd_req.batch_size, None);
}
// ========================================================================
// Phase 2: Background Load Monitoring Tests
// ========================================================================
#[tokio::test]
async fn test_background_load_monitoring() {
use std::collections::HashMap;
use tokio::sync::watch;
// Create a watch channel for testing
let (tx, rx) = watch::channel(HashMap::new());
// Simulate load updates
let mut loads = HashMap::new();
loads.insert("http://prefill1:8080".to_string(), 10);
loads.insert("http://prefill2:8080".to_string(), 20);
loads.insert("http://decode1:8080".to_string(), 5);
loads.insert("http://decode2:8080".to_string(), 15);
// Send the loads
tx.send(loads.clone()).unwrap();
// Verify receiver gets the update
let received_loads = rx.borrow();
assert_eq!(received_loads.get("http://prefill1:8080"), Some(&10));
assert_eq!(received_loads.get("http://prefill2:8080"), Some(&20));
assert_eq!(received_loads.get("http://decode1:8080"), Some(&5));
assert_eq!(received_loads.get("http://decode2:8080"), Some(&15));
}
#[test]
fn test_load_monitoring_configuration() {
// Test that load monitoring is only enabled for PowerOfTwo policy
let policies = vec![
(PDSelectionPolicy::Random, false),
(PDSelectionPolicy::PowerOfTwo, true),
(
PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
},
false,
),
];
for (policy, should_monitor) in policies {
match policy {
PDSelectionPolicy::PowerOfTwo => assert!(should_monitor),
_ => assert!(!should_monitor),
}
}
}
#[tokio::test]
async fn test_watch_channel_behavior() {
use std::collections::HashMap;
use tokio::sync::watch;
// Test watch channel's broadcast behavior
let (tx, rx1) = watch::channel(HashMap::new());
let rx2 = rx1.clone();
// Initial state - empty map
assert!(rx1.borrow().is_empty());
assert!(rx2.borrow().is_empty());
// Update 1
let mut loads = HashMap::new();
loads.insert("worker1".to_string(), 10);
tx.send(loads.clone()).unwrap();
// Both receivers see the update
assert_eq!(rx1.borrow().get("worker1"), Some(&10));
assert_eq!(rx2.borrow().get("worker1"), Some(&10));
// Update 2 - overwrites previous
loads.insert("worker1".to_string(), 20);
loads.insert("worker2".to_string(), 30);
tx.send(loads).unwrap();
// Both receivers see the latest state
assert_eq!(rx1.borrow().get("worker1"), Some(&20));
assert_eq!(rx2.borrow().get("worker2"), Some(&30));
}
// ========================================================================
// Tests based on bench_one_batch_server.py patterns
// ========================================================================
#[test]
fn test_generate_request_formats() {
// Based on bench_one_batch_server.py request patterns
// Test 1: Batch request with input_ids (most common in benchmarks)
let batch_request = json!({
"input_ids": [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12]],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
"ignore_eos": true,
},
"return_logprob": false,
"stream": true
});
let pd_req = PDRequest::from_json(&batch_request);
assert!(pd_req.is_stream);
assert_eq!(pd_req.batch_size, Some(3));
// Test 2: Request with return_logprob (critical for PD)
let logprob_request = json!({
"input_ids": [[1, 2, 3]],
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 8,
},
"return_logprob": true,
"stream": false
});
assert_eq!(logprob_request["return_logprob"], true);
assert_eq!(logprob_request["stream"], false);
// Test 3: Large batch sizes from benchmark
let batch_sizes = vec![1, 16, 64]; // From bench_one_batch_server.py
for bs in batch_sizes {
let request = json!({
"input_ids": vec![vec![1, 2, 3]; bs],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
},
"stream": true
});
let pd_req = PDRequest::from_json(&request);
assert_eq!(pd_req.batch_size, Some(bs));
}
}
#[test]
fn test_sampling_params_handling() {
// Test various sampling parameters from bench_one_batch_server.py
let sampling_params_variations = vec![
json!({
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": true
}),
json!({
"temperature": 0.7,
"max_new_tokens": 16,
"ignore_eos": false,
"top_p": 0.9,
"frequency_penalty": 0.5
}),
json!({
"temperature": 1.0,
"max_new_tokens": 64,
"json_schema": "$$ANY$$" // Structured output
}),
];
for params in sampling_params_variations {
let request = json!({
"input_ids": [[1, 2, 3]],
"sampling_params": params.clone(),
"stream": false
});
// Verify params are preserved
assert_eq!(request["sampling_params"], params);
}
}
#[test]
fn test_streaming_response_parsing() {
// Test SSE format parsing from streaming responses
let sse_chunks = ["data: {\"text\":\"Hello\",\"meta_info\":{\"completion_tokens\":1,\"finish_reason\":null}}",
"data: {\"text\":\" world\",\"meta_info\":{\"completion_tokens\":2,\"finish_reason\":null}}",
"data: {\"text\":\"!\",\"meta_info\":{\"completion_tokens\":3,\"finish_reason\":{\"type\":\"length\"}}}",
"data: [DONE]"];
for chunk in &sse_chunks[..3] {
assert!(chunk.starts_with("data: "));
let json_str = &chunk[6..]; // Skip "data: "
let parsed: serde_json::Value = serde_json::from_str(json_str).unwrap();
assert!(parsed["meta_info"]["completion_tokens"].is_u64());
}
// Test [DONE] detection
assert_eq!(sse_chunks[3], "data: [DONE]");
}
#[test]
fn test_ttft_calculation() {
// Test Time To First Token calculation pattern
let first_token_response = json!({
"text": "Hello",
"meta_info": {
"completion_tokens": 1,
"finish_reason": null
}
});
// TTFT is calculated when completion_tokens == 1
assert_eq!(first_token_response["meta_info"]["completion_tokens"], 1);
assert!(first_token_response["meta_info"]["finish_reason"].is_null());
}
#[test]
fn test_throughput_metrics() {
// Test throughput calculation patterns from bench_one_batch_server.py
let batch_size = 16;
let input_len = 1024;
let output_len = 16;
let ttft = 0.5; // seconds
let total_latency = 2.0; // seconds
// Input throughput = batch_size * input_len / ttft
let input_throughput = (batch_size as f64) * (input_len as f64) / ttft;
assert!((input_throughput - 32768.0).abs() < 0.01);
// Output throughput = batch_size * output_len / (latency - ttft)
let output_throughput = (batch_size as f64) * (output_len as f64) / (total_latency - ttft);
assert!((output_throughput - 170.67).abs() < 0.01);
}
#[test]
fn test_error_response_handling() {
// Test error response format from bench_one_batch_server.py
let error_response = json!({
"error": "Request has failed. Invalid input format."
});
assert!(error_response.get("error").is_some());
assert!(error_response["error"].as_str().unwrap().contains("failed"));
}
#[test]
fn test_structured_output_request() {
// Test structured output format (json_schema)
let structured_request = json!({
"text": "What is the capital of France? Answer in JSON.",
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 64,
"json_schema": "$$ANY$$"
},
"stream": false
});
assert_eq!(
structured_request["sampling_params"]["json_schema"],
"$$ANY$$"
);
}
#[test]
fn test_bootstrap_injection_with_benchmark_requests() {
use sglang_router_rs::core::{WorkerFactory, WorkerType};
// Test bootstrap injection with actual benchmark request patterns
let mut benchmark_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; 16], // Batch size 16
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 8,
"ignore_eos": true
},
"return_logprob": true,
"stream": true
});
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let batch_size = 16;
let hostname = get_hostname(prefill_worker.url());
benchmark_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
benchmark_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
benchmark_request["bootstrap_room"] =
json!((0..batch_size).map(|_| 12345u64).collect::<Vec<_>>());
// Verify bootstrap fields match batch size
assert_eq!(
benchmark_request["bootstrap_host"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
benchmark_request["bootstrap_port"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
benchmark_request["bootstrap_room"]
.as_array()
.unwrap()
.len(),
batch_size
);
// Verify original fields are preserved
assert_eq!(benchmark_request["return_logprob"], true);
assert_eq!(benchmark_request["stream"], true);
}
#[test]
fn test_server_info_response_format() {
// Test server info format expected by bench_one_batch_server.py
let server_info = json!({
"internal_states": [{
"avg_spec_accept_length": 3.5,
"last_gen_throughput": 2048.5,
"load": 16
}],
"prefill": [
{"url": "http://prefill1:8080", "load": 10},
{"url": "http://prefill2:8080", "load": 20}
],
"decode": [
{"url": "http://decode1:8080", "load": 5},
{"url": "http://decode2:8080", "load": 15}
]
});
// Verify structure matches what benchmark expects
assert!(server_info["internal_states"][0]["avg_spec_accept_length"].is_f64());
assert!(server_info["internal_states"][0]["last_gen_throughput"].is_f64());
assert!(server_info["prefill"].is_array());
assert!(server_info["decode"].is_array());
}
// ========================================================================
// Comprehensive Endpoint Coverage Test
// ========================================================================
#[test]
fn test_pd_endpoints_coverage() {
// Document all endpoints from Python mini_lb.py and verify implementation status
let implemented_endpoints = vec![
("/health", "GET", true),
("/health_generate", "GET", true), // Note: Python uses POST, we use GET
("/get_server_info", "GET", true),
("/v1/models", "GET", true),
("/get_model_info", "GET", true),
("/generate", "POST", true),
("/v1/chat/completions", "POST", true),
("/v1/completions", "POST", true),
("/flush_cache", "POST", true),
("/get_loads", "GET", true),
("/register", "POST", false), // NOT IMPLEMENTED - needs dynamic worker management
];
let implemented_count = implemented_endpoints
.iter()
.filter(|(_, _, impl_status)| *impl_status)
.count();
let total_count = implemented_endpoints.len();
// We've implemented 10 out of 11 endpoints (register is not needed for Phase 1/2)
assert_eq!(implemented_count, 10);
assert_eq!(total_count, 11);
// Document the missing endpoint
let missing: Vec<_> = implemented_endpoints
.iter()
.filter(|(_, _, impl_status)| !impl_status)
.map(|(endpoint, method, _)| format!("{} {}", method, endpoint))
.collect();
assert_eq!(missing, vec!["POST /register"]);
}
#[test]
fn test_large_batch_bootstrap_injection() {
// Test bootstrap injection performance with very large batches
// This simulates the bench_one_batch_server.py scenario
let large_batch_sizes = vec![1024, 4096, 8192];
for batch_size in large_batch_sizes {
let start = std::time::Instant::now();
// Simulate a large batch request
let mut large_batch_request = json!({
"input_ids": vec![vec![1, 2, 3, 4]; batch_size],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 16,
},
"stream": true
});
// Create a prefill worker to simulate injection
let prefill_worker =
WorkerFactory::create_prefill("http://prefill:8080".to_string(), Some(9000));
// Extract bootstrap port from worker type
let bootstrap_port = match prefill_worker.worker_type() {
WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = get_hostname(prefill_worker.url());
large_batch_request["bootstrap_host"] = json!(vec![hostname; batch_size]);
large_batch_request["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]);
large_batch_request["bootstrap_room"] = json!((0..batch_size)
.map(|_| rand::random::<u64>())
.collect::<Vec<_>>());
let elapsed = start.elapsed();
// Verify bootstrap fields are correctly sized
assert_eq!(
large_batch_request["bootstrap_host"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
large_batch_request["bootstrap_port"]
.as_array()
.unwrap()
.len(),
batch_size
);
assert_eq!(
large_batch_request["bootstrap_room"]
.as_array()
.unwrap()
.len(),
batch_size
);
// Bootstrap injection should be reasonably fast even for large batches
println!(
"Bootstrap injection for batch_size {} took {:?}",
batch_size, elapsed
);
assert!(
elapsed.as_millis() < 1000,
"Bootstrap injection took too long for batch size {}",
batch_size
);
}
}
#[test]
fn test_payload_size_calculation() {
// Test payload size estimation for bench_one_batch_server.py scenarios
let test_cases = vec![
(1, 1024, 16), // Small batch
(16, 1024, 16), // Medium batch
(64, 1024, 16), // Large batch
(8192, 4096, 5), // Benchmark scenario
];
for (batch_size, input_len, _output_len) in test_cases {
// Estimate payload size (rough calculation)
// Each token is ~4 bytes (i32), plus JSON overhead
let tokens_size = batch_size * input_len * 4; // 4 bytes per token
let json_overhead = batch_size * 100; // ~100 bytes overhead per request
let total_size = tokens_size + json_overhead;
println!(
"Batch size: {}, Input len: {}, Estimated payload: {} MB",
batch_size,
input_len,
total_size / (1024 * 1024)
);
// For the benchmark case (8192, 4096), this should be ~134 MB
if batch_size == 8192 && input_len == 4096 {
assert!(
total_size > 100 * 1024 * 1024,
"Benchmark payload should be > 100MB"
);
assert!(
total_size < 200 * 1024 * 1024,
"Benchmark payload should be < 200MB"
);
}
}
}
#[test]
fn test_policy_type_to_pd_selection_policy_mapping() {
// Test that PDSelectionPolicy doesn't include RoundRobin
let pd_policy_count = 3; // Random, PowerOfTwo, CacheAware
assert_eq!(
pd_policy_count, 3,
"PDSelectionPolicy should have exactly 3 variants"
);
// Verify that each PDSelectionPolicy variant can be created
let _random = PDSelectionPolicy::Random;
let _po2 = PDSelectionPolicy::PowerOfTwo;
let _cache_aware = PDSelectionPolicy::CacheAware {
cache_threshold: 0.5,
balance_abs_threshold: 32,
balance_rel_threshold: 1.1,
};
}
}

View File

@@ -0,0 +1,320 @@
//! Integration tests for tokenizers using real tokenizer data
//!
//! These tests download the TinyLlama tokenizer from HuggingFace to verify our tokenizer
//! implementation works correctly with real-world tokenizer files.
mod common;
use common::{ensure_tokenizer_cached, EXPECTED_HASHES, TEST_PROMPTS};
use sglang_router_rs::tokenizer::{
factory, huggingface::HuggingFaceTokenizer, sequence::Sequence, stop::*, stream::DecodeStream,
traits::*,
};
use std::sync::Arc;
const LONG_TEST_PROMPTS: [(&str, &str); 6] = [
("Tell me about the following text.", "Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris nisi ut aliquip ex ea commodo consequat."),
("Tell me about the following text.", "Duis aute irure dolor in reprehenderit in voluptate velit esse cillum dolore eu fugiat nulla pariatur. Excepteur sint occaecat cupidatat non proident, sunt in culpa qui officia deserunt mollit anim id est laborum."),
("Tell me about the following text.", "Sed ut perspiciatis unde omnis iste natus error sit voluptatem accusantium doloremque laudantium, totam rem aperiam, eaque ipsa quae ab illo inventore veritatis et quasi architecto beatae vitae dicta sunt explicabo. Nemo enim ipsam voluptatem quia voluptas sit aspernatur aut odit aut fugit, sed quia consequuntur magni dolores eos qui ratione voluptatem sequi nesciunt."),
("Tell me about the following text.", "Neque porro quisquam est, qui dolorem ipsum quia dolor sit amet, consectetur, adipisci velit, sed quia non numquam eius modi tempora incidunt ut labore et dolore magnam aliquam quaerat voluptatem."),
// Tennis-themed prompt for variety
("Tell me about the following text.", "In the ancient realm of Tennisia, the very magic of the land is drawn from the sport itself. Forehands light the skies, backhands carve the earth, and serves rumble like thunder across kingdoms. At the center of this balance lie four sacred Grand Slam relics: the Sapphire Trophy of Melbourne, the Emerald Chalice of Paris, the Ruby Crown of London, and the Diamond Orb of New York. Together, they keep the game's spirit alive.
But the relics are scattered, guarded by champions of legendary skill. The first is the Fire King of Clay, ruler of the crimson courts, whose topspin arcs blaze high and heavy, scorching all who dare stand across from him. The second is the Tempest Trickster, master of the baseline fortress, whose footwork and precision can turn back any storm, and whose returns arrive as if pulled by invisible strings. The third is the Shadow-Dancer of the Highlands, a tactician who thrives in the long rallies of twilight, changing pace and spin until opponents lose their rhythm. The fourth and final guardian is a towering Diamond Titan, a net-charging colossus whose volleys shatter the air itself.
Into this arena of gods steps the Silver-Wristed Knight — a player of impossible grace, whose game is an art form. His quest: to claim each relic not for glory, but to restore harmony to the rankings of the realm.
He travels across the Kingdom of Clay, where the points stretch like marathons and the air tastes of iron; through the Grasslands of London, where the ball skids low and the margins are razor-thin; over the Hard Courts of the East, where rallies turn into duels of endurance; and finally to the Cathedral of Lights in New York, where night matches burn with fevered energy.
Each battle is played under enchanted floodlights, the lines patrolled by spectral line judges whose calls are final. The crowd's roar swells with every break point, and the Silver-Wristed Knight's racket glows brightest when the match teeters at deuce. There are moments when doubt grips him — when his serve falters or his touch deserts him — but each challenge teaches a new stroke, culminating in the legendary Forehand of Dawn.
When the last relic is claimed, he stands not as a conqueror but as a custodian of the game, knowing that rivalries forge the very magic he protects. The balance is restored — until the next season begins."),
// Emoji stress test
("Tell me about the following text.", "😀😃😄😁😆🥹😅😂🤣🥲☺️😊😇🙂🙃😉🤩😎 🤪🥳🤓🙄🤪😵👻")
];
fn compute_hashes_for_tokenizer<E: Encoder>(tokenizer: &E, prompts: &[&str]) -> Vec<u64> {
prompts
.iter()
.map(|&prompt| {
tokenizer
.encode(prompt)
.expect("Failed to encode prompt")
.get_hash()
})
.collect()
}
#[test]
fn test_huggingface_tokenizer_hashes() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load HuggingFace tokenizer");
let prompt_hashes = compute_hashes_for_tokenizer(&tokenizer, &TEST_PROMPTS);
println!(
"HF Tokenizer: {:?}\nComputed Hashes: {:?}\nExpected Hashes: {:?}",
tokenizer_path, prompt_hashes, EXPECTED_HASHES
);
assert_eq!(prompt_hashes, EXPECTED_HASHES);
}
#[test]
fn test_tokenizer_encode_decode_lifecycle() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load HuggingFace tokenizer");
for prompt in TEST_PROMPTS.iter() {
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode token_ids");
assert_eq!(decoded, *prompt, "Encode-decode mismatch for: {}", prompt);
}
}
#[test]
fn test_sequence_operations() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
for prompt in TEST_PROMPTS.iter() {
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
// Test Sequence with append_text
let mut sequence = Sequence::new(tokenizer.clone());
sequence.append_text(prompt).expect("Failed to append text");
assert_eq!(
sequence.len(),
encoding.token_ids().len(),
"Sequence length mismatch"
);
assert_eq!(sequence.text().unwrap(), *prompt, "Sequence text mismatch");
// Test incremental decoding with append_token
let mut decoder = Sequence::new(tokenizer.clone());
let mut output = String::new();
for token_id in encoding.token_ids() {
let text = decoder
.append_token(*token_id)
.expect("Failed to append token");
output.push_str(&text);
}
assert_eq!(decoder.len(), sequence.len(), "Decoder length mismatch");
assert_eq!(
decoder.token_ids(),
sequence.token_ids(),
"Token IDs mismatch"
);
assert_eq!(output, *prompt, "Incremental decode mismatch");
}
}
#[test]
fn test_decode_stream() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
for prompt in TEST_PROMPTS.iter() {
let encoding = tokenizer.encode(prompt).expect("Failed to encode prompt");
let mut decoder = DecodeStream::new(tokenizer.clone(), &[], false);
let mut output = String::new();
for token_id in encoding.token_ids() {
if let Some(text) = decoder.step(*token_id).expect("Failed to decode token") {
output.push_str(&text);
}
}
assert_eq!(output, *prompt, "DecodeStream output mismatch");
}
}
#[test]
fn test_long_sequence_incremental_decode_with_prefill() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
for (input_text, output_text) in LONG_TEST_PROMPTS.iter() {
let input_encoding = tokenizer
.encode(input_text)
.expect("Failed to encode input");
let output_encoding = tokenizer
.encode(output_text)
.expect("Failed to encode output");
let mut decoder = DecodeStream::new(tokenizer.clone(), input_encoding.token_ids(), false);
let mut output = String::new();
for token_id in output_encoding.token_ids() {
if let Some(text) = decoder.step(*token_id).expect("Failed to decode token") {
output.push_str(&text);
}
}
assert_eq!(output.trim(), *output_text, "Long sequence decode mismatch");
}
}
#[test]
fn test_stop_sequence_decoder() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
// Test with various stop sequences
let test_cases = vec![
(
"Hello world! Stop here. Continue after.",
"Stop",
"Hello world! ",
),
("Testing stop sequences.", ".", "Testing stop sequences"),
("No stop sequence here", "xyz", "No stop sequence here"),
];
for (input, stop_seq, expected) in test_cases {
let config = StopSequenceConfig::default().with_stop_sequence(stop_seq);
let mut decoder = StopSequenceDecoder::new(tokenizer.clone(), config, false);
let encoding = tokenizer.encode(input).expect("Failed to encode");
let mut output = String::new();
let mut stopped = false;
for token_id in encoding.token_ids() {
match decoder.process_token(*token_id).unwrap() {
SequenceDecoderOutput::Text(text) => output.push_str(&text),
SequenceDecoderOutput::StoppedWithText(text) => {
output.push_str(&text);
stopped = true;
break;
}
SequenceDecoderOutput::Stopped => {
stopped = true;
break;
}
SequenceDecoderOutput::Held => {}
}
}
if !stopped {
// Flush any remaining text
if let SequenceDecoderOutput::Text(text) = decoder.flush() {
output.push_str(&text);
}
}
println!(
"Input: '{}', Stop: '{}', Output: '{}', Expected: '{}'",
input, stop_seq, output, expected
);
// The test should check if output starts with expected
// since stop sequences might not be perfectly aligned with token boundaries
assert!(
output.starts_with(expected) || output == input,
"Stop sequence test failed"
);
}
}
#[test]
fn test_factory_creation() {
// Test factory creation method
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = factory::create_tokenizer(tokenizer_path.to_str().unwrap())
.expect("Failed to create tokenizer via factory");
let encoding = tokenizer.encode(TEST_PROMPTS[0]).expect("Failed to encode");
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode");
assert_eq!(decoded, TEST_PROMPTS[0]);
}
#[test]
fn test_batch_encoding() {
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer");
let encodings = tokenizer
.encode_batch(&TEST_PROMPTS)
.expect("Failed to batch encode");
assert_eq!(encodings.len(), TEST_PROMPTS.len());
for (i, encoding) in encodings.iter().enumerate() {
let decoded = tokenizer
.decode(encoding.token_ids(), false)
.expect("Failed to decode");
assert_eq!(decoded, TEST_PROMPTS[i]);
}
}
#[test]
fn test_special_tokens() {
use sglang_router_rs::tokenizer::traits::Tokenizer as TokenizerTrait;
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer");
let special_tokens = tokenizer.get_special_tokens();
// TinyLlama should have at least BOS and EOS tokens
assert!(special_tokens.bos_token.is_some());
assert!(special_tokens.eos_token.is_some());
println!("Special tokens: {:?}", special_tokens);
}
#[test]
fn test_thread_safety() {
use std::thread;
let tokenizer_path = ensure_tokenizer_cached();
let tokenizer = Arc::new(
HuggingFaceTokenizer::from_file(tokenizer_path.to_str().unwrap())
.expect("Failed to load tokenizer"),
);
let handles: Vec<_> = TEST_PROMPTS
.iter()
.map(|&prompt| {
let tokenizer_clone = tokenizer.clone();
thread::spawn(move || {
let encoding = tokenizer_clone
.encode(prompt)
.expect("Failed to encode in thread");
let decoded = tokenizer_clone
.decode(encoding.token_ids(), false)
.expect("Failed to decode in thread");
assert_eq!(decoded, prompt);
})
})
.collect();
for handle in handles {
handle.join().expect("Thread panicked");
}
}

View File

@@ -0,0 +1,183 @@
//! DeepSeek V3 Parser Integration Tests
use sglang_router_rs::tool_parser::{DeepSeekParser, ParseState, StreamResult, ToolParser};
#[tokio::test]
async fn test_deepseek_complete_parsing() {
let parser = DeepSeekParser::new();
// Test single tool call
let input = r#"Let me help you with that.
<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Tokyo", "units": "celsius"}
```<tool▁call▁end><tool▁calls▁end>
The weather in Tokyo is..."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo");
assert_eq!(args["units"], "celsius");
}
#[tokio::test]
async fn test_deepseek_multiple_tools() {
let parser = DeepSeekParser::new();
let input = r#"<tool▁calls▁begin>
<tool▁call▁begin>function<tool▁sep>search
```json
{"query": "rust programming"}
```<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>translate
```json
{"text": "Hello World", "to": "ja"}
```<tool▁call▁end>
<tool▁calls▁end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "translate");
}
#[tokio::test]
async fn test_deepseek_streaming() {
let parser = DeepSeekParser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<tool▁calls▁begin><tool▁call▁begin>",
"function<tool▁sep>get_weather\n",
"```json\n",
r#"{"location": "#,
r#""Beijing", "#,
r#""units": "metric"}"#,
"\n```<tool▁call▁end><tool▁calls▁end>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "get_weather");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete);
}
#[tokio::test]
async fn test_deepseek_nested_json() {
let parser = DeepSeekParser::new();
let input = r#"<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>process
```json
{
"data": {
"nested": {
"deep": [1, 2, 3]
}
}
}
```<tool▁call▁end><tool▁calls▁end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["data"]["nested"]["deep"].is_array());
}
#[test]
fn test_deepseek_format_detection() {
let parser = DeepSeekParser::new();
// Should detect DeepSeek format
assert!(parser.detect_format("<tool▁calls▁begin>"));
assert!(parser.detect_format("text with <tool▁calls▁begin> marker"));
// Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>"));
assert!(!parser.detect_format("plain text"));
}
#[tokio::test]
async fn test_deepseek_malformed_json_handling() {
let parser = DeepSeekParser::new();
// Malformed JSON should be skipped
let input = r#"<tool▁calls▁begin>
<tool▁call▁begin>function<tool▁sep>broken
```json
{invalid json}
```<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>valid
```json
{"key": "value"}
```<tool▁call▁end>
<tool▁calls▁end>"#;
let result = parser.parse_complete(input).await.unwrap();
// Only the valid tool call should be parsed
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "valid");
}
#[tokio::test]
async fn test_normal_text_extraction() {
let parser = DeepSeekParser::new();
// Python extracts text before tool calls as normal_text
let input = r#"Let me help you with that.
<tool▁calls▁begin><tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Tokyo"}
```<tool▁call▁end><tool▁calls▁end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
// TODO: Verify normal text extraction when parser returns it
// In Python: normal_text = "Let me help you with that."
}
#[tokio::test]
async fn test_multiple_tool_calls() {
let parser = DeepSeekParser::new();
let input = r#"<tool▁calls▁begin>
<tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Tokyo"}
```<tool▁call▁end>
<tool▁call▁begin>function<tool▁sep>get_weather
```json
{"location": "Paris"}
```<tool▁call▁end>
<tool▁calls▁end><end▁of▁sentence>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "get_weather");
}

View File

@@ -0,0 +1,330 @@
//! Edge Cases and Error Handling Tests
//!
//! Tests for malformed input, edge cases, and error recovery
use sglang_router_rs::tool_parser::{
JsonParser, MistralParser, ParseState, ParserRegistry, PythonicParser, QwenParser,
StreamResult, ToolParser,
};
#[tokio::test]
async fn test_empty_input() {
let registry = ParserRegistry::new();
let parsers = vec!["json", "mistral", "qwen", "pythonic", "llama"];
for parser_name in parsers {
let parser = registry
.get_parser(&format!("test-{}", parser_name))
.unwrap();
let result = parser.parse_complete("").await.unwrap();
assert_eq!(
result.len(),
0,
"Parser {} should return empty for empty input",
parser_name
);
}
}
#[tokio::test]
async fn test_plain_text_no_tools() {
let plain_text = "This is just a regular response with no tool calls whatsoever.";
let json_parser = JsonParser::new();
assert_eq!(
json_parser.parse_complete(plain_text).await.unwrap().len(),
0
);
let mistral_parser = MistralParser::new();
assert_eq!(
mistral_parser
.parse_complete(plain_text)
.await
.unwrap()
.len(),
0
);
let qwen_parser = QwenParser::new();
assert_eq!(
qwen_parser.parse_complete(plain_text).await.unwrap().len(),
0
);
let pythonic_parser = PythonicParser::new();
assert_eq!(
pythonic_parser
.parse_complete(plain_text)
.await
.unwrap()
.len(),
0
);
}
#[tokio::test]
async fn test_incomplete_json() {
let json_parser = JsonParser::new();
let incomplete_cases = vec![
r#"{"name": "test""#, // Missing closing brace
r#"{"name": "test", "arguments":"#, // Incomplete arguments
r#"{"name": "test", "arguments": {"#, // Incomplete nested object
];
for input in incomplete_cases {
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(
result.len(),
0,
"Should not parse incomplete JSON: {}",
input
);
}
// This case might actually parse because [{"name": "test"}] is complete
// The trailing comma suggests more items but the first item is valid
let _result = json_parser
.parse_complete(r#"[{"name": "test"},"#)
.await
.unwrap();
// This could parse the first element or return empty - implementation dependent
}
#[tokio::test]
async fn test_malformed_mistral() {
let parser = MistralParser::new();
let malformed_cases = vec![
"[TOOL_CALLS]", // Missing array
"[TOOL_CALLS] {", // Not an array
"[TOOL_CALLS] [", // Incomplete array
"[TOOL_CALLS] [{]", // Invalid JSON in array
"[TOOL_CALLS] [{\"name\": }]", // Invalid value
];
for input in malformed_cases {
// Parser might return error or empty vec for malformed input
if let Ok(result) = parser.parse_complete(input).await {
assert_eq!(
result.len(),
0,
"Should not parse malformed Mistral: {}",
input
);
}
// Error is also acceptable for malformed input
}
}
#[tokio::test]
async fn test_missing_required_fields() {
let json_parser = JsonParser::new();
// Missing name field
let input = r#"{"arguments": {"x": 1}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should not parse without name field");
// Name is not a string
let input = r#"{"name": 123, "arguments": {}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should not parse with non-string name");
}
#[tokio::test]
async fn test_very_long_strings() {
let json_parser = JsonParser::new();
let long_string = "x".repeat(10000);
let input = format!(
r#"{{"name": "test", "arguments": {{"data": "{}"}}}}"#,
long_string
);
let result = json_parser.parse_complete(&input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["data"].as_str().unwrap().len(), 10000);
}
#[tokio::test]
async fn test_unicode_edge_cases() {
let json_parser = JsonParser::new();
// Various Unicode characters including emojis, CJK, RTL text
let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍 مرحبا עולם"}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Hello 世界 🌍 مرحبا עולם");
}
#[tokio::test]
async fn test_nested_brackets_in_strings() {
// Test that parsers correctly handle brackets within string literals
let mistral_parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array: [1, 2, 3]"}}]"#;
let result = mistral_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Array: [1, 2, 3]");
let pythonic_parser = PythonicParser::new();
let input = r#"[echo(text="List: [a, b, c]")]"#;
let result = pythonic_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "List: [a, b, c]");
}
#[tokio::test]
async fn test_multiple_formats_in_text() {
// Test that parsers don't get confused by other formats in the text
let json_parser = JsonParser::new();
let input = r#"
Here's some text with [TOOL_CALLS] that shouldn't trigger.
{"name": "actual_tool", "arguments": {}}
And some more text with <tool_call> tags.
"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "actual_tool");
}
#[tokio::test]
async fn test_escaped_characters() {
let json_parser = JsonParser::new();
let input = r#"{"name": "write", "arguments": {"content": "Line 1\nLine 2\r\nLine 3\tTabbed\\Backslash\"Quote"}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
let content = args["content"].as_str().unwrap();
assert!(content.contains('\n'));
assert!(content.contains('\t'));
assert!(content.contains('\\'));
assert!(content.contains('"'));
}
#[tokio::test]
async fn test_numeric_edge_cases() {
let json_parser = JsonParser::new();
let input = r#"{
"name": "calculate",
"arguments": {
"int": 42,
"float": 123.456,
"scientific": 1.23e-4,
"negative": -999,
"zero": 0,
"large": 9007199254740991
}
}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["int"], 42);
assert_eq!(args["float"], 123.456);
assert_eq!(args["scientific"], 0.000123);
assert_eq!(args["negative"], -999);
assert_eq!(args["zero"], 0);
assert_eq!(args["large"], 9007199254740991i64);
}
#[tokio::test]
async fn test_null_and_boolean_values() {
let json_parser = JsonParser::new();
let input = r#"{
"name": "configure",
"arguments": {
"enabled": true,
"disabled": false,
"optional": null
}
}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["enabled"], true);
assert_eq!(args["disabled"], false);
assert_eq!(args["optional"], serde_json::Value::Null);
}
#[tokio::test]
async fn test_partial_token_at_buffer_boundary() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// Test case that would fail with the bug:
// Send exactly "<tool" which is a 5-character prefix of "<tool_call>\n"
let result = parser.parse_incremental("<tool", &mut state).await.unwrap();
assert!(matches!(result, StreamResult::Incomplete));
assert_eq!(state.buffer, "<tool");
// Complete the token
let result = parser
.parse_incremental(
"_call>\n{\"name\": \"test\", \"arguments\": {}}\n</tool_call>",
&mut state,
)
.await
.unwrap();
// Should successfully parse after completing
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test");
}
_ => {
// In Phase 2 simplified streaming, might get Incomplete
// The important thing is it didn't fail to recognize the partial token
}
}
}
#[tokio::test]
async fn test_exact_prefix_lengths() {
let parser = QwenParser::new();
// Test various exact prefix lengths that would be missed by exclusive range
let test_cases = vec![
("<", 1), // 1-char prefix
("<t", 2), // 2-char prefix
("<tool", 5), // 5-char prefix (the main bug case)
("<tool_call", 10), // 10-char prefix
("<tool_call>", 11), // 11-char prefix (full start without \n)
];
for (prefix, expected_len) in test_cases {
let mut state = ParseState::new();
let result = parser.parse_incremental(prefix, &mut state).await.unwrap();
assert!(
matches!(result, StreamResult::Incomplete),
"Prefix '{}' (len {}) should be incomplete",
prefix,
expected_len
);
assert_eq!(
state.buffer, prefix,
"Buffer should contain the prefix '{}'",
prefix
);
}
}

View File

@@ -0,0 +1,194 @@
//! GLM-4 MoE Parser Integration Tests
use sglang_router_rs::tool_parser::{Glm4MoeParser, ParseState, StreamResult, ToolParser};
#[tokio::test]
async fn test_glm4_complete_parsing() {
let parser = Glm4MoeParser::new();
// Test single tool call
let input = r#"Let me search for that.
<tool_call>get_weather
<arg_key>city</arg_key>
<arg_value>Beijing</arg_value>
<arg_key>date</arg_key>
<arg_value>2024-12-25</arg_value>
</tool_call>
The weather will be..."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "Beijing");
assert_eq!(args["date"], "2024-12-25");
}
#[tokio::test]
async fn test_glm4_multiple_tools() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>search
<arg_key>query</arg_key>
<arg_value>rust tutorials</arg_value>
</tool_call>
<tool_call>translate
<arg_key>text</arg_key>
<arg_value>Hello World</arg_value>
<arg_key>target_lang</arg_key>
<arg_value>zh</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "translate");
}
#[tokio::test]
async fn test_glm4_type_conversion() {
let parser = Glm4MoeParser::new();
// Test various value types
let input = r#"<tool_call>process
<arg_key>count</arg_key>
<arg_value>42</arg_value>
<arg_key>rate</arg_key>
<arg_value>1.5</arg_value>
<arg_key>enabled</arg_key>
<arg_value>true</arg_value>
<arg_key>data</arg_key>
<arg_value>null</arg_value>
<arg_key>text</arg_key>
<arg_value>string value</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["count"], 42);
assert_eq!(args["rate"], 1.5);
assert_eq!(args["enabled"], true);
assert_eq!(args["data"], serde_json::Value::Null);
assert_eq!(args["text"], "string value");
}
#[tokio::test]
async fn test_glm4_streaming() {
let parser = Glm4MoeParser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<tool_call>",
"get_weather\n",
"<arg_key>city</arg_key>\n",
"<arg_value>Shanghai</arg_value>\n",
"<arg_key>units</arg_key>\n",
"<arg_value>celsius</arg_value>\n",
"</tool_call>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "get_weather");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete);
}
#[test]
fn test_glm4_format_detection() {
let parser = Glm4MoeParser::new();
// Should detect GLM-4 format
assert!(parser.detect_format("<tool_call>"));
assert!(parser.detect_format("text with <tool_call> marker"));
// Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool▁calls▁begin>"));
assert!(!parser.detect_format("plain text"));
}
#[tokio::test]
async fn test_glm4_python_literal_values() {
let parser = Glm4MoeParser::new();
// Test Python-style boolean values
let input = r#"<tool_call>config
<arg_key>debug</arg_key>
<arg_value>True</arg_value>
<arg_key>verbose</arg_key>
<arg_value>False</arg_value>
<arg_key>optional</arg_key>
<arg_value>None</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["debug"], true);
assert_eq!(args["verbose"], false);
assert_eq!(args["optional"], serde_json::Value::Null);
}
#[tokio::test]
async fn test_python_literals() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>test_func
<arg_key>bool_true</arg_key>
<arg_value>True</arg_value>
<arg_key>bool_false</arg_key>
<arg_value>False</arg_value>
<arg_key>none_val</arg_key>
<arg_value>None</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test_func");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["bool_true"], true);
assert_eq!(args["bool_false"], false);
assert_eq!(args["none_val"], serde_json::Value::Null);
}
#[tokio::test]
async fn test_nested_values() {
let parser = Glm4MoeParser::new();
let input = r#"<tool_call>process
<arg_key>data</arg_key>
<arg_value>{"nested": {"key": "value"}}</arg_value>
<arg_key>list</arg_key>
<arg_value>[1, 2, 3]</arg_value>
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["data"].is_object());
assert!(args["list"].is_array());
}

View File

@@ -0,0 +1,201 @@
//! GPT-OSS Parser Integration Tests
use sglang_router_rs::tool_parser::{GptOssParser, ParseState, StreamResult, ToolParser};
#[tokio::test]
async fn test_gpt_oss_complete_parsing() {
let parser = GptOssParser::new();
// Test single tool call
let input = r#"Let me search for that information.
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "rust programming", "limit": 10}<|call|>
Here are the results..."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10);
}
#[tokio::test]
async fn test_gpt_oss_multiple_tools() {
let parser = GptOssParser::new();
let input = r#"<|channel|>commentary to=functions.get_weather<|constrain|>json<|message|>{"location": "Paris"}<|call|>commentary
<|channel|>commentary to=functions.search<|constrain|>json<|message|>{"query": "Paris tourism"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
}
#[tokio::test]
async fn test_gpt_oss_with_namespace() {
let parser = GptOssParser::new();
// Test with different namespace patterns
let input = r#"<|channel|>commentary to=api.users.create<|constrain|>json<|message|>{"name": "John", "email": "john@example.com"}<|call|>
<|channel|>commentary to=tools.calculator.add<|constrain|>json<|message|>{"x": 10, "y": 20}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "create"); // Should extract last part
assert_eq!(result[1].function.name, "add");
}
#[tokio::test]
async fn test_gpt_oss_with_assistant_prefix() {
let parser = GptOssParser::new();
// Test with <|start|>assistant prefix
let input = r#"<|start|>assistant<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"key": "value"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_gpt_oss_empty_args() {
let parser = GptOssParser::new();
// Test with empty arguments
let input =
r#"<|channel|>commentary to=functions.get_time<|constrain|>json<|message|>{}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
assert_eq!(result[0].function.arguments, "{}");
}
#[tokio::test]
async fn test_gpt_oss_streaming() {
let parser = GptOssParser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<|channel|>commentary to=",
"functions.calculate",
"<|constrain|>json<|message|>",
r#"{"x": 10"#,
r#", "y": 20}"#,
"<|call|>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "calculate");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "calculate");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete);
}
#[test]
fn test_gpt_oss_format_detection() {
let parser = GptOssParser::new();
// Should detect GPT-OSS format
assert!(parser.detect_format("<|channel|>commentary to="));
assert!(parser.detect_format("<|channel|>commentary"));
assert!(parser.detect_format("text with <|channel|>commentary to= marker"));
// Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>"));
assert!(!parser.detect_format("plain text"));
}
#[tokio::test]
async fn test_gpt_oss_with_whitespace() {
let parser = GptOssParser::new();
// Test with whitespace after function name
let input = r#"<|channel|>commentary to=functions.test <|constrain|>json<|message|>{"key": "value"}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_gpt_oss_complex_json() {
let parser = GptOssParser::new();
// Test with complex nested JSON
let input = r#"<|channel|>commentary to=functions.process<|constrain|>json<|message|>{
"nested": {
"data": [1, 2, 3],
"config": {
"enabled": true
}
}
}<|call|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["nested"]["data"].is_array());
assert_eq!(args["nested"]["config"]["enabled"], true);
}
#[tokio::test]
async fn test_commentary_without_function() {
let parser = GptOssParser::new();
// Python should extract commentary as normal text
let input = r#"<|channel|>commentary<|message|>**Action plan**: 1. Do X 2. Do Y<|end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0); // No tool calls
// TODO: Verify normal text = "**Action plan**: 1. Do X 2. Do Y"
}
#[tokio::test]
async fn test_final_channel() {
let parser = GptOssParser::new();
let input = r#"<|channel|>commentary to=functions.test<|constrain|>json<|message|>{"x": 1}<|call|>
<|channel|>final<|message|>The result is calculated.<|return|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
// TODO: Verify normal text = "The result is calculated."
}
#[tokio::test]
async fn test_mixed_commentary_and_calls() {
let parser = GptOssParser::new();
let input = r#"<|channel|>commentary<|message|>Let me think<|end|>
<|channel|>commentary to=functions.calc<|constrain|>json<|message|>{"x": 5}<|call|>
<|channel|>commentary<|message|>Processing...<|end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calc");
// TODO: Verify normal text = "Let me think Processing..."
}

View File

@@ -0,0 +1,147 @@
//! JSON Parser Integration Tests
//!
//! Tests for the JSON parser which handles OpenAI, Claude, and generic JSON formats
use serde_json::json;
use sglang_router_rs::tool_parser::{JsonParser, ToolParser};
#[tokio::test]
async fn test_simple_json_tool_call() {
let parser = JsonParser::new();
let input = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "San Francisco");
}
#[tokio::test]
async fn test_json_array_of_tools() {
let parser = JsonParser::new();
let input = r#"[
{"name": "get_weather", "arguments": {"location": "SF"}},
{"name": "search", "arguments": {"query": "news"}}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "search");
}
#[tokio::test]
async fn test_json_with_parameters_key() {
let parser = JsonParser::new();
let input = r#"{"name": "calculate", "parameters": {"x": 10, "y": 20}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["x"], 10);
assert_eq!(args["y"], 20);
}
#[tokio::test]
async fn test_json_extraction_from_text() {
let parser = JsonParser::new();
let input = r#"I'll help you with that. {"name": "search", "arguments": {"query": "rust"}} Let me search for that."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
}
#[tokio::test]
async fn test_json_with_nested_objects() {
let parser = JsonParser::new();
let input = r#"{
"name": "update_config",
"arguments": {
"settings": {
"theme": "dark",
"language": "en",
"notifications": {
"email": true,
"push": false
}
}
}
}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "update_config");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["settings"]["theme"], "dark");
assert_eq!(args["settings"]["notifications"]["email"], true);
}
#[tokio::test]
async fn test_json_with_special_characters() {
let parser = JsonParser::new();
let input = r#"{"name": "echo", "arguments": {"text": "Line 1\nLine 2\tTabbed", "path": "C:\\Users\\test"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Line 1\nLine 2\tTabbed");
assert_eq!(args["path"], "C:\\Users\\test");
}
#[tokio::test]
async fn test_json_with_unicode() {
let parser = JsonParser::new();
let input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍", "emoji": "😊"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Hello 世界 🌍");
assert_eq!(args["emoji"], "😊");
}
#[tokio::test]
async fn test_json_empty_arguments() {
let parser = JsonParser::new();
let input = r#"{"name": "ping", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "ping");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args, json!({}));
}
#[tokio::test]
async fn test_json_invalid_format() {
let parser = JsonParser::new();
// Missing closing brace
let input = r#"{"name": "test", "arguments": {"key": "value""#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0);
// Not JSON at all
let input = "This is just plain text";
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test]
async fn test_json_format_detection() {
let parser = JsonParser::new();
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(parser.detect_format(r#"[{"name": "test"}]"#));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
}

View File

@@ -0,0 +1,160 @@
//! Kimi K2 Parser Integration Tests
use sglang_router_rs::tool_parser::{KimiK2Parser, ParseState, StreamResult, ToolParser};
#[tokio::test]
async fn test_kimik2_complete_parsing() {
let parser = KimiK2Parser::new();
// Test single tool call
let input = r#"Let me help you with that.
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.get_weather:0<|tool_call_argument_begin|>{"location": "Tokyo", "units": "celsius"}<|tool_call_end|>
<|tool_calls_section_end|>
The weather in Tokyo is..."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo");
assert_eq!(args["units"], "celsius");
}
#[tokio::test]
async fn test_kimik2_multiple_tools() {
let parser = KimiK2Parser::new();
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust tutorials"}<|tool_call_end|>
<|tool_call_begin|>functions.translate:1<|tool_call_argument_begin|>{"text": "Hello", "to": "ja"}<|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "translate");
}
#[tokio::test]
async fn test_kimik2_with_whitespace() {
let parser = KimiK2Parser::new();
// Test with extra whitespace
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|> functions.test:0 <|tool_call_argument_begin|> {"key": "value", "num": 42} <|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["key"], "value");
assert_eq!(args["num"], 42);
}
#[tokio::test]
async fn test_kimik2_streaming() {
let parser = KimiK2Parser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<|tool_calls_section_begin|>\n",
"<|tool_call_begin|>functions.",
"calculate:0",
"<|tool_call_argument_begin|>",
r#"{"x": 10, "#,
r#""y": 20}"#,
"<|tool_call_end|>\n",
"<|tool_calls_section_end|>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "calculate");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "calculate");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete);
}
#[test]
fn test_kimik2_format_detection() {
let parser = KimiK2Parser::new();
// Should detect Kimi K2 format
assert!(parser.detect_format("<|tool_calls_section_begin|>"));
assert!(parser.detect_format("<|tool_call_begin|>"));
assert!(parser.detect_format("text with <|tool_calls_section_begin|> marker"));
// Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>"));
assert!(!parser.detect_format("plain text"));
}
#[tokio::test]
async fn test_kimik2_sequential_indices() {
let parser = KimiK2Parser::new();
// Test with proper sequential indexing
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>functions.first:0<|tool_call_argument_begin|>{"param": "a"}<|tool_call_end|>
<|tool_call_begin|>functions.second:1<|tool_call_argument_begin|>{"param": "b"}<|tool_call_end|>
<|tool_call_begin|>functions.third:2<|tool_call_argument_begin|>{"param": "c"}<|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result[0].function.name, "first");
assert_eq!(result[1].function.name, "second");
assert_eq!(result[2].function.name, "third");
}
#[tokio::test]
async fn test_function_index_extraction() {
let parser = KimiK2Parser::new();
let input = r#"Text before tool calls.
<|tool_calls_section_begin|>
<|tool_call_begin|>functions.search:0<|tool_call_argument_begin|>{"query": "rust"}<|tool_call_end|>
<|tool_call_begin|>functions.calc:1<|tool_call_argument_begin|>{"x": 10}<|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "calc");
// TODO: Verify indices are preserved: 0 and 1
// TODO: Verify normal text = "Text before tool calls."
}
#[tokio::test]
async fn test_namespace_extraction() {
let parser = KimiK2Parser::new();
let input = r#"<|tool_calls_section_begin|>
<|tool_call_begin|>api.tools.search:0<|tool_call_argument_begin|>{"q": "test"}<|tool_call_end|>
<|tool_calls_section_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search"); // Should extract after last dot
}

View File

@@ -0,0 +1,424 @@
//! Llama Parser Integration Tests
//!
//! Tests for the Llama parser which handles <|python_tag|> format and plain JSON
use sglang_router_rs::tool_parser::{LlamaParser, ToolParser};
#[tokio::test]
async fn test_llama_python_tag_format() {
let parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "weather");
}
#[tokio::test]
async fn test_llama_plain_json_fallback() {
let parser = LlamaParser::new();
let input = r#"{"name": "calculate", "arguments": {"x": 5, "y": 10}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["x"], 5);
assert_eq!(args["y"], 10);
}
#[tokio::test]
async fn test_llama_with_text_before() {
let parser = LlamaParser::new();
let input = r#"Let me help you with that. <|python_tag|>{"name": "get_time", "arguments": {"timezone": "UTC"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["timezone"], "UTC");
}
#[tokio::test]
async fn test_llama_with_nested_json() {
let parser = LlamaParser::new();
let input = r#"<|python_tag|>{
"name": "update_settings",
"arguments": {
"preferences": {
"theme": "dark",
"language": "en"
},
"notifications": true
}
}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "update_settings");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["preferences"]["theme"], "dark");
assert_eq!(args["notifications"], true);
}
#[tokio::test]
async fn test_llama_empty_arguments() {
let parser = LlamaParser::new();
// With python_tag
let input = r#"<|python_tag|>{"name": "ping", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "ping");
// Plain JSON
let input = r#"{"name": "ping", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "ping");
}
#[tokio::test]
async fn test_llama_format_detection() {
let parser = LlamaParser::new();
assert!(parser.detect_format(r#"<|python_tag|>{"name": "test"}"#));
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#));
assert!(!parser.detect_format("plain text"));
assert!(!parser.detect_format(r#"{"key": "value"}"#)); // No name field
}
#[tokio::test]
async fn test_llama_invalid_json_after_tag() {
let parser = LlamaParser::new();
let input = r#"<|python_tag|>{"name": invalid}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test]
async fn test_llama_real_world_output() {
let parser = LlamaParser::new();
// Actual output from Llama 3.2 model - simplified for testing
let input = r#"I'll search for that information for you.
<|python_tag|>{"name": "web_search", "arguments": {"query": "Llama 3.2 model capabilities", "num_results": 5, "search_type": "recent"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "web_search");
// Test with nicely formatted JSON
let formatted_input = r#"<|python_tag|>{
"name": "get_current_time",
"arguments": {
"timezone": "America/New_York",
"format": "ISO8601"
}
}"#;
let result2 = parser.parse_complete(formatted_input).await.unwrap();
assert_eq!(result2.len(), 1);
assert_eq!(result2[0].function.name, "get_current_time");
}
#[tokio::test]
async fn test_llama_json_array_format() {
let parser = LlamaParser::new();
// Plain JSON array (should work as fallback)
let input = r#"[{"name": "func1", "arguments": {}}, {"name": "func2", "arguments": {}}]"#;
let result = parser.parse_complete(input).await.unwrap();
// Current implementation might handle this through JSON fallback
assert!(!result.is_empty());
}
#[tokio::test]
async fn test_single_json() {
// Test parsing plain JSON without python_tag
let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {"city": "Paris"}}"#;
let result = parser.parse_complete(text).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "Paris");
}
#[tokio::test]
async fn test_multiple_json_with_separator() {
// Test multiple JSON objects with semicolon separator
let parser = LlamaParser::new();
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {"city": "Paris"}};{"name": "get_tourist_attractions", "arguments": {"city": "Paris"}}"#;
let result = parser.parse_complete(text).await.unwrap();
// Note: Current implementation may only parse the first one due to semicolon handling
assert!(!result.is_empty());
assert_eq!(result[0].function.name, "get_weather");
}
#[tokio::test]
async fn test_multiple_json_with_separator_customized() {
// Test multiple JSON objects with python_tag repeated
let parser = LlamaParser::new();
let text = r#"<|python_tag|>{"name": "get_weather", "arguments": {}}<|python_tag|>{"name": "get_tourist_attractions", "arguments": {}}"#;
let result = parser.parse_complete(text).await.unwrap();
// Current implementation may handle this differently
assert!(!result.is_empty());
assert_eq!(result[0].function.name, "get_weather");
}
#[tokio::test]
async fn test_json_with_trailing_text() {
// Test JSON with trailing text after
let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {}} Some follow-up text"#;
let result = parser.parse_complete(text).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
}
#[tokio::test]
async fn test_invalid_then_valid_json() {
// Test error recovery - invalid JSON followed by valid JSON
let parser = LlamaParser::new();
let text = r#"{"name": "get_weather", "arguments": {{"name": "get_weather", "arguments": {}}"#;
let result = parser.parse_complete(text).await.unwrap();
// Should parse at least one valid JSON
if !result.is_empty() {
assert_eq!(result[0].function.name, "get_weather");
}
}
#[tokio::test]
async fn test_plain_text_only() {
// Test plain text with no tool calls
let parser = LlamaParser::new();
let text = "This is just plain explanation text.";
let result = parser.parse_complete(text).await.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test]
async fn test_with_python_tag_prefix() {
// Test text before python_tag
let parser = LlamaParser::new();
let text = r#"Some intro. <|python_tag|>{"name": "get_weather", "arguments": {}}"#;
let result = parser.parse_complete(text).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
}
// ============================================================================
// STREAMING TESTS
// ============================================================================
#[tokio::test]
async fn test_llama_streaming_simple() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// Send complete JSON at once
let full_json = r#"<|python_tag|>{"name": "search", "arguments": {"query": "weather"}}"#;
let result = parser
.parse_incremental(full_json, &mut state)
.await
.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "search");
}
_ => panic!("Expected ToolComplete for complete JSON input"),
}
}
#[tokio::test]
async fn test_llama_streaming_partial() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// Stream in chunks
let chunks = vec![
r#"<|python"#,
r#"_tag|>{"name": "#,
r#""calculate", "#,
r#""arguments": {"x": 10}"#,
r#"}"#,
];
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "calculate");
got_complete = true;
}
}
assert!(got_complete, "Should have completed parsing");
}
#[tokio::test]
async fn test_llama_streaming_plain_json() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// Stream plain JSON without python_tag
let chunks = vec![
r#"{"name": "#,
r#""search", "#,
r#""arguments": "#,
r#"{"query": "#,
r#""test"}}"#,
];
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "search");
got_complete = true;
}
}
assert!(got_complete, "Should have completed parsing");
}
#[tokio::test]
async fn test_llama_streaming_with_text_before() {
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let chunks = vec![
r#"Let me help you. "#,
r#"<|python_tag|>"#,
r#"{"name": "get_time","#,
r#" "arguments": {"#,
r#""timezone": "UTC"}}"#,
];
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "get_time");
got_complete = true;
}
}
assert!(got_complete, "Should have completed parsing");
}
#[tokio::test]
async fn test_llama_streaming_multiple_tools() {
// Test streaming multiple tool calls with semicolon separator
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text =
r#"<|python_tag|>{"name": "func1", "arguments": {}};{"name": "func2", "arguments": {}}"#;
let result = parser.parse_incremental(text, &mut state).await.unwrap();
// Should get first tool complete
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "func1");
}
_ => panic!("Expected first tool to be complete"),
}
// Process remaining buffer to get second tool
let result2 = parser.parse_incremental("", &mut state).await.unwrap();
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "func2");
}
_ => panic!("Expected second tool to be complete"),
}
}
#[tokio::test]
async fn test_llama_streaming_multiple_tools_chunked() {
// Test streaming multiple tool calls arriving in chunks
let parser = LlamaParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// First chunk - incomplete first JSON
let chunk1 = r#"<|python_tag|>{"name": "get_weather", "arguments""#;
let result1 = parser.parse_incremental(chunk1, &mut state).await.unwrap();
// Should be incomplete or have tool name
match result1 {
sglang_router_rs::tool_parser::StreamResult::Incomplete
| sglang_router_rs::tool_parser::StreamResult::ToolName { .. }
| sglang_router_rs::tool_parser::StreamResult::ToolArguments { .. } => {
// Expected - could get tool name or be incomplete or even partial args
}
_ => panic!(
"Expected incomplete or tool name for partial JSON, got: {:?}",
result1
),
}
// Second chunk - complete first JSON and separator
let chunk2 = r#": {"city": "Paris"}};{"name": "#;
let result2 = parser.parse_incremental(chunk2, &mut state).await.unwrap();
// Should get first tool complete
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["city"], "Paris");
}
_ => panic!("Expected first tool to be complete after separator"),
}
// Third chunk - complete second JSON
let chunk3 = r#""get_time", "arguments": {"timezone": "UTC"}}"#;
let result3 = parser.parse_incremental(chunk3, &mut state).await.unwrap();
// Should get second tool complete
match result3 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_time");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["timezone"], "UTC");
}
_ => {
// If not complete yet, try one more empty chunk
let result4 = parser.parse_incremental("", &mut state).await.unwrap();
match result4 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_time");
let args: serde_json::Value =
serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["timezone"], "UTC");
}
_ => panic!("Expected second tool to be complete"),
}
}
}
}

View File

@@ -0,0 +1,153 @@
//! Mistral Parser Integration Tests
//!
//! Tests for the Mistral parser which handles [TOOL_CALLS] format
use serde_json::json;
use sglang_router_rs::tool_parser::{MistralParser, ToolParser};
#[tokio::test]
async fn test_mistral_single_tool() {
let parser = MistralParser::new();
let input = r#"Let me search for that.
[TOOL_CALLS] [{"name": "search_web", "arguments": {"query": "latest news", "max_results": 5}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search_web");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "latest news");
assert_eq!(args["max_results"], 5);
}
#[tokio::test]
async fn test_mistral_multiple_tools() {
let parser = MistralParser::new();
let input = r#"I'll help you with both tasks.
[TOOL_CALLS] [
{"name": "get_weather", "arguments": {"city": "Tokyo", "units": "celsius"}},
{"name": "search_news", "arguments": {"query": "AI developments", "limit": 10}}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args0["city"], "Tokyo");
assert_eq!(result[1].function.name, "search_news");
let args1: serde_json::Value = serde_json::from_str(&result[1].function.arguments).unwrap();
assert_eq!(args1["query"], "AI developments");
}
#[tokio::test]
async fn test_mistral_nested_json() {
let parser = MistralParser::new();
let input = r#"Processing complex data.
[TOOL_CALLS] [{"name": "process_data", "arguments": {"config": {"nested": {"value": [1, 2, 3]}}, "enabled": true}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["config"]["nested"]["value"], json!([1, 2, 3]));
assert_eq!(args["enabled"], true);
}
#[tokio::test]
async fn test_mistral_with_text_after() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]
And here's some text after the tool call that should be ignored."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_mistral_empty_arguments() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "ping", "arguments": {}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "ping");
}
#[tokio::test]
async fn test_mistral_with_brackets_in_strings() {
let parser = MistralParser::new();
let input = r#"[TOOL_CALLS] [{"name": "echo", "arguments": {"text": "Array notation: arr[0] = value[1]"}}]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Array notation: arr[0] = value[1]");
}
#[tokio::test]
async fn test_mistral_format_detection() {
let parser = MistralParser::new();
assert!(parser.detect_format("[TOOL_CALLS] ["));
assert!(parser.detect_format("Some text [TOOL_CALLS] ["));
assert!(!parser.detect_format("Just plain text"));
assert!(!parser.detect_format("[{\"name\": \"test\"}]")); // JSON array without TOOL_CALLS
}
#[tokio::test]
async fn test_mistral_malformed_json() {
let parser = MistralParser::new();
// Missing closing bracket
let input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}"#;
if let Ok(result) = parser.parse_complete(input).await {
assert_eq!(result.len(), 0);
}
// Error is also acceptable for malformed input
// Invalid JSON inside
let input = r#"[TOOL_CALLS] [{"name": invalid}]"#;
if let Ok(result) = parser.parse_complete(input).await {
assert_eq!(result.len(), 0);
}
// Error is also acceptable for malformed input
}
#[tokio::test]
async fn test_mistral_real_world_output() {
let parser = MistralParser::new();
// Actual output from Mistral model
let input = r#"I'll search for information about Rust programming and check the weather in San Francisco.
[TOOL_CALLS] [
{
"name": "web_search",
"arguments": {
"query": "Rust programming language features 2024",
"max_results": 3,
"include_snippets": true
}
},
{
"name": "get_weather",
"arguments": {
"location": "San Francisco, CA",
"units": "fahrenheit",
"include_forecast": false
}
}
]
Let me execute these searches for you."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "web_search");
assert_eq!(result[1].function.name, "get_weather");
}

View File

@@ -0,0 +1,301 @@
//! Mixed Format and Additional Edge Case Tests
//!
//! Tests for edge cases across parsers and mixed format scenarios
use serde_json::json;
use sglang_router_rs::tool_parser::{
JsonParser, LlamaParser, MistralParser, ParseState, PythonicParser, QwenParser, StreamResult,
ToolParser,
};
#[tokio::test]
async fn test_mixed_formats_in_text() {
// Test that parsers correctly ignore other formats' markers
let json_parser = JsonParser::new();
let input = r#"
Some text with [TOOL_CALLS] marker that shouldn't trigger.
Also has <tool_call> tags and [function()] syntax.
But here's the actual JSON: {"name": "test", "arguments": {}}
"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
// Mistral parser should ignore JSON and other formats
let mistral_parser = MistralParser::new();
let input = r#"
{"name": "fake"} [function()] <tool_call>
[TOOL_CALLS] [{"name": "real", "arguments": {}}]
"#;
let result = mistral_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "real");
}
#[tokio::test]
async fn test_format_markers_in_string_content() {
// Test that format markers inside string content don't interfere
let pythonic_parser = PythonicParser::new();
let input = r#"[echo(text="Use [TOOL_CALLS] and <tool_call> in text")]"#;
let result = pythonic_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Use [TOOL_CALLS] and <tool_call> in text");
let qwen_parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "log", "arguments": {"msg": "Found [function()] pattern"}}
</tool_call>"#;
let result = qwen_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["msg"], "Found [function()] pattern");
}
#[tokio::test]
async fn test_deeply_nested_json_structures() {
let json_parser = JsonParser::new();
let input = r#"{
"name": "deep_process",
"arguments": {
"level1": {
"level2": {
"level3": {
"level4": {
"level5": {
"data": [1, 2, [3, [4, 5]]]
}
}
}
}
}
}
}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "deep_process");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["level1"]["level2"]["level3"]["level4"]["level5"]["data"].is_array());
}
#[tokio::test]
async fn test_multiple_sequential_calls_different_formats() {
// Simulate a scenario where different parts of text have different formats
// (though each parser will only recognize its own format)
let llama_parser = LlamaParser::new();
// Llama parser currently only returns the first tool found
let input = r#"First call: <|python_tag|>{"name": "call1", "arguments": {}}"#;
let result = llama_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "call1");
// Test plain JSON separately
let input2 = r#"{"name": "call2", "arguments": {"x": 1}}"#;
let result2 = llama_parser.parse_complete(input2).await.unwrap();
assert_eq!(result2.len(), 1);
assert_eq!(result2[0].function.name, "call2");
}
#[tokio::test]
async fn test_empty_and_whitespace_variations() {
let json_parser = JsonParser::new();
// Various whitespace scenarios
let cases = vec![
r#" {"name":"compact","arguments":{}} "#,
r#"
{"name": "spaced", "arguments": {}}
"#,
r#" {"name": "tabbed", "arguments": {}} "#, // tabs
];
for input in cases {
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1, "Should parse regardless of whitespace");
}
}
#[tokio::test]
async fn test_special_json_values() {
let json_parser = JsonParser::new();
// Test various special JSON values
let input = r#"{
"name": "test_special",
"arguments": {
"float_e": 1.23e10,
"float_neg_e": 1.23e-10,
"hex_like": "0x1234",
"very_long_num": 99999999999999999999,
"special_strings": ["", " ", "\u0000", "\u001f"],
"escaped": "\\n\\r\\t\\\"\\\\",
"unicode": "\u4e2d\u6587"
}
}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test_special");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["special_strings"].is_array());
assert!(args["escaped"].is_string());
}
#[tokio::test]
async fn test_parser_recovery_after_invalid_input() {
let mut state = ParseState::new();
let parser = JsonParser::new();
// Send invalid JSON first
let _ = parser.parse_incremental(r#"{"broken": "#, &mut state).await;
// Clear state and try valid JSON
state.buffer.clear();
let result = parser
.parse_incremental(r#"{"name": "valid", "arguments": {}}"#, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "valid");
}
_ => {
// Might be incomplete depending on implementation
}
}
}
#[tokio::test]
async fn test_boundary_cases_for_extraction() {
// Test edge cases in JSON extraction from text
let json_parser = JsonParser::new();
// JSON at the very beginning
let input = r#"{"name": "start", "arguments": {}} and then text"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "start");
// JSON at the very end
let input = r#"Some text first {"name": "end", "arguments": {}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "end");
// Multiple JSON objects in text (should find first valid one)
let input =
r#"Text {"name": "first", "arguments": {}} more {"name": "second", "arguments": {}}"#;
let result = json_parser.parse_complete(input).await.unwrap();
assert!(!result.is_empty());
assert_eq!(result[0].function.name, "first");
}
#[tokio::test]
async fn test_pythonic_edge_cases() {
let parser = PythonicParser::new();
// Function name with underscores and numbers
let input = r#"[func_name_2(param_1="value")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "func_name_2");
// Empty string argument
let input = r#"[process(text="")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "");
}
#[tokio::test]
async fn test_mistral_with_pretty_json() {
let parser = MistralParser::new();
// Pretty-printed JSON in Mistral format
let input = r#"[TOOL_CALLS] [
{
"name": "formatted",
"arguments": {
"nested": {
"key": "value"
},
"array": [
1,
2,
3
]
}
}
]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "formatted");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["nested"]["key"], "value");
assert_eq!(args["array"], json!([1, 2, 3]));
}
#[tokio::test]
async fn test_qwen_with_cdata_like_content() {
let parser = QwenParser::new();
// Test with content that looks like CDATA but isn't
// Note: QwenParser expects exactly "<tool_call>\n" with the newline
let input = r#"<tool_call>
{"name": "process", "arguments": {"xml": "<![CDATA[some data]]>"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["xml"], "<![CDATA[some data]]>");
}
#[tokio::test]
async fn test_extremely_long_function_names() {
let parser = PythonicParser::new();
let long_name = "very_long_function_name_that_might_appear_in_generated_code_somewhere";
let input = format!(r#"[{}(param="value")]"#, long_name);
let result = parser.parse_complete(&input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, long_name);
}
#[tokio::test]
async fn test_json_with_duplicate_keys() {
let parser = JsonParser::new();
// JSON with duplicate keys (last one should win per JSON spec)
let input = r#"{"name": "test", "arguments": {"key": "first", "key": "second"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
// JSON parsers typically keep the last value for duplicate keys
assert_eq!(args["key"], "second");
}

View File

@@ -0,0 +1,559 @@
//! Pythonic Parser Integration Tests
//!
//! Tests for the Pythonic parser which handles Python function call syntax
use serde_json::json;
use sglang_router_rs::tool_parser::{PythonicParser, ToolParser};
#[tokio::test]
async fn test_pythonic_single_function() {
let parser = PythonicParser::new();
let input = r#"[get_weather(city="London", units="celsius")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "London");
assert_eq!(args["units"], "celsius");
}
#[tokio::test]
async fn test_pythonic_multiple_functions() {
let parser = PythonicParser::new();
let input =
r#"[search_web(query="Rust programming", max_results=5), get_time(timezone="UTC")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search_web");
assert_eq!(result[1].function.name, "get_time");
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args0["query"], "Rust programming");
assert_eq!(args0["max_results"], 5);
}
#[tokio::test]
async fn test_pythonic_with_python_literals() {
let parser = PythonicParser::new();
let input = r#"[configure(enabled=True, disabled=False, optional=None)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["enabled"], true);
assert_eq!(args["disabled"], false);
assert_eq!(args["optional"], json!(null));
}
#[tokio::test]
async fn test_pythonic_with_lists_and_dicts() {
let parser = PythonicParser::new();
let input =
r#"[process_data(items=[1, 2, 3], config={"key": "value", "nested": {"deep": True}})]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["items"], json!([1, 2, 3]));
assert_eq!(args["config"]["key"], "value");
assert_eq!(args["config"]["nested"]["deep"], true);
}
#[tokio::test]
async fn test_pythonic_with_special_tokens() {
let parser = PythonicParser::new();
// Llama 4 sometimes outputs these tokens
let input = r#"<|python_start|>[calculate(x=10, y=20)]<|python_end|>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "calculate");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["x"], 10);
assert_eq!(args["y"], 20);
}
#[tokio::test]
async fn test_pythonic_with_nested_parentheses() {
let parser = PythonicParser::new();
let input = r#"[math_eval(expression="(2 + 3) * (4 - 1)", round_to=2)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["expression"], "(2 + 3) * (4 - 1)");
assert_eq!(args["round_to"], 2);
}
#[tokio::test]
async fn test_pythonic_with_escaped_quotes() {
let parser = PythonicParser::new();
let input = r#"[echo(text="She said \"Hello\" to him")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "She said \"Hello\" to him");
}
#[tokio::test]
async fn test_pythonic_empty_arguments() {
let parser = PythonicParser::new();
let input = r#"[ping()]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "ping");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args, json!({}));
}
#[tokio::test]
async fn test_pythonic_format_detection() {
let parser = PythonicParser::new();
assert!(parser.detect_format("[function_name("));
assert!(parser.detect_format("[get_weather(city=\"NYC\")]"));
assert!(!parser.detect_format("Just plain text"));
assert!(!parser.detect_format("[1, 2, 3]")); // Plain list
assert!(!parser.detect_format("{\"name\": \"test\"}")); // JSON
}
#[tokio::test]
async fn test_pythonic_invalid_syntax() {
let parser = PythonicParser::new();
// Missing closing bracket
let input = r#"[function(arg=value"#;
if let Ok(result) = parser.parse_complete(input).await {
assert_eq!(result.len(), 0);
}
// Error is also acceptable for invalid syntax
// Invalid Python syntax - empty parameter name
// Note: The parser currently accepts this invalid syntax and returns a result
// This is a known limitation of the current implementation
let input = r#"[function(=value)]"#;
if let Ok(result) = parser.parse_complete(input).await {
// The parser incorrectly accepts this, returning 1 result
// We'll accept this behavior for now but note it's not ideal
assert!(result.len() <= 1, "Should parse at most one function");
}
// Error would be the correct behavior
}
#[tokio::test]
async fn test_pythonic_real_world_llama4() {
let parser = PythonicParser::new();
// Actual output from Llama 4 model
let input = r#"I'll help you with multiple tasks. Let me search for information and perform calculations.
[web_search(query="latest Rust features", max_results=3, safe_search=True),
calculate(expression="42 * 3.14159", precision=2),
get_weather(city="San Francisco", units="fahrenheit", include_forecast=False)]
These functions will provide the information you need."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result[0].function.name, "web_search");
assert_eq!(result[1].function.name, "calculate");
assert_eq!(result[2].function.name, "get_weather");
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args0["query"], "latest Rust features");
assert_eq!(args0["safe_search"], true);
}
#[tokio::test]
async fn test_pythonic_nested_brackets_in_lists() {
let parser = PythonicParser::new();
// Test nested brackets within list arguments
let input = r#"[process_matrix(data=[[1, 2], [3, 4]], labels=["row[0]", "row[1]"])]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "process_matrix");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["data"], json!([[1, 2], [3, 4]]));
assert_eq!(args["labels"], json!(["row[0]", "row[1]"]));
}
#[tokio::test]
async fn test_pythonic_nested_brackets_in_dicts() {
let parser = PythonicParser::new();
// Test nested brackets within dictionary arguments
let input =
r#"[analyze(config={"patterns": ["[a-z]+", "[0-9]+"], "nested": {"list": [1, [2, 3]]}})]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "analyze");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["config"]["patterns"], json!(["[a-z]+", "[0-9]+"]));
assert_eq!(args["config"]["nested"]["list"], json!([1, [2, 3]]));
}
#[tokio::test]
async fn test_pythonic_mixed_quotes() {
let parser = PythonicParser::new();
// Test mixed quote types in arguments
let input = r#"[format_text(single='Hello', double="World", mixed="It's \"quoted\"")]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "format_text");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["single"], "Hello");
assert_eq!(args["double"], "World");
assert_eq!(args["mixed"], "It's \"quoted\"");
}
#[tokio::test]
async fn test_pythonic_complex_nesting() {
let parser = PythonicParser::new();
// Test complex nested structures
let input = r#"[transform(
matrix=[[1, [2, 3]], [4, [5, [6, 7]]]],
operations=[{"type": "scale", "factor": [2, 3]}, {"type": "rotate", "angle": 90}],
metadata={"tags": ["nested[0]", "nested[1]"], "config": {"depth": [1, 2, 3]}}
)]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "transform");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["matrix"].is_array());
assert!(args["operations"].is_array());
assert_eq!(args["operations"][0]["type"], "scale");
assert_eq!(args["metadata"]["config"]["depth"], json!([1, 2, 3]));
}
#[tokio::test]
async fn test_parse_streaming_no_brackets() {
// Test parsing text with no brackets (no tool calls)
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "This is just normal text without any tool calls.";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
// Expected - no tool calls found
assert_eq!(state.buffer, text);
}
_ => panic!("Should return Incomplete for text without tool calls"),
}
}
#[tokio::test]
async fn test_parse_streaming_complete_tool_call() {
// Test parsing a complete tool call
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "Here's a tool call: [get_weather(location='New York', unit='celsius')]";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "New York");
assert_eq!(args["unit"], "celsius");
assert_eq!(state.buffer, "");
}
_ => panic!("Should return ToolComplete for complete tool call"),
}
}
#[tokio::test]
async fn test_parse_streaming_text_before_tool_call() {
// Test parsing text that appears before a tool call
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "This is some text before [get_weather(location='London')]";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "London");
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_partial_tool_call() {
// Test parsing a partial tool call that spans multiple chunks
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// First chunk with opening bracket but no closing bracket
let text1 = "Let me check the weather: [get_weather(location=";
let result1 = parser.parse_incremental(text1, &mut state).await.unwrap();
match result1 {
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
assert!(state.buffer.contains("[get_weather(location="));
}
_ => panic!("First chunk should return Incomplete"),
}
// Second chunk completing the tool call
let text2 = "'Paris')]";
let result2 = parser.parse_incremental(text2, &mut state).await.unwrap();
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "Paris");
assert_eq!(state.buffer, "");
}
_ => panic!("Second chunk should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_bracket_without_text_before() {
// Test parsing a tool call that starts at the beginning of the text
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "[search(query='python programming')]";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "search");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["query"], "python programming");
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_text_after_tool_call() {
// Test parsing text that appears after a tool call
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// First chunk with complete tool call and some text after
let text = "[get_weather(location='Tokyo')] Here's the forecast:";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
// Text after tool call should remain in buffer
// Note: Current implementation may clear buffer, this behavior needs verification
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_multiple_tool_calls() {
// Test parsing multiple tool calls in sequence
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "[get_weather(location='Berlin'), search(query='restaurants')]";
// Current implementation may handle this as a single parse
let result = parser.parse_incremental(text, &mut state).await.unwrap();
// The parser should handle multiple tools in one bracket pair
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(_) => {
// Expected behavior - parses first tool
}
_ => {
// Also acceptable if it returns Incomplete waiting for more
}
}
}
#[tokio::test]
async fn test_parse_streaming_opening_bracket_only() {
// Test parsing text with only an opening bracket but no closing bracket
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "Let's try this: [";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
assert!(state.buffer.ends_with("["));
}
_ => panic!("Should return Incomplete for partial bracket"),
}
}
#[tokio::test]
async fn test_parse_streaming_nested_brackets() {
// Test parsing tool calls with nested brackets in arguments
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = "[get_weather(location='New York', unit='celsius', data=[1, 2, 3])]";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "New York");
assert_eq!(args["unit"], "celsius");
assert_eq!(args["data"], json!([1, 2, 3]));
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_nested_brackets_dict() {
// Test parsing tool calls with nested dictionaries and lists
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text = r#"[search(query='test', config={'options': [1, 2], 'nested': {'key': 'value'}})]"#;
let result = parser.parse_incremental(text, &mut state).await.unwrap();
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "search");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["query"], "test");
assert_eq!(args["config"]["options"], json!([1, 2]));
assert_eq!(args["config"]["nested"]["key"], "value");
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_multiple_tools_with_nested_brackets() {
// Test parsing multiple tool calls with nested brackets
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let text =
"[get_weather(location='Paris', data=[10, 20]), search(query='test', filters=['a', 'b'])]";
let result = parser.parse_incremental(text, &mut state).await.unwrap();
// Should parse both tools successfully
match result {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
// At least gets the first tool
assert_eq!(tool.function.name, "get_weather");
}
_ => panic!("Should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_partial_nested_brackets() {
// Test parsing partial tool calls with nested brackets across chunks
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
// First chunk with nested brackets but incomplete
let text1 = "Here's a call: [get_weather(location='Tokyo', data=[1, 2";
let result1 = parser.parse_incremental(text1, &mut state).await.unwrap();
match result1 {
sglang_router_rs::tool_parser::StreamResult::Incomplete => {
assert!(state
.buffer
.contains("[get_weather(location='Tokyo', data=[1, 2"));
}
_ => panic!("First chunk should return Incomplete"),
}
// Second chunk completing the nested brackets
let text2 = ", 3])]";
let result2 = parser.parse_incremental(text2, &mut state).await.unwrap();
match result2 {
sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo");
assert_eq!(args["data"], json!([1, 2, 3]));
}
_ => panic!("Second chunk should return ToolComplete"),
}
}
#[tokio::test]
async fn test_parse_streaming_with_python_start_and_end_token() {
// Test parsing a message that starts with <|python_start|> and <|python_end|> across chunks
let parser = PythonicParser::new();
let mut state = sglang_router_rs::tool_parser::ParseState::new();
let chunks = vec![
"Here's a call: ",
"<|python_",
"start|>[get_weather(location=",
"'Tokyo', data=[1, 2",
", 3])]<|python_end|>",
];
let mut got_tool = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let sglang_router_rs::tool_parser::StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["location"], "Tokyo");
assert_eq!(args["data"], json!([1, 2, 3]));
got_tool = true;
}
}
assert!(got_tool, "Should have parsed the tool call");
}
#[tokio::test]
async fn test_detect_and_parse_with_python_start_and_end_token() {
// Test parsing a message that starts with <|python_start|> and contains a valid tool call
let parser = PythonicParser::new();
let text = "User wants to get the weather in Mars. <|python_start|>[get_weather(location='Mars', unit='celsius')]<|python_end|> In this way we will get the weather in Mars.";
let result = parser.parse_complete(text).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["location"], "Mars");
assert_eq!(args["unit"], "celsius");
}

View File

@@ -0,0 +1,259 @@
//! Qwen Parser Integration Tests
//!
//! Tests for the Qwen parser which handles <tool_call>...</tool_call> format
use serde_json::json;
use sglang_router_rs::tool_parser::{ParseState, QwenParser, StreamResult, ToolParser};
#[tokio::test]
async fn test_qwen_single_tool() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "get_weather", "arguments": {"city": "Beijing", "units": "celsius"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["city"], "Beijing");
assert_eq!(args["units"], "celsius");
}
#[tokio::test]
async fn test_qwen_multiple_sequential_tools() {
let parser = QwenParser::new();
let input = r#"Let me help you with that.
<tool_call>
{"name": "search", "arguments": {"query": "Qwen model"}}
</tool_call>
<tool_call>
{"name": "translate", "arguments": {"text": "Hello", "to": "zh"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "translate");
}
#[tokio::test]
async fn test_qwen_pretty_printed_json() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{
"name": "create_document",
"arguments": {
"title": "Test Document",
"content": "This is a test",
"metadata": {
"author": "Qwen",
"tags": ["test", "example"]
}
}
}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "create_document");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["metadata"]["author"], "Qwen");
assert_eq!(args["metadata"]["tags"], json!(["test", "example"]));
}
#[tokio::test]
async fn test_qwen_with_text_between() {
let parser = QwenParser::new();
let input = r#"First, let me search for information.
<tool_call>
{"name": "search", "arguments": {"query": "test"}}
</tool_call>
Now I'll translate something.
<tool_call>
{"name": "translate", "arguments": {"text": "world", "to": "es"}}
</tool_call>
Done!"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "search");
assert_eq!(result[1].function.name, "translate");
}
#[tokio::test]
async fn test_qwen_empty_arguments() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "get_time", "arguments": {}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_time");
}
#[tokio::test]
async fn test_qwen_with_newlines_in_strings() {
let parser = QwenParser::new();
let input = r#"<tool_call>
{"name": "write_file", "arguments": {"content": "Line 1\nLine 2\nLine 3", "path": "/tmp/test.txt"}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["content"], "Line 1\nLine 2\nLine 3");
}
#[tokio::test]
async fn test_qwen_format_detection() {
let parser = QwenParser::new();
assert!(parser.detect_format("<tool_call>"));
assert!(parser.detect_format("Some text <tool_call>\n{"));
assert!(!parser.detect_format("Just plain text"));
assert!(!parser.detect_format("{\"name\": \"test\"}")); // Plain JSON
}
#[tokio::test]
async fn test_qwen_incomplete_tags() {
let parser = QwenParser::new();
// Missing closing tag
let input = r#"<tool_call>
{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0);
// Missing opening tag
let input = r#"{"name": "test", "arguments": {}}
</tool_call>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0);
}
#[tokio::test]
async fn test_qwen_real_world_output() {
let parser = QwenParser::new();
// Actual output from Qwen model
let input = r#"I'll help you search for information and perform calculations.
<tool_call>
{
"name": "web_search",
"arguments": {
"query": "quantum computing breakthroughs 2024",
"language": "en",
"region": "us",
"safe_search": true
}
}
</tool_call>
Let me also calculate something for you:
<tool_call>
{
"name": "calculator",
"arguments": {
"expression": "sqrt(144) + 3^2",
"precision": 2
}
}
</tool_call>
These tools will provide the information you need."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "web_search");
assert_eq!(result[1].function.name, "calculator");
let args0: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args0["query"], "quantum computing breakthroughs 2024");
assert_eq!(args0["safe_search"], true);
}
#[tokio::test]
async fn test_buffer_drain_optimization() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// First chunk - incomplete tool call
let chunk1 = "<tool_call>\n{\"name\": \"test1\", ";
let _result = parser.parse_incremental(chunk1, &mut state).await.unwrap();
// Phase 2 simplified streaming might not handle partial JSON correctly
// The important thing is buffer accumulation works
assert!(!state.buffer.is_empty());
// Complete first tool and start second
let chunk2 = "\"arguments\": {}}\n</tool_call><tool_call>\n{\"name\": \"test2\", ";
let result = parser.parse_incremental(chunk2, &mut state).await.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test1");
// After consuming the first tool, buffer should contain only the second tool start
assert!(state.buffer.starts_with("<tool_call>"));
assert!(state.buffer.contains("test2"));
}
_ => {
// Phase 2 simplified streaming might return Incomplete
// The important thing is the buffer is managed correctly
}
}
// Complete the second tool
let chunk3 = "\"arguments\": {\"x\": 1}}\n</tool_call>";
let result = parser.parse_incremental(chunk3, &mut state).await.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test2");
// Buffer should be empty after consuming all tools
assert!(state.buffer.is_empty() || !state.buffer.contains("</tool_call>"));
}
_ => {
// Phase 2 simplified streaming might handle this differently
}
}
}
#[tokio::test]
async fn test_buffer_efficiency_with_multiple_tools() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// Send multiple complete tools at once
let input = r#"<tool_call>
{"name": "tool1", "arguments": {"a": 1}}
</tool_call><tool_call>
{"name": "tool2", "arguments": {"b": 2}}
</tool_call><tool_call>
{"name": "tool3", "arguments": {"c": 3}}
</tool_call>"#;
// This should efficiently process tools using drain() without creating new strings
let result = parser.parse_incremental(input, &mut state).await.unwrap();
// In Phase 2, this will likely parse only the first tool
// The important thing is that drain() doesn't cause any issues
match result {
StreamResult::ToolComplete(tool) => {
assert!(["tool1", "tool2", "tool3"].contains(&tool.function.name.as_str()));
}
_ => {
// Simplified streaming might return Incomplete
}
}
// Verify no memory issues or panics occurred with drain()
// Test passes if we reach this point without panic
}

View File

@@ -0,0 +1,194 @@
//! Parser Registry Integration Tests
//!
//! Tests for model-to-parser mappings and registry functionality
use sglang_router_rs::tool_parser::ParserRegistry;
#[tokio::test]
async fn test_registry_has_all_parsers() {
let registry = ParserRegistry::new();
let parsers = registry.list_parsers();
assert!(parsers.contains(&"json"));
assert!(parsers.contains(&"mistral"));
assert!(parsers.contains(&"qwen"));
assert!(parsers.contains(&"pythonic"));
assert!(parsers.contains(&"llama"));
}
#[tokio::test]
async fn test_openai_models_use_json() {
let registry = ParserRegistry::new();
let models = vec!["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo", "gpt-4o"];
for model in models {
let parser = registry.get_parser(model).unwrap();
let test_input = r#"{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
}
#[tokio::test]
async fn test_anthropic_models_use_json() {
let registry = ParserRegistry::new();
let models = vec!["claude-3-opus", "claude-3-sonnet", "claude-2.1"];
for model in models {
let parser = registry.get_parser(model).unwrap();
let test_input = r#"{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
}
}
#[tokio::test]
async fn test_mistral_models() {
let registry = ParserRegistry::new();
let models = vec!["mistral-large", "mistral-medium", "mixtral-8x7b"];
for model in models {
let parser = registry.get_parser(model).unwrap();
let test_input = r#"[TOOL_CALLS] [{"name": "test", "arguments": {}}]"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
}
#[tokio::test]
async fn test_qwen_models() {
let registry = ParserRegistry::new();
let models = vec!["qwen2.5-72b", "Qwen2-7B", "qwen-max"];
for model in models {
let parser = registry.get_parser(model).unwrap();
let test_input = r#"<tool_call>
{"name": "test", "arguments": {}}
</tool_call>"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
}
#[tokio::test]
async fn test_llama_model_variants() {
let registry = ParserRegistry::new();
// Llama 4 uses pythonic
let parser = registry.get_parser("llama-4-70b").unwrap();
let test_input = r#"[get_weather(city="NYC")]"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "get_weather");
// Llama 3.2 uses python_tag
let parser = registry.get_parser("llama-3.2-8b").unwrap();
let test_input = r#"<|python_tag|>{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
// Other Llama models use JSON
let parser = registry.get_parser("llama-2-70b").unwrap();
let test_input = r#"{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
}
#[tokio::test]
async fn test_deepseek_models() {
let registry = ParserRegistry::new();
// DeepSeek uses pythonic format (simplified, v3 would need custom parser)
let parser = registry.get_parser("deepseek-coder").unwrap();
let test_input = r#"[function(arg="value")]"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "function");
}
#[tokio::test]
async fn test_unknown_model_fallback() {
let registry = ParserRegistry::new();
// Unknown models should fall back to JSON parser
let parser = registry.get_parser("unknown-model-xyz").unwrap();
let test_input = r#"{"name": "fallback", "arguments": {}}"#;
let result = parser.parse_complete(test_input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "fallback");
}
#[tokio::test]
async fn test_pattern_specificity() {
let registry = ParserRegistry::new();
// Test that more specific patterns take precedence
// llama-4* should match before llama-*
let parser = registry.get_parser("llama-4-70b").unwrap();
assert!(parser.detect_format(r#"[test_function(x=1)]"#)); // Pythonic format
let parser = registry.get_parser("llama-3-70b").unwrap();
assert!(parser.detect_format(r#"{"name": "test", "arguments": {}}"#)); // JSON format
}
#[tokio::test]
async fn test_real_world_model_outputs() {
let registry = ParserRegistry::new();
// Test with realistic outputs from different models
let test_cases = vec![
(
"gpt-4",
r#"I'll help you with that.
{"name": "search_web", "arguments": {"query": "latest AI news", "max_results": 5}}
Let me search for that information."#,
"search_web",
),
(
"mistral-large",
r#"Let me search for information about Rust.
[TOOL_CALLS] [
{"name": "search", "arguments": {"query": "Rust programming"}},
{"name": "get_weather", "arguments": {"city": "San Francisco"}}
]
I've initiated the search."#,
"search",
),
(
"qwen2.5",
r#"I'll check the weather for you.
<tool_call>
{
"name": "get_weather",
"arguments": {
"location": "Tokyo",
"units": "celsius"
}
}
</tool_call>
The weather information has been requested."#,
"get_weather",
),
];
for (model, output, expected_name) in test_cases {
let parser = registry.get_parser(model).unwrap();
let result = parser.parse_complete(output).await.unwrap();
assert!(!result.is_empty(), "No tools parsed for model {}", model);
assert_eq!(
result[0].function.name, expected_name,
"Wrong function name for model {}",
model
);
}
}

View File

@@ -0,0 +1,245 @@
//! Step3 Parser Integration Tests
use sglang_router_rs::tool_parser::{ParseState, Step3Parser, StreamResult, ToolParser};
#[tokio::test]
async fn test_step3_complete_parsing() {
let parser = Step3Parser::new();
// Test single tool call
let input = r#"Let me help you.
<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="search">
<steptml:parameter name="query">rust programming</steptml:parameter>
<steptml:parameter name="limit">10</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>
Here are the results..."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
// Verify arguments
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust programming");
assert_eq!(args["limit"], 10);
}
#[tokio::test]
async fn test_step3_multiple_tools() {
let parser = Step3Parser::new();
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="get_weather">
<steptml:parameter name="location">Tokyo</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_call_begin>function<tool_sep><steptml:invoke name="get_news">
<steptml:parameter name="category">tech</steptml:parameter>
<steptml:parameter name="limit">5</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "get_weather");
assert_eq!(result[1].function.name, "get_news");
}
#[tokio::test]
async fn test_step3_type_conversion() {
let parser = Step3Parser::new();
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="process">
<steptml:parameter name="count">100</steptml:parameter>
<steptml:parameter name="rate">2.5</steptml:parameter>
<steptml:parameter name="active">true</steptml:parameter>
<steptml:parameter name="optional">null</steptml:parameter>
<steptml:parameter name="text">hello world</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["count"], 100);
assert_eq!(args["rate"], 2.5);
assert_eq!(args["active"], true);
assert_eq!(args["optional"], serde_json::Value::Null);
assert_eq!(args["text"], "hello world");
}
#[tokio::test]
async fn test_step3_streaming() {
let parser = Step3Parser::new();
let mut state = ParseState::new();
// Simulate streaming chunks
let chunks = vec![
"<tool_calls_begin>\n",
"<tool_call_begin>function",
"<tool_sep><steptml:invoke name=\"calc\">",
"\n<steptml:parameter name=\"x\">10</steptml:parameter>",
"\n<steptml:parameter name=\"y\">20</steptml:parameter>",
"\n</steptml:invoke><tool_call_end>",
"\n<tool_calls_end>",
];
let mut found_name = false;
let mut found_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
match result {
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "calc");
found_name = true;
}
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "calc");
found_complete = true;
}
_ => {}
}
}
assert!(found_name || found_complete);
}
#[test]
fn test_step3_format_detection() {
let parser = Step3Parser::new();
// Should detect Step3 format
assert!(parser.detect_format("<tool_calls_begin>"));
assert!(parser.detect_format("text with <tool_calls_begin> marker"));
// Should not detect other formats
assert!(!parser.detect_format("[TOOL_CALLS]"));
assert!(!parser.detect_format("<tool_call>"));
assert!(!parser.detect_format("plain text"));
}
#[tokio::test]
async fn test_step3_nested_steptml() {
let parser = Step3Parser::new();
// Test with complex parameter values
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="config">
<steptml:parameter name="settings">{"nested": {"key": "value"}}</steptml:parameter>
<steptml:parameter name="array">[1, 2, 3]</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "config");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["settings"].is_object());
assert!(args["array"].is_array());
}
#[tokio::test]
async fn test_step3_python_literals() {
let parser = Step3Parser::new();
// Test Python-style literals
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="test">
<steptml:parameter name="bool_true">True</steptml:parameter>
<steptml:parameter name="bool_false">False</steptml:parameter>
<steptml:parameter name="none_value">None</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["bool_true"], true);
assert_eq!(args["bool_false"], false);
assert_eq!(args["none_value"], serde_json::Value::Null);
}
#[tokio::test]
async fn test_steptml_format() {
let parser = Step3Parser::new();
let input = r#"Text before.
<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="search">
<steptml:parameter name="query">rust lang</steptml:parameter>
<steptml:parameter name="limit">10</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>Text after."#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["query"], "rust lang");
assert_eq!(args["limit"], 10);
// TODO: Verify normal text extraction
}
#[tokio::test]
async fn test_json_parameter_values() {
let parser = Step3Parser::new();
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="config">
<steptml:parameter name="settings">{"nested": {"value": true}}</steptml:parameter>
<steptml:parameter name="items">[1, 2, 3]</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert!(args["settings"].is_object());
assert!(args["items"].is_array());
}
#[tokio::test]
async fn test_step3_parameter_with_angle_brackets() {
let parser = Step3Parser::new();
// Test parameter value containing < character
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="compare">
<steptml:parameter name="expression">a < b && b > c</steptml:parameter>
<steptml:parameter name="context">comparison test</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "compare");
// Verify the parameter value was parsed correctly
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["expression"], "a < b && b > c");
assert_eq!(args["context"], "comparison test");
}
#[tokio::test]
async fn test_step3_empty_function_name() {
let parser = Step3Parser::new();
// Test empty function name
let input = r#"<tool_calls_begin>
<tool_call_begin>function<tool_sep><steptml:invoke name="">
<steptml:parameter name="param">value</steptml:parameter>
</steptml:invoke><tool_call_end>
<tool_calls_end>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0); // Should reject empty function name
}

View File

@@ -0,0 +1,341 @@
//! Streaming Parser Tests
//!
//! Tests for incremental/streaming parsing capabilities across all parsers
use sglang_router_rs::tool_parser::{
JsonParser, LlamaParser, MistralParser, ParseState, PythonicParser, QwenParser, StreamResult,
ToolParser,
};
#[tokio::test]
async fn test_json_streaming_simple() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// Phase 2 note: This test sends the full JSON at once in the last chunk
// In real streaming, chunks would be smaller
let full_json = r#"{"name": "get_weather", "arguments": {"location": "San Francisco"}}"#;
let result = parser
.parse_incremental(full_json, &mut state)
.await
.unwrap();
// With complete JSON sent at once, we should get ToolComplete
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
}
_ => {
panic!("Expected ToolComplete for complete JSON input");
}
}
}
#[tokio::test]
async fn test_json_streaming_array() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// Stream a JSON array of tools
let chunks = vec![
r#"["#,
r#"{"name": "tool1", "#,
r#""arguments": {}}, "#,
r#"{"name": "tool2", "#,
r#""arguments": {"x": 1"#,
r#"}}]"#,
];
let mut tool_count = 0;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let StreamResult::ToolComplete(_) = result {
tool_count += 1;
}
}
// Current implementation may handle this differently
// We're mainly testing that it doesn't crash
assert!(tool_count <= 2, "Should parse at most 2 tools");
}
#[tokio::test]
async fn test_mistral_streaming() {
let parser = MistralParser::new();
let mut state = ParseState::new();
let chunks = vec![
r#"Here is the result: "#,
r#"[TOOL_CALLS] ["#,
r#"{"name": "#,
r#""search", "#,
r#""arguments": "#,
r#"{"query": "#,
r#""rust lang""#,
r#"}}]"#,
];
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "search");
got_complete = true;
}
}
assert!(got_complete, "Should have completed parsing");
}
#[tokio::test]
async fn test_pythonic_streaming() {
let parser = PythonicParser::new();
let mut state = ParseState::new();
// Send complete pythonic format at once
let full_input = r#"[get_weather(city="London", units="celsius")]"#;
let result = parser
.parse_incremental(full_input, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "get_weather");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert_eq!(args["city"], "London");
}
_ => {
panic!("Expected ToolComplete for complete pythonic input");
}
}
}
#[tokio::test]
async fn test_llama_streaming_with_python_tag() {
let parser = LlamaParser::new();
let mut state = ParseState::new();
let chunks = vec![
r#"Let me help. "#,
r#"<|python"#,
r#"_tag|>"#,
r#"{"name": "#,
r#""calculate", "#,
r#""arguments": "#,
r#"{"x": 10}"#,
r#"}"#,
];
let mut got_complete = false;
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
if let StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "calculate");
got_complete = true;
}
}
assert!(got_complete, "Should have completed parsing");
}
#[tokio::test]
async fn test_qwen_streaming() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// Send complete Qwen format at once (with exact format expected by parser)
// Note: Parser expects newline after both tags
let full_input = "<tool_call>\n{\"name\": \"translate\", \"arguments\": {\"text\": \"hello\", \"to\": \"zh\"}}\n</tool_call>";
let result = parser
.parse_incremental(full_input, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "translate");
}
other => {
panic!(
"Expected ToolComplete for complete Qwen input, got: {:?}",
other
);
}
}
}
#[tokio::test]
async fn test_streaming_incomplete_stays_incomplete() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// Send truly incomplete JSON that can't be auto-completed
let chunks = vec![r#"{"na"#, r#"me": "#];
for chunk in chunks {
let result = parser.parse_incremental(chunk, &mut state).await.unwrap();
// Should return Incomplete for partial JSON that can't be auto-completed
assert!(
matches!(result, StreamResult::Incomplete),
"Should return Incomplete for partial JSON, got: {:?}",
result
);
}
// Buffer should contain the accumulated incomplete JSON
assert!(!state.buffer.is_empty());
}
#[tokio::test]
async fn test_streaming_with_text_before_tool() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// For streaming, the parser expects clean JSON
// Mixed text extraction only works in parse_complete, not parse_incremental
let full_input = r#"{"name": "test", "arguments": {}}"#;
let result = parser
.parse_incremental(full_input, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test");
}
other => {
panic!("Expected ToolComplete, got: {:?}", other);
}
}
}
#[tokio::test]
async fn test_streaming_buffer_accumulation() {
let parser = JsonParser::new();
// Test: Complete JSON should clear buffer after parsing
let mut state = ParseState::new();
// Send partial JSON that can't be interpreted as complete
let result1 = parser
.parse_incremental(r#"{"na"#, &mut state)
.await
.unwrap();
assert!(matches!(result1, StreamResult::Incomplete));
assert!(
!state.buffer.is_empty(),
"Buffer should accumulate incomplete JSON"
);
// Send rest of JSON
let result2 = parser
.parse_incremental(r#"me": "test", "arguments": {}}"#, &mut state)
.await
.unwrap();
match result2 {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "test");
assert!(
state.buffer.is_empty(),
"Buffer should be cleared after complete parse"
);
}
_ => panic!(
"Expected ToolComplete for complete JSON, got: {:?}",
result2
),
}
}
#[tokio::test]
async fn test_streaming_multiple_tools_sequential() {
let parser = QwenParser::new();
let mut state = ParseState::new();
// Send complete Qwen format with newlines
let full_input = r#"<tool_call>
{"name": "tool1", "arguments": {}}
</tool_call>"#;
let result = parser
.parse_incremental(full_input, &mut state)
.await
.unwrap();
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "tool1");
}
_ => {
panic!("Expected ToolComplete for first tool");
}
}
}
#[tokio::test]
async fn test_streaming_reset_after_error() {
let parser = JsonParser::new();
// First attempt with invalid JSON
let mut state1 = ParseState::new();
let _ = parser
.parse_incremental(r#"{"name": invalid}"#, &mut state1)
.await;
// Second attempt with valid JSON should work with fresh state
let mut state2 = ParseState::new();
let result = parser
.parse_incremental(r#"{"name": "test", "arguments": {}}"#, &mut state2)
.await
.unwrap();
if let StreamResult::ToolComplete(tool) = result {
assert_eq!(tool.function.name, "test");
}
}
#[tokio::test]
async fn test_streaming_with_unicode_chunks() {
let parser = JsonParser::new();
let mut state = ParseState::new();
// Send complete JSON with unicode
let full_input = r#"{"name": "translate", "arguments": {"text": "Hello 世界 🌍"}}"#;
let result = parser
.parse_incremental(full_input, &mut state)
.await
.unwrap();
// Phase 2 may return partial results even with complete JSON
// The important thing is that unicode is handled without crashes
match result {
StreamResult::ToolComplete(tool) => {
assert_eq!(tool.function.name, "translate");
let args: serde_json::Value = serde_json::from_str(&tool.function.arguments).unwrap();
assert!(args["text"].as_str().unwrap().contains("世界"));
}
StreamResult::ToolName { name, .. } => {
assert_eq!(name, "translate");
// Phase 2 partial streaming behavior - acceptable
}
StreamResult::ToolArguments { arguments, .. } => {
// Verify unicode was preserved
let args: serde_json::Value = serde_json::from_str(&arguments).unwrap();
assert!(args["text"].as_str().unwrap().contains("世界"));
}
other => {
panic!("Unexpected result: {:?}", other);
}
}
}

View File

@@ -0,0 +1,247 @@
//! Wrapper Token Tests
//!
//! Tests for JSON parser with custom wrapper tokens
use sglang_router_rs::tool_parser::{JsonParser, TokenConfig, ToolParser};
#[tokio::test]
async fn test_json_with_xml_style_wrapper() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()],
separator: ", ".to_string(),
});
let input =
r#"Some text before <tool>{"name": "test", "arguments": {"x": 1}}</tool> and after"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["x"], 1);
}
#[tokio::test]
async fn test_json_with_multiple_wrapper_pairs() {
// Test with multiple start/end token pairs
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string(), "<<TOOL>>".to_string()],
end_tokens: vec!["</tool>".to_string(), "<</TOOL>>".to_string()],
separator: ", ".to_string(),
});
// Test first pair
let input1 = r#"<tool>{"name": "tool1", "arguments": {}}</tool>"#;
let result1 = parser.parse_complete(input1).await.unwrap();
assert_eq!(result1.len(), 1);
assert_eq!(result1[0].function.name, "tool1");
// Test second pair
let input2 = r#"<<TOOL>>{"name": "tool2", "arguments": {}}<</TOOL>>"#;
let result2 = parser.parse_complete(input2).await.unwrap();
assert_eq!(result2.len(), 1);
assert_eq!(result2[0].function.name, "tool2");
}
#[tokio::test]
async fn test_json_with_only_start_token() {
// Test when only start token is provided (no end token)
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec![">>>FUNCTION:".to_string()],
end_tokens: vec!["".to_string()], // Empty end token
separator: ", ".to_string(),
});
let input = r#"Some preamble >>>FUNCTION:{"name": "execute", "arguments": {"cmd": "ls"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "execute");
}
#[tokio::test]
async fn test_json_with_custom_separator() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["[FUNC]".to_string()],
end_tokens: vec!["[/FUNC]".to_string()],
separator: " | ".to_string(), // Custom separator
});
// Though we're not testing multiple tools here, the separator is configured
let input = r#"[FUNC]{"name": "test", "arguments": {}}[/FUNC]"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}
#[tokio::test]
async fn test_json_with_nested_wrapper_tokens_in_content() {
// Known limitation: When wrapper tokens appear inside JSON strings,
// the simple regex-based extraction may fail. This would require
// a more sophisticated parser that understands JSON string escaping.
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<call>".to_string()],
end_tokens: vec!["</call>".to_string()],
separator: ", ".to_string(),
});
let input =
r#"<call>{"name": "echo", "arguments": {"text": "Use <call> and </call> tags"}}</call>"#;
let result = parser.parse_complete(input).await.unwrap();
// This is a known limitation - the parser may fail when end tokens appear in content
// For now, we accept this behavior
if result.is_empty() {
// Parser failed due to nested tokens - this is expected
assert_eq!(
result.len(),
0,
"Known limitation: nested wrapper tokens in content"
);
} else {
// If it does parse, verify it's correct
assert_eq!(result[0].function.name, "echo");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["text"], "Use <call> and </call> tags");
}
}
#[tokio::test]
async fn test_json_extraction_without_wrapper_tokens() {
// Default parser without wrapper tokens should extract JSON from text
let parser = JsonParser::new();
let input = r#"
Here is some text before the JSON.
{"name": "search", "arguments": {"query": "test"}}
And here is some text after.
"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "search");
}
#[tokio::test]
async fn test_json_with_multiline_wrapper_content() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["```json\n".to_string()],
end_tokens: vec!["\n```".to_string()],
separator: ", ".to_string(),
});
let input = r#"Here's the function call:
```json
{
"name": "format_code",
"arguments": {
"language": "rust",
"code": "fn main() {}"
}
}
```
Done!"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "format_code");
}
#[tokio::test]
async fn test_json_with_special_chars_in_tokens() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["{{FUNC[[".to_string()],
end_tokens: vec!["]]FUNC}}".to_string()],
separator: ", ".to_string(),
});
let input = r#"{{FUNC[[{"name": "test", "arguments": {"special": "[]{}"}}]]FUNC}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
let args: serde_json::Value = serde_json::from_str(&result[0].function.arguments).unwrap();
assert_eq!(args["special"], "[]{}");
}
#[tokio::test]
async fn test_json_multiple_tools_with_wrapper() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<fn>".to_string()],
end_tokens: vec!["</fn>".to_string()],
separator: ", ".to_string(),
});
// Multiple wrapped JSON objects
let input = r#"
<fn>{"name": "tool1", "arguments": {}}</fn>
Some text between.
<fn>{"name": "tool2", "arguments": {"x": 1}}</fn>
"#;
// Current implementation might handle this as separate calls
// Let's test that at least the first one is parsed
let result = parser.parse_complete(input).await.unwrap();
assert!(!result.is_empty(), "Should parse at least one tool");
assert_eq!(result[0].function.name, "tool1");
}
#[tokio::test]
async fn test_json_wrapper_with_array() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tools>".to_string()],
end_tokens: vec!["</tools>".to_string()],
separator: ", ".to_string(),
});
let input = r#"<tools>[
{"name": "func1", "arguments": {}},
{"name": "func2", "arguments": {"param": "value"}}
]</tools>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 2);
assert_eq!(result[0].function.name, "func1");
assert_eq!(result[1].function.name, "func2");
}
#[tokio::test]
async fn test_json_incomplete_wrapper_tokens() {
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec!["<tool>".to_string()],
end_tokens: vec!["</tool>".to_string()],
separator: ", ".to_string(),
});
// Missing end token
let input = r#"<tool>{"name": "test", "arguments": {}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should not parse without closing token");
// Missing start token
let input = r#"{"name": "test", "arguments": {}}</tool>"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 0, "Should not parse without opening token");
}
#[tokio::test]
async fn test_json_empty_wrapper_tokens() {
// Test with empty wrapper tokens (should behave like default)
let parser = JsonParser::with_config(TokenConfig {
start_tokens: vec![],
end_tokens: vec![],
separator: ", ".to_string(),
});
let input = r#"{"name": "test", "arguments": {"key": "value"}}"#;
let result = parser.parse_complete(input).await.unwrap();
assert_eq!(result.len(), 1);
assert_eq!(result[0].function.name, "test");
}