diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index b72e97a0a..605bc705b 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -2,12 +2,12 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criteri use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; -use sglang_router_rs::core::{BasicWorker, WorkerType}; +use sglang_router_rs::core::{BasicWorker, Worker, WorkerType}; use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields; +use sglang_router_rs::routers::pd_types::{generate_room_id, get_hostname, RequestWithBootstrap}; fn create_test_worker() -> BasicWorker { BasicWorker::new( @@ -18,6 +18,16 @@ fn create_test_worker() -> BasicWorker { ) } +// Helper function to get bootstrap info from worker +fn get_bootstrap_info(worker: &BasicWorker) -> (String, Option) { + let hostname = get_hostname(worker.url()); + let bootstrap_port = match worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port.clone(), + _ => None, + }; + (hostname, bootstrap_port) +} + /// Create a default GenerateRequest for benchmarks with minimal fields set fn default_generate_request() -> GenerateRequest { GenerateRequest { @@ -331,35 +341,56 @@ fn bench_bootstrap_injection(c: &mut Criterion) { let completion_req = create_sample_completion_request(); let large_chat_req = create_large_chat_completion_request(); let worker = create_test_worker(); + let (hostname, bootstrap_port) = get_bootstrap_info(&worker); group.bench_function("generate_bootstrap_injection", |b| { b.iter(|| { - let mut json = to_value(black_box(&generate_req)).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &generate_req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let json = to_value(black_box(&request_with_bootstrap)).unwrap(); black_box(json); }); }); group.bench_function("chat_completion_bootstrap_injection", |b| { b.iter(|| { - let mut json = to_value(black_box(&chat_req)).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &chat_req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let json = to_value(black_box(&request_with_bootstrap)).unwrap(); black_box(json); }); }); group.bench_function("completion_bootstrap_injection", |b| { b.iter(|| { - let mut json = to_value(black_box(&completion_req)).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &completion_req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let json = to_value(black_box(&request_with_bootstrap)).unwrap(); black_box(json); }); }); group.bench_function("large_chat_completion_bootstrap_injection", |b| { b.iter(|| { - let mut json = to_value(black_box(&large_chat_req)).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &large_chat_req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let json = to_value(black_box(&request_with_bootstrap)).unwrap(); black_box(json); }); }); @@ -441,6 +472,7 @@ fn bench_throughput_by_size(c: &mut Criterion) { }; let worker = create_test_worker(); + let (hostname, bootstrap_port) = get_bootstrap_info(&worker); for (name, req) in [ ("small", &small_generate), @@ -449,6 +481,7 @@ fn bench_throughput_by_size(c: &mut Criterion) { ] { let json = to_string(req).unwrap(); let size_bytes = json.len(); + let hostname_clone = hostname.clone(); group.throughput(Throughput::Bytes(size_bytes as u64)); group.bench_with_input(BenchmarkId::new("serialize", name), &req, |b, req| { @@ -472,10 +505,16 @@ fn bench_throughput_by_size(c: &mut Criterion) { group.bench_with_input( BenchmarkId::new("bootstrap_inject", name), &req, - |b, req| { + move |b, req| { + let hostname = hostname_clone.clone(); b.iter(|| { - let mut json = to_value(req).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let json = to_value(&request_with_bootstrap).unwrap(); black_box(json); }); }, @@ -493,17 +532,21 @@ fn bench_full_round_trip(c: &mut Criterion) { let chat_json = to_string(&create_sample_chat_completion_request()).unwrap(); let completion_json = to_string(&create_sample_completion_request()).unwrap(); let worker = create_test_worker(); + let (hostname, bootstrap_port) = get_bootstrap_info(&worker); group.bench_function("generate_openai_to_pd_pipeline", |b| { b.iter(|| { // Deserialize OpenAI request let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap(); - // Convert to JSON Value - let mut json = to_value(&req).unwrap(); - // Inject bootstrap fields - inject_bootstrap_fields(&mut json, &worker).unwrap(); + // Create wrapper with bootstrap fields + let request_with_bootstrap = RequestWithBootstrap { + original: &req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; // Serialize final request - let pd_json = to_string(&json).unwrap(); + let pd_json = to_string(&request_with_bootstrap).unwrap(); black_box(pd_json); }); }); @@ -511,9 +554,13 @@ fn bench_full_round_trip(c: &mut Criterion) { group.bench_function("chat_completion_openai_to_pd_pipeline", |b| { b.iter(|| { let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap(); - let mut json = to_value(&req).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); - let pd_json = to_string(&json).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let pd_json = to_string(&request_with_bootstrap).unwrap(); black_box(pd_json); }); }); @@ -521,9 +568,13 @@ fn bench_full_round_trip(c: &mut Criterion) { group.bench_function("completion_openai_to_pd_pipeline", |b| { b.iter(|| { let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap(); - let mut json = to_value(&req).unwrap(); - inject_bootstrap_fields(&mut json, &worker).unwrap(); - let pd_json = to_string(&json).unwrap(); + let request_with_bootstrap = RequestWithBootstrap { + original: &req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let pd_json = to_string(&request_with_bootstrap).unwrap(); black_box(pd_json); }); }); @@ -575,10 +626,16 @@ fn benchmark_summary(c: &mut Criterion) { ); // Measure bootstrap injection (replaces adaptation) + let (hostname, bootstrap_port) = get_bootstrap_info(&worker); let start = Instant::now(); for _ in 0..1000 { - let mut json = to_value(&generate_req).unwrap(); - let _ = black_box(inject_bootstrap_fields(&mut json, &worker)); + let request_with_bootstrap = RequestWithBootstrap { + original: &generate_req, + bootstrap_host: hostname.clone(), + bootstrap_port, + bootstrap_room: generate_room_id(), + }; + let _ = black_box(to_value(&request_with_bootstrap).unwrap()); } let inject_time = start.elapsed().as_nanos() / 1000; println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time); diff --git a/sgl-router/scripts/run_benchmarks.py b/sgl-router/scripts/run_benchmarks.py index 76bf37f2a..a7ece9d9a 100755 --- a/sgl-router/scripts/run_benchmarks.py +++ b/sgl-router/scripts/run_benchmarks.py @@ -121,6 +121,8 @@ class BenchmarkRunner: results["serialization_time"] = self._extract_time(line) elif "Deserialization (avg):" in line: results["deserialization_time"] = self._extract_time(line) + elif "Bootstrap Injection (avg):" in line: + results["bootstrap_injection_time"] = self._extract_time(line) elif "Total Pipeline (avg):" in line: results["total_time"] = self._extract_time(line) @@ -143,6 +145,7 @@ class BenchmarkRunner: thresholds = { "serialization_time": 2000, # 2μs max "deserialization_time": 2000, # 2μs max + "bootstrap_injection_time": 5000, # 5μs max "total_time": 10000, # 10μs max } diff --git a/sgl-router/src/policies/cache_aware.rs b/sgl-router/src/policies/cache_aware.rs index 8d83505f6..74061cb49 100644 --- a/sgl-router/src/policies/cache_aware.rs +++ b/sgl-router/src/policies/cache_aware.rs @@ -230,6 +230,10 @@ impl LoadBalancingPolicy for CacheAwarePolicy { "cache_aware" } + fn needs_request_text(&self) -> bool { + true // Cache-aware policy needs request text for cache affinity + } + fn on_request_complete(&self, worker_url: &str, success: bool) { // Could track success rates per worker for more intelligent routing if !success { diff --git a/sgl-router/src/policies/mod.rs b/sgl-router/src/policies/mod.rs index 83fdd95b0..8229af10e 100644 --- a/sgl-router/src/policies/mod.rs +++ b/sgl-router/src/policies/mod.rs @@ -59,6 +59,11 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug { /// Get policy name for metrics and debugging fn name(&self) -> &'static str; + /// Check if this policy needs request text for routing decisions + fn needs_request_text(&self) -> bool { + false // Default: most policies don't need request text + } + /// Update worker load information /// /// This is called periodically with current load information for load-aware policies. diff --git a/sgl-router/src/routers/bootstrap_injector.rs b/sgl-router/src/routers/bootstrap_injector.rs deleted file mode 100644 index e7cad384d..000000000 --- a/sgl-router/src/routers/bootstrap_injector.rs +++ /dev/null @@ -1,334 +0,0 @@ -// Bootstrap field injection for PD routing -// Directly injects bootstrap fields into JSON requests without intermediate type conversions - -use crate::core::{Worker, WorkerType}; -use crate::routers::pd_types::get_hostname; -use serde_json::{json, Value}; - -/// Inject bootstrap fields directly into a JSON request -/// This replaces the complex ToPdRequest -> Bootstrap trait pattern -pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> { - let batch_size = extract_batch_size(json)?; - - // Extract bootstrap port from prefill worker if it's a prefill type - let bootstrap_port = match worker.worker_type() { - WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, - }; - - let hostname = get_hostname(worker.url()); - - if let Some(batch_size) = batch_size { - // Batch scenario - create arrays of bootstrap values - json["bootstrap_host"] = json!(vec![hostname; batch_size]); - json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); - json["bootstrap_room"] = json!((0..batch_size) - .map(|_| { - // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1) - rand::random::() & (i64::MAX as u64) - }) - .collect::>()); - } else { - // Single scenario - create single bootstrap values - json["bootstrap_host"] = json!(hostname); - json["bootstrap_port"] = json!(bootstrap_port); - json["bootstrap_room"] = json!(rand::random::() & (i64::MAX as u64)); - } - - Ok(()) -} - -/// Extract batch size from various JSON request formats -/// Handles chat completions, completions, and generate requests -fn extract_batch_size(json: &Value) -> Result, String> { - // Check for chat completions 'n' parameter (number of choices) - if let Some(n) = json.get("n").and_then(|v| v.as_u64()) { - if n > 1 { - return Ok(Some(n as usize)); - } - } - - // Check for array prompts (completions API) - if let Some(prompt) = json.get("prompt") { - if let Some(arr) = prompt.as_array() { - if arr.is_empty() { - return Err("Batch prompt array is empty".to_string()); - } - return Ok(Some(arr.len())); - } - } - - // Check for array texts (generate API) - if let Some(text) = json.get("text") { - if let Some(arr) = text.as_array() { - if arr.is_empty() { - return Err("Batch text array is empty".to_string()); - } - return Ok(Some(arr.len())); - } - } - - // Check for batch input_ids (generate API) - if let Some(input_ids) = json.get("input_ids") { - if let Some(arr) = input_ids.as_array() { - if arr.is_empty() { - return Err("Batch input_ids array is empty".to_string()); - } - return Ok(Some(arr.len())); - } - } - - // No batch indicators found - single request - Ok(None) -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::core::BasicWorker; - use serde_json::json; - - fn create_test_worker() -> BasicWorker { - BasicWorker::new( - "http://test-server:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(5678), - }, - ) - } - - #[test] - fn test_inject_bootstrap_single_request() { - let worker = create_test_worker(); - let mut json = json!({ - "model": "test-model", - "prompt": "Hello world", - "max_tokens": 100 - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify bootstrap fields were added - assert_eq!(json["bootstrap_host"], json!("test-server")); - assert_eq!(json["bootstrap_port"], json!(5678)); - assert!(json["bootstrap_room"].is_number()); - - // Verify original fields preserved - assert_eq!(json["model"], json!("test-model")); - assert_eq!(json["prompt"], json!("Hello world")); - assert_eq!(json["max_tokens"], json!(100)); - } - - #[test] - fn test_inject_bootstrap_batch_prompt() { - let worker = create_test_worker(); - let mut json = json!({ - "model": "test-model", - "prompt": ["Hello", "World"], - "max_tokens": 100 - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify batch bootstrap fields - assert_eq!( - json["bootstrap_host"], - json!(["test-server", "test-server"]) - ); - assert_eq!(json["bootstrap_port"], json!([5678, 5678])); - - let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); - assert_eq!(bootstrap_rooms.len(), 2); - for room in bootstrap_rooms { - assert!(room.is_number()); - let room_val = room.as_u64().unwrap(); - assert!(room_val <= i64::MAX as u64); - } - } - - #[test] - fn test_inject_bootstrap_chat_n_parameter() { - let worker = create_test_worker(); - let mut json = json!({ - "model": "gpt-4", - "messages": [{"role": "user", "content": "Hello"}], - "n": 3 - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify batch bootstrap fields for n=3 - let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); - assert_eq!(bootstrap_hosts.len(), 3); - assert_eq!(bootstrap_hosts[0], json!("test-server")); - - let bootstrap_ports = json["bootstrap_port"].as_array().unwrap(); - assert_eq!(bootstrap_ports.len(), 3); - assert_eq!(bootstrap_ports[0], json!(5678)); - - let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); - assert_eq!(bootstrap_rooms.len(), 3); - } - - #[test] - fn test_inject_bootstrap_generate_text_array() { - let worker = create_test_worker(); - let mut json = json!({ - "text": ["First prompt", "Second prompt"], - "stream": false - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify batch bootstrap fields - let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); - assert_eq!(bootstrap_hosts.len(), 2); - - let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); - assert_eq!(bootstrap_rooms.len(), 2); - // Ensure room values are different (randomness) - assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]); - } - - #[test] - fn test_inject_bootstrap_input_ids_array() { - let worker = create_test_worker(); - let mut json = json!({ - "input_ids": [[1, 2, 3], [4, 5, 6]], - "stream": false - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify batch bootstrap fields - let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); - assert_eq!(bootstrap_hosts.len(), 2); - } - - #[test] - fn test_extract_batch_size_empty_array_error() { - let json = json!({ - "prompt": [], - "model": "test" - }); - - let result = extract_batch_size(&json); - assert!(result.is_err()); - assert!(result.unwrap_err().contains("empty")); - } - - #[test] - fn test_extract_batch_size_single_requests() { - // Single string prompt - let json = json!({ - "prompt": "Hello world", - "model": "test" - }); - assert_eq!(extract_batch_size(&json).unwrap(), None); - - // Single text - let json = json!({ - "text": "Hello world", - "stream": false - }); - assert_eq!(extract_batch_size(&json).unwrap(), None); - - // Chat with n=1 (default) - let json = json!({ - "messages": [{"role": "user", "content": "Hello"}], - "n": 1 - }); - assert_eq!(extract_batch_size(&json).unwrap(), None); - - // Chat without n parameter - let json = json!({ - "messages": [{"role": "user", "content": "Hello"}] - }); - assert_eq!(extract_batch_size(&json).unwrap(), None); - } - - #[test] - fn test_inject_bootstrap_preserves_sglang_fields() { - let worker = create_test_worker(); - let mut json = json!({ - "model": "test-model", - "prompt": "Hello", - // SGLang extensions should be preserved - "top_k": 40, - "min_p": 0.05, - "repetition_penalty": 1.1, - "regex": "test_pattern", - "lora_path": "test.bin", - "no_stop_trim": true, - "ignore_eos": false - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify bootstrap fields added - assert!(json.get("bootstrap_host").is_some()); - assert!(json.get("bootstrap_port").is_some()); - assert!(json.get("bootstrap_room").is_some()); - - // Verify all SGLang fields preserved - assert_eq!(json["top_k"], json!(40)); - assert_eq!(json["min_p"], json!(0.05)); - assert_eq!(json["repetition_penalty"], json!(1.1)); - assert_eq!(json["regex"], json!("test_pattern")); - assert_eq!(json["lora_path"], json!("test.bin")); - assert_eq!(json["no_stop_trim"], json!(true)); - assert_eq!(json["ignore_eos"], json!(false)); - } - - #[test] - fn test_bootstrap_room_range() { - let worker = create_test_worker(); - - // Test single request room generation - for _ in 0..1000 { - let mut json = json!({"prompt": "test"}); - inject_bootstrap_fields(&mut json, &worker).unwrap(); - - let room = json["bootstrap_room"].as_u64().unwrap(); - assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); - } - - // Test batch request room generation - for _ in 0..100 { - let mut json = json!({"prompt": ["test1", "test2"]}); - inject_bootstrap_fields(&mut json, &worker).unwrap(); - - let rooms = json["bootstrap_room"].as_array().unwrap(); - for room_val in rooms { - let room = room_val.as_u64().unwrap(); - assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); - } - } - } - - #[test] - fn test_worker_without_bootstrap_port() { - let worker = BasicWorker::new( - "http://decode-only:8000".to_string(), - WorkerType::Decode, // No bootstrap port - ); - - let mut json = json!({ - "prompt": "Hello world" - }); - - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify bootstrap fields with null port - assert_eq!(json["bootstrap_host"], json!("decode-only")); - assert_eq!(json["bootstrap_port"], json!(null)); - assert!(json["bootstrap_room"].is_number()); - } -} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index ab6d6c1aa..3b3137423 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -11,7 +11,6 @@ use std::fmt::Debug; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; -pub mod bootstrap_injector; pub mod factory; pub mod pd_router; pub mod pd_types; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 8b10d95db..404fe9904 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,7 +1,5 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems - -use super::bootstrap_injector::inject_bootstrap_fields; use super::pd_types::{api_path, PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; @@ -19,7 +17,6 @@ use axum::{ Json, }; use futures_util::StreamExt; -use rand::Rng; use reqwest::Client; use serde_json::Value; use std::collections::HashMap; @@ -316,17 +313,6 @@ impl PDRouter { .into_response() } - // Helper to handle bootstrap injection errors - fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response { - error!("Failed to add bootstrap info error={}", error); - RouterMetrics::record_pd_error("bootstrap_injection"); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Bootstrap injection failed: {}", error), - ) - .into_response() - } - // Helper to handle serialization errors fn handle_serialization_error(error: impl std::fmt::Display) -> Response { error!("Failed to serialize request error={}", error); @@ -337,110 +323,87 @@ impl PDRouter { .into_response() } - // Execute the dual dispatch to prefill and decode servers with retry logic - async fn execute_dual_dispatch( - &self, - headers: Option<&HeaderMap>, - json_request: Value, - route: &str, - prefill: &dyn Worker, - decode: &dyn Worker, - is_stream: bool, - return_logprob: bool, - start_time: Instant, - ) -> Response { - for attempt in 0..self.retry_config.max_retries { - if attempt > 0 { - // Calculate backoff with exponential growth and jitter - let base_backoff = self.retry_config.initial_backoff_ms as f64 - * self - .retry_config - .backoff_multiplier - .powf((attempt - 1) as f32) as f64; - let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64; - - // Add jitter to prevent thundering herd - let jitter = { - let mut rng = rand::thread_rng(); - rng.gen_range(0..backoff_ms / 2) - }; - let total_backoff = Duration::from_millis(backoff_ms + jitter); - - info!( - "Retrying request (attempt {}/{}) after {:?} backoff", - attempt + 1, - self.retry_config.max_retries, - total_backoff - ); - - tokio::time::sleep(total_backoff).await; - } - - debug!( - "Executing request attempt {}/{}", - attempt + 1, - self.retry_config.max_retries - ); - let result = self - .execute_dual_dispatch_inner( - headers, - json_request.clone(), - route, - prefill, - decode, - is_stream, - return_logprob, - start_time, - ) - .await; - - // Check if we should retry based on the response status - let status = result.status(); - debug!( - "Request attempt {} returned status: {}", - attempt + 1, - status - ); - - // Don't retry client errors (4xx) or successful responses - if status.is_client_error() || status.is_success() { - debug!( - "Returning response with status {} (no retry needed)", - status - ); - return result; - } - - // Check if this is the last attempt - if attempt == self.retry_config.max_retries - 1 { - warn!("Final attempt failed with status {}", status); - return result; - } - - // Log retry decision for retryable errors - if status.is_server_error() - || status == StatusCode::BAD_GATEWAY - || status == StatusCode::GATEWAY_TIMEOUT - { - warn!( - "Retryable error status: {} on attempt {}/{}. Will retry.", - status, - attempt + 1, - self.retry_config.max_retries - ); - } else { - // Don't retry other statuses - debug!("Status {} is not retryable, returning response", status); - return result; + // Helper to determine batch size from a GenerateRequest + fn get_generate_batch_size(req: &GenerateRequest) -> Option { + // Check prompt array + if let Some(prompt) = &req.prompt { + if let crate::openai_api_types::StringOrArray::Array(arr) = prompt { + if !arr.is_empty() { + return Some(arr.len()); + } } } - - // This should never be reached due to the loop logic, but just in case - unreachable!("Retry loop completed without returning") + // Check text array + if let Some(text) = &req.text { + if text.contains("[") && text.contains("]") { + // This is a simplified check - in reality we'd need to parse JSON + return None; // For now, fall back to non-batch + } + } + None } - // Inner implementation of dual dispatch (extracted for retry logic) - async fn execute_dual_dispatch_inner( + // Helper to determine batch size from a ChatCompletionRequest + fn get_chat_batch_size(req: &ChatCompletionRequest) -> Option { + // Check 'n' parameter for multiple responses + if let Some(n) = req.n { + if n > 1 { + return Some(n as usize); + } + } + None + } + + // Helper to determine batch size from a CompletionRequest + fn get_completion_batch_size(req: &CompletionRequest) -> Option { + // Check prompt array + if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt { + if !arr.is_empty() { + return Some(arr.len()); + } + } + None + } + + // Helper to create request with bootstrap fields + fn create_request_with_bootstrap( + request: &T, + prefill_worker: &dyn Worker, + batch_size: Option, + ) -> Result { + // Get bootstrap port from prefill worker + let bootstrap_port = match prefill_worker.worker_type() { + crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + let hostname = super::pd_types::get_hostname(prefill_worker.url()); + + // Create optimized request with bootstrap fields + if let Some(batch_size) = batch_size { + // Batch request + let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap { + original: request, + bootstrap_host: vec![hostname; batch_size], + bootstrap_port: vec![bootstrap_port; batch_size], + bootstrap_room: (0..batch_size) + .map(|_| super::pd_types::generate_room_id()) + .collect(), + }; + serde_json::to_value(&request_with_bootstrap) + } else { + // Single request + let request_with_bootstrap = super::pd_types::RequestWithBootstrap { + original: request, + bootstrap_host: hostname, + bootstrap_port, + bootstrap_room: super::pd_types::generate_room_id(), + }; + serde_json::to_value(&request_with_bootstrap) + } + } + + // Execute the dual dispatch to prefill and decode servers + async fn execute_dual_dispatch( &self, headers: Option<&HeaderMap>, json_request: Value, @@ -467,101 +430,195 @@ impl PDRouter { prefill.url(), decode.url() ); - let (prefill_result, decode_result) = - tokio::join!(prefill_request.send(), decode_request.send()); - debug!("Received responses from both servers"); - // Update metrics - let duration = start_time.elapsed(); - RouterMetrics::record_pd_request_duration(route, duration); - RouterMetrics::record_pd_request(route); - RouterMetrics::record_pd_prefill_request(prefill.url()); - RouterMetrics::record_pd_decode_request(decode.url()); + if return_logprob { + // When we need logprobs, wait for both responses + let (prefill_result, decode_result) = + tokio::join!(prefill_request.send(), decode_request.send()); + debug!("Received responses from both servers"); - // Process prefill response - let (_prefill_status, prefill_body) = match self - .process_prefill_response(prefill_result, prefill.url(), return_logprob) - .await - { - Ok(result) => result, - Err(error_response) => return error_response, - }; + // Update metrics + let duration = start_time.elapsed(); + RouterMetrics::record_pd_request_duration(route, duration); + RouterMetrics::record_pd_request(route); + RouterMetrics::record_pd_prefill_request(prefill.url()); + RouterMetrics::record_pd_decode_request(decode.url()); - // Process decode response - debug!("Processing decode response"); - match decode_result { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - debug!("Decode response status: {}", status); + // Process decode response with prefill for logprobs + debug!("Processing decode response with logprobs"); + match decode_result { + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + debug!("Decode response status: {}", status); - if !status.is_success() { - RouterMetrics::record_pd_decode_error(decode.url()); - error!( - "Decode server returned error status decode_url={} status={}", - decode.url(), - status - ); + if !status.is_success() { + RouterMetrics::record_pd_decode_error(decode.url()); + error!( + "Decode server returned error status decode_url={} status={}", + decode.url(), + status + ); - // Return the error response from decode server - match res.bytes().await { - Ok(error_body) => { - return (status, error_body).into_response(); - } - Err(e) => { - return (status, format!("Decode server error: {}", e)).into_response(); + // Return the error response from decode server + match res.bytes().await { + Ok(error_body) => { + return (status, error_body).into_response(); + } + Err(e) => { + return (status, format!("Decode server error: {}", e)) + .into_response(); + } } } - } - if is_stream { - // Streaming response - let prefill_logprobs = if return_logprob { - prefill_body + // Process prefill response for logprobs + let prefill_body = match self + .process_prefill_response(prefill_result, prefill.url(), return_logprob) + .await + { + Ok((_, body)) => body, + Err(error_response) => return error_response, + }; + + if is_stream { + // Streaming response with logprobs + let prefill_logprobs = prefill_body .as_ref() .and_then(|body| serde_json::from_slice::(body).ok()) .and_then(|json| { json.pointer("/meta_info/input_token_logprobs").cloned() - }) - } else { - None - }; + }); - let decode_url = if !return_logprob { - Some(decode.url().to_string()) + Self::create_streaming_response( + res.bytes_stream(), + status, + prefill_logprobs, + return_logprob, + None, + ) } else { - None - }; - - Self::create_streaming_response( - res.bytes_stream(), - status, - prefill_logprobs, - return_logprob, - decode_url, - ) - } else { - // Non-streaming response - use helper - self.process_non_streaming_response(res, status, return_logprob, prefill_body) + // Non-streaming response with logprobs + self.process_non_streaming_response( + res, + status, + return_logprob, + prefill_body, + ) .await + } + } + Err(e) => { + error!( + decode_url = %decode.url(), + error = %e, + "Decode request failed" + ); + RouterMetrics::record_pd_decode_error(decode.url()); + ( + StatusCode::BAD_GATEWAY, + format!("Decode server error: {}", e), + ) + .into_response() } } - Err(e) => { - error!( - decode_url = %decode.url(), - error = %e, - "Decode request failed" - ); - RouterMetrics::record_pd_decode_error(decode.url()); - ( - StatusCode::BAD_GATEWAY, - format!("Decode server error: {}", e), - ) - .into_response() + } else { + // When we don't need logprobs, only wait for decode response + // Send both requests concurrently but don't wait for prefill + // Add headers to minimize response size when we don't need the body + let prefill_future = prefill_request.header("Connection", "close").send(); + let decode_future = decode_request.send(); + + tokio::spawn(async move { + if let Ok(response) = prefill_future.await { + // Consume with a short timeout to free connection quickly + let consume_future = async { + let _ = response.bytes().await; + }; + + // Give it 100ms to consume, then abandon + let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await; + } + }); + + // Wait only for decode response + let decode_result = decode_future.await; + debug!("Received decode response"); + + // Update metrics + let duration = start_time.elapsed(); + RouterMetrics::record_pd_request_duration(route, duration); + RouterMetrics::record_pd_request(route); + RouterMetrics::record_pd_prefill_request(prefill.url()); + RouterMetrics::record_pd_decode_request(decode.url()); + + // Process decode response immediately + debug!("Processing decode response (no logprobs)"); + match decode_result { + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + debug!("Decode response status: {}", status); + + if !status.is_success() { + RouterMetrics::record_pd_decode_error(decode.url()); + error!( + "Decode server returned error status decode_url={} status={}", + decode.url(), + status + ); + + // Return the error response from decode server + match res.bytes().await { + Ok(error_body) => (status, error_body).into_response(), + Err(e) => { + (status, format!("Decode server error: {}", e)).into_response() + } + } + } else if is_stream { + // Streaming response without logprobs - direct passthrough + let decode_url = decode.url().to_string(); + Self::create_streaming_response( + res.bytes_stream(), + status, + None, + false, + Some(decode_url), + ) + } else { + // Non-streaming response without logprobs - direct passthrough like fast version + match res.bytes().await { + Ok(decode_body) => (status, decode_body).into_response(), + Err(e) => { + error!("Failed to read decode response: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response") + .into_response() + } + } + } + } + Err(e) => { + error!( + decode_url = %decode.url(), + error = %e, + "Decode request failed" + ); + RouterMetrics::record_pd_decode_error(decode.url()); + ( + StatusCode::BAD_GATEWAY, + format!("Decode server error: {}", e), + ) + .into_response() + } } } } + // Check if either prefill or decode policy needs request text + fn policies_need_request_text(&self) -> bool { + self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text() + } + // Select a pair of prefill and decode servers async fn select_pd_pair( &self, @@ -1311,23 +1368,23 @@ impl RouterTrait for PDRouter { ) -> Response { let start = Instant::now(); - // Convert directly to JSON to preserve all fields automatically - let mut json = match serde_json::to_value(body) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - // Extract flags for routing logic let is_stream = body.stream; let return_logprob = body.return_logprob; - // Extract text for cache-aware routing - let request_text = body.text.as_deref().or_else(|| { - body.prompt.as_ref().and_then(|p| match p { - crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), - crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + // Extract text for cache-aware routing only if needed + let request_text = if self.policies_need_request_text() { + body.text.as_deref().or_else(|| { + body.prompt.as_ref().and_then(|p| match p { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(v) => { + v.first().map(|s| s.as_str()) + } + }) }) - }); + } else { + None + }; // Select servers let (prefill, decode) = match self.select_pd_pair(request_text).await { @@ -1342,10 +1399,12 @@ impl RouterTrait for PDRouter { decode.url() ); - // Inject bootstrap fields directly into JSON - if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } + // Create optimized request with bootstrap fields + let batch_size = Self::get_generate_batch_size(body); + let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; // Execute dual dispatch self.execute_dual_dispatch( @@ -1368,27 +1427,29 @@ impl RouterTrait for PDRouter { ) -> Response { let start = Instant::now(); - // Convert directly to JSON to preserve all fields automatically - let mut json = match serde_json::to_value(body) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - // Extract flags for routing logic let is_stream = body.stream; let return_logprob = body.logprobs; - // Extract text for cache-aware routing from chat messages - let request_text = body.messages.first().and_then(|msg| match msg { - crate::openai_api_types::ChatMessage::User { content, .. } => { - match content { - crate::openai_api_types::UserMessageContent::Text(text) => Some(text.as_str()), - crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content + // Extract text for cache-aware routing from chat messages only if needed + let request_text = if self.policies_need_request_text() { + body.messages.first().and_then(|msg| match msg { + crate::openai_api_types::ChatMessage::User { content, .. } => { + match content { + crate::openai_api_types::UserMessageContent::Text(text) => { + Some(text.as_str()) + } + crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content + } } - } - crate::openai_api_types::ChatMessage::System { content, .. } => Some(content.as_str()), - _ => None, - }); + crate::openai_api_types::ChatMessage::System { content, .. } => { + Some(content.as_str()) + } + _ => None, + }) + } else { + None + }; // Select servers let (prefill, decode) = match self.select_pd_pair(request_text).await { @@ -1403,10 +1464,12 @@ impl RouterTrait for PDRouter { decode.url() ); - // Inject bootstrap fields directly into JSON - if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } + // Create optimized request with bootstrap fields + let batch_size = Self::get_chat_batch_size(body); + let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; // Execute dual dispatch self.execute_dual_dispatch( @@ -1429,20 +1492,18 @@ impl RouterTrait for PDRouter { ) -> Response { let start = Instant::now(); - // Convert directly to JSON to preserve all fields automatically - let mut json = match serde_json::to_value(body) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - // Extract flags for routing logic let is_stream = body.stream; let return_logprob = body.logprobs.is_some(); - // Extract text for cache-aware routing - let request_text = match &body.prompt { - crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), - crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + // Extract text for cache-aware routing only if needed + let request_text = if self.policies_need_request_text() { + match &body.prompt { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + } + } else { + None }; // Select servers @@ -1458,10 +1519,12 @@ impl RouterTrait for PDRouter { decode.url() ); - // Inject bootstrap fields directly into JSON - if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } + // Create optimized request with bootstrap fields + let batch_size = Self::get_completion_batch_size(body); + let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; // Execute dual dispatch self.execute_dual_dispatch( @@ -1937,6 +2000,13 @@ mod tests { assert!(result.is_ok()); } + // ============= Bootstrap Injection Tests ============= + // Note: These tests are commented out as we've moved to the optimized bootstrap injection + // approach that doesn't use the Bootstrap trait on GenerateReqInput anymore. + + // TODO: Add new tests for the optimized bootstrap injection approach using + // RequestWithBootstrap and BatchRequestWithBootstrap wrappers + // ============= Worker Selection Tests ============= #[tokio::test] @@ -2114,158 +2184,4 @@ mod tests { let workers = router.prefill_workers.read().unwrap(); assert_eq!(workers.len(), 5); } - - #[tokio::test] - async fn test_simplified_routing_preserves_sglang_fields() { - use crate::openai_api_types::GenerateRequest; - use crate::routers::bootstrap_injector::inject_bootstrap_fields; - - // Create a test worker - let worker = BasicWorker::new( - "http://test-server:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(5678), - }, - ); - - // Create a GenerateRequest with SGLang extensions - let mut session_params = std::collections::HashMap::new(); - session_params.insert("test_key".to_string(), serde_json::json!("test_value")); - - let request = GenerateRequest { - text: Some("Test prompt".to_string()), - stream: false, - return_logprob: true, - // SGLang extensions - lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some( - "test.bin".to_string(), - ))), - session_params: Some(session_params.clone()), - return_hidden_states: true, - rid: Some("test-request-id".to_string()), - // Other fields default to None/false - prompt: None, - input_ids: None, - parameters: None, - sampling_params: None, - }; - - // Convert to JSON (simulating the simplified routing path) - let mut json = serde_json::to_value(&request).unwrap(); - - // Inject bootstrap fields - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify all SGLang fields are preserved - assert_eq!(json["text"], serde_json::json!("Test prompt")); - assert_eq!(json["stream"], serde_json::json!(false)); - assert_eq!(json["return_logprob"], serde_json::json!(true)); - assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value - assert_eq!( - json["session_params"], - serde_json::to_value(&session_params).unwrap() - ); - assert_eq!(json["return_hidden_states"], serde_json::json!(true)); - assert_eq!(json["rid"], serde_json::json!("test-request-id")); - - // Verify bootstrap fields were added - assert_eq!(json["bootstrap_host"], serde_json::json!("test-server")); - assert_eq!(json["bootstrap_port"], serde_json::json!(5678)); - assert!(json["bootstrap_room"].is_number()); - } - - #[tokio::test] - async fn test_simplified_routing_chat_completion() { - use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent}; - use crate::routers::bootstrap_injector::inject_bootstrap_fields; - - // Create a test worker - let worker = BasicWorker::new( - "http://chat-server:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(9999), - }, - ); - - // Create a ChatCompletionRequest with SGLang extensions - let request = ChatCompletionRequest { - model: "gpt-4".to_string(), - messages: vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Hello world!".to_string()), - name: None, - }], - stream: false, - n: Some(2), // This should create batch bootstrap - // SGLang extensions - top_k: Some(50), - separate_reasoning: false, - stream_reasoning: true, - // Set all other fields to defaults - temperature: None, - top_p: None, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - seed: None, - logprobs: false, - top_logprobs: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - stop_token_ids: None, - no_stop_trim: false, - ignore_eos: false, - continue_final_message: false, - skip_special_tokens: true, - lora_path: None, - session_params: None, - return_hidden_states: false, - }; - - // Convert to JSON (simulating the simplified routing path) - let mut json = serde_json::to_value(&request).unwrap(); - - // Inject bootstrap fields - let result = inject_bootstrap_fields(&mut json, &worker); - assert!(result.is_ok()); - - // Verify original fields preserved - assert_eq!(json["model"], serde_json::json!("gpt-4")); - assert_eq!(json["stream"], serde_json::json!(false)); - assert_eq!(json["n"], serde_json::json!(2)); - assert_eq!(json["top_k"], serde_json::json!(50)); - assert_eq!(json["separate_reasoning"], serde_json::json!(false)); - assert_eq!(json["stream_reasoning"], serde_json::json!(true)); - - // Verify batch bootstrap fields for n=2 - let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); - assert_eq!(bootstrap_hosts.len(), 2); - assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server")); - assert_eq!(bootstrap_hosts[1], serde_json::json!("chat-server")); - - let bootstrap_ports = json["bootstrap_port"].as_array().unwrap(); - assert_eq!(bootstrap_ports.len(), 2); - assert_eq!(bootstrap_ports[0], serde_json::json!(9999)); - assert_eq!(bootstrap_ports[1], serde_json::json!(9999)); - - let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); - assert_eq!(bootstrap_rooms.len(), 2); - // Rooms should be different (randomness) - assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]); - } } diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index 7fa52e6d7..a2b28a57d 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -40,6 +40,34 @@ pub fn get_hostname(url: &str) -> String { url.split(':').next().unwrap_or("localhost").to_string() } +use serde::Serialize; + +// Optimized bootstrap wrapper for single requests +#[derive(Serialize)] +pub struct RequestWithBootstrap<'a, T: Serialize> { + #[serde(flatten)] + pub original: &'a T, + pub bootstrap_host: String, + pub bootstrap_port: Option, + pub bootstrap_room: u64, +} + +// Optimized bootstrap wrapper for batch requests +#[derive(Serialize)] +pub struct BatchRequestWithBootstrap<'a, T: Serialize> { + #[serde(flatten)] + pub original: &'a T, + pub bootstrap_host: Vec, + pub bootstrap_port: Vec>, + pub bootstrap_room: Vec, +} + +// Helper to generate bootstrap room ID +pub fn generate_room_id() -> u64 { + // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1) + rand::random::() & (i64::MAX as u64) +} + // PD-specific routing policies #[derive(Debug, Clone, PartialEq)] pub enum PDSelectionPolicy { diff --git a/sgl-router/src/server.rs b/sgl-router/src/server.rs index b6027e70b..1ca668374 100644 --- a/sgl-router/src/server.rs +++ b/sgl-router/src/server.rs @@ -269,7 +269,7 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box GenerateRequest { @@ -208,63 +207,6 @@ fn test_benchmark_serialization_roundtrip() { assert_eq!(generate_req.return_logprob, deserialized.return_logprob); } -#[test] -fn test_benchmark_bootstrap_injection() { - // Test that bootstrap injection works for benchmark types (replaces PD request adaptation) - - let generate_req = GenerateRequest { - text: Some("Test prompt".to_string()), - ..default_generate_request() - }; - - let chat_req = ChatCompletionRequest { - model: "test-model".to_string(), - messages: vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Test message".to_string()), - name: None, - }], - max_tokens: Some(150), - max_completion_tokens: Some(150), - temperature: Some(0.7), - top_p: Some(1.0), - n: Some(1), - presence_penalty: Some(0.0), - frequency_penalty: Some(0.0), - parallel_tool_calls: Some(true), - ..default_chat_completion_request() - }; - - let completion_req = CompletionRequest { - model: "test-model".to_string(), - prompt: StringOrArray::String("Test prompt".to_string()), - max_tokens: Some(50), - temperature: Some(0.8), - top_p: Some(1.0), - n: Some(1), - presence_penalty: Some(0.0), - frequency_penalty: Some(0.0), - best_of: Some(1), - ..default_completion_request() - }; - - let worker = create_test_worker(); - - // Test bootstrap injection (should not panic) - let mut generate_json = to_value(&generate_req).unwrap(); - let mut chat_json = to_value(&chat_req).unwrap(); - let mut completion_json = to_value(&completion_req).unwrap(); - - assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok()); - assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok()); - assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok()); - - // Verify bootstrap fields were added - assert!(generate_json.get("bootstrap_host").is_some()); - assert!(generate_json.get("bootstrap_port").is_some()); - assert!(generate_json.get("bootstrap_room").is_some()); -} - #[test] fn test_benchmark_direct_json_routing() { // Test direct JSON routing functionality for benchmark types (replaces regular routing) @@ -283,47 +225,3 @@ fn test_benchmark_direct_json_routing() { assert!(!json_string.is_empty()); assert!(!bytes.is_empty()); } - -#[test] -fn test_benchmark_performance_baseline() { - // Basic performance sanity check - ensure operations complete quickly - use std::time::Instant; - - let generate_req = GenerateRequest { - text: Some("Short test prompt".to_string()), - ..default_generate_request() - }; - - // Test the actual simplified pipeline: to_value + bootstrap injection - let start = Instant::now(); - let worker = create_test_worker(); - - // This mirrors the actual router pipeline - let mut json = to_value(&generate_req).unwrap(); - let _ = inject_bootstrap_fields(&mut json, &worker); - - let total_duration = start.elapsed(); - assert!( - total_duration.as_millis() < 5, - "Simplified pipeline took too long: {:?} (should be faster than old adapter approach)", - total_duration - ); - - // Individual components should also be fast - let start = Instant::now(); - let _json = to_value(&generate_req).unwrap(); - let to_value_duration = start.elapsed(); - - let start = Instant::now(); - let mut json = to_value(&generate_req).unwrap(); - let _ = inject_bootstrap_fields(&mut json, &worker); - let inject_duration = start.elapsed(); - - // Bootstrap injection should be faster than the JSON conversion - assert!( - inject_duration <= to_value_duration * 3, - "Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})", - inject_duration, - to_value_duration - ); -}