[pd-router] Add Configurable Retry Logic to Reduce Backend Pressure (#8744)
@@ -39,6 +39,8 @@ pub struct RouterConfig {
pub max_concurrent_requests: usize,
/// CORS allowed origins
pub cors_allowed_origins: Vec<String>,
/// Retry configuration
pub retry: RetryConfig,
}

/// Routing mode configuration
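The hunk above adds a retry field to RouterConfig. A minimal sketch (not part of this commit, values are made up) of overriding just that field while keeping the other defaults, assuming the Default impls added below:

// Hypothetical override for illustration only.
let config = RouterConfig {
    retry: RetryConfig {
        max_retries: 5,
        initial_backoff_ms: 50,
        max_backoff_ms: 5_000,
        backoff_multiplier: 1.5,
    },
    ..RouterConfig::default()
};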
@@ -182,6 +184,30 @@ impl Default for DiscoveryConfig {
}
}

/// Retry configuration for request handling
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
/// Maximum number of retry attempts
pub max_retries: u32,
/// Initial backoff delay in milliseconds
pub initial_backoff_ms: u64,
/// Maximum backoff delay in milliseconds
pub max_backoff_ms: u64,
/// Backoff multiplier for exponential backoff
pub backoff_multiplier: f32,
}

impl Default for RetryConfig {
fn default() -> Self {
Self {
max_retries: 3,
initial_backoff_ms: 100,
max_backoff_ms: 10000,
backoff_multiplier: 2.0,
}
}
}

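For reference, the schedule these defaults imply in the PD router's retry loop (shown later in this diff) is roughly: first retry after ~100ms, second after ~200ms, each plus random jitter, capped at max_backoff_ms. A hedged sketch of the same arithmetic, using a hypothetical helper name:

// Illustrative only; mirrors the backoff computed in execute_dual_dispatch below.
// retry_index starts at 1 for the first retry.
fn backoff_before_retry(cfg: &RetryConfig, retry_index: u32) -> u64 {
    let base = cfg.initial_backoff_ms as f64
        * (cfg.backoff_multiplier as f64).powi(retry_index as i32 - 1);
    base.min(cfg.max_backoff_ms as f64) as u64
}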
/// Metrics configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MetricsConfig {
@@ -210,7 +236,7 @@ impl Default for RouterConfig {
host: "127.0.0.1".to_string(),
port: 3001,
max_payload_size: 268_435_456, // 256MB
request_timeout_secs: 600,
request_timeout_secs: 3600, // 1 hour to match Python mini LB
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
@@ -222,6 +248,7 @@ impl Default for RouterConfig {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
}
}
}
@@ -277,7 +304,7 @@ mod tests {
assert_eq!(config.host, "127.0.0.1");
assert_eq!(config.port, 3001);
assert_eq!(config.max_payload_size, 268_435_456);
assert_eq!(config.request_timeout_secs, 600);
assert_eq!(config.request_timeout_secs, 3600);
assert_eq!(config.worker_startup_timeout_secs, 300);
assert_eq!(config.worker_startup_check_interval_secs, 10);
assert!(config.discovery.is_none());
@@ -332,6 +359,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

let json = serde_json::to_string(&config).unwrap();
@@ -759,6 +787,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

assert!(config.mode.is_pd_mode());
@@ -810,6 +839,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

assert!(!config.mode.is_pd_mode());
@@ -857,6 +887,7 @@ mod tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

assert!(config.has_service_discovery());

@@ -19,7 +19,7 @@ pub enum PolicyType {
Random,
RoundRobin,
CacheAware,
PowerOfTwo, // Moved from PD-specific, now shared
PowerOfTwo,
}

#[pyclass]
@@ -45,7 +45,6 @@ struct Router {
selector: HashMap<String, String>,
service_discovery_port: u16,
service_discovery_namespace: Option<String>,
// PD service discovery fields
prefill_selector: HashMap<String, String>,
decode_selector: HashMap<String, String>,
bootstrap_port_annotation: String,
@@ -53,14 +52,11 @@ struct Router {
prometheus_host: Option<String>,
request_timeout_secs: u64,
request_id_headers: Option<Vec<String>>,
// PD mode flag
pd_disaggregation: bool,
// PD-specific fields (only used when pd_disaggregation is true)
prefill_urls: Option<Vec<(String, Option<u16>)>>,
decode_urls: Option<Vec<String>>,
prefill_policy: Option<PolicyType>,
decode_policy: Option<PolicyType>,
// Additional server config fields
max_concurrent_requests: usize,
cors_allowed_origins: Vec<String>,
}
@@ -150,6 +146,7 @@ impl Router {
request_id_headers: self.request_id_headers.clone(),
max_concurrent_requests: self.max_concurrent_requests,
cors_allowed_origins: self.cors_allowed_origins.clone(),
retry: config::RetryConfig::default(),
})
}
}
@@ -289,7 +286,6 @@ impl Router {
check_interval: std::time::Duration::from_secs(60),
port: self.service_discovery_port,
namespace: self.service_discovery_namespace.clone(),
// PD mode configuration
pd_mode: self.pd_disaggregation,
prefill_selector: self.prefill_selector.clone(),
decode_selector: self.decode_selector.clone(),

@@ -50,6 +50,7 @@ impl RouterFactory {
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.dp_aware,
ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
)?;

Ok(Box::new(router))
@@ -79,6 +80,7 @@ impl RouterFactory {
ctx.client.clone(),
ctx.router_config.worker_startup_timeout_secs,
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
)?;

Ok(Box::new(router))

@@ -3,6 +3,7 @@

use super::pd_types::{api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError};
use super::request_adapter::ToPdRequest;
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
@@ -16,6 +17,8 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use rand::Rng;
use reqwest::Client;
use serde_json::Value;
use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
@@ -36,6 +39,7 @@ pub struct PDRouter {
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub client: Client,
pub retry_config: RetryConfig,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
@@ -180,6 +184,7 @@ impl PDRouter {
client: Client,
timeout_secs: u64,
interval_secs: u64,
retry_config: RetryConfig,
) -> Result<Self, String> {
// Convert URLs to Worker trait objects
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
@@ -260,6 +265,7 @@ impl PDRouter {
worker_loads,
load_monitor_handle,
client,
retry_config,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
@@ -294,6 +300,38 @@ impl PDRouter {
}
}

// Helper to handle server selection errors
fn handle_server_selection_error(error: String) -> Response {
error!("Failed to select PD pair error={}", error);
RouterMetrics::record_pd_error("server_selection");
(
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", error),
)
.into_response()
}

// Helper to handle bootstrap injection errors
fn handle_bootstrap_error(error: impl std::fmt::Display) -> Response {
error!("Failed to add bootstrap info error={}", error);
RouterMetrics::record_pd_error("bootstrap_injection");
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", error),
)
.into_response()
}

// Helper to handle serialization errors
fn handle_serialization_error(error: impl std::fmt::Display) -> Response {
error!("Failed to serialize request error={}", error);
(
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response()
}

// Route a typed generate request
pub async fn route_generate(
&self,
@@ -320,15 +358,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};

// Log routing decision
@@ -341,26 +371,13 @@ impl PDRouter {

// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}

// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};

// Execute dual dispatch
@@ -406,15 +423,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};

// Log routing decision
@@ -425,28 +434,14 @@ impl PDRouter {
decode.url()
);

// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}

// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};

// Execute dual dispatch
@@ -485,15 +480,7 @@ impl PDRouter {
// Select servers
let (prefill, decode) = match self.select_pd_pair(request_text).await {
Ok(pair) => pair,
Err(e) => {
error!("Failed to select PD pair error={}", e);
RouterMetrics::record_pd_error("server_selection");
return (
StatusCode::SERVICE_UNAVAILABLE,
format!("No available servers: {}", e),
)
.into_response();
}
Err(e) => return Self::handle_server_selection_error(e),
};

// Log routing decision
@@ -504,28 +491,14 @@ impl PDRouter {
decode.url()
);

// Add bootstrap info using the trait method
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
error!("Failed to add bootstrap info error={}", e);
RouterMetrics::record_pd_error("bootstrap_injection");
return (
StatusCode::INTERNAL_SERVER_ERROR,
format!("Bootstrap injection failed: {}", e),
)
.into_response();
return Self::handle_bootstrap_error(e);
}

// Convert to JSON after bootstrap injection
let json_with_bootstrap = match serde_json::to_value(&typed_req) {
Ok(json) => json,
Err(e) => {
error!("Failed to serialize request error={}", e);
return (
StatusCode::INTERNAL_SERVER_ERROR,
"Failed to serialize request",
)
.into_response();
}
Err(e) => return Self::handle_serialization_error(e),
};

// Execute dual dispatch
@@ -542,7 +515,7 @@ impl PDRouter {
.await
}

// Execute the dual dispatch to prefill and decode servers
// Execute the dual dispatch to prefill and decode servers with retry logic
async fn execute_dual_dispatch(
&self,
headers: Option<&HeaderMap>,
@@ -553,38 +526,128 @@ impl PDRouter {
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> Response {
for attempt in 0..self.retry_config.max_retries {
if attempt > 0 {
// Calculate backoff with exponential growth and jitter
let base_backoff = self.retry_config.initial_backoff_ms as f64
* self
.retry_config
.backoff_multiplier
.powf((attempt - 1) as f32) as f64;
let backoff_ms = base_backoff.min(self.retry_config.max_backoff_ms as f64) as u64;

// Add jitter to prevent thundering herd
let jitter = {
let mut rng = rand::thread_rng();
rng.gen_range(0..backoff_ms / 2)
};
let total_backoff = Duration::from_millis(backoff_ms + jitter);

info!(
"Retrying request (attempt {}/{}) after {:?} backoff",
attempt + 1,
self.retry_config.max_retries,
total_backoff
);

tokio::time::sleep(total_backoff).await;
}

debug!(
"Executing request attempt {}/{}",
attempt + 1,
self.retry_config.max_retries
);
let result = self
.execute_dual_dispatch_inner(
headers,
json_request.clone(),
route,
prefill,
decode,
is_stream,
return_logprob,
start_time,
)
.await;

// Check if we should retry based on the response status
let status = result.status();
debug!(
"Request attempt {} returned status: {}",
attempt + 1,
status
);

// Don't retry client errors (4xx) or successful responses
if status.is_client_error() || status.is_success() {
debug!(
"Returning response with status {} (no retry needed)",
status
);
return result;
}

// Check if this is the last attempt
if attempt == self.retry_config.max_retries - 1 {
warn!("Final attempt failed with status {}", status);
return result;
}

// Log retry decision for retryable errors
if status.is_server_error()
|| status == StatusCode::BAD_GATEWAY
|| status == StatusCode::GATEWAY_TIMEOUT
{
warn!(
"Retryable error status: {} on attempt {}/{}. Will retry.",
status,
attempt + 1,
self.retry_config.max_retries
);
} else {
// Don't retry other statuses
debug!("Status {} is not retryable, returning response", status);
return result;
}
}

// This should never be reached due to the loop logic, but just in case
unreachable!("Retry loop completed without returning")
}

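Condensed for reference, the per-attempt decision made by the loop above boils down to the following hypothetical helper (not part of the commit): 2xx and 4xx responses return immediately, 5xx responses are retried while attempts remain, and any other status is returned as-is.

// Illustrative sketch only; the real code logs and sleeps between attempts.
fn should_retry(status: StatusCode, attempt: u32, max_retries: u32) -> bool {
    !status.is_success()
        && !status.is_client_error()
        && status.is_server_error()
        && attempt + 1 < max_retries
}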
// Inner implementation of dual dispatch (extracted for retry logic)
async fn execute_dual_dispatch_inner(
&self,
headers: Option<&HeaderMap>,
json_request: Value,
route: &str,
prefill: &dyn Worker,
decode: &dyn Worker,
is_stream: bool,
return_logprob: bool,
start_time: Instant,
) -> Response {
// Update load tracking for both workers
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);

// Build requests using .json() method
let mut prefill_request = self
.client
.post(api_path(prefill.url(), route))
.json(&json_request);
// Build requests with headers
let prefill_request =
self.build_request_with_headers(prefill.url(), route, &json_request, headers);

let mut decode_request = self
.client
.post(api_path(decode.url(), route))
.json(&json_request);

// Copy headers from original request (excluding content-type and content-length which are set by .json())
if let Some(headers) = headers {
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str != "content-type" && name_str != "content-length" {
// Skip headers with non-ASCII values
if value.to_str().is_ok() {
prefill_request = prefill_request.header(name, value);
decode_request = decode_request.header(name, value);
}
}
}
}
let decode_request =
self.build_request_with_headers(decode.url(), route, &json_request, headers);

// Send both requests concurrently
debug!(
"Sending concurrent requests to prefill={} decode={}",
prefill.url(),
decode.url()
);
let (prefill_result, decode_result) =
tokio::join!(prefill_request.send(), decode_request.send());
debug!("Received responses from both servers");

// Update metrics
let duration = start_time.elapsed();
@@ -593,11 +656,22 @@ impl PDRouter {
RouterMetrics::record_pd_prefill_request(prefill.url());
RouterMetrics::record_pd_decode_request(decode.url());

// Process prefill response
let (_prefill_status, prefill_body) = match self
.process_prefill_response(prefill_result, prefill.url(), return_logprob)
.await
{
Ok(result) => result,
Err(error_response) => return error_response,
};

// Process decode response
debug!("Processing decode response");
match decode_result {
Ok(res) => {
let status = StatusCode::from_u16(res.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);
debug!("Decode response status: {}", status);

if !status.is_success() {
RouterMetrics::record_pd_decode_error(decode.url());
@@ -618,128 +692,36 @@ impl PDRouter {
}
}

// Log prefill errors for debugging
if let Err(e) = &prefill_result {
error!(
"Prefill server failed (non-critical) prefill_url={} error={}",
prefill.url(),
e
);
RouterMetrics::record_pd_prefill_error(prefill.url());
}

if is_stream {
// Streaming response
if return_logprob {
// Get prefill logprobs for merging
let prefill_logprobs =
match prefill_result {
Ok(prefill_res) => match prefill_res.bytes().await {
Ok(body) => serde_json::from_slice::<Value>(&body)
.ok()
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
}),
Err(_) => None,
},
Err(_) => None,
};

// Stream with logprob merging
let stream = res.bytes_stream();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
// Try to merge logprobs
if let Ok(merged) = Self::merge_streaming_logprobs(
prefill_logprobs.clone(),
&chunk,
) {
if tx.send(Ok(merged)).is_err() {
break;
}
} else {
if tx.send(Ok(chunk)).is_err() {
break;
}
}
}
Err(e) => {
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});

let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);

let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
let prefill_logprobs = if return_logprob {
prefill_body
.as_ref()
.and_then(|body| serde_json::from_slice::<Value>(body).ok())
.and_then(|json| {
json.pointer("/meta_info/input_token_logprobs").cloned()
})
} else {
// No logprob merging needed
let stream = res.bytes_stream();
let decode_url = decode.url().to_string();
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
None
};

tokio::spawn(async move {
let mut stream = stream;
while let Some(chunk) = stream.next().await {
match chunk {
Ok(bytes) => {
if tx.send(Ok(bytes)).is_err() {
break;
}
}
Err(e) => {
error!(
"Stream error from decode server {}: {}",
decode_url, e
);
RouterMetrics::record_pd_stream_error(&decode_url);
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});
let decode_url = if !return_logprob {
Some(decode.url().to_string())
} else {
None
};

let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);

let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}
Self::create_streaming_response(
res.bytes_stream(),
status,
prefill_logprobs,
return_logprob,
decode_url,
)
} else {
// Non-streaming response
match res.bytes().await {
Ok(decode_body) => {
if return_logprob {
self.merge_logprobs(prefill_result, decode_body, status)
.await
} else {
(status, decode_body).into_response()
}
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response")
.into_response()
}
}
// Non-streaming response - use helper
self.process_non_streaming_response(res, status, return_logprob, prefill_body)
.await
}
}
Err(e) => {
@@ -758,62 +740,6 @@ impl PDRouter {
}
}

// Merge logprobs from prefill and decode responses
async fn merge_logprobs(
&self,
prefill_result: Result<reqwest::Response, reqwest::Error>,
decode_body: bytes::Bytes,
status: StatusCode,
) -> Response {
match prefill_result {
Ok(prefill_res) => {
match prefill_res.bytes().await {
Ok(prefill_body) => {
match (
serde_json::from_slice::<Value>(&prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) {
(Ok(prefill_json), Ok(mut decode_json)) => {
// Merge input_token_logprobs
if let (Some(prefill_meta), Some(decode_meta)) = (
prefill_json.get("meta_info"),
decode_json.get_mut("meta_info"),
) {
if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
prefill_meta.get("input_token_logprobs"),
decode_meta.get_mut("input_token_logprobs"),
) {
if let (Some(p_arr), Some(d_arr)) = (
prefill_logprobs.as_array(),
decode_logprobs.as_array(),
) {
let mut merged = p_arr.clone();
merged.extend(d_arr.clone());
decode_meta["input_token_logprobs"] =
Value::Array(merged);
}
}
}
let mut response = Json(decode_json).into_response();
*response.status_mut() = status;
response
}
_ => {
warn!("Failed to parse responses for logprob merging");
(status, decode_body).into_response()
}
}
}
Err(e) => {
warn!("Failed to read prefill response: {}", e);
(status, decode_body).into_response()
}
}
}
Err(_) => (status, decode_body).into_response(),
}
}

// Select a pair of prefill and decode servers
async fn select_pd_pair(
&self,
@@ -900,6 +826,229 @@ impl PDRouter {
}
}

// Helper to create a streaming response
fn create_streaming_response(
stream: impl futures_util::Stream<Item = Result<bytes::Bytes, reqwest::Error>> + Send + 'static,
status: StatusCode,
prefill_logprobs: Option<Value>,
return_logprob: bool,
decode_url: Option<String>,
) -> Response {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::spawn(async move {
futures_util::pin_mut!(stream);
while let Some(chunk_result) = stream.next().await {
match chunk_result {
Ok(chunk) => {
let result = if return_logprob && prefill_logprobs.is_some() {
// Try to merge logprobs
Self::merge_streaming_logprobs(prefill_logprobs.clone(), &chunk)
.unwrap_or(chunk)
} else {
chunk
};

if tx.send(Ok(result)).is_err() {
break;
}
}
Err(e) => {
if let Some(ref url) = decode_url {
error!("Stream error from decode server {}: {}", url, e);
RouterMetrics::record_pd_stream_error(url);
}
let _ = tx.send(Err(format!("Stream error: {}", e)));
break;
}
}
}
});

let stream = UnboundedReceiverStream::new(rx);
let body = Body::from_stream(stream);

let mut response = Response::new(body);
*response.status_mut() = status;
response
.headers_mut()
.insert(CONTENT_TYPE, HeaderValue::from_static("text/event-stream"));
response
}

// Helper to process non-streaming decode response with logprob merging
async fn process_non_streaming_response(
&self,
res: reqwest::Response,
status: StatusCode,
return_logprob: bool,
prefill_body: Option<bytes::Bytes>,
) -> Response {
match res.bytes().await {
Ok(decode_body) => {
if return_logprob && prefill_body.is_some() {
// Merge logprobs from prefill and decode
let prefill_body = prefill_body.as_ref().unwrap();
match (
serde_json::from_slice::<Value>(prefill_body),
serde_json::from_slice::<Value>(&decode_body),
) {
(Ok(prefill_json), Ok(mut decode_json)) => {
// Use helper to merge logprobs
Self::merge_logprobs_in_json(&prefill_json, &mut decode_json);

// Return merged response
match serde_json::to_vec(&decode_json) {
Ok(body) => (status, body).into_response(),
Err(e) => {
error!("Failed to serialize merged response: {}", e);
(status, decode_body).into_response()
}
}
}
_ => {
// If parsing fails, just return decode response
warn!("Failed to parse responses for logprob merging");
(status, decode_body).into_response()
}
}
} else {
(status, decode_body).into_response()
}
}
Err(e) => {
error!("Failed to read decode response: {}", e);
(StatusCode::INTERNAL_SERVER_ERROR, "Failed to read response").into_response()
}
}
}

// Helper to process prefill response and extract body if needed for logprobs
async fn process_prefill_response(
&self,
prefill_result: Result<reqwest::Response, reqwest::Error>,
prefill_url: &str,
return_logprob: bool,
) -> Result<(StatusCode, Option<bytes::Bytes>), Response> {
// Check prefill result first - it's critical for disaggregated mode
let prefill_response = match prefill_result {
Ok(response) => response,
Err(e) => {
RouterMetrics::record_pd_prefill_error(prefill_url);
error!(
"Prefill server failed (CRITICAL) prefill_url={} error={}. Decode will timeout without prefill KV cache.",
prefill_url,
e
);

// Return error immediately - don't wait for decode to timeout
return Err((
StatusCode::BAD_GATEWAY,
format!(
"Prefill server error: {}. This will cause decode timeout.",
e
),
)
.into_response());
}
};

let prefill_status = StatusCode::from_u16(prefill_response.status().as_u16())
.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR);

// Check if prefill succeeded
if !prefill_status.is_success() {
RouterMetrics::record_pd_prefill_error(prefill_url);

// Get error body from prefill
let error_msg = prefill_response
.text()
.await
.unwrap_or_else(|_| "Unknown prefill error".to_string());

error!(
"Prefill server returned error status prefill_url={} status={} body={}",
prefill_url, prefill_status, error_msg
);

return Err((
prefill_status,
format!("Prefill server error ({}): {}", prefill_status, error_msg),
)
.into_response());
}

// Read prefill body if needed for logprob merging
let prefill_body = if return_logprob {
match prefill_response.bytes().await {
Ok(body) => Some(body),
Err(e) => {
warn!("Failed to read prefill response body for logprobs: {}", e);
None
}
}
} else {
// For non-logprob requests, just consume the response without storing
debug!("Consuming prefill response body (non-logprob request)");
match prefill_response.bytes().await {
Ok(_) => debug!("Prefill response consumed successfully"),
Err(e) => warn!("Error consuming prefill response: {}", e),
}
None
};

Ok((prefill_status, prefill_body))
}

// Helper to build a request with headers copied from the original request
fn build_request_with_headers(
&self,
url: &str,
route: &str,
json_request: &Value,
headers: Option<&HeaderMap>,
) -> reqwest::RequestBuilder {
let mut request = self.client.post(api_path(url, route)).json(json_request);

// Copy headers from original request (excluding content-type and content-length which are set by .json())
if let Some(headers) = headers {
for (name, value) in headers.iter() {
let name_str = name.as_str();
if name_str != "content-type" && name_str != "content-length" {
// Skip headers with non-ASCII values
if value.to_str().is_ok() {
request = request.header(name, value);
}
}
}
}

request
}

// Helper to merge logprobs from prefill and decode responses
fn merge_logprobs_in_json(prefill_json: &Value, decode_json: &mut Value) -> bool {
if let (Some(prefill_meta), Some(decode_meta)) = (
prefill_json.get("meta_info"),
decode_json.get_mut("meta_info"),
) {
if let (Some(prefill_logprobs), Some(decode_logprobs)) = (
prefill_meta.get("input_token_logprobs"),
decode_meta.get_mut("input_token_logprobs"),
) {
if let (Some(prefill_arr), Some(decode_arr)) =
(prefill_logprobs.as_array(), decode_logprobs.as_array_mut())
{
let mut merged = prefill_arr.clone();
merged.extend(decode_arr.clone());
decode_meta["input_token_logprobs"] = Value::Array(merged);
return true;
}
}
}
false
}

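A small illustration (made-up values, callable like this from the router's own tests since the helper is private) of what merge_logprobs_in_json does to the decode payload: the prefill logprobs are prepended in place to the decode list.

// Before: prefill {"meta_info": {"input_token_logprobs": [[-0.1], [-0.2]]}}
//         decode  {"meta_info": {"input_token_logprobs": [[-0.3]]}}
let prefill = serde_json::json!({"meta_info": {"input_token_logprobs": [[-0.1], [-0.2]]}});
let mut decode = serde_json::json!({"meta_info": {"input_token_logprobs": [[-0.3]]}});
assert!(PDRouter::merge_logprobs_in_json(&prefill, &mut decode));
assert_eq!(
    decode.pointer("/meta_info/input_token_logprobs").unwrap(),
    &serde_json::json!([[-0.1], [-0.2], [-0.3]])
);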
// Simple helper to merge logprobs in streaming responses
fn merge_streaming_logprobs(
prefill_logprobs: Option<Value>,
@@ -1316,7 +1465,6 @@ impl PDRouter {

use crate::routers::{RouterTrait, WorkerManagement};
use async_trait::async_trait;
use reqwest::Client;

#[async_trait]
impl WorkerManagement for PDRouter {
@@ -1558,6 +1706,7 @@ mod tests {
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
client: Client::new(),
retry_config: RetryConfig::default(),
_prefill_health_checker: None,
_decode_health_checker: None,
}

@@ -1,3 +1,4 @@
use crate::config::types::RetryConfig;
use crate::core::{HealthChecker, Worker, WorkerFactory};
use crate::metrics::RouterMetrics;
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
@@ -11,6 +12,7 @@ use axum::{
Json,
};
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
@@ -39,6 +41,7 @@ pub struct Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
@@ -54,6 +57,7 @@ impl Router {
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
retry_config: RetryConfig,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
@@ -120,6 +124,7 @@ impl Router {
interval_secs,
dp_aware,
api_key,
retry_config,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
@@ -141,6 +146,12 @@ impl Router {
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
if worker_urls.is_empty() {
return Err(
"Timeout waiting for workers to become healthy: no workers provided".to_string(),
);
}

let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
@@ -365,11 +376,13 @@ impl Router {
) -> Response {
// Handle retries like the original implementation
let start = Instant::now();
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;
// Use retry config for per-worker retries
let max_request_retries = self.retry_config.max_retries;
// Total retries across all workers (2x to allow trying multiple workers)
let max_total_retries = self.retry_config.max_retries * 2;
let mut total_retries = 0;

while total_retries < MAX_TOTAL_RETRIES {
while total_retries < max_total_retries {
// Extract routing text directly from typed request
let text = typed_req.extract_text_for_routing();
let is_stream = typed_req.is_stream();
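In the regular (non-PD) router the hard-coded constants become config-driven; with the default RetryConfig the resulting budget works out as in this sketch (illustration only, not code from the commit):

let retry = RetryConfig::default();
let max_request_retries = retry.max_retries;     // 3 attempts against the same worker
let max_total_retries = retry.max_retries * 2;   // 6 attempts spread across workers
assert_eq!((max_request_retries, max_total_retries), (3, 6));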
@@ -379,7 +392,7 @@ impl Router {
let mut request_retries = 0;

// Try the same worker multiple times
while request_retries < MAX_REQUEST_RETRIES {
while request_retries < max_request_retries {
if total_retries >= 1 {
info!("Retrying request after {} failed attempts", total_retries);
RouterMetrics::record_retry(route);
@@ -429,13 +442,13 @@ impl Router {
route,
worker_url,
request_retries + 1,
MAX_REQUEST_RETRIES
max_request_retries
);

request_retries += 1;
total_retries += 1;

if request_retries == MAX_REQUEST_RETRIES {
if request_retries == max_request_retries {
warn!(
"Removing failed worker after typed request failures worker_url={}",
worker_url
@@ -1003,7 +1016,6 @@ impl Router {
}

use async_trait::async_trait;
use reqwest::Client;

#[async_trait]
impl WorkerManagement for Router {
@@ -1210,6 +1222,7 @@ mod tests {
dp_aware: false,
api_key: None,
client: Client::new(),
retry_config: RetryConfig::default(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,
@@ -1237,8 +1250,10 @@ mod tests {

#[test]
fn test_wait_for_healthy_workers_empty_list() {
// Empty list will timeout as there are no workers to check
let result = Router::wait_for_healthy_workers(&[], 1, 1);
assert!(result.is_ok());
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
}

#[test]

@@ -580,8 +580,17 @@ mod tests {
use crate::routers::router::Router;

let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router =
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap();
let router = Router::new(
vec![],
policy,
reqwest::Client::new(),
5,
1,
false,
None,
crate::config::types::RetryConfig::default(),
)
.unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}


@@ -8,7 +8,7 @@ use axum::{
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;
use tower::ServiceExt;
@@ -44,6 +44,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

Self::new_with_config(config, worker_configs).await
@@ -1085,6 +1086,7 @@ mod error_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

let ctx = TestContext::new_with_config(
@@ -1431,6 +1433,7 @@ mod pd_mode_tests {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

// Create app context
@@ -1584,6 +1587,7 @@ mod request_id_tests {
request_id_headers: Some(vec!["custom-id".to_string(), "trace-id".to_string()]),
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

let ctx = TestContext::new_with_config(

@@ -3,7 +3,7 @@ mod common;
use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType};
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;

@@ -35,6 +35,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

let mut workers = Vec::new();

@@ -4,7 +4,7 @@ use common::mock_worker::{HealthStatus, MockWorker, MockWorkerConfig, WorkerType
use futures_util::StreamExt;
use reqwest::Client;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::routers::{RouterFactory, RouterTrait};
use std::sync::Arc;

@@ -36,6 +36,7 @@ impl TestContext {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

let mut workers = Vec::new();

@@ -2,7 +2,7 @@
mod test_pd_routing {
use rand::Rng;
use serde_json::json;
use sglang_router_rs::config::{PolicyConfig, RouterConfig, RoutingMode};
use sglang_router_rs::config::{PolicyConfig, RetryConfig, RouterConfig, RoutingMode};
use sglang_router_rs::core::{WorkerFactory, WorkerType};
use sglang_router_rs::routers::pd_types::get_hostname;
use sglang_router_rs::routers::pd_types::PDSelectionPolicy;
@@ -178,6 +178,7 @@ mod test_pd_routing {
request_id_headers: None,
max_concurrent_requests: 64,
cors_allowed_origins: vec![],
retry: RetryConfig::default(),
};

// Router creation will fail due to health checks, but config should be valid