[router] migrate router from actix to axum (#8479)

This commit is contained in:
Simo Lin
2025-07-30 17:47:19 -07:00
committed by GitHub
parent 299803343d
commit 66a398f49d
18 changed files with 3626 additions and 3549 deletions

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -1,62 +1,2 @@
pub mod mock_worker;
use actix_web::web;
use reqwest::Client;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::server::AppState;
/// Helper function to create test router configuration
pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
RouterConfig {
mode: RoutingMode::Regular { worker_urls },
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 256 * 1024 * 1024, // 256MB
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
}
}
/// Helper function to create test router configuration with no health check
pub fn create_test_config_no_workers() -> RouterConfig {
RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
}, // Empty to skip health check
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 256 * 1024 * 1024, // 256MB
request_timeout_secs: 600,
worker_startup_timeout_secs: 0, // No wait
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
log_level: None,
request_id_headers: None,
}
}
/// Helper function to create test app state
pub async fn create_test_app_state(config: RouterConfig) -> Result<web::Data<AppState>, String> {
// Create a non-blocking client
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
.build()
.map_err(|e| e.to_string())?;
let app_state = AppState::new(config, client)?;
Ok(web::Data::new(app_state))
}
pub mod test_app;

View File

@@ -0,0 +1,42 @@
use axum::Router;
use reqwest::Client;
use sglang_router_rs::{
config::RouterConfig,
routers::RouterTrait,
server::{build_app, AppState},
};
use std::sync::Arc;
/// Create a test Axum application using the actual server's build_app function
pub fn create_test_app(
router: Arc<dyn RouterTrait>,
client: Client,
router_config: &RouterConfig,
) -> Router {
// Create AppState with the test router
let app_state = Arc::new(AppState {
router,
client,
_concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(
router_config.max_concurrent_requests,
)),
});
// Configure request ID headers (use defaults if not specified)
let request_id_headers = router_config.request_id_headers.clone().unwrap_or_else(|| {
vec![
"x-request-id".to_string(),
"x-correlation-id".to_string(),
"x-trace-id".to_string(),
"request-id".to_string(),
]
});
// Use the actual server's build_app function
build_app(
app_state,
router_config.max_payload_size,
request_id_headers,
router_config.cors_allowed_origins.clone(),
)
}

View File

@@ -1,43 +1,27 @@
mod common;
use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App};
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::server::{
add_worker, generate, v1_chat_completions, v1_completions, AppState,
};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
/// Test context for request type testing
struct RequestTestContext {
/// Test context that manages mock workers
struct TestContext {
workers: Vec<MockWorker>,
app_state: web::Data<AppState>,
router: Arc<dyn RouterTrait>,
}
impl RequestTestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
// Start mock workers
for config in worker_configs {
let mut worker = MockWorker::new(config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
// Create router config
let config = RouterConfig {
let mut config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3006,
port: 3003,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
@@ -49,528 +33,348 @@ impl RequestTestContext {
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
.build()
.unwrap();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
let app_state = AppState::new(config, client).unwrap();
let app_state = web::Data::new(app_state);
// Add workers via HTTP API
let app =
actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker))
.await;
for url in &worker_urls {
let req = actix_test::TestRequest::post()
.uri(&format!("/add_worker?url={}", url))
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert!(resp.status().is_success());
for worker_config in worker_configs {
let mut worker = MockWorker::new(worker_config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
Self { workers, app_state }
}
config.mode = RoutingMode::Regular { worker_urls };
async fn create_app(
&self,
) -> impl actix_web::dev::Service<
actix_http::Request,
Response = actix_web::dev::ServiceResponse,
Error = actix_web::Error,
> {
actix_test::init_service(
App::new()
.app_data(self.app_state.clone())
.service(generate)
.service(v1_chat_completions)
.service(v1_completions),
)
.await
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
.await
.unwrap()
.unwrap();
let router = Arc::from(router);
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self { workers, router }
}
async fn shutdown(mut self) {
// Small delay to ensure any pending operations complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
for worker in &mut self.workers {
worker.stop().await;
}
// Another small delay to ensure cleanup completes
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
async fn make_request(
&self,
endpoint: &str,
body: serde_json::Value,
) -> Result<serde_json::Value, String> {
let client = Client::new();
// Get any worker URL for testing
let worker_urls = self.router.get_worker_urls();
if worker_urls.is_empty() {
return Err("No available workers".to_string());
}
let worker_url = &worker_urls[0];
let response = client
.post(&format!("{}{}", worker_url, endpoint))
.json(&body)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
response
.json::<serde_json::Value>()
.await
.map_err(|e| format!("Failed to parse response: {}", e))
}
}
#[cfg(test)]
mod generate_input_format_tests {
mod request_format_tests {
use super::*;
#[test]
fn test_generate_with_text_input() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
#[tokio::test]
async fn test_generate_request_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Standard text input
let payload = json!({
"text": "Hello world",
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body: serde_json::Value = actix_test::read_body_json(resp).await;
assert!(body.get("text").is_some());
ctx.shutdown().await;
// Test 1: Basic text request
let payload = json!({
"text": "Hello, world!",
"stream": false
});
}
#[test]
fn test_generate_with_prompt_input() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
let app = ctx.create_app().await;
// Prompt input (alternative to text)
let payload = json!({
"prompt": "Once upon a time",
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_generate_with_input_ids() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Input IDs (tokenized input)
let payload = json!({
"input_ids": [1, 2, 3, 4, 5],
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_generate_with_all_parameters() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21004,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// All generation parameters
let payload = json!({
"text": "Complete this",
// Test 2: Request with sampling parameters
let payload = json!({
"text": "Tell me a story",
"sampling_params": {
"temperature": 0.7,
"top_p": 0.9,
"top_k": 50,
"max_new_tokens": 100,
"min_new_tokens": 10,
"frequency_penalty": 0.5,
"presence_penalty": 0.3,
"repetition_penalty": 1.1,
"stop": [".", "!", "?"],
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
}
#[cfg(test)]
mod chat_completion_format_tests {
use super::*;
#[test]
fn test_chat_with_system_message() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21010,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
]
});
let req = actix_test::TestRequest::post()
.uri("/v1/chat/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
// Note: Function calling and tools tests are commented out because
// they require special handling in the mock worker that's not implemented yet.
// In production, these would be forwarded to the actual model.
// #[test]
// fn test_chat_with_function_calling() {
// // Test would go here when mock worker supports function calling
// }
// #[test]
// fn test_chat_with_tools() {
// // Test would go here when mock worker supports tools
// }
#[test]
fn test_chat_with_response_format() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21013,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Return JSON"}
],
"response_format": {
"type": "json_object"
}
});
let req = actix_test::TestRequest::post()
.uri("/v1/chat/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
}
#[cfg(test)]
mod completion_format_tests {
use super::*;
#[test]
fn test_completion_with_single_prompt() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21020,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"max_tokens": 50
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body: serde_json::Value = actix_test::read_body_json(resp).await;
assert!(body.get("choices").is_some());
ctx.shutdown().await;
});
}
#[test]
fn test_completion_with_batch_prompts() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21021,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": ["First prompt", "Second prompt", "Third prompt"],
"max_tokens": 30
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_completion_with_echo() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21022,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": "Echo this prompt",
"echo": true,
"max_tokens": 20
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_completion_with_logprobs() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21023,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": "Calculate probability",
"logprobs": 5,
"max_tokens": 10
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_completion_with_suffix() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21024,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": "Insert text here: ",
"suffix": " and continue from here.",
"max_tokens": 20
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
}
#[cfg(test)]
mod stop_sequence_tests {
use super::*;
#[test]
fn test_stop_sequences_array() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21030,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Generate until stop",
"stop": [".", "!", "?", "\n"],
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
#[test]
fn test_stop_sequences_string() {
System::new().block_on(async {
let ctx = RequestTestContext::new(vec![MockWorkerConfig {
port: 21031,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Generate until stop",
"stop": "\n\n",
"stream": false
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
"top_p": 0.9
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test 3: Request with input_ids
let payload = json!({
"input_ids": [1, 2, 3, 4, 5],
"sampling_params": {
"temperature": 0.0,
"max_new_tokens": 50
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_chat_completions_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test 1: Basic chat completion
let payload = json!({
"model": "test-model",
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"}
],
"stream": false
});
let result = ctx.make_request("/v1/chat/completions", payload).await;
assert!(result.is_ok());
let response = result.unwrap();
assert!(response.get("choices").is_some());
assert!(response.get("id").is_some());
assert_eq!(
response.get("object").and_then(|v| v.as_str()),
Some("chat.completion")
);
// Test 2: Chat completion with parameters
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Tell me a joke"}
],
"temperature": 0.8,
"max_tokens": 150,
"top_p": 0.95,
"stream": false
});
let result = ctx.make_request("/v1/chat/completions", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_completions_formats() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test 1: Basic completion
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"max_tokens": 50,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
let response = result.unwrap();
assert!(response.get("choices").is_some());
assert_eq!(
response.get("object").and_then(|v| v.as_str()),
Some("text_completion")
);
// Test 2: Completion with array prompt
let payload = json!({
"model": "test-model",
"prompt": ["First prompt", "Second prompt"],
"temperature": 0.5,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
// Test 3: Completion with logprobs
let payload = json!({
"model": "test-model",
"prompt": "The capital of France is",
"max_tokens": 10,
"logprobs": 5,
"stream": false
});
let result = ctx.make_request("/v1/completions", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_batch_requests() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19004,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test batch text generation
let payload = json!({
"text": ["First text", "Second text", "Third text"],
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 50
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test batch with input_ids
let payload = json!({
"input_ids": [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_special_parameters() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19005,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test with return_logprob
let payload = json!({
"text": "Test",
"return_logprob": true,
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test with json_schema
let payload = json!({
"text": "Generate JSON",
"sampling_params": {
"temperature": 0.0,
"json_schema": "$$ANY$$"
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
// Test with ignore_eos
let payload = json!({
"text": "Continue forever",
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 100,
"ignore_eos": true
},
"stream": false
});
let result = ctx.make_request("/generate", payload).await;
assert!(result.is_ok());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_error_handling() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 19006,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
// Test with empty body - should still work with mock worker
let payload = json!({});
let result = ctx.make_request("/generate", payload).await;
// Mock worker accepts empty body
assert!(result.is_ok());
ctx.shutdown().await;
}
}

View File

@@ -1,47 +1,28 @@
mod common;
use actix_web::{http::StatusCode, rt::System, test as actix_test, web, App};
use bytes::Bytes;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::server::{
add_worker, generate, list_workers, v1_chat_completions, v1_completions, AppState,
};
use std::time::Instant;
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
/// Test context for streaming tests
struct StreamingTestContext {
/// Test context that manages mock workers
struct TestContext {
workers: Vec<MockWorker>,
app_state: web::Data<AppState>,
router: Arc<dyn RouterTrait>,
}
impl StreamingTestContext {
impl TestContext {
async fn new(worker_configs: Vec<MockWorkerConfig>) -> Self {
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
// Start mock workers
for config in worker_configs {
let mut worker = MockWorker::new(config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
// Give workers time to start
tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;
// Create router config with empty worker URLs initially
// We'll add workers via the /add_worker endpoint
let config = RouterConfig {
let mut config = RouterConfig {
mode: RoutingMode::Regular {
worker_urls: vec![],
},
policy: PolicyConfig::Random,
host: "127.0.0.1".to_string(),
port: 3003,
port: 3004,
max_payload_size: 256 * 1024 * 1024,
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
@@ -53,530 +34,325 @@ impl StreamingTestContext {
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
let client = Client::builder()
.timeout(std::time::Duration::from_secs(config.request_timeout_secs))
.build()
.unwrap();
let mut workers = Vec::new();
let mut worker_urls = Vec::new();
let app_state = AppState::new(config, client).unwrap();
let app_state = web::Data::new(app_state);
// Add workers via HTTP API
let app =
actix_test::init_service(App::new().app_data(app_state.clone()).service(add_worker))
.await;
for url in &worker_urls {
let req = actix_test::TestRequest::post()
.uri(&format!("/add_worker?url={}", url))
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert!(resp.status().is_success());
for worker_config in worker_configs {
let mut worker = MockWorker::new(worker_config);
let url = worker.start().await.unwrap();
worker_urls.push(url);
workers.push(worker);
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(200)).await;
}
Self { workers, app_state }
}
config.mode = RoutingMode::Regular { worker_urls };
async fn create_app(
&self,
) -> impl actix_web::dev::Service<
actix_http::Request,
Response = actix_web::dev::ServiceResponse,
Error = actix_web::Error,
> {
actix_test::init_service(
App::new()
.app_data(self.app_state.clone())
.service(generate)
.service(v1_chat_completions)
.service(v1_completions)
.service(list_workers),
)
.await
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
.await
.unwrap()
.unwrap();
let router = Arc::from(router);
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self { workers, router }
}
async fn shutdown(mut self) {
// Small delay to ensure any pending operations complete
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
for worker in &mut self.workers {
worker.stop().await;
}
// Another small delay to ensure cleanup completes
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
/// Parse SSE (Server-Sent Events) from response body
async fn parse_sse_stream(body: Bytes) -> Vec<serde_json::Value> {
let text = String::from_utf8_lossy(&body);
let mut events = Vec::new();
async fn make_streaming_request(
&self,
endpoint: &str,
body: serde_json::Value,
) -> Result<Vec<String>, String> {
let client = Client::new();
for line in text.lines() {
if line.starts_with("data: ") {
let data = &line[6..];
if data == "[DONE]" {
continue;
}
if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
events.push(json);
}
// Get any worker URL for testing
let worker_urls = self.router.get_worker_urls();
if worker_urls.is_empty() {
return Err("No available workers".to_string());
}
}
events
}
#[cfg(test)]
mod basic_streaming_tests {
use super::*;
#[test]
fn test_router_uses_mock_workers() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19000,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Verify workers are registered with the router
let req = actix_test::TestRequest::get()
.uri("/list_workers")
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body: serde_json::Value = actix_test::read_body_json(resp).await;
let urls = body["urls"].as_array().unwrap();
assert_eq!(urls.len(), 1);
assert!(urls[0].as_str().unwrap().contains("19000"));
ctx.shutdown().await;
});
}
#[test]
fn test_generate_streaming() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Hello, streaming world!",
"stream": true,
"max_new_tokens": 50
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
// Check content type
let content_type = resp.headers().get("content-type").unwrap();
assert_eq!(content_type, "text/event-stream");
// Read streaming body
let body = actix_test::read_body(resp).await;
let events = parse_sse_stream(body).await;
// Verify we got multiple chunks
assert!(events.len() > 1);
// Verify first chunk has text
assert!(events[0].get("text").is_some());
// Verify last chunk has finish_reason in meta_info
let last_event = events.last().unwrap();
assert!(last_event.get("meta_info").is_some());
let meta_info = &last_event["meta_info"];
assert!(meta_info.get("finish_reason").is_some());
ctx.shutdown().await;
});
}
#[test]
fn test_chat_completion_streaming() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Hello, streaming!"}
],
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/v1/chat/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get("content-type").unwrap(),
"text/event-stream"
);
let body = actix_test::read_body(resp).await;
let events = parse_sse_stream(body).await;
// Verify we got streaming events
// Note: Mock doesn't provide full OpenAI format, just verify we got chunks
assert!(!events.is_empty(), "Should have received streaming events");
ctx.shutdown().await;
});
}
#[test]
fn test_completion_streaming() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"stream": true,
"max_tokens": 30
});
let req = actix_test::TestRequest::post()
.uri("/v1/completions")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
assert_eq!(
resp.headers().get("content-type").unwrap(),
"text/event-stream"
);
let _body = actix_test::read_body(resp).await;
ctx.shutdown().await;
});
}
}
#[cfg(test)]
mod streaming_performance_tests {
use super::*;
#[test]
fn test_streaming_first_token_latency() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19010,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10, // Small delay to simulate processing
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Measure latency",
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let start = Instant::now();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
// Note: actix_test framework doesn't provide easy access to streaming chunks.
// The ideal solution would be to:
// 1. Start the router as a real HTTP server
// 2. Use reqwest::Client to make streaming requests
// 3. Measure time to first chunk properly
//
// For now, we verify that streaming responses work correctly,
// but cannot accurately measure TTFT with actix_test.
let body = actix_test::read_body(resp).await;
let total_time = start.elapsed();
// Verify we got streaming data
let events = parse_sse_stream(body).await;
assert!(!events.is_empty(), "Should receive streaming events");
// With mock worker delay of 10ms, total time should still be reasonable
assert!(
total_time.as_millis() < 1000,
"Total response took {}ms",
total_time.as_millis()
);
ctx.shutdown().await;
});
}
#[test]
fn test_concurrent_streaming_requests() {
System::new().block_on(async {
// Test basic concurrent streaming functionality
let ctx = StreamingTestContext::new(vec![
MockWorkerConfig {
port: 19050,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
},
MockWorkerConfig {
port: 19051,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
},
])
.await;
let app = ctx.create_app().await;
// Send a moderate number of concurrent requests for unit testing
use futures::future::join_all;
let mut futures = Vec::new();
for i in 0..20 {
let app_ref = &app;
let future = async move {
let payload = json!({
"text": format!("Concurrent request {}", i),
"stream": true,
"max_new_tokens": 5
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(app_ref, req).await;
resp.status() == StatusCode::OK
};
futures.push(future);
}
let results = join_all(futures).await;
let successful = results.iter().filter(|&&r| r).count();
// All requests should succeed in a unit test environment
assert_eq!(
successful, 20,
"Expected all 20 requests to succeed, got {}",
successful
);
ctx.shutdown().await;
});
}
// Note: Extreme load testing has been moved to benches/streaming_load_test.rs
// Run with: cargo run --release --bin streaming_load_test 10000 10
// Or: cargo bench streaming_load_test
}
#[cfg(test)]
mod streaming_error_tests {
use super::*;
#[test]
fn test_streaming_with_worker_failure() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19020,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 1.0, // Always fail
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "This should fail",
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::INTERNAL_SERVER_ERROR);
ctx.shutdown().await;
});
}
#[test]
fn test_streaming_with_invalid_payload() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19021,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
// Missing required fields
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
// TODO: Router should validate payload and reject requests with missing content fields
// Currently, the router accepts requests with no prompt/text/input_ids which is a bug
// This should return StatusCode::BAD_REQUEST once proper validation is implemented
assert_eq!(resp.status(), StatusCode::OK);
ctx.shutdown().await;
});
}
}
#[cfg(test)]
mod streaming_content_tests {
use super::*;
#[test]
fn test_unicode_streaming() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19030,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Test Unicode: 你好世界 🌍 émojis",
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body = actix_test::read_body(resp).await;
let events = parse_sse_stream(body).await;
// Verify events were parsed correctly (Unicode didn't break parsing)
assert!(!events.is_empty());
ctx.shutdown().await;
});
}
#[test]
fn test_incremental_text_building() {
System::new().block_on(async {
let ctx = StreamingTestContext::new(vec![MockWorkerConfig {
port: 19031,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
let payload = json!({
"text": "Build text incrementally",
"stream": true
});
let req = actix_test::TestRequest::post()
.uri("/generate")
.set_json(&payload)
.to_request();
let resp = actix_test::call_service(&app, req).await;
assert_eq!(resp.status(), StatusCode::OK);
let body = actix_test::read_body(resp).await;
let events = parse_sse_stream(body).await;
// Build complete text from chunks
let mut complete_text = String::new();
for event in &events {
if let Some(text) = event.get("text").and_then(|t| t.as_str()) {
complete_text.push_str(text);
let worker_url = &worker_urls[0];
let response = client
.post(&format!("{}{}", worker_url, endpoint))
.json(&body)
.send()
.await
.map_err(|e| format!("Request failed: {}", e))?;
if !response.status().is_success() {
return Err(format!("Request failed with status: {}", response.status()));
}
// Check if it's a streaming response
let content_type = response
.headers()
.get("content-type")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
if !content_type.contains("text/event-stream") {
return Err("Response is not a stream".to_string());
}
let mut stream = response.bytes_stream();
let mut events = Vec::new();
while let Some(chunk) = stream.next().await {
if let Ok(bytes) = chunk {
let text = String::from_utf8_lossy(&bytes);
for line in text.lines() {
if line.starts_with("data: ") {
events.push(line[6..].to_string());
}
}
}
}
// Verify we got some text
assert!(!complete_text.is_empty());
ctx.shutdown().await;
});
Ok(events)
}
}
#[cfg(test)]
mod streaming_tests {
use super::*;
#[tokio::test]
async fn test_generate_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20001,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"text": "Stream test",
"stream": true,
"sampling_params": {
"temperature": 0.7,
"max_new_tokens": 10
}
});
let result = ctx.make_streaming_request("/generate", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
// Should have at least one data chunk and [DONE]
assert!(events.len() >= 2);
assert_eq!(events.last().unwrap(), "[DONE]");
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_chat_completions_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20002,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"model": "test-model",
"messages": [
{"role": "user", "content": "Count to 3"}
],
"stream": true,
"max_tokens": 20
});
let result = ctx
.make_streaming_request("/v1/chat/completions", payload)
.await;
assert!(result.is_ok());
let events = result.unwrap();
assert!(events.len() >= 2); // At least one chunk + [DONE]
// Verify events are valid JSON (except [DONE])
for event in &events {
if event != "[DONE]" {
let parsed: Result<serde_json::Value, _> = serde_json::from_str(event);
assert!(parsed.is_ok(), "Invalid JSON in SSE event: {}", event);
let json = parsed.unwrap();
assert_eq!(
json.get("object").and_then(|v| v.as_str()),
Some("chat.completion.chunk")
);
}
}
ctx.shutdown().await;
}
#[tokio::test]
async fn test_v1_completions_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20003,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
let payload = json!({
"model": "test-model",
"prompt": "Once upon a time",
"stream": true,
"max_tokens": 15
});
let result = ctx.make_streaming_request("/v1/completions", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
assert!(events.len() >= 2); // At least one chunk + [DONE]
ctx.shutdown().await;
}
#[tokio::test]
async fn test_streaming_with_error() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20004,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 1.0, // Always fail
}])
.await;
let payload = json!({
"text": "This should fail",
"stream": true
});
let result = ctx.make_streaming_request("/generate", payload).await;
// With fail_rate: 1.0, the request should fail
assert!(result.is_err());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_streaming_timeouts() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20005,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 100, // Slow response
fail_rate: 0.0,
}])
.await;
let payload = json!({
"text": "Slow stream",
"stream": true,
"sampling_params": {
"max_new_tokens": 5
}
});
let start = std::time::Instant::now();
let result = ctx.make_streaming_request("/generate", payload).await;
let elapsed = start.elapsed();
assert!(result.is_ok());
let events = result.unwrap();
// Should have received multiple chunks over time
assert!(!events.is_empty());
assert!(elapsed.as_millis() >= 100); // At least one delay
ctx.shutdown().await;
}
#[tokio::test]
async fn test_batch_streaming() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 20006,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 10,
fail_rate: 0.0,
}])
.await;
// Batch request with streaming
let payload = json!({
"text": ["First", "Second", "Third"],
"stream": true,
"sampling_params": {
"max_new_tokens": 5
}
});
let result = ctx.make_streaming_request("/generate", payload).await;
assert!(result.is_ok());
let events = result.unwrap();
// Should have multiple events for batch
assert!(events.len() >= 4); // At least 3 responses + [DONE]
ctx.shutdown().await;
}
#[tokio::test]
async fn test_sse_format_parsing() {
// Test SSE format parsing
let parse_sse_chunk = |chunk: &[u8]| -> Vec<String> {
let text = String::from_utf8_lossy(chunk);
text.lines()
.filter(|line| line.starts_with("data: "))
.map(|line| line[6..].to_string())
.collect()
};
let sse_data =
b"data: {\"text\":\"Hello\"}\n\ndata: {\"text\":\" world\"}\n\ndata: [DONE]\n\n";
let events = parse_sse_chunk(sse_data);
assert_eq!(events.len(), 3);
assert_eq!(events[0], "{\"text\":\"Hello\"}");
assert_eq!(events[1], "{\"text\":\" world\"}");
assert_eq!(events[2], "[DONE]");
// Test with mixed content
let mixed = b"event: message\ndata: {\"test\":true}\n\n: comment\ndata: [DONE]\n\n";
let events = parse_sse_chunk(mixed);
assert_eq!(events.len(), 2);
assert_eq!(events[0], "{\"test\":true}");
assert_eq!(events[1], "[DONE]");
}
}

View File

@@ -176,6 +176,8 @@ mod test_pd_routing {
log_dir: None,
log_level: None,
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
};
// Router creation will fail due to health checks, but config should be valid