diff --git a/sgl-router/benches/request_processing.rs b/sgl-router/benches/request_processing.rs index 576d07d2f..db5cdc901 100644 --- a/sgl-router/benches/request_processing.rs +++ b/sgl-router/benches/request_processing.rs @@ -97,6 +97,7 @@ fn create_sample_completion_request() -> CompletionRequest { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), } } diff --git a/sgl-router/src/openai_api_types.rs b/sgl-router/src/openai_api_types.rs index 9870fd06b..d57e61767 100644 --- a/sgl-router/src/openai_api_types.rs +++ b/sgl-router/src/openai_api_types.rs @@ -91,6 +91,10 @@ pub struct CompletionRequest { /// If specified, our system will make a best effort to sample deterministically #[serde(skip_serializing_if = "Option::is_none")] pub seed: Option<u64>, + + /// Additional fields including bootstrap info for PD routing + #[serde(flatten)] + pub other: serde_json::Map<String, serde_json::Value>, } impl GenerationRequest for CompletionRequest { diff --git a/sgl-router/src/routers/pd_router.rs b/sgl-router/src/routers/pd_router.rs index 7c70a3873..ab9927d24 100644 --- a/sgl-router/src/routers/pd_router.rs +++ b/sgl-router/src/routers/pd_router.rs @@ -420,6 +420,77 @@ impl PDRouter { .await } + // Route a completion request while preserving OpenAI format + pub async fn route_completion( + &self, + client: &reqwest::Client, + req: &HttpRequest, + mut typed_req: CompletionRequest, + route: &str, + ) -> HttpResponse { + let start = Instant::now(); + + // Get stream flag and return_logprob flag before moving the request + let is_stream = typed_req.stream; + let return_logprob = typed_req.logprobs.is_some(); + + // Extract text for cache-aware routing from the typed request + let request_text = match &typed_req.prompt { + crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()), + crate::openai_api_types::StringOrArray::Array(arr) => arr.first().map(|s| s.as_str()), + }; + + // Select servers + let (prefill, decode) = match self.select_pd_pair(client, 
request_text).await { + Ok(pair) => pair, + Err(e) => { + error!("Failed to select PD pair: {}", e); + RouterMetrics::record_pd_error("server_selection"); + return HttpResponse::ServiceUnavailable() + .body(format!("No available servers: {}", e)); + } + }; + + // Log routing decision + info!( + "PD routing: {} -> prefill={}, decode={}", + route, + prefill.url(), + decode.url() + ); + + // Add bootstrap info using the trait method + if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) { + error!("Failed to add bootstrap info: {}", e); + RouterMetrics::record_pd_error("bootstrap_injection"); + return HttpResponse::InternalServerError() + .body(format!("Bootstrap injection failed: {}", e)); + } + + // Convert to JSON after bootstrap injection + let json_with_bootstrap = match serde_json::to_value(&typed_req) { + Ok(json) => json, + Err(e) => { + error!("Failed to serialize request: {}", e); + return HttpResponse::InternalServerError().body("Failed to serialize request"); + } + }; + + // Execute dual dispatch + self.execute_dual_dispatch( + client, + req, + json_with_bootstrap, + route, + prefill.as_ref(), + decode.as_ref(), + is_stream, + return_logprob, + start, + ) + .await + } + // Execute the dual dispatch to prefill and decode servers #[allow(clippy::too_many_arguments)] async fn execute_dual_dispatch( @@ -1302,23 +1373,12 @@ impl RouterTrait for PDRouter { req: &HttpRequest, body: serde_json::Value, ) -> HttpResponse { - match serde_json::from_value::<CompletionRequest>(body.clone()) { + match serde_json::from_value::<CompletionRequest>(body) { Ok(openai_req) => { - // Convert OpenAI format to PD format (CompletionRequest -> GenerateReqInput) - let pd_req = openai_req.to_pd_request(); - PDRouter::route_generate(self, client, req, pd_req, "/v1/completions").await - } - Err(_) => { - // If that fails, try to deserialize directly as PD format (for backwards compatibility) - match serde_json::from_value::<GenerateReqInput>(body) { - Ok(pd_req) => { - PDRouter::route_generate(self, client, req, pd_req, 
"/v1/completions").await - } - Err(e) => { - HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)) - } - } + // Use the new method that preserves OpenAI format + PDRouter::route_completion(self, client, req, openai_req, "/v1/completions").await } + Err(e) => HttpResponse::BadRequest().body(format!("Invalid request format: {}", e)), } } diff --git a/sgl-router/src/routers/pd_types.rs b/sgl-router/src/routers/pd_types.rs index e83ab5b60..993f2bf3d 100644 --- a/sgl-router/src/routers/pd_types.rs +++ b/sgl-router/src/routers/pd_types.rs @@ -1,6 +1,7 @@ // Essential PDLB types extracted for PD routing use crate::core::{Worker, WorkerType}; +use crate::openai_api_types::{CompletionRequest, StringOrArray}; use serde::{Deserialize, Serialize}; use serde_json::Value; @@ -233,3 +234,235 @@ impl Bootstrap for ChatReqInput { self.bootstrap_room = Some(bootstrap_room); } } + +// Bootstrap implementation for CompletionRequest to preserve OpenAI format +impl Bootstrap for CompletionRequest { + fn is_stream(&self) -> bool { + self.stream + } + + fn get_batch_size(&self) -> Result<Option<usize>, String> { + if let StringOrArray::Array(prompts) = &self.prompt { + if prompts.is_empty() { + return Err("Batch prompt array is empty".to_string()); + } + return Ok(Some(prompts.len())); + } + + // Single string prompt + Ok(None) + } + + fn set_bootstrap_info( + &mut self, + bootstrap_host: BootstrapHost, + bootstrap_port: BootstrapPort, + bootstrap_room: BootstrapRoom, + ) { + // Insert bootstrap_host - it serializes correctly whether Single or Batch + if let Ok(host_value) = serde_json::to_value(&bootstrap_host) { + self.other.insert("bootstrap_host".to_string(), host_value); + } + + // Insert bootstrap_port - it serializes correctly whether Single or Batch + if let Ok(port_value) = serde_json::to_value(&bootstrap_port) { + self.other.insert("bootstrap_port".to_string(), port_value); + } + + // Insert bootstrap_room - it serializes correctly whether Single or Batch + if let 
Ok(room_value) = serde_json::to_value(&bootstrap_room) { + self.other.insert("bootstrap_room".to_string(), room_value); + } + } +} + +#[cfg(test)] +mod bootstrap_tests { + use super::*; + use crate::openai_api_types::StringOrArray; + + #[test] + fn test_completion_batch_size_with_array_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Should return batch size for array prompt + assert_eq!(req.get_batch_size().unwrap(), Some(2)); + } + + #[test] + fn test_completion_batch_size_with_single_prompt() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("single prompt".to_string()), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Should return None for single prompt + assert_eq!(req.get_batch_size().unwrap(), None); + } + + #[test] + fn test_completion_batch_size_with_n_parameter() { + let req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::String("single prompt".to_string()), + n: Some(3), + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: 
None, + }; + + // Should return None for single string prompt, even with n > 1 + // SGLang handles n parameter differently than batch requests + assert_eq!(req.get_batch_size().unwrap(), None); + } + + #[test] + fn test_completion_bootstrap_single_values() { + let mut req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; + + // Set bootstrap info - should always use single values + req.set_bootstrap_info( + BootstrapHost::Single("test-server".to_string()), + BootstrapPort::Single(Some(5678)), + BootstrapRoom::Single(12345), + ); + + // Verify single values were created + assert!(req.other.get("bootstrap_host").unwrap().is_string()); + assert!(req.other.get("bootstrap_port").unwrap().is_number()); + assert!(req.other.get("bootstrap_room").unwrap().is_number()); + + assert_eq!( + req.other.get("bootstrap_host").unwrap().as_str().unwrap(), + "test-server" + ); + assert_eq!( + req.other.get("bootstrap_port").unwrap().as_u64().unwrap(), + 5678 + ); + assert_eq!( + req.other.get("bootstrap_room").unwrap().as_u64().unwrap(), + 12345 + ); + } + + #[test] + fn test_completion_bootstrap_array_values() { + let mut req = CompletionRequest { + model: "test".to_string(), + prompt: StringOrArray::Array(vec!["prompt1".to_string(), "prompt2".to_string()]), + n: None, + other: serde_json::Map::new(), + suffix: None, + max_tokens: None, + temperature: None, + top_p: None, + stream: false, + stream_options: None, + logprobs: None, + echo: false, + stop: None, + presence_penalty: None, + frequency_penalty: None, + best_of: None, + logit_bias: None, + user: None, + seed: None, + }; 
+ + // Set bootstrap info with arrays + req.set_bootstrap_info( + BootstrapHost::Batch(vec!["test-server".to_string(); 2]), + BootstrapPort::Batch(vec![Some(5678); 2]), + BootstrapRoom::Batch(vec![12345, 67890]), + ); + + // Verify arrays were created correctly + assert!(req.other.get("bootstrap_host").unwrap().is_array()); + assert!(req.other.get("bootstrap_port").unwrap().is_array()); + assert!(req.other.get("bootstrap_room").unwrap().is_array()); + + let hosts = req.other.get("bootstrap_host").unwrap().as_array().unwrap(); + assert_eq!(hosts.len(), 2); + assert_eq!(hosts[0].as_str().unwrap(), "test-server"); + + let ports = req.other.get("bootstrap_port").unwrap().as_array().unwrap(); + assert_eq!(ports.len(), 2); + assert_eq!(ports[0].as_u64().unwrap(), 5678); + + let rooms = req.other.get("bootstrap_room").unwrap().as_array().unwrap(); + assert_eq!(rooms.len(), 2); + assert_eq!(rooms[0].as_u64().unwrap(), 12345); + assert_eq!(rooms[1].as_u64().unwrap(), 67890); + } +} diff --git a/sgl-router/src/routers/request_adapter.rs b/sgl-router/src/routers/request_adapter.rs index 201c61aa5..f29bcecc9 100644 --- a/sgl-router/src/routers/request_adapter.rs +++ b/sgl-router/src/routers/request_adapter.rs @@ -648,6 +648,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -687,6 +688,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -725,6 +727,7 @@ mod tests { user: Some("user123".to_string()), seed: Some(42), suffix: Some("...".to_string()), + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -768,6 +771,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); @@ -799,6 +803,7 @@ mod tests { user: None, seed: None, suffix: None, + other: serde_json::Map::new(), }; let pd_req = req.to_pd_request(); diff --git 
a/sgl-router/tests/benchmark_integration.rs b/sgl-router/tests/benchmark_integration.rs index 317859000..b7876e223 100644 --- a/sgl-router/tests/benchmark_integration.rs +++ b/sgl-router/tests/benchmark_integration.rs @@ -86,6 +86,7 @@ fn test_benchmark_request_creation() { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), }; // Test serialization works @@ -181,6 +182,7 @@ fn test_benchmark_request_adaptation() { logit_bias: None, user: None, seed: None, + other: serde_json::Map::new(), }; // Test PD adaptation (should not panic)