diff --git a/sgl-router/Cargo.toml b/sgl-router/Cargo.toml
index b23b6d7ac..74b1ed129 100644
--- a/sgl-router/Cargo.toml
+++ b/sgl-router/Cargo.toml
@@ -42,6 +42,9 @@ url = "2.5.4"
 
 [dev-dependencies]
 criterion = { version = "0.5", features = ["html_reports"] }
+tokio-stream = "0.1"
+actix-http = "3.0"
+futures = "0.3"
 
 [[bench]]
 name = "request_processing"
diff --git a/sgl-router/tests/common/mock_worker.rs b/sgl-router/tests/common/mock_worker.rs
new file mode 100644
index 000000000..c5129febc
--- /dev/null
+++ b/sgl-router/tests/common/mock_worker.rs
@@ -0,0 +1,650 @@
+use actix_web::{middleware, web, App, HttpRequest, HttpResponse, HttpServer};
+use futures_util::StreamExt;
+use serde_json::json;
+use std::sync::Arc;
+use std::time::{SystemTime, UNIX_EPOCH};
+use tokio::sync::RwLock;
+use uuid;
+
+/// Configuration for mock worker behavior
+#[derive(Clone)]
+pub struct MockWorkerConfig {
+    pub port: u16,
+    pub worker_type: WorkerType,
+    pub health_status: HealthStatus,
+    pub response_delay_ms: u64,
+    pub fail_rate: f32,
+}
+
+#[derive(Clone, Debug)]
+pub enum WorkerType {
+    Regular,
+    Prefill,
+    Decode,
+}
+
+#[derive(Clone, Debug)]
+pub enum HealthStatus {
+    Healthy,
+    Unhealthy,
+    Degraded,
+}
+
+/// Mock worker server for testing
+pub struct MockWorker {
+    config: Arc<RwLock<MockWorkerConfig>>,
+    server_handle: Option<actix_web::dev::ServerHandle>,
+}
+
+impl MockWorker {
+    pub fn new(config: MockWorkerConfig) -> Self {
+        Self {
+            config: Arc::new(RwLock::new(config)),
+            server_handle: None,
+        }
+    }
+
+    /// Start the mock worker server
+    pub async fn start(&mut self) -> Result<String, Box<dyn std::error::Error>> {
+        let config = self.config.clone();
+        let port = config.read().await.port;
+
+        let server = HttpServer::new(move || {
+            App::new()
+                .app_data(web::Data::new(config.clone()))
+                .wrap(middleware::Logger::default())
+                .route("/health", web::get().to(health_handler))
+                .route("/health_generate", web::get().to(health_generate_handler))
+                .route("/get_server_info", web::get().to(server_info_handler))
+                .route("/get_model_info", web::get().to(model_info_handler))
+                .route("/generate", web::post().to(generate_handler))
+                .route(
+                    "/v1/chat/completions",
+                    web::post().to(chat_completions_handler),
+                )
+                .route("/v1/completions", web::post().to(completions_handler))
+                .route("/flush_cache", web::post().to(flush_cache_handler))
+                .route("/v1/models", web::get().to(v1_models_handler))
+        })
+        .bind(("127.0.0.1", port))?
+        .run();
+
+        let handle = server.handle();
+        self.server_handle = Some(handle);
+
+        tokio::spawn(server);
+
+        Ok(format!("http://127.0.0.1:{}", port))
+    }
+
+    /// Stop the mock worker server
+    pub async fn stop(&mut self) {
+        if let Some(handle) = self.server_handle.take() {
+            // First try graceful stop with short timeout
+            handle.stop(false);
+            // Give it a moment to stop gracefully
+            tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+        }
+    }
+
+    /// Update the mock worker configuration
+    pub async fn update_config<F>(&self, updater: F)
+    where
+        F: FnOnce(&mut MockWorkerConfig),
+    {
+        let mut config = self.config.write().await;
+        updater(&mut *config);
+    }
+}
+
+// Handler implementations
+
+async fn health_handler(config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    let config = config.read().await;
+
+    match config.health_status {
+        HealthStatus::Healthy => HttpResponse::Ok().json(json!({
+            "status": "healthy",
+            "timestamp": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
+            "worker_type": format!("{:?}", config.worker_type),
+        })),
+        HealthStatus::Unhealthy => HttpResponse::ServiceUnavailable().json(json!({
+            "status": "unhealthy",
+            "error": "Worker is not responding"
+        })),
+        HealthStatus::Degraded => HttpResponse::Ok().json(json!({
+            "status": "degraded",
+            "warning": "High load detected"
+        })),
+    }
+}
+
+async fn health_generate_handler(config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    let config = config.read().await;
+
+    if matches!(config.health_status, HealthStatus::Healthy) {
+        HttpResponse::Ok().json(json!({
+            "status": "ok",
+            "queue_length": 0,
+            "processing_time_ms": config.response_delay_ms
+        }))
+    } else {
+        HttpResponse::ServiceUnavailable().json(json!({
+            "error": "Generation service unavailable"
+        }))
+    }
+}
+
+async fn server_info_handler(config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    let config = config.read().await;
+
+    // Return response matching actual sglang server implementation
+    HttpResponse::Ok().json(json!({
+        // Server args fields
+        "model_path": "mock-model-path",
+        "tokenizer_path": "mock-tokenizer-path",
+        "port": config.port,
+        "host": "127.0.0.1",
+        "max_num_batched_tokens": 32768,
+        "max_prefill_tokens": 16384,
+        "mem_fraction_static": 0.88,
+        "tp_size": 1,
+        "dp_size": 1,
+        "stream_interval": 8,
+        "dtype": "float16",
+        "device": "cuda",
+        "enable_flashinfer": true,
+        "enable_p2p_check": true,
+        "context_length": 32768,
+        "chat_template": null,
+        "disable_radix_cache": false,
+        "enable_torch_compile": false,
+        "trust_remote_code": false,
+        "show_time_cost": false,
+
+        // Scheduler info fields
+        "waiting_queue_size": 0,
+        "running_queue_size": 0,
+        "req_to_token_ratio": 1.2,
+        "min_running_requests": 0,
+        "max_running_requests": 2048,
+        "max_req_num": 8192,
+        "max_batch_tokens": 32768,
+        "schedule_policy": "lpm",
+        "schedule_conservativeness": 1.0,
+
+        // Additional fields
+        "version": "0.3.0",
+        "internal_states": [{
+            "waiting_queue_size": 0,
+            "running_queue_size": 0
+        }]
+    }))
+}
+
+async fn model_info_handler(_config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    // Return response matching actual sglang server implementation
+    HttpResponse::Ok().json(json!({
+        "model_path": "mock-model-path",
+        "tokenizer_path": "mock-tokenizer-path",
+        "is_generation": true,
+        "preferred_sampling_params": {
+            "temperature": 0.7,
+            "top_p": 0.9,
+            "top_k": 40,
+            "max_tokens": 2048
+        }
+    }))
+}
+
+async fn generate_handler(
+    config: web::Data<Arc<RwLock<MockWorkerConfig>>>,
+    _req: HttpRequest,
+    payload: web::Json<serde_json::Value>,
+) -> HttpResponse {
+    let config = config.read().await;
+
+    // Simulate failure based on fail_rate
+    if rand::random::<f32>() < config.fail_rate {
+        return HttpResponse::InternalServerError().json(json!({
+            "error": "Random failure for testing"
+        }));
+    }
+
+    // Simulate processing delay
+    if config.response_delay_ms > 0 {
+        tokio::time::sleep(tokio::time::Duration::from_millis(config.response_delay_ms)).await;
+    }
+
+    let is_stream = payload
+        .get("stream")
+        .and_then(|v| v.as_bool())
+        .unwrap_or(false);
+
+    if is_stream {
+        // Return streaming response matching sglang format
+        let (tx, rx) = tokio::sync::mpsc::channel(10);
+        let stream_delay = config.response_delay_ms;
+        let request_id = format!("mock-req-{}", rand::random::<u32>());
+
+        tokio::spawn(async move {
+            let tokens = vec!["This ", "is ", "a ", "mock ", "response."];
+            let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64();
+
+            for (i, token) in tokens.iter().enumerate() {
+                let chunk = json!({
+                    "text": token,
+                    "meta_info": {
+                        "id": &request_id,
+                        "finish_reason": if i == tokens.len() - 1 {
+                            json!({"type": "stop", "matched_stop": null})
+                        } else {
+                            json!(null)
+                        },
+                        "prompt_tokens": 10,
+                        "completion_tokens": i + 1,
+                        "cached_tokens": 0,
+                        "e2e_latency": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64() - timestamp_start
+                    }
+                });
+
+                if tx.send(format!("data: {}\n\n", serde_json::to_string(&chunk).unwrap())).await.is_err() {
+                    break;
+                }
+
+                if stream_delay > 0 {
+                    tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await;
+                }
+            }
+
+            let _ = tx.send("data: [DONE]\n\n".to_string()).await;
+        });
+
+        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+
+        HttpResponse::Ok()
+            .content_type("text/event-stream")
+            .insert_header(("Cache-Control", "no-cache"))
+            .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk))))
+    } else {
+        // Return non-streaming response matching sglang format
+        let request_id = format!("mock-req-{}", rand::random::<u32>());
+        let timestamp_start = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs_f64();
+
+        HttpResponse::Ok().json(json!({
+            "text": "Mock generated response for the input",
+            "meta_info": {
+                "id": request_id,
+                "finish_reason": {
+                    "type": "stop",
+                    "matched_stop": null
+                },
+                "prompt_tokens": 10,
+                "completion_tokens": 7,
+                "cached_tokens": 0,
+                "e2e_latency": 0.042
+            }
+        }))
+    }
+}
+
+async fn chat_completions_handler(
+    config: web::Data<Arc<RwLock<MockWorkerConfig>>>,
+    payload: web::Json<serde_json::Value>,
+) -> HttpResponse {
+    let config = config.read().await;
+
+    // Simulate failure
+    if rand::random::<f32>() < config.fail_rate {
+        return HttpResponse::InternalServerError().json(json!({
+            "error": "Chat completion failed"
+        }));
+    }
+
+    let is_stream = payload
+        .get("stream")
+        .and_then(|v| v.as_bool())
+        .unwrap_or(false);
+
+    if is_stream {
+        // Return proper streaming response for chat completions
+        let (tx, rx) = tokio::sync::mpsc::channel(10);
+        let stream_delay = config.response_delay_ms;
+        let model = payload
+            .get("model")
+            .and_then(|m| m.as_str())
+            .unwrap_or("mock-model")
+            .to_string();
+
+        tokio::spawn(async move {
+            let chat_id = format!("chatcmpl-mock{}", rand::random::<u32>());
+            let timestamp = SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_secs();
+
+            // Send initial chunk with role
+            let initial_chunk = json!({
+                "id": &chat_id,
+                "object": "chat.completion.chunk",
+                "created": timestamp,
+                "model": &model,
+                "choices": [{
+                    "index": 0,
+                    "delta": {
+                        "role": "assistant"
+                    },
+                    "finish_reason": null
+                }]
+            });
+
+            let _ = tx
+                .send(format!(
+                    "data: {}\n\n",
serde_json::to_string(&initial_chunk).unwrap() + )) + .await; + + // Send content chunks + let content_chunks = [ + "This ", + "is ", + "a ", + "mock ", + "streaming ", + "chat ", + "response.", + ]; + for chunk in content_chunks.iter() { + let data = json!({ + "id": &chat_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": &model, + "choices": [{ + "index": 0, + "delta": { + "content": chunk + }, + "finish_reason": null + }] + }); + + if tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&data).unwrap() + )) + .await + .is_err() + { + break; + } + + if stream_delay > 0 { + tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await; + } + } + + // Send final chunk with finish_reason + let final_chunk = json!({ + "id": &chat_id, + "object": "chat.completion.chunk", + "created": timestamp, + "model": &model, + "choices": [{ + "index": 0, + "delta": {}, + "finish_reason": "stop" + }] + }); + + let _ = tx + .send(format!( + "data: {}\n\n", + serde_json::to_string(&final_chunk).unwrap() + )) + .await; + let _ = tx.send("data: [DONE]\n\n".to_string()).await; + }); + + let stream = tokio_stream::wrappers::ReceiverStream::new(rx); + + HttpResponse::Ok() + .content_type("text/event-stream") + .insert_header(("Cache-Control", "no-cache")) + .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk)))) + } else { + // Non-streaming response matching OpenAI format + let model = payload + .get("model") + .and_then(|m| m.as_str()) + .unwrap_or("mock-model") + .to_string(); + + HttpResponse::Ok().json(json!({ + "id": format!("chatcmpl-{}", uuid::Uuid::new_v4()), + "object": "chat.completion", + "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(), + "model": model, + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "This is a mock chat completion response." 
+                },
+                "logprobs": null,
+                "finish_reason": "stop",
+                "matched_stop": null
+            }],
+            "usage": {
+                "prompt_tokens": 10,
+                "completion_tokens": 8,
+                "total_tokens": 18,
+                "prompt_tokens_details": {
+                    "cached_tokens": 0
+                }
+            }
+        }))
+    }
+}
+
+async fn completions_handler(
+    config: web::Data<Arc<RwLock<MockWorkerConfig>>>,
+    payload: web::Json<serde_json::Value>,
+) -> HttpResponse {
+    let config = config.read().await;
+
+    if rand::random::<f32>() < config.fail_rate {
+        return HttpResponse::InternalServerError().json(json!({
+            "error": "Completion failed"
+        }));
+    }
+
+    // Check if streaming is requested
+    let is_stream = payload
+        .get("stream")
+        .and_then(|v| v.as_bool())
+        .unwrap_or(false);
+
+    let prompts = payload
+        .get("prompt")
+        .map(|p| {
+            if p.is_array() {
+                p.as_array().unwrap().len()
+            } else {
+                1
+            }
+        })
+        .unwrap_or(1);
+
+    if is_stream {
+        // Return streaming response for completions
+        let (tx, rx) = tokio::sync::mpsc::channel(10);
+        let stream_delay = config.response_delay_ms;
+        let model = payload
+            .get("model")
+            .and_then(|m| m.as_str())
+            .unwrap_or("mock-model")
+            .to_string();
+
+        tokio::spawn(async move {
+            let completion_id = format!("cmpl-mock{}", rand::random::<u32>());
+            let timestamp = SystemTime::now()
+                .duration_since(UNIX_EPOCH)
+                .unwrap()
+                .as_secs();
+
+            // Stream completions for each prompt
+            for prompt_idx in 0..prompts {
+                let prompt_suffix = format!("{} ", prompt_idx);
+                let tokens = vec!["This ", "is ", "mock ", "completion ", prompt_suffix.as_str()];
+
+                for (token_idx, token) in tokens.iter().enumerate() {
+                    let data = json!({
+                        "id": &completion_id,
+                        "object": "text_completion",
+                        "created": timestamp,
+                        "model": &model,
+                        "choices": [{
+                            "text": token,
+                            "index": prompt_idx,
+                            "logprobs": null,
+                            "finish_reason": if token_idx == tokens.len() - 1 { Some("stop") } else { None }
+                        }]
+                    });
+
+                    if tx
+                        .send(format!(
+                            "data: {}\n\n",
+                            serde_json::to_string(&data).unwrap()
+                        ))
+                        .await
+                        .is_err()
+                    {
+                        return;
+                    }
+
+                    if stream_delay > 0 {
+                        tokio::time::sleep(tokio::time::Duration::from_millis(stream_delay)).await;
+                    }
+                }
+            }
+
+            let _ = tx.send("data: [DONE]\n\n".to_string()).await;
+        });
+
+        let stream = tokio_stream::wrappers::ReceiverStream::new(rx);
+
+        HttpResponse::Ok()
+            .content_type("text/event-stream")
+            .insert_header(("Cache-Control", "no-cache"))
+            .streaming(stream.map(|chunk| Ok::<_, actix_web::Error>(bytes::Bytes::from(chunk))))
+    } else {
+        // Return non-streaming response
+        let mut choices = vec![];
+        for i in 0..prompts {
+            choices.push(json!({
+                "text": format!("Mock completion {}", i),
+                "index": i,
+                "logprobs": null,
+                "finish_reason": "stop"
+            }));
+        }
+
+        HttpResponse::Ok().json(json!({
+            "id": format!("cmpl-mock{}", rand::random::<u32>()),
+            "object": "text_completion",
+            "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
+            "model": payload.get("model").and_then(|m| m.as_str()).unwrap_or("mock-model"),
+            "choices": choices,
+            "usage": {
+                "prompt_tokens": 5 * prompts,
+                "completion_tokens": 10 * prompts,
+                "total_tokens": 15 * prompts
+            }
+        }))
+    }
+}
+
+async fn flush_cache_handler(_config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    HttpResponse::Ok().json(json!({
+        "status": "success",
+        "message": "Cache flushed",
+        "freed_entries": 42
+    }))
+}
+
+async fn v1_models_handler(_config: web::Data<Arc<RwLock<MockWorkerConfig>>>) -> HttpResponse {
+    HttpResponse::Ok().json(json!({
+        "object": "list",
+        "data": [{
+            "id": "mock-model-v1",
+            "object": "model",
+            "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
+            "owned_by": "sglang",
+            "permission": [{
+                "id": "modelperm-mock",
+                "object": "model_permission",
+                "created": SystemTime::now().duration_since(UNIX_EPOCH).unwrap().as_secs(),
+                "allow_create_engine": false,
+                "allow_sampling": true,
+                "allow_logprobs": true,
+                "allow_search_indices": false,
+                "allow_view": true,
+                "allow_fine_tuning": false,
+                "organization": "*",
+                "group": null,
+                "is_blocking": false
+            }],
+            "root": "mock-model-v1",
+            "parent": null
+        }]
+    }))
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_mock_worker_lifecycle() {
+        let config = MockWorkerConfig {
+            port: 18080,
+            worker_type: WorkerType::Regular,
+            health_status: HealthStatus::Healthy,
+            response_delay_ms: 0,
+            fail_rate: 0.0,
+        };
+
+        let mut worker = MockWorker::new(config);
+
+        // Start the worker
+        let url = worker.start().await.unwrap();
+        assert_eq!(url, "http://127.0.0.1:18080");
+
+        // Give server time to start
+        tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
+
+        // Test health endpoint
+        let client = reqwest::Client::new();
+        let resp = client.get(&format!("{}/health", url)).send().await.unwrap();
+
+        assert_eq!(resp.status(), 200);
+        let body: serde_json::Value = resp.json().await.unwrap();
+        assert_eq!(body["status"], "healthy");
+
+        // Update config to unhealthy
+        worker
+            .update_config(|c| c.health_status = HealthStatus::Unhealthy)
+            .await;
+
+        // Test health again
+        let resp = client.get(&format!("{}/health", url)).send().await.unwrap();
+
+        assert_eq!(resp.status(), 503);
+
+        // Stop the worker
+        worker.stop().await;
+    }
+}
diff --git a/sgl-router/tests/common/mod.rs b/sgl-router/tests/common/mod.rs
new file mode 100644
index 000000000..34467cd08
--- /dev/null
+++ b/sgl-router/tests/common/mod.rs
@@ -0,0 +1,56 @@
+pub mod mock_worker;
+
+use actix_web::web;
+use reqwest::Client;
+use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
+use sglang_router_rs::server::AppState;
+
+/// Helper function to create test router configuration
+pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
+    RouterConfig {
+        mode: RoutingMode::Regular { worker_urls },
+        policy: PolicyConfig::Random,
+        host: "127.0.0.1".to_string(),
+        port: 3001,
+        max_payload_size: 256 * 1024 * 1024, // 256MB
+        request_timeout_secs: 600,
+        worker_startup_timeout_secs: 300,
+        worker_startup_check_interval_secs: 10,
+        discovery: None,
+        metrics: None,
+        log_dir: None,
+        log_level: None,
+    }
+}
+
+/// Helper function to create test router configuration with no health check
+pub fn create_test_config_no_workers() -> RouterConfig {
+    RouterConfig {
+        mode: RoutingMode::Regular {
+            worker_urls: vec![],
+        }, // Empty to skip health check
+        policy: PolicyConfig::Random,
+        host: "127.0.0.1".to_string(),
+        port: 3001,
+        max_payload_size: 256 * 1024 * 1024, // 256MB
+        request_timeout_secs: 600,
+        worker_startup_timeout_secs: 0, // No wait
+        worker_startup_check_interval_secs: 10,
+        discovery: None,
+        metrics: None,
+        log_dir: None,
+        log_level: None,
+    }
+}
+
+/// Helper function to create test app state
+pub async fn create_test_app_state(config: RouterConfig) -> Result<web::Data<AppState>, String> {
+    // Create a non-blocking client
+    let client = Client::builder()
+        .timeout(std::time::Duration::from_secs(config.request_timeout_secs))
+        .build()
+        .map_err(|e| e.to_string())?;
+
+    let app_state = AppState::new(config, client)?;
+    Ok(web::Data::new(app_state))
+}