diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index a997b8dfd..b72e97a0a 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -1,12 +1,22 @@ use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion, Throughput}; -use serde_json::{from_str, to_string, to_vec}; +use serde_json::{from_str, to_string, to_value, to_vec}; use std::time::Instant; +use sglang_router_rs::core::{BasicWorker, WorkerType}; use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields; + +fn create_test_worker() -> BasicWorker { + BasicWorker::new( + "http://test-server:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(5678), + }, + ) +} /// Create a default GenerateRequest for benchmarks with minimal fields set fn default_generate_request() -> GenerateRequest { @@ -312,49 +322,54 @@ fn bench_json_deserialization(c: &mut Criterion) { group.finish(); } -// Benchmark request adaptation from OpenAI to PD format -fn bench_request_adaptation(c: &mut Criterion) { - let mut group = c.benchmark_group("request_adaptation"); +// Benchmark bootstrap injection (replaces request adaptation) +fn bench_bootstrap_injection(c: &mut Criterion) { + let mut group = c.benchmark_group("bootstrap_injection"); let generate_req = create_sample_generate_request(); let chat_req = create_sample_chat_completion_request(); let completion_req = create_sample_completion_request(); let large_chat_req = create_large_chat_completion_request(); + let worker = create_test_worker(); - group.bench_function("generate_to_pd", |b| { + group.bench_function("generate_bootstrap_injection", |b| { b.iter(|| { - let 
pd_req = black_box(generate_req.clone()).to_pd_request(); - black_box(pd_req); + let mut json = to_value(black_box(&generate_req)).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + black_box(json); }); }); - group.bench_function("chat_completion_to_pd", |b| { + group.bench_function("chat_completion_bootstrap_injection", |b| { b.iter(|| { - let pd_req = black_box(chat_req.clone()).to_pd_request(); - black_box(pd_req); + let mut json = to_value(black_box(&chat_req)).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + black_box(json); }); }); - group.bench_function("completion_to_pd", |b| { + group.bench_function("completion_bootstrap_injection", |b| { b.iter(|| { - let pd_req = black_box(completion_req.clone()).to_pd_request(); - black_box(pd_req); + let mut json = to_value(black_box(&completion_req)).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + black_box(json); }); }); - group.bench_function("large_chat_completion_to_pd", |b| { + group.bench_function("large_chat_completion_bootstrap_injection", |b| { b.iter(|| { - let pd_req = black_box(large_chat_req.clone()).to_pd_request(); - black_box(pd_req); + let mut json = to_value(black_box(&large_chat_req)).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + black_box(json); }); }); group.finish(); } -// Benchmark regular routing (RouteableRequest methods) -fn bench_regular_routing(c: &mut Criterion) { - let mut group = c.benchmark_group("regular_routing"); +// Benchmark direct JSON routing (replaces regular routing) +fn bench_direct_json_routing(c: &mut Criterion) { + let mut group = c.benchmark_group("direct_json_routing"); let generate_req = create_sample_generate_request(); let chat_req = create_sample_chat_completion_request(); @@ -362,35 +377,42 @@ fn bench_regular_routing(c: &mut Criterion) { group.bench_function("generate_to_json", |b| { b.iter(|| { - let json = black_box(&generate_req).to_json().unwrap(); + let json = 
to_value(black_box(&generate_req)).unwrap(); + black_box(json); + }); + }); + + group.bench_function("generate_to_json_string", |b| { + b.iter(|| { + let json = to_string(black_box(&generate_req)).unwrap(); black_box(json); }); }); group.bench_function("generate_to_bytes", |b| { b.iter(|| { - let bytes = black_box(&generate_req).to_bytes().unwrap(); + let bytes = to_vec(black_box(&generate_req)).unwrap(); black_box(bytes); }); }); group.bench_function("chat_completion_to_json", |b| { b.iter(|| { - let json = black_box(&chat_req).to_json().unwrap(); + let json = to_value(black_box(&chat_req)).unwrap(); black_box(json); }); }); - group.bench_function("chat_completion_to_bytes", |b| { + group.bench_function("chat_completion_to_json_string", |b| { b.iter(|| { - let bytes = black_box(&chat_req).to_bytes().unwrap(); - black_box(bytes); + let json = to_string(black_box(&chat_req)).unwrap(); + black_box(json); }); }); group.bench_function("completion_to_json", |b| { b.iter(|| { - let json = black_box(&completion_req).to_json().unwrap(); + let json = to_value(black_box(&completion_req)).unwrap(); black_box(json); }); }); @@ -418,6 +440,8 @@ fn bench_throughput_by_size(c: &mut Criterion) { ..default_generate_request() }; + let worker = create_test_worker(); + for (name, req) in [ ("small", &small_generate), ("medium", &medium_generate), @@ -445,33 +469,41 @@ fn bench_throughput_by_size(c: &mut Criterion) { }, ); - group.bench_with_input(BenchmarkId::new("adapt_to_pd", name), &req, |b, req| { - b.iter(|| { - let pd_req = (*req).clone().to_pd_request(); - black_box(pd_req); - }); - }); + group.bench_with_input( + BenchmarkId::new("bootstrap_inject", name), + &req, + |b, req| { + b.iter(|| { + let mut json = to_value(req).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + black_box(json); + }); + }, + ); } group.finish(); } -// Benchmark full round-trip: deserialize -> adapt -> serialize +// Benchmark full round-trip: deserialize -> inject bootstrap -> 
serialize fn bench_full_round_trip(c: &mut Criterion) { let mut group = c.benchmark_group("full_round_trip"); let generate_json = to_string(&create_sample_generate_request()).unwrap(); let chat_json = to_string(&create_sample_chat_completion_request()).unwrap(); let completion_json = to_string(&create_sample_completion_request()).unwrap(); + let worker = create_test_worker(); group.bench_function("generate_openai_to_pd_pipeline", |b| { b.iter(|| { // Deserialize OpenAI request let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap(); - // Adapt to PD format - let pd_req = req.to_pd_request(); - // Serialize PD request - let pd_json = to_string(&pd_req).unwrap(); + // Convert to JSON Value + let mut json = to_value(&req).unwrap(); + // Inject bootstrap fields + inject_bootstrap_fields(&mut json, &worker).unwrap(); + // Serialize final request + let pd_json = to_string(&json).unwrap(); black_box(pd_json); }); }); @@ -479,8 +511,9 @@ fn bench_full_round_trip(c: &mut Criterion) { group.bench_function("chat_completion_openai_to_pd_pipeline", |b| { b.iter(|| { let req: ChatCompletionRequest = from_str(black_box(&chat_json)).unwrap(); - let pd_req = req.to_pd_request(); - let pd_json = to_string(&pd_req).unwrap(); + let mut json = to_value(&req).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + let pd_json = to_string(&json).unwrap(); black_box(pd_json); }); }); @@ -488,19 +521,21 @@ fn bench_full_round_trip(c: &mut Criterion) { group.bench_function("completion_openai_to_pd_pipeline", |b| { b.iter(|| { let req: CompletionRequest = from_str(black_box(&completion_json)).unwrap(); - let pd_req = req.to_pd_request(); - let pd_json = to_string(&pd_req).unwrap(); + let mut json = to_value(&req).unwrap(); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + let pd_json = to_string(&json).unwrap(); black_box(pd_json); }); }); - group.bench_function("generate_regular_routing_pipeline", |b| { + 
group.bench_function("generate_direct_json_pipeline", |b| { b.iter(|| { // Deserialize OpenAI request let req: GenerateRequest = from_str(black_box(&generate_json)).unwrap(); - // Convert to JSON for regular routing - let routing_json = req.to_json().unwrap(); - black_box(routing_json); + // Convert to JSON for direct routing (no bootstrap injection) + let routing_json = to_value(&req).unwrap(); + let json_string = to_string(&routing_json).unwrap(); + black_box(json_string); }); }); @@ -515,6 +550,7 @@ fn benchmark_summary(c: &mut Criterion) { // Quick performance overview let generate_req = create_sample_generate_request(); + let worker = create_test_worker(); println!("\nQuick Performance Overview:"); @@ -538,32 +574,39 @@ fn benchmark_summary(c: &mut Criterion) { deserialize_time ); - // Measure adaptation + // Measure bootstrap injection (replaces adaptation) let start = Instant::now(); for _ in 0..1000 { - let _ = black_box(generate_req.clone().to_pd_request()); + let mut json = to_value(&generate_req).unwrap(); + let _ = black_box(inject_bootstrap_fields(&mut json, &worker)); } - let adapt_time = start.elapsed().as_nanos() / 1000; - println!(" * PD Adaptation (avg): {:>8} ns/req", adapt_time); + let inject_time = start.elapsed().as_nanos() / 1000; + println!(" * Bootstrap Injection (avg): {:>6} ns/req", inject_time); // Calculate ratios - let total_pipeline = serialize_time + deserialize_time + adapt_time; + let total_pipeline = serialize_time + deserialize_time + inject_time; println!(" * Total Pipeline (avg): {:>8} ns/req", total_pipeline); println!("\nPerformance Insights:"); if deserialize_time > serialize_time * 2 { println!(" • Deserialization is significantly faster than serialization"); } - if adapt_time < serialize_time / 10 { + if inject_time < serialize_time / 10 { println!( - " • PD adaptation overhead is negligible ({:.1}% of serialization)", - (adapt_time as f64 / serialize_time as f64) * 100.0 + " • Bootstrap injection overhead is negligible 
({:.1}% of serialization)", + (inject_time as f64 / serialize_time as f64) * 100.0 ); } - if total_pipeline < 10_000 { - println!(" • Total pipeline latency is excellent (< 10μs)"); + if total_pipeline < 100_000 { + println!(" • Total pipeline latency is excellent (< 100μs)"); } + println!("\nSimplification Benefits:"); + println!(" • Eliminated complex type conversion layer"); + println!(" • Reduced memory allocations"); + println!(" • Automatic field preservation (no manual mapping)"); + println!(" • Direct JSON manipulation improves performance"); + println!("\nRecommendations:"); if serialize_time > deserialize_time { println!(" • Focus optimization efforts on serialization rather than deserialization"); @@ -581,8 +624,8 @@ criterion_group!( benchmark_summary, bench_json_serialization, bench_json_deserialization, - bench_request_adaptation, - bench_regular_routing, + bench_bootstrap_injection, + bench_direct_json_routing, bench_throughput_by_size, bench_full_round_trip ); diff --git a/sgl-router/scripts/run_benchmarks.py b/sgl-router/scripts/run_benchmarks.py index 307c3557b..76bf37f2a 100755 --- a/sgl-router/scripts/run_benchmarks.py +++ b/sgl-router/scripts/run_benchmarks.py @@ -121,8 +121,6 @@ class BenchmarkRunner: results["serialization_time"] = self._extract_time(line) elif "Deserialization (avg):" in line: results["deserialization_time"] = self._extract_time(line) - elif "PD Adaptation (avg):" in line: - results["adaptation_time"] = self._extract_time(line) elif "Total Pipeline (avg):" in line: results["total_time"] = self._extract_time(line) @@ -145,7 +143,6 @@ class BenchmarkRunner: thresholds = { "serialization_time": 2000, # 2μs max "deserialization_time": 2000, # 2μs max - "adaptation_time": 5000, # 5μs max "total_time": 10000, # 10μs max } diff --git a/sgl-router/src/routers/bootstrap_injector.rs b/sgl-router/src/routers/bootstrap_injector.rs new file mode 100644 index 000000000..e7cad384d --- /dev/null +++ 
b/sgl-router/src/routers/bootstrap_injector.rs @@ -0,0 +1,334 @@ +// Bootstrap field injection for PD routing +// Directly injects bootstrap fields into JSON requests without intermediate type conversions + +use crate::core::{Worker, WorkerType}; +use crate::routers::pd_types::get_hostname; +use serde_json::{json, Value}; + +/// Inject bootstrap fields directly into a JSON request +/// This replaces the complex ToPdRequest -> Bootstrap trait pattern +pub fn inject_bootstrap_fields(json: &mut Value, worker: &dyn Worker) -> Result<(), String> { + let batch_size = extract_batch_size(json)?; + + // Extract bootstrap port from prefill worker if it's a prefill type + let bootstrap_port = match worker.worker_type() { + WorkerType::Prefill { bootstrap_port } => bootstrap_port, + _ => None, + }; + + let hostname = get_hostname(worker.url()); + + if let Some(batch_size) = batch_size { + // Batch scenario - create arrays of bootstrap values + json["bootstrap_host"] = json!(vec![hostname; batch_size]); + json["bootstrap_port"] = json!(vec![bootstrap_port; batch_size]); + json["bootstrap_room"] = json!((0..batch_size) + .map(|_| { + // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1) + rand::random::<u64>() & (i64::MAX as u64) + }) + .collect::<Vec<_>>()); + } else { + // Single scenario - create single bootstrap values + json["bootstrap_host"] = json!(hostname); + json["bootstrap_port"] = json!(bootstrap_port); + json["bootstrap_room"] = json!(rand::random::<u64>() & (i64::MAX as u64)); + } + + Ok(()) +} + +/// Extract batch size from various JSON request formats +/// Handles chat completions, completions, and generate requests +fn extract_batch_size(json: &Value) -> Result<Option<usize>, String> { + // Check for chat completions 'n' parameter (number of choices) + if let Some(n) = json.get("n").and_then(|v| v.as_u64()) { + if n > 1 { + return Ok(Some(n as usize)); + } + } + + // Check for array prompts (completions API) + if let Some(prompt) = json.get("prompt") { 
+ if let Some(arr) = prompt.as_array() { + if arr.is_empty() { + return Err("Batch prompt array is empty".to_string()); + } + return Ok(Some(arr.len())); + } + } + + // Check for array texts (generate API) + if let Some(text) = json.get("text") { + if let Some(arr) = text.as_array() { + if arr.is_empty() { + return Err("Batch text array is empty".to_string()); + } + return Ok(Some(arr.len())); + } + } + + // Check for batch input_ids (generate API) + if let Some(input_ids) = json.get("input_ids") { + if let Some(arr) = input_ids.as_array() { + if arr.is_empty() { + return Err("Batch input_ids array is empty".to_string()); + } + return Ok(Some(arr.len())); + } + } + + // No batch indicators found - single request + Ok(None) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::core::BasicWorker; + use serde_json::json; + + fn create_test_worker() -> BasicWorker { + BasicWorker::new( + "http://test-server:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(5678), + }, + ) + } + + #[test] + fn test_inject_bootstrap_single_request() { + let worker = create_test_worker(); + let mut json = json!({ + "model": "test-model", + "prompt": "Hello world", + "max_tokens": 100 + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify bootstrap fields were added + assert_eq!(json["bootstrap_host"], json!("test-server")); + assert_eq!(json["bootstrap_port"], json!(5678)); + assert!(json["bootstrap_room"].is_number()); + + // Verify original fields preserved + assert_eq!(json["model"], json!("test-model")); + assert_eq!(json["prompt"], json!("Hello world")); + assert_eq!(json["max_tokens"], json!(100)); + } + + #[test] + fn test_inject_bootstrap_batch_prompt() { + let worker = create_test_worker(); + let mut json = json!({ + "model": "test-model", + "prompt": ["Hello", "World"], + "max_tokens": 100 + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify batch 
bootstrap fields + assert_eq!( + json["bootstrap_host"], + json!(["test-server", "test-server"]) + ); + assert_eq!(json["bootstrap_port"], json!([5678, 5678])); + + let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); + assert_eq!(bootstrap_rooms.len(), 2); + for room in bootstrap_rooms { + assert!(room.is_number()); + let room_val = room.as_u64().unwrap(); + assert!(room_val <= i64::MAX as u64); + } + } + + #[test] + fn test_inject_bootstrap_chat_n_parameter() { + let worker = create_test_worker(); + let mut json = json!({ + "model": "gpt-4", + "messages": [{"role": "user", "content": "Hello"}], + "n": 3 + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify batch bootstrap fields for n=3 + let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); + assert_eq!(bootstrap_hosts.len(), 3); + assert_eq!(bootstrap_hosts[0], json!("test-server")); + + let bootstrap_ports = json["bootstrap_port"].as_array().unwrap(); + assert_eq!(bootstrap_ports.len(), 3); + assert_eq!(bootstrap_ports[0], json!(5678)); + + let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); + assert_eq!(bootstrap_rooms.len(), 3); + } + + #[test] + fn test_inject_bootstrap_generate_text_array() { + let worker = create_test_worker(); + let mut json = json!({ + "text": ["First prompt", "Second prompt"], + "stream": false + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify batch bootstrap fields + let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); + assert_eq!(bootstrap_hosts.len(), 2); + + let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); + assert_eq!(bootstrap_rooms.len(), 2); + // Ensure room values are different (randomness) + assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]); + } + + #[test] + fn test_inject_bootstrap_input_ids_array() { + let worker = create_test_worker(); + let mut json = json!({ + "input_ids": [[1, 2, 3], [4, 
5, 6]], + "stream": false + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify batch bootstrap fields + let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); + assert_eq!(bootstrap_hosts.len(), 2); + } + + #[test] + fn test_extract_batch_size_empty_array_error() { + let json = json!({ + "prompt": [], + "model": "test" + }); + + let result = extract_batch_size(&json); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("empty")); + } + + #[test] + fn test_extract_batch_size_single_requests() { + // Single string prompt + let json = json!({ + "prompt": "Hello world", + "model": "test" + }); + assert_eq!(extract_batch_size(&json).unwrap(), None); + + // Single text + let json = json!({ + "text": "Hello world", + "stream": false + }); + assert_eq!(extract_batch_size(&json).unwrap(), None); + + // Chat with n=1 (default) + let json = json!({ + "messages": [{"role": "user", "content": "Hello"}], + "n": 1 + }); + assert_eq!(extract_batch_size(&json).unwrap(), None); + + // Chat without n parameter + let json = json!({ + "messages": [{"role": "user", "content": "Hello"}] + }); + assert_eq!(extract_batch_size(&json).unwrap(), None); + } + + #[test] + fn test_inject_bootstrap_preserves_sglang_fields() { + let worker = create_test_worker(); + let mut json = json!({ + "model": "test-model", + "prompt": "Hello", + // SGLang extensions should be preserved + "top_k": 40, + "min_p": 0.05, + "repetition_penalty": 1.1, + "regex": "test_pattern", + "lora_path": "test.bin", + "no_stop_trim": true, + "ignore_eos": false + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify bootstrap fields added + assert!(json.get("bootstrap_host").is_some()); + assert!(json.get("bootstrap_port").is_some()); + assert!(json.get("bootstrap_room").is_some()); + + // Verify all SGLang fields preserved + assert_eq!(json["top_k"], json!(40)); + assert_eq!(json["min_p"], 
json!(0.05)); + assert_eq!(json["repetition_penalty"], json!(1.1)); + assert_eq!(json["regex"], json!("test_pattern")); + assert_eq!(json["lora_path"], json!("test.bin")); + assert_eq!(json["no_stop_trim"], json!(true)); + assert_eq!(json["ignore_eos"], json!(false)); + } + + #[test] + fn test_bootstrap_room_range() { + let worker = create_test_worker(); + + // Test single request room generation + for _ in 0..1000 { + let mut json = json!({"prompt": "test"}); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + + let room = json["bootstrap_room"].as_u64().unwrap(); + assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); + } + + // Test batch request room generation + for _ in 0..100 { + let mut json = json!({"prompt": ["test1", "test2"]}); + inject_bootstrap_fields(&mut json, &worker).unwrap(); + + let rooms = json["bootstrap_room"].as_array().unwrap(); + for room_val in rooms { + let room = room_val.as_u64().unwrap(); + assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); + } + } + } + + #[test] + fn test_worker_without_bootstrap_port() { + let worker = BasicWorker::new( + "http://decode-only:8000".to_string(), + WorkerType::Decode, // No bootstrap port + ); + + let mut json = json!({ + "prompt": "Hello world" + }); + + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify bootstrap fields with null port + assert_eq!(json["bootstrap_host"], json!("decode-only")); + assert_eq!(json["bootstrap_port"], json!(null)); + assert!(json["bootstrap_room"].is_number()); + } +} diff --git a/sgl-router/src/routers/mod.rs b/sgl-router/src/routers/mod.rs index 75f12c63b..ab6d6c1aa 100644 --- a/sgl-router/src/routers/mod.rs +++ b/sgl-router/src/routers/mod.rs @@ -11,10 +11,10 @@ use std::fmt::Debug; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; +pub mod bootstrap_injector; pub mod factory; pub mod pd_router; pub mod pd_types; -pub mod request_adapter; pub 
mod router; pub use factory::RouterFactory; diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index dccb68e8f..8b10d95db 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -1,14 +1,16 @@ // PD (Prefill-Decode) Router Implementation // This module handles routing for disaggregated prefill-decode systems -use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError}; -use super::request_adapter::ToPdRequest; +use super::bootstrap_injector::inject_bootstrap_fields; +use super::pd_types::{api_path, PDRouterError}; use crate::config::types::RetryConfig; use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard}; use crate::metrics::RouterMetrics; use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest}; use crate::policies::LoadBalancingPolicy; +use crate::routers::{RouterTrait, WorkerManagement}; use crate::tree::Tree; +use async_trait::async_trait; use axum::{ body::Body, extract::Request, @@ -46,18 +48,26 @@ pub struct PDRouter { impl PDRouter { // Dynamic worker management methods for service discovery + + // Private helper method to perform health check on a new server + async fn wait_for_server_health(&self, url: &str) -> Result<(), PDRouterError> { + crate::routers::router::Router::wait_for_healthy_workers( + &[url.to_string()], + self.timeout_secs, + self.interval_secs, + ) + .map_err(|_| PDRouterError::HealthCheckFailed { + url: url.to_string(), + }) + } + pub async fn add_prefill_server( &self, url: String, bootstrap_port: Option, ) -> Result { // Wait for the new server to be healthy - crate::routers::router::Router::wait_for_healthy_workers( - &[url.clone()], - self.timeout_secs, - self.interval_secs, - ) - .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + self.wait_for_server_health(&url).await?; // Create Worker for the new prefill server let worker = 
WorkerFactory::create_prefill(url.clone(), bootstrap_port); @@ -88,12 +98,7 @@ impl PDRouter { pub async fn add_decode_server(&self, url: String) -> Result { // Wait for the new server to be healthy - crate::routers::router::Router::wait_for_healthy_workers( - &[url.clone()], - self.timeout_secs, - self.interval_secs, - ) - .map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?; + self.wait_for_server_health(&url).await?; // Create Worker for the new decode server let worker = WorkerFactory::create_decode(url.clone()); @@ -332,189 +337,6 @@ impl PDRouter { .into_response() } - // Route a typed generate request - pub async fn route_generate( - &self, - headers: Option<&HeaderMap>, - mut typed_req: GenerateReqInput, - route: &str, - ) -> Response { - let start = Instant::now(); - - // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.stream; - let return_logprob = typed_req - .other - .get("return_logprob") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - // Extract text for cache-aware routing from the typed request - let request_text = typed_req.text.as_ref().and_then(|t| match t { - super::pd_types::InputText::Single(s) => Some(s.as_str()), - super::pd_types::InputText::Batch(v) => v.first().map(|s| s.as_str()), - }); - - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; - - // Log routing decision - info!( - "PD routing decision route={} prefill_url={} decode_url={}", - route, - prefill.url(), - decode.url() - ); - - // Add bootstrap info using the trait method - if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } - - // Convert to JSON after bootstrap injection - let json_with_bootstrap = match serde_json::to_value(&typed_req) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - - // 
Execute dual dispatch - self.execute_dual_dispatch( - headers, - json_with_bootstrap, - route, - prefill.as_ref(), - decode.as_ref(), - is_stream, - return_logprob, - start, - ) - .await - } - - // Route a typed chat request - pub async fn route_chat( - &self, - headers: Option<&HeaderMap>, - mut typed_req: ChatReqInput, - route: &str, - ) -> Response { - let start = Instant::now(); - - // Get stream flag and return_logprob flag before moving the request - let is_stream = typed_req.stream; - let return_logprob = typed_req - .other - .get("return_logprob") - .and_then(|v| v.as_bool()) - .unwrap_or(false); - - // Extract text for cache-aware routing from chat messages - let request_text = typed_req - .other - .get("messages") - .and_then(|messages| messages.as_array()) - .and_then(|arr| arr.first()) - .and_then(|msg| msg.get("content")) - .and_then(|content| content.as_str()); - - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; - - // Log routing decision - info!( - "PD routing decision route={} prefill_url={} decode_url={}", - route, - prefill.url(), - decode.url() - ); - - if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } - - // Convert to JSON after bootstrap injection - let json_with_bootstrap = match serde_json::to_value(&typed_req) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - - // Execute dual dispatch - self.execute_dual_dispatch( - headers, - json_with_bootstrap, - route, - prefill.as_ref(), - decode.as_ref(), - is_stream, - return_logprob, - start, - ) - .await - } - - // Route a completion request while preserving OpenAI format - pub async fn route_completion( - &self, - headers: Option<&HeaderMap>, - mut typed_req: CompletionRequest, - route: &str, - ) -> Response { - let start = Instant::now(); - - // Get stream flag and 
return_logprob flag before moving the request - let is_stream = typed_req.stream; - let return_logprob = typed_req.logprobs.is_some(); - - // Extract text for cache-aware routing from the typed request - let request_text = match &typed_req.prompt { - crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), - crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()), - }; - - // Select servers - let (prefill, decode) = match self.select_pd_pair(request_text).await { - Ok(pair) => pair, - Err(e) => return Self::handle_server_selection_error(e), - }; - - // Log routing decision - info!( - "PD routing decision route={} prefill_url={} decode_url={}", - route, - prefill.url(), - decode.url() - ); - - if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { - return Self::handle_bootstrap_error(e); - } - - // Convert to JSON after bootstrap injection - let json_with_bootstrap = match serde_json::to_value(&typed_req) { - Ok(json) => json, - Err(e) => return Self::handle_serialization_error(e), - }; - - // Execute dual dispatch - self.execute_dual_dispatch( - headers, - json_with_bootstrap, - route, - prefill.as_ref(), - decode.as_ref(), - is_stream, - return_logprob, - start, - ) - .await - } - // Execute the dual dispatch to prefill and decode servers with retry logic async fn execute_dual_dispatch( &self, @@ -1090,7 +912,7 @@ impl PDRouter { // Helper functions -async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option { +async fn get_worker_load(client: &Client, worker_url: &str) -> Option { match client.get(format!("{}/get_load", worker_url)).send().await { Ok(res) if res.status().is_success() => match res.bytes().await { Ok(bytes) => match serde_json::from_slice::(&bytes) { @@ -1123,349 +945,6 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option Response { - // Test model generation capability by selecting a random pair and testing them - // Note: This endpoint actually 
causes the model to generate tokens, so we only test one pair - - // Select a random worker pair using the policy - let (prefill, decode) = match self.select_pd_pair(None).await { - Ok(pair) => pair, - Err(e) => { - return ( - StatusCode::SERVICE_UNAVAILABLE, - format!("No healthy worker pair available: {}", e), - ) - .into_response(); - } - }; - - // Test prefill server's health_generate - let prefill_url = format!("{}/health_generate", prefill.url()); - let prefill_result = self.client.get(&prefill_url).send().await; - - // Test decode server's health_generate - let decode_url = format!("{}/health_generate", decode.url()); - let decode_result = self.client.get(&decode_url).send().await; - - // Check results - let mut errors = Vec::new(); - - match prefill_result { - Ok(res) if res.status().is_success() => { - debug!( - "Health generate passed for prefill server: {}", - prefill.url() - ); - } - Ok(res) => { - errors.push(format!( - "Prefill {} returned status {}", - prefill.url(), - res.status() - )); - } - Err(e) => { - errors.push(format!("Prefill {} error: {}", prefill.url(), e)); - } - } - - match decode_result { - Ok(res) if res.status().is_success() => { - debug!("Health generate passed for decode server: {}", decode.url()); - } - Ok(res) => { - errors.push(format!( - "Decode {} returned status {}", - decode.url(), - res.status() - )); - } - Err(e) => { - errors.push(format!("Decode {} error: {}", decode.url(), e)); - } - } - - if errors.is_empty() { - ( - StatusCode::OK, - format!( - "Health generate passed on selected pair: prefill={}, decode={}", - prefill.url(), - decode.url() - ), - ) - .into_response() - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - format!("Health generate failed: {:?}", errors), - ) - .into_response() - } - } - - pub async fn get_server_info(&self) -> Response { - // Get info from the first decode server to match sglang's server info format - let first_decode_url = if let Ok(workers) = self.decode_workers.read() { - 
workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access decode workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_decode_url { - match self - .client - .get(format!("{}/get_server_info", worker_url)) - .send() - .await - { - Ok(res) if res.status().is_success() => { - match res.json::().await { - Ok(info) => { - // The decode server should already return the proper format - // with tokenizer_path and other fields that bench_one_batch_server.py expects - Json(info).into_response() - } - Err(e) => { - error!("Failed to parse server info: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to parse server info: {}", e), - ) - .into_response() - } - } - } - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - ( - status, - format!("Decode server returned status: {}", res.status()), - ) - .into_response() - } - Err(e) => { - error!("Failed to get server info: {}", e); - ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to get server info: {}", e), - ) - .into_response() - } - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No decode servers available", - ) - .into_response() - } - } - - pub async fn get_models(&self, req: Request) -> Response { - // Extract headers first to avoid Send issues - let headers = crate::routers::router::copy_request_headers(&req); - - // Get first prefill worker URL to avoid holding lock across await - let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access prefill workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_worker_url { - // Send request directly without going through Router - let mut request_builder = self.client.get(format!("{}/v1/models", worker_url)); - for (name, value) in 
headers { - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" - { - request_builder = request_builder.header(name, value); - } - } - match request_builder.send().await { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - match res.bytes().await { - Ok(body) => (status, body).into_response(), - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response(), - } - } - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to send request: {}", e), - ) - .into_response(), - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No prefill servers available", - ) - .into_response() - } - } - - pub async fn get_loads(&self, client: &reqwest::Client) -> Response { - let p_urls: Vec<_> = self - .prefill_workers - .read() - .unwrap() - .iter() - .map(|w| w.url().to_string()) - .collect(); - let d_urls: Vec<_> = self - .decode_workers - .read() - .unwrap() - .iter() - .map(|w| w.url().to_string()) - .collect(); - - let mut prefill_loads = Vec::new(); - let mut decode_loads = Vec::new(); - - for url in &p_urls { - let load = get_worker_load(client, url).await.unwrap_or(-1); - prefill_loads.push(serde_json::json!({ - "engine": format!("(Prefill@{})", url), - "load": load as i64 - })); - } - - for url in &d_urls { - let load = get_worker_load(client, url).await.unwrap_or(-1); - decode_loads.push(serde_json::json!({ - "engine": format!("(Decode@{})", url), - "load": load as i64 - })); - } - - Json(serde_json::json!({ - "prefill": prefill_loads, - "decode": decode_loads - })) - .into_response() - } - - pub async fn get_model_info(&self, req: Request) -> Response { - // Extract headers first to avoid Send issues - let headers = crate::routers::router::copy_request_headers(&req); - - // Get model info from the first prefill server (matches original Rust PDLB behavior) - // Get first prefill worker URL to 
avoid holding lock across await - let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { - workers.first().map(|w| w.url().to_string()) - } else { - return ( - StatusCode::INTERNAL_SERVER_ERROR, - "Failed to access prefill workers", - ) - .into_response(); - }; - - if let Some(worker_url) = first_worker_url { - let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url)); - for (name, value) in headers { - if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length" - { - request_builder = request_builder.header(name, value); - } - } - match request_builder.send().await { - Ok(res) => { - let status = StatusCode::from_u16(res.status().as_u16()) - .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); - match res.bytes().await { - Ok(body) => (status, body).into_response(), - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to read response body: {}", e), - ) - .into_response(), - } - } - Err(e) => ( - StatusCode::INTERNAL_SERVER_ERROR, - format!("Failed to send request: {}", e), - ) - .into_response(), - } - } else { - ( - StatusCode::SERVICE_UNAVAILABLE, - "No prefill servers available", - ) - .into_response() - } - } - - pub async fn flush_cache(&self, client: &reqwest::Client) -> Response { - let mut tasks = Vec::new(); - - // Flush cache on all prefill servers - for worker in self.prefill_workers.read().unwrap().iter() { - let url = format!("{}/flush_cache", worker.url()); - tasks.push(client.post(&url).send()); - } - - // Flush cache on all decode servers - for worker in self.decode_workers.read().unwrap().iter() { - let url = format!("{}/flush_cache", worker.url()); - tasks.push(client.post(&url).send()); - } - - let results = futures_util::future::join_all(tasks).await; - - let mut all_success = true; - for (i, result) in results.into_iter().enumerate() { - match result { - Ok(res) if res.status().is_success() => {} - Ok(res) => { - all_success = false; - warn!( - "Server {} returned 
status {} for flush_cache", - i, - res.status() - ); - } - Err(e) => { - all_success = false; - error!("Server {} error during flush_cache: {}", i, e); - } - } - } - - if all_success { - (StatusCode::OK, "Cache flushed on all servers").into_response() - } else { - ( - StatusCode::INTERNAL_SERVER_ERROR, - "Cache flush failed on one or more servers", - ) - .into_response() - } - } -} - -use crate::routers::{RouterTrait, WorkerManagement}; -use async_trait::async_trait; - #[async_trait] impl WorkerManagement for PDRouter { async fn add_worker(&self, _worker_url: &str) -> Result { @@ -1556,23 +1035,273 @@ impl RouterTrait for PDRouter { } async fn health_generate(&self, _req: Request) -> Response { - // Use the existing PDRouter health_generate method - PDRouter::health_generate(self).await + // Test model generation capability by selecting a random pair and testing them + // Note: This endpoint actually causes the model to generate tokens, so we only test one pair + + // Select a random worker pair using the policy + let (prefill, decode) = match self.select_pd_pair(None).await { + Ok(pair) => pair, + Err(e) => { + return ( + StatusCode::SERVICE_UNAVAILABLE, + format!("No healthy worker pair available: {}", e), + ) + .into_response(); + } + }; + + // Test prefill server's health_generate + let prefill_url = format!("{}/health_generate", prefill.url()); + let prefill_result = self.client.get(&prefill_url).send().await; + + // Test decode server's health_generate + let decode_url = format!("{}/health_generate", decode.url()); + let decode_result = self.client.get(&decode_url).send().await; + + // Check results + let mut errors = Vec::new(); + + match prefill_result { + Ok(res) if res.status().is_success() => { + debug!( + "Health generate passed for prefill server: {}", + prefill.url() + ); + } + Ok(res) => { + errors.push(format!( + "Prefill {} returned status {}", + prefill.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Prefill {} error: {}", 
prefill.url(), e)); + } + } + + match decode_result { + Ok(res) if res.status().is_success() => { + debug!("Health generate passed for decode server: {}", decode.url()); + } + Ok(res) => { + errors.push(format!( + "Decode {} returned status {}", + decode.url(), + res.status() + )); + } + Err(e) => { + errors.push(format!("Decode {} error: {}", decode.url(), e)); + } + } + + if errors.is_empty() { + ( + StatusCode::OK, + format!( + "Health generate passed on selected pair: prefill={}, decode={}", + prefill.url(), + decode.url() + ), + ) + .into_response() + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + format!("Health generate failed: {:?}", errors), + ) + .into_response() + } } async fn get_server_info(&self, _req: Request) -> Response { - // Use the existing PDRouter get_server_info method - PDRouter::get_server_info(self).await + // Get info from the first decode server to match sglang's server info format + let first_decode_url = if let Ok(workers) = self.decode_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to access decode workers", + ) + .into_response(); + }; + + if let Some(worker_url) = first_decode_url { + match self + .client + .get(format!("{}/get_server_info", worker_url)) + .send() + .await + { + Ok(res) if res.status().is_success() => { + match res.json::().await { + Ok(info) => { + // The decode server should already return the proper format + // with tokenizer_path and other fields that bench_one_batch_server.py expects + Json(info).into_response() + } + Err(e) => { + error!("Failed to parse server info: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to parse server info: {}", e), + ) + .into_response() + } + } + } + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + ( + status, + format!("Decode server returned status: {}", res.status()), + ) + .into_response() + } + 
Err(e) => { + error!("Failed to get server info: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get server info: {}", e), + ) + .into_response() + } + } + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + "No decode servers available", + ) + .into_response() + } } async fn get_models(&self, req: Request) -> Response { - // Use the existing PDRouter get_models method - PDRouter::get_models(self, req).await + // Extract headers first to avoid Send issues + let headers = crate::routers::router::copy_request_headers(&req); + + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to access prefill workers", + ) + .into_response(); + }; + + if let Some(worker_url) = first_worker_url { + let url = format!("{}/v1/models", worker_url); + let mut request_builder = self.client.get(&url); + + // Add headers + for (name, value) in headers { + request_builder = request_builder.header(name, value); + } + + match request_builder.send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(body) => (StatusCode::OK, body).into_response(), + Err(e) => { + error!("Failed to read response body: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response() + } + }, + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + ( + status, + format!("Prefill server returned status: {}", res.status()), + ) + .into_response() + } + Err(e) => { + error!("Failed to get models: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get models: {}", e), + ) + .into_response() + } + } + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + "No prefill servers available", + ) + .into_response() + } } async fn 
get_model_info(&self, req: Request) -> Response { - // Use the existing PDRouter get_model_info method - PDRouter::get_model_info(self, req).await + // Extract headers first to avoid Send issues + let headers = crate::routers::router::copy_request_headers(&req); + + // Get first prefill worker URL to avoid holding lock across await + let first_worker_url = if let Ok(workers) = self.prefill_workers.read() { + workers.first().map(|w| w.url().to_string()) + } else { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to access prefill workers", + ) + .into_response(); + }; + + if let Some(worker_url) = first_worker_url { + let url = format!("{}/get_model_info", worker_url); + let mut request_builder = self.client.get(&url); + + // Add headers + for (name, value) in headers { + request_builder = request_builder.header(name, value); + } + + match request_builder.send().await { + Ok(res) if res.status().is_success() => match res.bytes().await { + Ok(body) => (StatusCode::OK, body).into_response(), + Err(e) => { + error!("Failed to read response body: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to read response body: {}", e), + ) + .into_response() + } + }, + Ok(res) => { + let status = StatusCode::from_u16(res.status().as_u16()) + .unwrap_or(StatusCode::INTERNAL_SERVER_ERROR); + ( + status, + format!("Prefill server returned status: {}", res.status()), + ) + .into_response() + } + Err(e) => { + error!("Failed to get model info: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("Failed to get model info: {}", e), + ) + .into_response() + } + } + } else { + ( + StatusCode::SERVICE_UNAVAILABLE, + "No prefill servers available", + ) + .into_response() + } } async fn route_generate( @@ -1580,10 +1309,56 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &GenerateRequest, ) -> Response { - // Convert OpenAI format to PD format - let pd_req = body.clone().to_pd_request(); + let start = Instant::now(); - 
PDRouter::route_generate(self, headers, pd_req, "/generate").await + // Convert directly to JSON to preserve all fields automatically + let mut json = match serde_json::to_value(body) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; + + // Extract flags for routing logic + let is_stream = body.stream; + let return_logprob = body.return_logprob; + + // Extract text for cache-aware routing + let request_text = body.text.as_deref().or_else(|| { + body.prompt.as_ref().and_then(|p| match p { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + }) + }); + + // Select servers + let (prefill, decode) = match self.select_pd_pair(request_text).await { + Ok(pair) => pair, + Err(e) => return Self::handle_server_selection_error(e), + }; + + // Log routing decision + info!( + "PD routing decision route=/generate prefill_url={} decode_url={}", + prefill.url(), + decode.url() + ); + + // Inject bootstrap fields directly into JSON + if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { + return Self::handle_bootstrap_error(e); + } + + // Execute dual dispatch + self.execute_dual_dispatch( + headers, + json, + "/generate", + prefill.as_ref(), + decode.as_ref(), + is_stream, + return_logprob, + start, + ) + .await } async fn route_chat( @@ -1591,10 +1366,60 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &ChatCompletionRequest, ) -> Response { - // Convert OpenAI format to PD format - let pd_req = body.clone().to_pd_request(); + let start = Instant::now(); - PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await + // Convert directly to JSON to preserve all fields automatically + let mut json = match serde_json::to_value(body) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; + + // Extract flags for routing logic + let is_stream = body.stream; + let 
return_logprob = body.logprobs; + + // Extract text for cache-aware routing from chat messages + let request_text = body.messages.first().and_then(|msg| match msg { + crate::openai_api_types::ChatMessage::User { content, .. } => { + match content { + crate::openai_api_types::UserMessageContent::Text(text) => Some(text.as_str()), + crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content + } + } + crate::openai_api_types::ChatMessage::System { content, .. } => Some(content.as_str()), + _ => None, + }); + + // Select servers + let (prefill, decode) = match self.select_pd_pair(request_text).await { + Ok(pair) => pair, + Err(e) => return Self::handle_server_selection_error(e), + }; + + // Log routing decision + info!( + "PD routing decision route=/v1/chat/completions prefill_url={} decode_url={}", + prefill.url(), + decode.url() + ); + + // Inject bootstrap fields directly into JSON + if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { + return Self::handle_bootstrap_error(e); + } + + // Execute dual dispatch + self.execute_dual_dispatch( + headers, + json, + "/v1/chat/completions", + prefill.as_ref(), + decode.as_ref(), + is_stream, + return_logprob, + start, + ) + .await } async fn route_completion( @@ -1602,18 +1427,196 @@ impl RouterTrait for PDRouter { headers: Option<&HeaderMap>, body: &CompletionRequest, ) -> Response { - // Use the new method that preserves OpenAI format - PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await + let start = Instant::now(); + + // Convert directly to JSON to preserve all fields automatically + let mut json = match serde_json::to_value(body) { + Ok(json) => json, + Err(e) => return Self::handle_serialization_error(e), + }; + + // Extract flags for routing logic + let is_stream = body.stream; + let return_logprob = body.logprobs.is_some(); + + // Extract text for cache-aware routing + let request_text = match &body.prompt { + 
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()), + }; + + // Select servers + let (prefill, decode) = match self.select_pd_pair(request_text).await { + Ok(pair) => pair, + Err(e) => return Self::handle_server_selection_error(e), + }; + + // Log routing decision + info!( + "PD routing decision route=/v1/completions prefill_url={} decode_url={}", + prefill.url(), + decode.url() + ); + + // Inject bootstrap fields directly into JSON + if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) { + return Self::handle_bootstrap_error(e); + } + + // Execute dual dispatch + self.execute_dual_dispatch( + headers, + json, + "/v1/completions", + prefill.as_ref(), + decode.as_ref(), + is_stream, + return_logprob, + start, + ) + .await } async fn flush_cache(&self) -> Response { - // Use the existing PDRouter flush_cache method - PDRouter::flush_cache(self, &self.client).await + let mut results = Vec::new(); + let mut errors = Vec::new(); + + // Get prefill worker URLs first to avoid holding lock across await + let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { + workers + .iter() + .map(|w| w.url().to_string()) + .collect::>() + } else { + errors.push("Failed to access prefill workers".to_string()); + Vec::new() + }; + + // Flush prefill workers + for worker_url in prefill_urls { + let url = format!("{}/flush_cache", worker_url); + match self.client.post(&url).send().await { + Ok(res) if res.status().is_success() => { + results.push(format!("Prefill {}: OK", worker_url)); + } + Ok(res) => { + errors.push(format!( + "Prefill {} returned status: {}", + worker_url, + res.status() + )); + } + Err(e) => { + errors.push(format!("Prefill {} error: {}", worker_url, e)); + } + } + } + + // Get decode worker URLs first to avoid holding lock across await + let decode_urls = if let Ok(workers) = self.decode_workers.read() { + workers + .iter() + .map(|w| 
w.url().to_string()) + .collect::>() + } else { + errors.push("Failed to access decode workers".to_string()); + Vec::new() + }; + + // Flush decode workers + for worker_url in decode_urls { + let url = format!("{}/flush_cache", worker_url); + match self.client.post(&url).send().await { + Ok(res) if res.status().is_success() => { + results.push(format!("Decode {}: OK", worker_url)); + } + Ok(res) => { + errors.push(format!( + "Decode {} returned status: {}", + worker_url, + res.status() + )); + } + Err(e) => { + errors.push(format!("Decode {} error: {}", worker_url, e)); + } + } + } + + if errors.is_empty() { + ( + StatusCode::OK, + format!("Cache flushed successfully: {:?}", results), + ) + .into_response() + } else { + ( + StatusCode::PARTIAL_CONTENT, + format!( + "Partial success. Results: {:?}, Errors: {:?}", + results, errors + ), + ) + .into_response() + } } async fn get_worker_loads(&self) -> Response { - // Use the existing PDRouter get_loads method - PDRouter::get_loads(self, &self.client).await + let mut loads = HashMap::new(); + let mut errors = Vec::new(); + + // Get prefill worker URLs first to avoid holding lock across await + let prefill_urls = if let Ok(workers) = self.prefill_workers.read() { + workers + .iter() + .map(|w| w.url().to_string()) + .collect::>() + } else { + errors.push("Failed to access prefill workers".to_string()); + Vec::new() + }; + + // Get loads from prefill workers + for worker_url in prefill_urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("prefill_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from prefill {}", worker_url)); + } + } + } + + // Get decode worker URLs first to avoid holding lock across await + let decode_urls = if let Ok(workers) = self.decode_workers.read() { + workers + .iter() + .map(|w| w.url().to_string()) + .collect::>() + } else { + errors.push("Failed to access decode workers".to_string()); + Vec::new() + }; + 
+ // Get loads from decode workers + for worker_url in decode_urls { + match get_worker_load(&self.client, &worker_url).await { + Some(load) => { + loads.insert(format!("decode_{}", worker_url), load); + } + None => { + errors.push(format!("Failed to get load from decode {}", worker_url)); + } + } + } + + let response_data = serde_json::json!({ + "loads": loads, + "errors": errors + }); + + (StatusCode::OK, Json(response_data)).into_response() } fn router_type(&self) -> &'static str { @@ -1688,7 +1691,6 @@ mod tests { use super::*; use crate::core::{BasicWorker, WorkerType}; use crate::policies::{CacheAwarePolicy, RandomPolicy}; - use crate::routers::pd_types::SingleOrBatch; fn create_test_pd_router() -> PDRouter { let prefill_policy = Arc::new(RandomPolicy::new()); @@ -1935,90 +1937,6 @@ mod tests { assert!(result.is_ok()); } - // ============= Bootstrap Injection Tests ============= - - #[test] - fn test_bootstrap_injection_with_existing_fields() { - let mut req = GenerateReqInput { - text: Some(SingleOrBatch::Single("Test".to_string())), - input_ids: None, - stream: false, - bootstrap_host: Some(SingleOrBatch::Single("existing-host".to_string())), - bootstrap_port: Some(SingleOrBatch::Single(Some(9999))), - bootstrap_room: Some(SingleOrBatch::Single(12345)), - other: Value::Object(serde_json::Map::new()), - }; - - let prefill_worker = create_test_worker( - "http://new-host:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(8080), - }, - true, - ); - - // Bootstrap info is added regardless of existing fields - let result = req.add_bootstrap_info(prefill_worker.as_ref()); - assert!(result.is_ok()); - - // Bootstrap info should be updated with new values - assert_eq!( - req.bootstrap_host, - Some(SingleOrBatch::Single("new-host".to_string())) - ); - assert_eq!(req.bootstrap_port, Some(SingleOrBatch::Single(Some(8080)))); - // Room should be regenerated (different from original) - if let Some(SingleOrBatch::Single(room)) = req.bootstrap_room { - 
assert_ne!(room, 12345); - } else { - panic!("Expected single room ID"); - } - } - - #[test] - fn test_bootstrap_room_generation() { - let mut req1 = GenerateReqInput { - text: Some(SingleOrBatch::Single("Test".to_string())), - input_ids: None, - stream: false, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(serde_json::Map::new()), - }; - - let mut req2 = GenerateReqInput { - text: Some(SingleOrBatch::Single("Test".to_string())), - input_ids: None, - stream: false, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(serde_json::Map::new()), - }; - - let prefill_worker = create_test_worker( - "http://host:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(8080), - }, - true, - ); - - // Add bootstrap info to both requests - let _ = req1.add_bootstrap_info(prefill_worker.as_ref()); - let _ = req2.add_bootstrap_info(prefill_worker.as_ref()); - - // Room IDs should be different - if let (Some(SingleOrBatch::Single(room1)), Some(SingleOrBatch::Single(room2))) = - (req1.bootstrap_room, req2.bootstrap_room) - { - assert_ne!(room1, room2, "Room IDs should be unique"); - } else { - panic!("Expected single room IDs"); - } - } - // ============= Worker Selection Tests ============= #[tokio::test] @@ -2196,4 +2114,158 @@ mod tests { let workers = router.prefill_workers.read().unwrap(); assert_eq!(workers.len(), 5); } + + #[tokio::test] + async fn test_simplified_routing_preserves_sglang_fields() { + use crate::openai_api_types::GenerateRequest; + use crate::routers::bootstrap_injector::inject_bootstrap_fields; + + // Create a test worker + let worker = BasicWorker::new( + "http://test-server:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(5678), + }, + ); + + // Create a GenerateRequest with SGLang extensions + let mut session_params = std::collections::HashMap::new(); + session_params.insert("test_key".to_string(), serde_json::json!("test_value")); + + 
let request = GenerateRequest { + text: Some("Test prompt".to_string()), + stream: false, + return_logprob: true, + // SGLang extensions + lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some( + "test.bin".to_string(), + ))), + session_params: Some(session_params.clone()), + return_hidden_states: true, + rid: Some("test-request-id".to_string()), + // Other fields default to None/false + prompt: None, + input_ids: None, + parameters: None, + sampling_params: None, + }; + + // Convert to JSON (simulating the simplified routing path) + let mut json = serde_json::to_value(&request).unwrap(); + + // Inject bootstrap fields + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify all SGLang fields are preserved + assert_eq!(json["text"], serde_json::json!("Test prompt")); + assert_eq!(json["stream"], serde_json::json!(false)); + assert_eq!(json["return_logprob"], serde_json::json!(true)); + assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value + assert_eq!( + json["session_params"], + serde_json::to_value(&session_params).unwrap() + ); + assert_eq!(json["return_hidden_states"], serde_json::json!(true)); + assert_eq!(json["rid"], serde_json::json!("test-request-id")); + + // Verify bootstrap fields were added + assert_eq!(json["bootstrap_host"], serde_json::json!("test-server")); + assert_eq!(json["bootstrap_port"], serde_json::json!(5678)); + assert!(json["bootstrap_room"].is_number()); + } + + #[tokio::test] + async fn test_simplified_routing_chat_completion() { + use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent}; + use crate::routers::bootstrap_injector::inject_bootstrap_fields; + + // Create a test worker + let worker = BasicWorker::new( + "http://chat-server:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(9999), + }, + ); + + // Create a ChatCompletionRequest with SGLang extensions + let request = 
ChatCompletionRequest { + model: "gpt-4".to_string(), + messages: vec![ChatMessage::User { + role: "user".to_string(), + content: UserMessageContent::Text("Hello world!".to_string()), + name: None, + }], + stream: false, + n: Some(2), // This should create batch bootstrap + // SGLang extensions + top_k: Some(50), + separate_reasoning: false, + stream_reasoning: true, + // Set all other fields to defaults + temperature: None, + top_p: None, + stream_options: None, + stop: None, + max_tokens: None, + max_completion_tokens: None, + presence_penalty: None, + frequency_penalty: None, + logit_bias: None, + user: None, + seed: None, + logprobs: false, + top_logprobs: None, + response_format: None, + tools: None, + tool_choice: None, + parallel_tool_calls: None, + functions: None, + function_call: None, + min_p: None, + min_tokens: None, + repetition_penalty: None, + regex: None, + ebnf: None, + stop_token_ids: None, + no_stop_trim: false, + ignore_eos: false, + continue_final_message: false, + skip_special_tokens: true, + lora_path: None, + session_params: None, + return_hidden_states: false, + }; + + // Convert to JSON (simulating the simplified routing path) + let mut json = serde_json::to_value(&request).unwrap(); + + // Inject bootstrap fields + let result = inject_bootstrap_fields(&mut json, &worker); + assert!(result.is_ok()); + + // Verify original fields preserved + assert_eq!(json["model"], serde_json::json!("gpt-4")); + assert_eq!(json["stream"], serde_json::json!(false)); + assert_eq!(json["n"], serde_json::json!(2)); + assert_eq!(json["top_k"], serde_json::json!(50)); + assert_eq!(json["separate_reasoning"], serde_json::json!(false)); + assert_eq!(json["stream_reasoning"], serde_json::json!(true)); + + // Verify batch bootstrap fields for n=2 + let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap(); + assert_eq!(bootstrap_hosts.len(), 2); + assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server")); + assert_eq!(bootstrap_hosts[1], 
serde_json::json!("chat-server")); + + let bootstrap_ports = json["bootstrap_port"].as_array().unwrap(); + assert_eq!(bootstrap_ports.len(), 2); + assert_eq!(bootstrap_ports[0], serde_json::json!(9999)); + assert_eq!(bootstrap_ports[1], serde_json::json!(9999)); + + let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap(); + assert_eq!(bootstrap_rooms.len(), 2); + // Rooms should be different (randomness) + assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]); + } } diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index 34dabdd26..7fa52e6d7 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -1,10 +1,3 @@ -// Essential PDLB types extracted for PD routing - -use crate::core::{Worker, WorkerType}; -use crate::openai_api_types::{CompletionRequest, StringOrArray}; -use serde::{Deserialize, Serialize}; -use serde_json::Value; - // Custom error type for PD router operations #[derive(Debug, thiserror::Error)] pub enum PDRouterError { @@ -58,428 +51,3 @@ pub enum PDSelectionPolicy { balance_rel_threshold: f32, }, } -// Bootstrap types from PDLB -#[derive(Debug, Deserialize, Serialize, PartialEq)] -#[serde(untagged)] -pub enum SingleOrBatch { - Single(T), - Batch(Vec), -} - -pub type InputIds = SingleOrBatch>; -pub type InputText = SingleOrBatch; -pub type BootstrapHost = SingleOrBatch; -pub type BootstrapPort = SingleOrBatch>; -pub type BootstrapRoom = SingleOrBatch; - -// Bootstrap trait for request handling -pub trait Bootstrap: Send + Sync { - fn is_stream(&self) -> bool; - fn get_batch_size(&self) -> Result, String>; - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ); - - fn add_bootstrap_info(&mut self, prefill_worker: &dyn Worker) -> Result<(), String> { - let batch_size = self.get_batch_size()?; - - // Extract bootstrap port from prefill worker if it's a prefill type - let bootstrap_port = match 
prefill_worker.worker_type() { - WorkerType::Prefill { bootstrap_port } => bootstrap_port, - _ => None, - }; - - let hostname = get_hostname(prefill_worker.url()); - - if let Some(batch_size) = batch_size { - self.set_bootstrap_info( - BootstrapHost::Batch(vec![hostname; batch_size]), - BootstrapPort::Batch(vec![bootstrap_port; batch_size]), - // Use high-quality random numbers to minimize collision risk - BootstrapRoom::Batch( - (0..batch_size) - .map(|_| { - // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1) - rand::random::() & (i64::MAX as u64) - }) - .collect(), - ), - ); - } else { - self.set_bootstrap_info( - BootstrapHost::Single(hostname), - BootstrapPort::Single(bootstrap_port), - BootstrapRoom::Single( - // Generate a value in the range [0, 2^63 - 1] to match Python's random.randint(0, 2**63 - 1) - rand::random::() & (i64::MAX as u64), - ), - ); - } - Ok(()) - } -} - -// Request types -#[derive(Debug, Deserialize, Serialize)] -pub struct GenerateReqInput { - pub text: Option, - pub input_ids: Option, - #[serde(default)] - pub stream: bool, - pub bootstrap_host: Option, - pub bootstrap_port: Option, - pub bootstrap_room: Option, - - #[serde(flatten)] - pub other: Value, -} - -impl GenerateReqInput { - pub fn get_batch_size(&self) -> Result, String> { - if self.text.is_some() && self.input_ids.is_some() { - return Err("Both text and input_ids are present in the request".to_string()); - } - - // Check text batch - if let Some(InputText::Batch(texts)) = &self.text { - if texts.is_empty() { - return Err("Batch text array is empty".to_string()); - } - return Ok(Some(texts.len())); - } - - // Check input_ids batch - if let Some(InputIds::Batch(ids)) = &self.input_ids { - if ids.is_empty() { - return Err("Batch input_ids array is empty".to_string()); - } - // Validate each sequence is not empty - for (i, seq) in ids.iter().enumerate() { - if seq.is_empty() { - return Err(format!("Input sequence at index {} is empty", 
i)); - } - } - return Ok(Some(ids.len())); - } - - Ok(None) - } -} - -impl Bootstrap for GenerateReqInput { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_batch_size(&self) -> Result, String> { - self.get_batch_size() - } - - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ) { - self.bootstrap_host = Some(bootstrap_host); - self.bootstrap_port = Some(bootstrap_port); - self.bootstrap_room = Some(bootstrap_room); - } -} - -#[derive(Debug, Deserialize, Serialize)] -pub struct ChatReqInput { - #[serde(default)] - pub stream: bool, - pub bootstrap_host: Option, - pub bootstrap_port: Option, - pub bootstrap_room: Option, - - #[serde(flatten)] - pub other: Value, -} - -impl Bootstrap for ChatReqInput { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_batch_size(&self) -> Result, String> { - // Check if 'n' parameter is present and > 1 - if let Some(n_value) = self.other.get("n") { - if let Some(n) = n_value.as_u64() { - if n > 1 { - return Ok(Some(n as usize)); - } - } - } - Ok(None) - } - - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ) { - self.bootstrap_host = Some(bootstrap_host); - self.bootstrap_port = Some(bootstrap_port); - self.bootstrap_room = Some(bootstrap_room); - } -} - -// Bootstrap implementation for CompletionRequest to preserve OpenAI format -impl Bootstrap for CompletionRequest { - fn is_stream(&self) -> bool { - self.stream - } - - fn get_batch_size(&self) -> Result, String> { - if let StringOrArray::Array(prompts) = &self.prompt { - if prompts.is_empty() { - return Err("Batch prompt array is empty".to_string()); - } - return Ok(Some(prompts.len())); - } - - // Single string prompt - Ok(None) - } - - fn set_bootstrap_info( - &mut self, - bootstrap_host: BootstrapHost, - bootstrap_port: BootstrapPort, - bootstrap_room: BootstrapRoom, - ) { - // 
Insert bootstrap_host - it serializes correctly whether Single or Batch - if let Ok(host_value) = serde_json::to_value(&bootstrap_host) { - self.other.insert("bootstrap_host".to_string(), host_value); - } - - // Insert bootstrap_port - it serializes correctly whether Single or Batch - if let Ok(port_value) = serde_json::to_value(&bootstrap_port) { - self.other.insert("bootstrap_port".to_string(), port_value); - } - - // Insert bootstrap_room - it serializes correctly whether Single or Batch - if let Ok(room_value) = serde_json::to_value(&bootstrap_room) { - self.other.insert("bootstrap_room".to_string(), room_value); - } - } -} - -#[cfg(test)] -mod bootstrap_tests { - use super::*; - use crate::core::BasicWorker; - use crate::openai_api_types::StringOrArray; - - /// Create a default CompletionRequest for testing with minimal fields set - fn default_completion_request() -> CompletionRequest { - CompletionRequest { - model: String::new(), - prompt: StringOrArray::String(String::new()), - n: None, - other: serde_json::Map::new(), - suffix: None, - max_tokens: None, - temperature: None, - top_p: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - // SGLang Extensions - top_k: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - json_schema: None, - stop_token_ids: None, - no_stop_trim: false, - ignore_eos: false, - skip_special_tokens: true, - // SGLang Extensions - lora_path: None, - session_params: None, - return_hidden_states: false, - } - } - - #[test] - fn test_completion_batch_size_with_array_prompt() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), - ..default_completion_request() - }; - - // Should return batch size for array prompt - 
assert_eq!(req.get_batch_size().unwrap(), Some(2)); - } - - #[test] - fn test_completion_batch_size_with_single_prompt() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::String("single prompt".to_string()), - ..default_completion_request() - }; - - // Should return None for single prompt - assert_eq!(req.get_batch_size().unwrap(), None); - } - - #[test] - fn test_completion_batch_size_with_n_parameter() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::String("single prompt".to_string()), - n: Some(3), - ..default_completion_request() - }; - - // Should return None for single string prompt, even with n > 1 - // SGLang handles n parameter differently than batch requests - assert_eq!(req.get_batch_size().unwrap(), None); - } - - #[test] - fn test_completion_bootstrap_single_values() { - let mut req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), - ..default_completion_request() - }; - - // Set bootstrap info - should always use single values - req.set_bootstrap_info( - BootstrapHost::Single("test-server".to_string()), - BootstrapPort::Single(Some(5678)), - BootstrapRoom::Single(12345), - ); - - // Verify single values were created - assert!(req.other.get("bootstrap_host").unwrap().is_string()); - assert!(req.other.get("bootstrap_port").unwrap().is_number()); - assert!(req.other.get("bootstrap_room").unwrap().is_number()); - - assert_eq!( - req.other.get("bootstrap_host").unwrap().as_str().unwrap(), - "test-server" - ); - assert_eq!( - req.other.get("bootstrap_port").unwrap().as_u64().unwrap(), - 5678 - ); - assert_eq!( - req.other.get("bootstrap_room").unwrap().as_u64().unwrap(), - 12345 - ); - } - - #[test] - fn test_completion_bootstrap_array_values() { - let mut req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), - 
..default_completion_request() - }; - - // Set bootstrap info with arrays - req.set_bootstrap_info( - BootstrapHost::Batch(vec!["test-server".to_string(); 2]), - BootstrapPort::Batch(vec![Some(5678); 2]), - BootstrapRoom::Batch(vec![12345, 67890]), - ); - - // Verify arrays were created correctly - assert!(req.other.get("bootstrap_host").unwrap().is_array()); - assert!(req.other.get("bootstrap_port").unwrap().is_array()); - assert!(req.other.get("bootstrap_room").unwrap().is_array()); - - let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap(); - assert_eq!(hosts.len(), 2); - assert_eq!(hosts[0].as_str().unwrap(), "test-server"); - - let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap(); - assert_eq!(ports.len(), 2); - assert_eq!(ports[0].as_u64().unwrap(), 5678); - - let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap(); - assert_eq!(rooms.len(), 2); - assert_eq!(rooms[0].as_u64().unwrap(), 12345); - assert_eq!(rooms[1].as_u64().unwrap(), 67890); - } - - #[test] - fn test_bootstrap_room_range() { - // Test that bootstrap_room values are within the expected range [0, 2^63 - 1] - let worker = BasicWorker::new( - "http://test:8000".to_string(), - WorkerType::Prefill { - bootstrap_port: Some(8080), - }, - ); - - // Test single request - let mut single_req = GenerateReqInput { - text: Some(InputText::Single("test".to_string())), - input_ids: None, - stream: false, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(serde_json::Map::new()), - }; - - for _ in 0..200000 { - single_req.add_bootstrap_info(&worker).unwrap(); - if let Some(BootstrapRoom::Single(room)) = single_req.bootstrap_room { - // Verify the room value is within signed 64-bit range - assert!(room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); - } else { - panic!("Expected single bootstrap room"); - } - } - - // Test batch request - let mut batch_req = GenerateReqInput { - text: 
Some(InputText::Batch(vec![ - "test1".to_string(), - "test2".to_string(), - ])), - input_ids: None, - stream: false, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(serde_json::Map::new()), - }; - - for _ in 0..200000 { - batch_req.add_bootstrap_info(&worker).unwrap(); - if let Some(BootstrapRoom::Batch(rooms)) = &batch_req.bootstrap_room { - for room in rooms { - // Verify each room value is within signed 64-bit range - assert!(*room <= i64::MAX as u64, "Room {} exceeds i64::MAX", room); - } - } else { - panic!("Expected batch bootstrap rooms"); - } - } - } -} diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs deleted file mode 100644 index 809244793..000000000 --- a/sgl-router/src/routers/request_adapter.rs +++ /dev/null @@ -1,1512 +0,0 @@ -// Request adapter to bridge OpenAI API types with PD routing requirements - -use super::pd_types::{Bootstrap, ChatReqInput, GenerateReqInput, SingleOrBatch}; -use crate::openai_api_types::{ - ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, StringOrArray, -}; -use serde_json::Value; - -/// Adapter trait to convert OpenAI requests to PD-compatible requests -pub trait ToPdRequest { - type Output: Bootstrap; - fn to_pd_request(self) -> Self::Output; -} - -// Helper macro to insert optional fields into a map -macro_rules! insert_if_some { - ($map:expr, $($field:expr => $key:expr),* $(,)?) => { - $( - if let Some(value) = $field { - $map.insert($key.to_string(), serde_json::to_value(value).unwrap_or(Value::Null)); - } - )* - }; -} - -// Helper macro for simple value insertions -macro_rules! insert_value { - ($map:expr, $($field:expr => $key:expr),* $(,)?) 
=> { - $( - $map.insert($key.to_string(), $field.into()); - )* - }; -} - -// ============= Generate Request Adapter ============= - -impl ToPdRequest for GenerateRequest { - type Output = GenerateReqInput; - - fn to_pd_request(self) -> Self::Output { - // Build the other fields first - let mut other = serde_json::Map::new(); - - // Handle text input - check in priority order: text (SGLang), prompt (OpenAI) - let (text, input_ids) = if let Some(text_str) = self.text { - // SGLang native format - (Some(SingleOrBatch::Single(text_str)), None) - } else if let Some(prompt) = self.prompt { - // OpenAI style prompt - let text = match prompt { - StringOrArray::String(s) => Some(SingleOrBatch::Single(s)), - StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)), - }; - (text, None) - } else if let Some(ids) = self.input_ids { - // Input IDs case - let input_ids = match ids { - crate::openai_api_types::InputIds::Single(ids) => Some(SingleOrBatch::Single(ids)), - crate::openai_api_types::InputIds::Batch(ids) => Some(SingleOrBatch::Batch(ids)), - }; - (None, input_ids) - } else { - // No input provided - (None, None) - }; - - // Add parameters to other - handle both old and new style - if let Some(params) = self.parameters { - // For generate endpoint, extract max_new_tokens to top level if present - let mut params_value = serde_json::to_value(¶ms).unwrap_or(Value::Null); - if let Value::Object(ref mut params_map) = params_value { - // Move max_new_tokens to top level if it exists - if let Some(max_new_tokens) = params_map.remove("max_new_tokens") { - other.insert("max_new_tokens".to_string(), max_new_tokens); - } - // Move temperature to top level if it exists - if let Some(temperature) = params_map.remove("temperature") { - other.insert("temperature".to_string(), temperature); - } - } - // Only add parameters if there are remaining fields - if !params_value.is_null() && params_value.as_object().map_or(false, |m| !m.is_empty()) - { - other.insert("parameters".to_string(), 
params_value); - } - } - - // Add sampling_params if present - if let Some(sampling_params) = self.sampling_params { - let params_value = serde_json::to_value(&sampling_params).unwrap_or(Value::Null); - if !params_value.is_null() { - // Extract commonly used fields to top level - if let Value::Object(ref params_map) = params_value { - if let Some(max_new_tokens) = params_map.get("max_new_tokens") { - other.insert("max_new_tokens".to_string(), max_new_tokens.clone()); - } - if let Some(temperature) = params_map.get("temperature") { - other.insert("temperature".to_string(), temperature.clone()); - } - } - other.insert("sampling_params".to_string(), params_value); - } - } - - // Add other fields - insert_value!(other, - self.stream => "stream", - self.return_logprob => "return_logprob" - ); - - GenerateReqInput { - text, - input_ids, - stream: self.stream, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(other), - } - } -} - -// ============= Completion Request Adapter ============= - -impl ToPdRequest for CompletionRequest { - type Output = GenerateReqInput; - - fn to_pd_request(self) -> Self::Output { - // Convert CompletionRequest to GenerateReqInput - let text = match self.prompt { - StringOrArray::String(s) => Some(SingleOrBatch::Single(s)), - StringOrArray::Array(v) => Some(SingleOrBatch::Batch(v)), - }; - - // Map OpenAI parameters to generate parameters - let mut other = serde_json::Map::new(); - - // Create parameters object - let mut params = serde_json::Map::new(); - - // Map OpenAI fields to internal parameter names - insert_if_some!(params, - self.max_tokens => "max_new_tokens", - self.temperature => "temperature", - self.top_p => "top_p", - self.n => "best_of", - self.logprobs => "top_n_tokens", - self.seed => "seed" - ); - - // Special handling for fields that need transformation - if let Some(presence_penalty) = self.presence_penalty { - params.insert( - "repetition_penalty".to_string(), - (1.0 + 
presence_penalty).into(), - ); - } - - if let Some(stop) = self.stop { - let stop_sequences = match stop { - StringOrArray::String(s) => vec![s], - StringOrArray::Array(v) => v, - }; - params.insert("stop".to_string(), stop_sequences.into()); - } - - if self.echo { - params.insert("return_full_text".to_string(), true.into()); - } - - other.insert("parameters".to_string(), Value::Object(params)); - - // Store original model and stream flag - insert_value!(other, - self.model => "model", - self.stream => "stream" - ); - - // Add SGLang extension fields - insert_if_some!(other, - // SGLang Extensions - Priority 1 - self.top_k => "top_k", - self.min_p => "min_p", - self.min_tokens => "min_tokens", - self.repetition_penalty => "repetition_penalty", - self.regex => "regex", - self.ebnf => "ebnf", - self.stop_token_ids => "stop_token_ids", - // SGLang Extensions - Priority 2 - self.lora_path => "lora_path", - self.session_params => "session_params" - ); - - // SGLang boolean extensions (CompletionRequest has these as bool, not Option) - other.insert("no_stop_trim".to_string(), self.no_stop_trim.into()); - other.insert("ignore_eos".to_string(), self.ignore_eos.into()); - other.insert( - "skip_special_tokens".to_string(), - self.skip_special_tokens.into(), - ); - other.insert( - "return_hidden_states".to_string(), - self.return_hidden_states.into(), - ); - - GenerateReqInput { - text, - input_ids: None, - stream: self.stream, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: Value::Object(other), - } - } -} - -// ============= Chat Completion Request Adapter ============= - -impl ToPdRequest for ChatCompletionRequest { - type Output = ChatReqInput; - - fn to_pd_request(self) -> Self::Output { - let mut other = serde_json::Map::new(); - - // Add required fields - insert_if_some!(other, - Some(&self.messages) => "messages" - ); - - insert_value!(other, - self.model => "model", - self.stream => "stream" - ); - - // Add all optional fields - 
insert_if_some!(other, - self.temperature => "temperature", - self.top_p => "top_p", - self.n => "n", - self.stream_options => "stream_options", - self.stop => "stop", - self.max_tokens => "max_tokens", - self.max_completion_tokens => "max_completion_tokens", - self.presence_penalty => "presence_penalty", - self.frequency_penalty => "frequency_penalty", - self.logit_bias => "logit_bias", - self.user => "user", - self.seed => "seed", - self.top_logprobs => "top_logprobs", - self.response_format => "response_format", - self.tools => "tools", - self.tool_choice => "tool_choice", - self.parallel_tool_calls => "parallel_tool_calls", - self.functions => "functions", - self.function_call => "function_call", - // SGLang Extensions - Priority 1 - self.top_k => "top_k", - self.min_p => "min_p", - self.min_tokens => "min_tokens", - self.repetition_penalty => "repetition_penalty", - self.regex => "regex", - self.ebnf => "ebnf", - self.stop_token_ids => "stop_token_ids", - // SGLang Extensions - Priority 2 - self.lora_path => "lora_path", - self.session_params => "session_params" - ); - - // Handle boolean flags - if self.logprobs { - other.insert("logprobs".to_string(), true.into()); - } - - // SGLang boolean extensions (ChatCompletionRequest has these as bool, not Option) - other.insert("no_stop_trim".to_string(), self.no_stop_trim.into()); - other.insert("ignore_eos".to_string(), self.ignore_eos.into()); - other.insert( - "continue_final_message".to_string(), - self.continue_final_message.into(), - ); - other.insert( - "skip_special_tokens".to_string(), - self.skip_special_tokens.into(), - ); - other.insert( - "separate_reasoning".to_string(), - self.separate_reasoning.into(), - ); - other.insert("stream_reasoning".to_string(), self.stream_reasoning.into()); - other.insert( - "return_hidden_states".to_string(), - self.return_hidden_states.into(), - ); - - ChatReqInput { - stream: self.stream, - bootstrap_host: None, - bootstrap_port: None, - bootstrap_room: None, - other: 
Value::Object(other), - } - } -} - -// ============= Direct routing support for regular router ============= - -/// Extension trait for routing without PD conversion -pub trait RouteableRequest: GenerationRequest + serde::Serialize + Clone { - /// Convert to JSON for sending to backend - fn to_json(&self) -> Result { - serde_json::to_value(self) - } - - /// Convert to bytes for legacy routing - fn to_bytes(&self) -> Result { - let json = serde_json::to_vec(self)?; - Ok(bytes::Bytes::from(json)) - } -} - -impl RouteableRequest for GenerateRequest {} -impl RouteableRequest for CompletionRequest {} -impl RouteableRequest for ChatCompletionRequest {} - -#[cfg(test)] -mod tests { - use super::*; - use crate::openai_api_types::*; - use serde_json::json; - use std::collections::HashMap; - - // ============= Test Helper Functions ============= - // - // These helper functions create default request instances with all required SGLang extension fields - // properly initialized. Use the struct spread operator `..default_*_request()` to override only - // the fields you need for specific tests, avoiding repetitive boilerplate code. 
- // - // Example usage: - // let req = GenerateRequest { - // text: Some("Custom text".to_string()), - // stream: true, - // ..default_generate_request() - // }; - - /// Create a default GenerateRequest with minimal fields set - fn default_generate_request() -> GenerateRequest { - GenerateRequest { - text: None, - prompt: None, - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, - // SGLang Extensions - lora_path: None, - session_params: None, - return_hidden_states: false, - rid: None, - } - } - - /// Create a default CompletionRequest with minimal fields set - fn default_completion_request() -> CompletionRequest { - CompletionRequest { - model: "test-model".to_string(), - prompt: StringOrArray::String("test prompt".to_string()), - max_tokens: None, - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - suffix: None, - // SGLang Extensions - top_k: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - json_schema: None, - stop_token_ids: None, - no_stop_trim: false, - ignore_eos: false, - skip_special_tokens: true, - // SGLang Extensions - lora_path: None, - session_params: None, - return_hidden_states: false, - other: serde_json::Map::new(), - } - } - - /// Create a default ChatCompletionRequest with minimal fields set - fn default_chat_completion_request() -> ChatCompletionRequest { - ChatCompletionRequest { - model: "test-model".to_string(), - messages: vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("test message".to_string()), - name: None, - }], - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - 
frequency_penalty: None, - logit_bias: None, - logprobs: false, - top_logprobs: None, - user: None, - seed: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, - // SGLang Extensions - top_k: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - stop_token_ids: None, - no_stop_trim: false, - ignore_eos: false, - continue_final_message: false, - skip_special_tokens: true, - // SGLang Extensions - lora_path: None, - session_params: None, - separate_reasoning: true, - stream_reasoning: true, - return_hidden_states: false, - } - } - - // ============= GenerateRequest to_pd_request Tests ============= - - #[test] - fn test_generate_to_pd_request_with_text_only() { - let req = GenerateRequest { - text: Some("Hello world".to_string()), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - // Check text field conversion - assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Hello world")); - assert!(pd_req.input_ids.is_none()); - - // Check bootstrap fields are None - assert!(pd_req.bootstrap_host.is_none()); - assert!(pd_req.bootstrap_port.is_none()); - assert!(pd_req.bootstrap_room.is_none()); - - // Check stream flag - assert_eq!(pd_req.stream, false); - - // Check other fields - let other = pd_req.other.as_object().unwrap(); - assert_eq!(other.get("stream"), Some(&json!(false))); - assert_eq!(other.get("return_logprob"), Some(&json!(false))); - } - - #[test] - fn test_generate_to_pd_request_with_prompt_string() { - let req = GenerateRequest { - prompt: Some(StringOrArray::String("Test prompt".to_string())), - stream: true, - return_logprob: true, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Test prompt")); - assert!(pd_req.input_ids.is_none()); - assert_eq!(pd_req.stream, true); - - let 
other = pd_req.other.as_object().unwrap(); - assert_eq!(other.get("stream"), Some(&json!(true))); - assert_eq!(other.get("return_logprob"), Some(&json!(true))); - } - - #[test] - fn test_generate_to_pd_request_with_prompt_array() { - let req = GenerateRequest { - text: None, - prompt: Some(StringOrArray::Array(vec![ - "Prompt 1".to_string(), - "Prompt 2".to_string(), - "Prompt 3".to_string(), - ])), - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - match pd_req.text { - Some(SingleOrBatch::Batch(ref batch)) => { - assert_eq!(batch.len(), 3); - assert_eq!(batch[0], "Prompt 1"); - assert_eq!(batch[1], "Prompt 2"); - assert_eq!(batch[2], "Prompt 3"); - } - _ => panic!("Expected batch text"), - } - } - - #[test] - fn test_generate_to_pd_request_with_single_input_ids() { - let req = GenerateRequest { - input_ids: Some(InputIds::Single(vec![100, 200, 300, 400])), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - assert!(pd_req.text.is_none()); - assert!(matches!( - pd_req.input_ids, - Some(SingleOrBatch::Single(ref ids)) if ids == &vec![100, 200, 300, 400] - )); - } - - #[test] - fn test_generate_to_pd_request_with_batch_input_ids() { - let req = GenerateRequest { - input_ids: Some(InputIds::Batch(vec![ - vec![1, 2, 3], - vec![4, 5, 6, 7], - vec![8, 9], - ])), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - match pd_req.input_ids { - Some(SingleOrBatch::Batch(ref batch)) => { - assert_eq!(batch.len(), 3); - assert_eq!(batch[0], vec![1, 2, 3]); - assert_eq!(batch[1], vec![4, 5, 6, 7]); - assert_eq!(batch[2], vec![8, 9]); - } - _ => panic!("Expected batch input_ids"), - } - } - - #[test] - fn test_generate_to_pd_request_priority_text_over_prompt() { - let req = GenerateRequest { - text: Some("SGLang text".to_string()), - prompt: Some(StringOrArray::String("OpenAI 
prompt".to_string())), - input_ids: Some(InputIds::Single(vec![1, 2, 3])), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - // text should take priority - assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "SGLang text")); - assert!(pd_req.input_ids.is_none()); - } - - #[test] - fn test_generate_to_pd_request_priority_prompt_over_input_ids() { - let req = GenerateRequest { - prompt: Some(StringOrArray::String("OpenAI prompt".to_string())), - input_ids: Some(InputIds::Single(vec![1, 2, 3])), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - // prompt should take priority over input_ids - assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "OpenAI prompt")); - assert!(pd_req.input_ids.is_none()); - } - - #[test] - fn test_generate_to_pd_request_with_parameters() { - let params = GenerateParameters { - max_new_tokens: Some(100), - temperature: Some(0.8), - top_p: Some(0.95), - seed: Some(12345), - stop: Some(vec!["END".to_string(), "STOP".to_string()]), - repetition_penalty: Some(1.1), - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("test".to_string()), - parameters: Some(params), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Check that max_new_tokens and temperature were extracted to top level - assert_eq!(other.get("max_new_tokens"), Some(&json!(100))); - assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); - - // Check that other parameters remain under "parameters" - let params = other.get("parameters").unwrap().as_object().unwrap(); - assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); - assert_eq!(params.get("seed"), Some(&json!(12345))); - assert_eq!(params.get("stop"), Some(&json!(vec!["END", "STOP"]))); - assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1 < 0.0001); - } - - #[test] - fn 
test_generate_to_pd_request_with_sampling_params() { - let sampling = SamplingParams { - max_new_tokens: Some(200), - temperature: Some(0.7), - top_p: Some(0.9), - top_k: Some(50), - frequency_penalty: Some(0.1), - presence_penalty: Some(0.2), - repetition_penalty: Some(1.05), - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("test".to_string()), - sampling_params: Some(sampling), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Check extracted top-level fields - assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); - assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); - - // Check full sampling_params is preserved - let sampling = other.get("sampling_params").unwrap().as_object().unwrap(); - assert_eq!(sampling.get("max_new_tokens"), Some(&json!(200))); - assert!(sampling.get("temperature").unwrap().as_f64().unwrap() - 0.7 < 0.0001); - assert!(sampling.get("top_p").unwrap().as_f64().unwrap() - 0.9 < 0.0001); - assert_eq!(sampling.get("top_k"), Some(&json!(50))); - assert!(sampling.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); - assert!(sampling.get("presence_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); - } - - #[test] - fn test_generate_to_pd_request_sampling_params_override_parameters() { - // When both parameters and sampling_params have max_new_tokens/temperature, - // sampling_params should take precedence (processed last) - let params = GenerateParameters { - max_new_tokens: Some(100), - temperature: Some(0.5), - ..Default::default() - }; - - let sampling = SamplingParams { - max_new_tokens: Some(200), - temperature: Some(0.9), - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: Some(params), - sampling_params: Some(sampling), - return_logprob: false, - ..default_generate_request() - }; - 
- let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Should use values from sampling_params since they're processed last - assert_eq!(other.get("max_new_tokens"), Some(&json!(200))); - assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.9 < 0.0001); - } - - #[test] - fn test_generate_to_pd_request_empty_parameters() { - let params = GenerateParameters::default(); - - let req = GenerateRequest { - text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: Some(params), - sampling_params: None, - return_logprob: false, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Should not have parameters field if all values are None/default - assert!(!other.contains_key("parameters")); - assert!(!other.contains_key("max_new_tokens")); - assert!(!other.contains_key("temperature")); - } - - #[test] - fn test_generate_to_pd_request_all_fields() { - let params = GenerateParameters { - max_new_tokens: Some(150), - temperature: Some(0.6), - top_k: Some(40), - ..Default::default() - }; - - let sampling = SamplingParams { - max_new_tokens: Some(250), // Will override parameters - temperature: Some(0.8), // Will override parameters - presence_penalty: Some(0.1), - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("Complex test".to_string()), - prompt: Some(StringOrArray::String("Ignored prompt".to_string())), - input_ids: None, - stream: true, - parameters: Some(params), - sampling_params: Some(sampling), - return_logprob: true, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - // Verify all fields - assert!(matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complex test")); - assert!(pd_req.input_ids.is_none()); - assert_eq!(pd_req.stream, true); - assert!(pd_req.bootstrap_host.is_none()); - assert!(pd_req.bootstrap_port.is_none()); - 
assert!(pd_req.bootstrap_room.is_none()); - - let other = pd_req.other.as_object().unwrap(); - assert_eq!(other.get("stream"), Some(&json!(true))); - assert_eq!(other.get("return_logprob"), Some(&json!(true))); - // Sampling params override parameters - assert_eq!(other.get("max_new_tokens"), Some(&json!(250))); - assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); - assert!(other.contains_key("parameters")); - assert!(other.contains_key("sampling_params")); - } - - // ============= CompletionRequest to_pd_request Tests ============= - - #[test] - fn test_completion_to_pd_request_basic() { - let req = CompletionRequest { - model: "gpt-3.5-turbo".to_string(), - prompt: StringOrArray::String("Complete this sentence".to_string()), - ..default_completion_request() - }; - - let pd_req = req.to_pd_request(); - - assert!( - matches!(pd_req.text, Some(SingleOrBatch::Single(ref s)) if s == "Complete this sentence") - ); - assert!(pd_req.input_ids.is_none()); - assert_eq!(pd_req.stream, false); - - let other = pd_req.other.as_object().unwrap(); - assert_eq!(other.get("model"), Some(&json!("gpt-3.5-turbo"))); - assert_eq!(other.get("stream"), Some(&json!(false))); - } - - #[test] - fn test_completion_to_pd_request_array_prompt() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::Array(vec![ - "First prompt".to_string(), - "Second prompt".to_string(), - ]), - ..default_completion_request() - }; - - let pd_req = req.to_pd_request(); - - match pd_req.text { - Some(SingleOrBatch::Batch(ref batch)) => { - assert_eq!(batch.len(), 2); - assert_eq!(batch[0], "First prompt"); - assert_eq!(batch[1], "Second prompt"); - } - _ => panic!("Expected batch text"), - } - } - - #[test] - fn test_completion_to_pd_request_parameter_mapping() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::String("test".to_string()), - max_tokens: Some(150), // -> max_new_tokens - temperature: Some(0.75), - 
top_p: Some(0.92), - n: Some(3), // -> best_of - stream: true, - stream_options: None, - logprobs: Some(10), // -> top_n_tokens - echo: true, // -> return_full_text - stop: Some(StringOrArray::Array(vec![ - "\\n".to_string(), - "END".to_string(), - ])), - presence_penalty: Some(0.5), // -> repetition_penalty = 1.5 - frequency_penalty: Some(0.2), - best_of: Some(5), - logit_bias: None, - user: Some("user123".to_string()), - seed: Some(42), - suffix: Some("...".to_string()), - ..default_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - let params = other.get("parameters").unwrap().as_object().unwrap(); - - // Check parameter mappings - assert_eq!(params.get("max_new_tokens"), Some(&json!(150))); - assert!(params.get("temperature").unwrap().as_f64().unwrap() - 0.75 < 0.0001); - assert!(params.get("top_p").unwrap().as_f64().unwrap() - 0.92 < 0.0001); - assert_eq!(params.get("best_of"), Some(&json!(3))); - assert_eq!(params.get("top_n_tokens"), Some(&json!(10))); - assert_eq!(params.get("return_full_text"), Some(&json!(true))); - assert_eq!(params.get("stop"), Some(&json!(vec!["\\n", "END"]))); - assert!(params.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.5 < 0.0001); - assert_eq!(params.get("seed"), Some(&json!(42))); - - // Check other fields - assert_eq!(other.get("model"), Some(&json!("test"))); - assert_eq!(other.get("stream"), Some(&json!(true))); - } - - #[test] - fn test_completion_to_pd_request_stop_string() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::String("test".to_string()), - stop: Some(StringOrArray::String("STOP".to_string())), - max_tokens: None, - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - presence_penalty: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - suffix: None, - ..default_completion_request() - }; 
- - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - let params = other.get("parameters").unwrap().as_object().unwrap(); - - // Single string stop should be converted to array - assert_eq!(params.get("stop"), Some(&json!(vec!["STOP"]))); - } - - #[test] - fn test_completion_to_pd_request_no_presence_penalty() { - let req = CompletionRequest { - model: "test".to_string(), - prompt: StringOrArray::String("test".to_string()), - presence_penalty: None, - max_tokens: None, - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - logprobs: None, - echo: false, - stop: None, - frequency_penalty: None, - best_of: None, - logit_bias: None, - user: None, - seed: None, - suffix: None, - ..default_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - let params = other.get("parameters").unwrap().as_object().unwrap(); - - // Should not have repetition_penalty if presence_penalty is None - assert!(!params.contains_key("repetition_penalty")); - } - - // ============= ChatCompletionRequest to_pd_request Tests ============= - - #[test] - fn test_chat_to_pd_request_basic() { - let messages = vec![ - ChatMessage::System { - role: "system".to_string(), - content: "You are a helpful assistant".to_string(), - name: None, - }, - ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Hello!".to_string()), - name: None, - }, - ]; - - let req = ChatCompletionRequest { - messages, - model: "gpt-4".to_string(), - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - - assert_eq!(pd_req.stream, false); - assert!(pd_req.bootstrap_host.is_none()); - assert!(pd_req.bootstrap_port.is_none()); - assert!(pd_req.bootstrap_room.is_none()); - - let other = pd_req.other.as_object().unwrap(); - assert!(other.contains_key("messages")); - assert_eq!(other.get("model"), Some(&json!("gpt-4"))); - 
assert_eq!(other.get("stream"), Some(&json!(false))); - - // Check messages are preserved - let messages = other.get("messages").unwrap().as_array().unwrap(); - assert_eq!(messages.len(), 2); - } - - #[test] - fn test_chat_to_pd_request_with_all_optional_fields() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Test".to_string()), - name: Some("test_user".to_string()), - }]; - - let mut logit_bias = HashMap::new(); - logit_bias.insert("50256".to_string(), -100.0f32); - - let tool = Tool { - tool_type: "function".to_string(), - function: Function { - name: "get_weather".to_string(), - description: Some("Get weather info".to_string()), - parameters: json!({"type": "object"}), - }, - }; - - let req = ChatCompletionRequest { - messages, - model: "gpt-4".to_string(), - temperature: Some(0.8), - top_p: Some(0.95), - n: Some(2), - stream: true, - stream_options: Some(StreamOptions { - include_usage: Some(true), - }), - stop: Some(StringOrArray::String("\\n\\n".to_string())), - max_tokens: Some(200), - max_completion_tokens: Some(150), - presence_penalty: Some(0.1), - frequency_penalty: Some(0.2), - logit_bias: Some(logit_bias), - logprobs: true, - top_logprobs: Some(5), - user: Some("user456".to_string()), - seed: Some(12345), - response_format: Some(ResponseFormat::JsonObject), - tools: Some(vec![tool]), - tool_choice: Some(ToolChoice::Auto), - parallel_tool_calls: Some(false), - functions: None, - function_call: None, - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Check all fields are preserved - assert!(other.get("temperature").unwrap().as_f64().unwrap() - 0.8 < 0.0001); - assert!(other.get("top_p").unwrap().as_f64().unwrap() - 0.95 < 0.0001); - assert_eq!(other.get("n"), Some(&json!(2))); - assert_eq!(other.get("stream"), Some(&json!(true))); - assert!(other.contains_key("stream_options")); - 
assert!(other.contains_key("stop")); - assert_eq!(other.get("max_tokens"), Some(&json!(200))); - assert_eq!(other.get("max_completion_tokens"), Some(&json!(150))); - assert!(other.get("presence_penalty").unwrap().as_f64().unwrap() - 0.1 < 0.0001); - assert!(other.get("frequency_penalty").unwrap().as_f64().unwrap() - 0.2 < 0.0001); - assert!(other.contains_key("logit_bias")); - assert_eq!(other.get("logprobs"), Some(&json!(true))); - assert_eq!(other.get("top_logprobs"), Some(&json!(5))); - assert_eq!(other.get("user"), Some(&json!("user456"))); - assert_eq!(other.get("seed"), Some(&json!(12345))); - assert!(other.contains_key("response_format")); - assert!(other.contains_key("tools")); - assert!(other.contains_key("tool_choice")); - assert_eq!(other.get("parallel_tool_calls"), Some(&json!(false))); - } - - #[test] - fn test_chat_to_pd_request_multimodal_content() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Parts(vec![ - ContentPart::Text { - text: "What's in this image?".to_string(), - }, - ContentPart::ImageUrl { - image_url: ImageUrl { - url: "https://example.com/image.jpg".to_string(), - detail: Some("high".to_string()), - }, - }, - ]), - name: None, - }]; - - let req = ChatCompletionRequest { - messages, - model: "gpt-4-vision".to_string(), - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Messages with multimodal content should be preserved - assert!(other.contains_key("messages")); - let messages = other.get("messages").unwrap().as_array().unwrap(); - assert_eq!(messages.len(), 1); - - // Verify the message structure is preserved - let msg = &messages[0]; - assert_eq!(msg["role"], "user"); - assert!(msg["content"].is_array()); - } - - #[test] - fn test_chat_to_pd_request_logprobs_boolean() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Test".to_string()), 
- name: None, - }]; - - let req = ChatCompletionRequest { - messages, - model: "test".to_string(), - logprobs: true, // Boolean logprobs flag - top_logprobs: Some(3), - temperature: None, - top_p: None, - n: None, - stream: false, - stream_options: None, - stop: None, - max_tokens: None, - max_completion_tokens: None, - presence_penalty: None, - frequency_penalty: None, - logit_bias: None, - user: None, - seed: None, - response_format: None, - tools: None, - tool_choice: None, - parallel_tool_calls: None, - functions: None, - function_call: None, - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - assert_eq!(other.get("logprobs"), Some(&json!(true))); - assert_eq!(other.get("top_logprobs"), Some(&json!(3))); - } - - #[test] - fn test_chat_to_pd_request_minimal_fields() { - let messages = vec![ChatMessage::Assistant { - role: "assistant".to_string(), - content: Some("I can help with that.".to_string()), - name: None, - tool_calls: None, - function_call: None, - reasoning_content: None, - }]; - - let req = ChatCompletionRequest { - messages, - model: "gpt-3.5-turbo".to_string(), - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Should only have required fields - assert!(other.contains_key("messages")); - assert!(other.contains_key("model")); - assert!(other.contains_key("stream")); - - // Optional fields should not be present - assert!(!other.contains_key("temperature")); - assert!(!other.contains_key("top_p")); - assert!(!other.contains_key("max_tokens")); - assert!(!other.contains_key("stop")); - } - - #[test] - fn test_routeable_request_to_json() { - let req = GenerateRequest { - text: Some("test".to_string()), - ..default_generate_request() - }; - - let json = req.to_json().unwrap(); - assert_eq!(json["text"], "test"); - assert_eq!(json["stream"], false); - } - - // ============= Macro Tests 
============= - - #[test] - fn test_insert_if_some_macro() { - let mut map = serde_json::Map::new(); - - let some_value: Option = Some(42); - let none_value: Option = None; - - insert_if_some!(map, - some_value => "present", - none_value => "absent" - ); - - assert_eq!(map.get("present"), Some(&json!(42))); - assert!(!map.contains_key("absent")); - } - - #[test] - fn test_insert_value_macro() { - let mut map = serde_json::Map::new(); - - let value1 = "test"; - let value2 = 42; - - insert_value!(map, - value1 => "string_field", - value2 => "int_field" - ); - - assert_eq!(map.get("string_field"), Some(&json!("test"))); - assert_eq!(map.get("int_field"), Some(&json!(42))); - } - - // ============= Edge Cases and Error Handling ============= - - #[test] - fn test_null_value_handling() { - let params = GenerateParameters { - max_new_tokens: None, - temperature: None, - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: Some(params), - sampling_params: None, - return_logprob: false, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Should not have parameters field if all fields are None - assert!(!other.contains_key("parameters")); - } - - #[test] - fn test_large_batch_conversion() { - let large_batch: Vec = (0..1000).map(|i| format!("item_{}", i)).collect(); - - let req = GenerateRequest { - text: None, - prompt: Some(StringOrArray::Array(large_batch.clone())), - input_ids: None, - stream: false, - parameters: None, - sampling_params: None, - return_logprob: false, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - if let Some(SingleOrBatch::Batch(batch)) = pd_req.text { - assert_eq!(batch.len(), 1000); - assert_eq!(batch[0], "item_0"); - assert_eq!(batch[999], "item_999"); - } else { - panic!("Expected batch text"); - } - } - - #[test] - fn 
test_unicode_string_handling() { - let unicode_text = "Hello 世界 🌍 नमस्ते мир".to_string(); - - let req = GenerateRequest { - text: Some(unicode_text.clone()), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - if let Some(SingleOrBatch::Single(text)) = pd_req.text { - assert_eq!(text, unicode_text); - } else { - panic!("Expected single text"); - } - } - - #[test] - fn test_deeply_nested_parameters() { - let mut nested_params = serde_json::Map::new(); - nested_params.insert( - "nested".to_string(), - json!({ - "level1": { - "level2": { - "level3": "value" - } - } - }), - ); - - let params = GenerateParameters { - max_new_tokens: Some(100), - ..Default::default() - }; - - let req = GenerateRequest { - text: Some("test".to_string()), - prompt: None, - input_ids: None, - stream: false, - parameters: Some(params), - sampling_params: None, - return_logprob: false, - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Parameters should be preserved even with nested structures - assert!(other.contains_key("max_new_tokens")); - } - - // ============= Bootstrap Field Tests ============= - - #[test] - fn test_bootstrap_fields_none() { - let req = GenerateRequest { - text: Some("test".to_string()), - ..default_generate_request() - }; - - let pd_req = req.to_pd_request(); - - assert_eq!(pd_req.bootstrap_host, None); - assert_eq!(pd_req.bootstrap_port, None); - assert_eq!(pd_req.bootstrap_room, None); - } - - // ============= SGLang Extension Field Pass-Through Tests ============= - - #[test] - fn test_chat_completion_sglang_extensions_passed_through() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Test".to_string()), - name: None, - }]; - - let mut session_params = std::collections::HashMap::new(); - session_params.insert( - "key".to_string(), - serde_json::Value::String("value".to_string()), - ); - - let req = 
ChatCompletionRequest { - messages, - model: "test-model".to_string(), - // SGLang Extensions - Priority 1 - top_k: Some(40), - min_p: Some(0.05), - min_tokens: Some(10), - repetition_penalty: Some(1.1), - regex: Some("test_regex".to_string()), - ebnf: Some("test_ebnf".to_string()), - stop_token_ids: Some(vec![1, 2, 3]), - // SGLang Extensions - Priority 2 - lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))), - session_params: Some(session_params.clone()), - // Boolean extensions (ChatCompletionRequest has these as bool, not Option) - no_stop_trim: true, - ignore_eos: false, - continue_final_message: true, - skip_special_tokens: false, - separate_reasoning: true, - stream_reasoning: false, - return_hidden_states: true, - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Verify SGLang extensions are passed through - assert_eq!(other.get("top_k"), Some(&json!(40))); - assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001); - assert_eq!(other.get("min_tokens"), Some(&json!(10))); - assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001); - assert_eq!(other.get("regex"), Some(&json!("test_regex"))); - assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf"))); - assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3]))); - assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin"))); - assert_eq!( - other.get("session_params"), - Some(&serde_json::to_value(&session_params).unwrap()) - ); - - // Verify boolean extensions - assert_eq!(other.get("no_stop_trim"), Some(&json!(true))); - assert_eq!(other.get("ignore_eos"), Some(&json!(false))); - assert_eq!(other.get("continue_final_message"), Some(&json!(true))); - assert_eq!(other.get("skip_special_tokens"), Some(&json!(false))); - assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); - assert_eq!(other.get("stream_reasoning"), 
Some(&json!(false))); - assert_eq!(other.get("return_hidden_states"), Some(&json!(true))); - } - - #[test] - fn test_completion_request_sglang_extensions_passed_through() { - let mut session_params = std::collections::HashMap::new(); - session_params.insert( - "key".to_string(), - serde_json::Value::String("value".to_string()), - ); - - let req = CompletionRequest { - prompt: StringOrArray::String("Test prompt".to_string()), - model: "test-model".to_string(), - // SGLang Extensions - Priority 1 - top_k: Some(40), - min_p: Some(0.05), - min_tokens: Some(10), - repetition_penalty: Some(1.1), - regex: Some("test_regex".to_string()), - ebnf: Some("test_ebnf".to_string()), - stop_token_ids: Some(vec![1, 2, 3]), - // SGLang Extensions - Priority 2 - lora_path: Some(LoRAPath::Single(Some("test_lora.bin".to_string()))), - session_params: Some(session_params.clone()), - // Boolean extensions (CompletionRequest only has these 4 boolean fields) - no_stop_trim: true, - ignore_eos: false, - skip_special_tokens: false, - return_hidden_states: true, - ..default_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Verify SGLang extensions are passed through - assert_eq!(other.get("top_k"), Some(&json!(40))); - assert!((other.get("min_p").unwrap().as_f64().unwrap() - 0.05).abs() < 0.0001); - assert_eq!(other.get("min_tokens"), Some(&json!(10))); - assert!((other.get("repetition_penalty").unwrap().as_f64().unwrap() - 1.1).abs() < 0.0001); - assert_eq!(other.get("regex"), Some(&json!("test_regex"))); - assert_eq!(other.get("ebnf"), Some(&json!("test_ebnf"))); - assert_eq!(other.get("stop_token_ids"), Some(&json!(vec![1, 2, 3]))); - assert_eq!(other.get("lora_path"), Some(&json!("test_lora.bin"))); - assert_eq!( - other.get("session_params"), - Some(&serde_json::to_value(&session_params).unwrap()) - ); - - // Verify boolean extensions (only the ones CompletionRequest has) - assert_eq!(other.get("no_stop_trim"), 
Some(&json!(true))); - assert_eq!(other.get("ignore_eos"), Some(&json!(false))); - assert_eq!(other.get("skip_special_tokens"), Some(&json!(false))); - assert_eq!(other.get("return_hidden_states"), Some(&json!(true))); - } - - #[test] - fn test_sglang_extensions_none_values_not_passed_through() { - let messages = vec![ChatMessage::User { - role: "user".to_string(), - content: UserMessageContent::Text("Test".to_string()), - name: None, - }]; - - let req = ChatCompletionRequest { - messages, - model: "test-model".to_string(), - // All SGLang extensions as None/default - Optional fields won't appear, bools will use defaults - top_k: None, - min_p: None, - min_tokens: None, - repetition_penalty: None, - regex: None, - ebnf: None, - stop_token_ids: None, - lora_path: None, - session_params: None, - // Boolean fields use defaults (false for most, true for some with default_true) - no_stop_trim: false, - ignore_eos: false, - continue_final_message: false, - skip_special_tokens: true, // This has default_true - separate_reasoning: true, // This has default_true - stream_reasoning: true, // This has default_true - return_hidden_states: false, - ..default_chat_completion_request() - }; - - let pd_req = req.to_pd_request(); - let other = pd_req.other.as_object().unwrap(); - - // Verify None values are not included - assert!(!other.contains_key("top_k")); - assert!(!other.contains_key("min_p")); - assert!(!other.contains_key("min_tokens")); - assert!(!other.contains_key("repetition_penalty")); - assert!(!other.contains_key("regex")); - assert!(!other.contains_key("ebnf")); - assert!(!other.contains_key("stop_token_ids")); - assert!(!other.contains_key("lora_path")); - assert!(!other.contains_key("session_params")); - - // Boolean fields are always present with their values (can't be None) - assert_eq!(other.get("no_stop_trim"), Some(&json!(false))); - assert_eq!(other.get("ignore_eos"), Some(&json!(false))); - assert_eq!(other.get("continue_final_message"), 
Some(&json!(false))); - assert_eq!(other.get("skip_special_tokens"), Some(&json!(true))); // default_true - assert_eq!(other.get("separate_reasoning"), Some(&json!(true))); // default_true - assert_eq!(other.get("stream_reasoning"), Some(&json!(true))); // default_true - assert_eq!(other.get("return_hidden_states"), Some(&json!(false))); - } -} diff --git a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index 75c55986f..196e509ca 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -1,12 +1,15 @@ // Integration test to ensure benchmarks compile and basic functionality works // This prevents benchmarks from breaking in CI +// +// UPDATED: Removed deprecated ToPdRequest usage, now uses direct JSON serialization -use serde_json::{from_str, to_string}; +use serde_json::{from_str, to_string, to_value}; +use sglang_router_rs::core::{BasicWorker, WorkerType}; use sglang_router_rs::openai_api_types::{ ChatCompletionRequest, ChatMessage, CompletionRequest, GenerateParameters, GenerateRequest, SamplingParams, StringOrArray, UserMessageContent, }; -use sglang_router_rs::routers::request_adapter::{RouteableRequest, ToPdRequest}; +use sglang_router_rs::routers::bootstrap_injector::inject_bootstrap_fields; /// Create a default GenerateRequest for benchmarks with minimal fields set fn default_generate_request() -> GenerateRequest { @@ -114,6 +117,15 @@ fn default_completion_request() -> CompletionRequest { } } +fn create_test_worker() -> BasicWorker { + BasicWorker::new( + "http://test-server:8000".to_string(), + WorkerType::Prefill { + bootstrap_port: Some(5678), + }, + ) +} + #[test] fn test_benchmark_request_creation() { // Ensure all benchmark request types can be created without panicking @@ -197,8 +209,8 @@ fn test_benchmark_serialization_roundtrip() { } #[test] -fn test_benchmark_request_adaptation() { - // Test that PD request adaptation works for benchmark types +fn 
test_benchmark_bootstrap_injection() { + // Test that bootstrap injection works for benchmark types (replaces PD request adaptation) let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), @@ -236,24 +248,40 @@ fn test_benchmark_request_adaptation() { ..default_completion_request() }; - // Test PD adaptation (should not panic) - let _pd_generate = generate_req.to_pd_request(); - let _pd_chat = chat_req.to_pd_request(); - let _pd_completion = completion_req.to_pd_request(); + let worker = create_test_worker(); + + // Test bootstrap injection (should not panic) + let mut generate_json = to_value(&generate_req).unwrap(); + let mut chat_json = to_value(&chat_req).unwrap(); + let mut completion_json = to_value(&completion_req).unwrap(); + + assert!(inject_bootstrap_fields(&mut generate_json, &worker).is_ok()); + assert!(inject_bootstrap_fields(&mut chat_json, &worker).is_ok()); + assert!(inject_bootstrap_fields(&mut completion_json, &worker).is_ok()); + + // Verify bootstrap fields were added + assert!(generate_json.get("bootstrap_host").is_some()); + assert!(generate_json.get("bootstrap_port").is_some()); + assert!(generate_json.get("bootstrap_room").is_some()); } #[test] -fn test_benchmark_regular_routing() { - // Test regular routing functionality for benchmark types +fn test_benchmark_direct_json_routing() { + // Test direct JSON routing functionality for benchmark types (replaces regular routing) let generate_req = GenerateRequest { text: Some("Test prompt".to_string()), ..default_generate_request() }; - // Test regular routing methods (should not panic) - let _json = generate_req.to_json(); - let _bytes = generate_req.to_bytes(); + // Test direct JSON conversion (replaces regular routing methods) + let json = to_value(&generate_req).unwrap(); + let json_string = to_string(&json).unwrap(); + let bytes = json_string.as_bytes(); + + // Verify conversions work + assert!(!json_string.is_empty()); + assert!(!bytes.is_empty()); } #[test] @@ -266,23 
+294,36 @@ fn test_benchmark_performance_baseline() { ..default_generate_request() }; - // Serialization should be fast (< 1ms for simple requests) + // Test the actual simplified pipeline: to_value + bootstrap injection let start = Instant::now(); - let _json = to_string(&generate_req).unwrap(); - let serialize_duration = start.elapsed(); + let worker = create_test_worker(); + + // This mirrors the actual router pipeline + let mut json = to_value(&generate_req).unwrap(); + let _ = inject_bootstrap_fields(&mut json, &worker); + + let total_duration = start.elapsed(); assert!( - serialize_duration.as_millis() < 1, - "Serialization took too long: {:?}", - serialize_duration + total_duration.as_millis() < 5, + "Simplified pipeline took too long: {:?} (should be faster than old adapter approach)", + total_duration ); - // PD adaptation should be very fast (< 1ms) + // Individual components should also be fast let start = Instant::now(); - let _pd_req = generate_req.to_pd_request(); - let adapt_duration = start.elapsed(); + let _json = to_value(&generate_req).unwrap(); + let to_value_duration = start.elapsed(); + + let start = Instant::now(); + let mut json = to_value(&generate_req).unwrap(); + let _ = inject_bootstrap_fields(&mut json, &worker); + let inject_duration = start.elapsed(); + + // to_value + bootstrap injection combined should take at most 3x a bare to_value call assert!( - adapt_duration.as_millis() < 1, - "PD adaptation took too long: {:?}", - adapt_duration + inject_duration <= to_value_duration * 3, + "Bootstrap injection ({:?}) should not be much slower than JSON conversion ({:?})", + inject_duration, + to_value_duration ); }