[router] fix req handling order, improve serialization, remove retry (#8888)

This commit is contained in:
Simo Lin
2025-08-06 23:24:39 -07:00
committed by GitHub
parent 2d120f8b18
commit a69b637014
10 changed files with 432 additions and 856 deletions

View File

@@ -1,7 +1,5 @@
// PD (Prefill-Decode) Router Implementation
// This module handles routing for disaggregated prefill-decode systems
use super::bootstrap_injector::inject_bootstrap_fields;
use super::pd_types::{api_path, PDRouterError};
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
@@ -19,7 +17,6 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use rand::Rng;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
@@ -316,17 +313,6 @@ impl PDRouter {
.into_response()
}
// Helper to handle bootstrap injection errors
fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response {
error!("Failed to add bootstrap info error={}", error);
RouterMetrics::record_pd_error("bootstrap_injection");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", error),
)
.into_response()
}
// Helper to handle serialization errors
fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
error!("Failed to serialize request error={}", error);
@@ -337,110 +323,87 @@ impl PDRouter {
.into_response()
}
// Execute the dual dispatch to prefill and decode servers with retry logic
async fn execute_dual_dispatch(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
route: &str,
prefill: &dyn Worker,
decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> Response {
for attempt in 0..self.retry_config.max_retries {
if attempt > 0 {
// Calculate backoff with exponential growth and jitter
let base_backoff = self.retry_config.initial_backoff_ms as f64
* self
.retry_config
.backoff_multiplier
.powf((attempt - 1) as f32) as f64;
let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64;
// Add jitter to prevent thundering herd
let jitter = {
let mut rng = rand::thread_rng();
rng.gen_range(0..backoff_ms / 2)
};
let total_backoff = Duration::from_millis(backoff_ms + jitter);
info!(
"Retrying request (attempt {}/{}) after {:?} backoff",
attempt + 1,
self.retry_config.max_retries,
total_backoff
);
tokio::time::sleep(total_backoff).await;
}
debug!(
"Executing request attempt {}/{}",
attempt + 1,
self.retry_config.max_retries
);
let result = self
.execute_dual_dispatch_inner(
headers,
json_request.clone(),
route,
prefill,
decode,
is_stream,
return_logprob,
start_time,
)
.await;
// Check if we should retry based on the response status
let status = result.status();
debug!(
"Request attempt {} returned status: {}",
attempt + 1,
status
);
// Don't retry client errors (4xx) or successful responses
if status.is_client_error() || status.is_success() {
debug!(
"Returning response with status {} (no retry needed)",
status
);
return result;
}
// Check if this is the last attempt
if attempt == self.retry_config.max_retries - 1 {
warn!("Final attempt failed with status {}", status);
return result;
}
// Log retry decision for retryable errors
if status.is_server_error()
|| status == StatusCode::BAD_GATEWAY
|| status == StatusCode::GATEWAY_TIMEOUT
{
warn!(
"Retryable error status: {} on attempt {}/{}. Will retry.",
status,
attempt + 1,
self.retry_config.max_retries
);
} else {
// Don't retry other statuses
debug!("Status {} is not retryable, returning response", status);
return result;
// Helper to determine batch size from a GenerateRequest
fn get_generate_batch_size(req: &GenerateRequest) -> Option<usize> {
// Check prompt array
if let Some(prompt) = &req.prompt {
if let crate::openai_api_types::StringOrArray::Array(arr) = prompt {
if !arr.is_empty() {
return Some(arr.len());
}
}
}
// This should never be reached due to the loop logic, but just in case
unreachable!("Retry loop completed without returning")
// Check text array
if let Some(text) = &req.text {
if text.contains("[") && text.contains("]") {
// This is a simplified check - in reality we'd need to parse JSON
return None; // For now, fall back to non-batch
}
}
None
}
// Inner implementation of dual dispatch (extracted for retry logic)
async fn execute_dual_dispatch_inner(
// Helper to determine batch size from a ChatCompletionRequest
fn get_chat_batch_size(req: &ChatCompletionRequest) -> Option<usize> {
// Check 'n' parameter for multiple responses
if let Some(n) = req.n {
if n > 1 {
return Some(n as usize);
}
}
None
}
// Helper to determine batch size from a CompletionRequest
fn get_completion_batch_size(req: &CompletionRequest) -> Option<usize> {
// Check prompt array
if let crate::openai_api_types::StringOrArray::Array(arr) = &req.prompt {
if !arr.is_empty() {
return Some(arr.len());
}
}
None
}
// Helper to create request with bootstrap fields
fn create_request_with_bootstrap<T: serde::Serialize>(
request: &T,
prefill_worker: &dyn Worker,
batch_size: Option<usize>,
) -> Result<serde_json::Value, serde_json::Error> {
// Get bootstrap port from prefill worker
let bootstrap_port = match prefill_worker.worker_type() {
crate::core::WorkerType::Prefill { bootstrap_port } => bootstrap_port,
_ => None,
};
let hostname = super::pd_types::get_hostname(prefill_worker.url());
// Create optimized request with bootstrap fields
if let Some(batch_size) = batch_size {
// Batch request
let request_with_bootstrap = super::pd_types::BatchRequestWithBootstrap {
original: request,
bootstrap_host: vec![hostname; batch_size],
bootstrap_port: vec![bootstrap_port; batch_size],
bootstrap_room: (0..batch_size)
.map(|_| super::pd_types::generate_room_id())
.collect(),
};
serde_json::to_value(&request_with_bootstrap)
} else {
// Single request
let request_with_bootstrap = super::pd_types::RequestWithBootstrap {
original: request,
bootstrap_host: hostname,
bootstrap_port,
bootstrap_room: super::pd_types::generate_room_id(),
};
serde_json::to_value(&request_with_bootstrap)
}
}
// Execute the dual dispatch to prefill and decode servers
async fn execute_dual_dispatch(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
@@ -467,101 +430,195 @@ impl PDRouter {
prefill.url(),
decode.url()
);
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
// Update metrics
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(route, duration);
RouterMetrics::record_pd_request(route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
if return_logprob {
// When we need logprobs, wait for both responses
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");
// Process prefill response
let (_prefill_status, prefill_body) = match self
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
.await
{
Ok(result) => result,
Err(error_response) => return error_response,
};
// Update metrics
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(route, duration);
RouterMetrics::record_pd_request(route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response
debug!("Processing decode response");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
// Process decode response with prefill for logprobs
debug!("Processing decode response with logprobs");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => {
return (status, error_body).into_response();
}
Err(e) => {
return (status, format!("Decode server error: {}", e)).into_response();
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => {
return (status, error_body).into_response();
}
Err(e) => {
return (status, format!("Decode server error: {}", e))
.into_response();
}
}
}
}
if is_stream {
// Streaming response
let prefill_logprobs = if return_logprob {
prefill_body
// Process prefill response for logprobs
let prefill_body = match self
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
.await
{
Ok((_, body)) => body,
Err(error_response) => return error_response,
};
if is_stream {
// Streaming response with logprobs
let prefill_logprobs = prefill_body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
})
} else {
None
};
});
let decode_url = if !return_logprob {
Some(decode.url().to_string())
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
None,
)
} else {
None
};
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
decode_url,
)
} else {
// Non-streaming response - use helper
self.process_non_streaming_response(res, status, return_logprob, prefill_body)
// Non-streaming response with logprobs
self.process_non_streaming_response(
res,
status,
return_logprob,
prefill_body,
)
.await
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
} else {
// When we don't need logprobs, only wait for decode response
// Send both requests concurrently but don't wait for prefill
// Add headers to minimize response size when we don't need the body
let prefill_future = prefill_request.header("Connection", "close").send();
let decode_future = decode_request.send();
tokio::spawn(async move {
if let Ok(response) = prefill_future.await {
// Consume with a short timeout to free connection quickly
let consume_future = async {
let _ = response.bytes().await;
};
// Give it 100ms to consume, then abandon
let _ = tokio::time::timeout(Duration::from_millis(100), consume_future).await;
}
});
// Wait only for decode response
let decode_result = decode_future.await;
debug!("Received decode response");
// Update metrics
let duration = start_time.elapsed();
RouterMetrics::record_pd_request_duration(route, duration);
RouterMetrics::record_pd_request(route);
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());
// Process decode response immediately
debug!("Processing decode response (no logprobs)");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);
if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
error!(
"Decode server returned error status decode_url={} status={}",
decode.url(),
status
);
// Return the error response from decode server
match res.bytes().await {
Ok(error_body) => (status, error_body).into_response(),
Err(e) => {
(status, format!("Decode server error: {}", e)).into_response()
}
}
} else if is_stream {
// Streaming response without logprobs - direct passthrough
let decode_url = decode.url().to_string();
Self::create_streaming_response(
res.bytes_stream(),
status,
None,
false,
Some(decode_url),
)
} else {
// Non-streaming response without logprobs - direct passthrough like fast version
match res.bytes().await {
Ok(decode_body) => (status, decode_body).into_response(),
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response()
}
}
}
}
Err(e) => {
error!(
decode_url = %decode.url(),
error = %e,
"Decode request failed"
);
RouterMetrics::record_pd_decode_error(decode.url());
(
StatusCode::BAD_GATEWAY,
format!("Decode server error: {}", e),
)
.into_response()
}
}
}
}
// Check if either prefill or decode policy needs request text
fn policies_need_request_text(&self) -> bool {
self.prefill_policy.needs_request_text() || self.decode_policy.needs_request_text()
}
// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
@@ -1311,23 +1368,23 @@ impl RouterTrait for PDRouter {
) -> Response {
let start = Instant::now();
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.return_logprob;
// Extract text for cache-aware routing
let request_text = body.text.as_deref().or_else(|| {
body.prompt.as_ref().and_then(|p| match p {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
// Extract text for cache-aware routing only if needed
let request_text = if self.policies_need_request_text() {
body.text.as_deref().or_else(|| {
body.prompt.as_ref().and_then(|p| match p {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => {
v.first().map(|s| s.as_str())
}
})
})
});
} else {
None
};
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
@@ -1342,10 +1399,12 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
// Create optimized request with bootstrap fields
let batch_size = Self::get_generate_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
self.execute_dual_dispatch(
@@ -1368,27 +1427,29 @@ impl RouterTrait for PDRouter {
) -> Response {
let start = Instant::now();
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.logprobs;
// Extract text for cache-aware routing from chat messages
let request_text = body.messages.first().and_then(|msg| match msg {
crate::openai_api_types::ChatMessage::User { content, .. } => {
match content {
crate::openai_api_types::UserMessageContent::Text(text) => Some(text.as_str()),
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
// Extract text for cache-aware routing from chat messages only if needed
let request_text = if self.policies_need_request_text() {
body.messages.first().and_then(|msg| match msg {
crate::openai_api_types::ChatMessage::User { content, .. } => {
match content {
crate::openai_api_types::UserMessageContent::Text(text) => {
Some(text.as_str())
}
crate::openai_api_types::UserMessageContent::Parts(_) => None, // Skip complex content
}
}
}
crate::openai_api_types::ChatMessage::System { content, .. } => Some(content.as_str()),
_ => None,
});
crate::openai_api_types::ChatMessage::System { content, .. } => {
Some(content.as_str())
}
_ => None,
})
} else {
None
};
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
@@ -1403,10 +1464,12 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
// Create optimized request with bootstrap fields
let batch_size = Self::get_chat_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
self.execute_dual_dispatch(
@@ -1429,20 +1492,18 @@ impl RouterTrait for PDRouter {
) -> Response {
let start = Instant::now();
// Convert directly to JSON to preserve all fields automatically
let mut json = match serde_json::to_value(body) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Extract flags for routing logic
let is_stream = body.stream;
let return_logprob = body.logprobs.is_some();
// Extract text for cache-aware routing
let request_text = match &body.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
// Extract text for cache-aware routing only if needed
let request_text = if self.policies_need_request_text() {
match &body.prompt {
crate::openai_api_types::StringOrArray::String(s) => Some(s.as_str()),
crate::openai_api_types::StringOrArray::Array(v) => v.first().map(|s| s.as_str()),
}
} else {
None
};
// Select servers
@@ -1458,10 +1519,12 @@ impl RouterTrait for PDRouter {
decode.url()
);
// Inject bootstrap fields directly into JSON
if let Err(e) = inject_bootstrap_fields(&mut json, prefill.as_ref()) {
return Self::handle_bootstrap_error(e);
}
// Create optimized request with bootstrap fields
let batch_size = Self::get_completion_batch_size(body);
let json = match Self::create_request_with_bootstrap(body, prefill.as_ref(), batch_size) {
Ok(json) => json,
Err(e) => return Self::handle_serialization_error(e),
};
// Execute dual dispatch
self.execute_dual_dispatch(
@@ -1937,6 +2000,13 @@ mod tests {
assert!(result.is_ok());
}
// ============= Bootstrap Injection Tests =============
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
// TODO: Add new tests for the optimized bootstrap injection approach using
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
// ============= Worker Selection Tests =============
#[tokio::test]
@@ -2114,158 +2184,4 @@ mod tests {
let workers = router.prefill_workers.read().unwrap();
assert_eq!(workers.len(), 5);
}
#[tokio::test]
async fn test_simplified_routing_preserves_sglang_fields() {
use crate::openai_api_types::GenerateRequest;
use crate::routers::bootstrap_injector::inject_bootstrap_fields;
// Create a test worker
let worker = BasicWorker::new(
"http://test-server:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(5678),
},
);
// Create a GenerateRequest with SGLang extensions
let mut session_params = std::collections::HashMap::new();
session_params.insert("test_key".to_string(), serde_json::json!("test_value"));
let request = GenerateRequest {
text: Some("Test prompt".to_string()),
stream: false,
return_logprob: true,
// SGLang extensions
lora_path: Some(crate::openai_api_types::LoRAPath::Single(Some(
"test.bin".to_string(),
))),
session_params: Some(session_params.clone()),
return_hidden_states: true,
rid: Some("test-request-id".to_string()),
// Other fields default to None/false
prompt: None,
input_ids: None,
parameters: None,
sampling_params: None,
};
// Convert to JSON (simulating the simplified routing path)
let mut json = serde_json::to_value(&request).unwrap();
// Inject bootstrap fields
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify all SGLang fields are preserved
assert_eq!(json["text"], serde_json::json!("Test prompt"));
assert_eq!(json["stream"], serde_json::json!(false));
assert_eq!(json["return_logprob"], serde_json::json!(true));
assert_eq!(json["lora_path"], serde_json::json!("test.bin")); // LoRAPath::Single serializes as just the inner value
assert_eq!(
json["session_params"],
serde_json::to_value(&session_params).unwrap()
);
assert_eq!(json["return_hidden_states"], serde_json::json!(true));
assert_eq!(json["rid"], serde_json::json!("test-request-id"));
// Verify bootstrap fields were added
assert_eq!(json["bootstrap_host"], serde_json::json!("test-server"));
assert_eq!(json["bootstrap_port"], serde_json::json!(5678));
assert!(json["bootstrap_room"].is_number());
}
#[tokio::test]
async fn test_simplified_routing_chat_completion() {
use crate::openai_api_types::{ChatCompletionRequest, ChatMessage, UserMessageContent};
use crate::routers::bootstrap_injector::inject_bootstrap_fields;
// Create a test worker
let worker = BasicWorker::new(
"http://chat-server:8000".to_string(),
WorkerType::Prefill {
bootstrap_port: Some(9999),
},
);
// Create a ChatCompletionRequest with SGLang extensions
let request = ChatCompletionRequest {
model: "gpt-4".to_string(),
messages: vec![ChatMessage::User {
role: "user".to_string(),
content: UserMessageContent::Text("Hello world!".to_string()),
name: None,
}],
stream: false,
n: Some(2), // This should create batch bootstrap
// SGLang extensions
top_k: Some(50),
separate_reasoning: false,
stream_reasoning: true,
// Set all other fields to defaults
temperature: None,
top_p: None,
stream_options: None,
stop: None,
max_tokens: None,
max_completion_tokens: None,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
seed: None,
logprobs: false,
top_logprobs: None,
response_format: None,
tools: None,
tool_choice: None,
parallel_tool_calls: None,
functions: None,
function_call: None,
min_p: None,
min_tokens: None,
repetition_penalty: None,
regex: None,
ebnf: None,
stop_token_ids: None,
no_stop_trim: false,
ignore_eos: false,
continue_final_message: false,
skip_special_tokens: true,
lora_path: None,
session_params: None,
return_hidden_states: false,
};
// Convert to JSON (simulating the simplified routing path)
let mut json = serde_json::to_value(&request).unwrap();
// Inject bootstrap fields
let result = inject_bootstrap_fields(&mut json, &worker);
assert!(result.is_ok());
// Verify original fields preserved
assert_eq!(json["model"], serde_json::json!("gpt-4"));
assert_eq!(json["stream"], serde_json::json!(false));
assert_eq!(json["n"], serde_json::json!(2));
assert_eq!(json["top_k"], serde_json::json!(50));
assert_eq!(json["separate_reasoning"], serde_json::json!(false));
assert_eq!(json["stream_reasoning"], serde_json::json!(true));
// Verify batch bootstrap fields for n=2
let bootstrap_hosts = json["bootstrap_host"].as_array().unwrap();
assert_eq!(bootstrap_hosts.len(), 2);
assert_eq!(bootstrap_hosts[0], serde_json::json!("chat-server"));
assert_eq!(bootstrap_hosts[1], serde_json::json!("chat-server"));
let bootstrap_ports = json["bootstrap_port"].as_array().unwrap();
assert_eq!(bootstrap_ports.len(), 2);
assert_eq!(bootstrap_ports[0], serde_json::json!(9999));
assert_eq!(bootstrap_ports[1], serde_json::json!(9999));
let bootstrap_rooms = json["bootstrap_room"].as_array().unwrap();
assert_eq!(bootstrap_rooms.len(), 2);
// Rooms should be different (randomness)
assert_ne!(bootstrap_rooms[0], bootstrap_rooms[1]);
}
}