[router] Add Rust Binary Entrypoint for SGLang Router (#9089)

This commit is contained in:
Simo Lin
2025-08-11 21:37:36 -07:00
committed by GitHub
parent a218490136
commit 9d68bdb240
12 changed files with 638 additions and 78 deletions

View File

@@ -11,29 +11,32 @@ pub struct RouterFactory;
impl RouterFactory {
/// Create a router instance from application context
pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
match &ctx.router_config.mode {
RoutingMode::Regular { worker_urls } => {
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
}
RoutingMode::PrefillDecode {
prefill_urls,
decode_urls,
prefill_policy,
decode_policy,
} => Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&ctx.router_config.policy,
ctx,
),
} => {
Self::create_pd_router(
prefill_urls,
decode_urls,
prefill_policy.as_ref(),
decode_policy.as_ref(),
&ctx.router_config.policy,
ctx,
)
.await
}
}
}
/// Create a regular router with injected policy
fn create_regular_router(
async fn create_regular_router(
worker_urls: &[String],
policy_config: &PolicyConfig,
ctx: &Arc<AppContext>,
@@ -52,13 +55,14 @@ impl RouterFactory {
ctx.router_config.api_key.clone(),
ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
)?;
)
.await?;
Ok(Box::new(router))
}
/// Create a PD router with injected policy
fn create_pd_router(
async fn create_pd_router(
prefill_urls: &[(String, Option<u16>)],
decode_urls: &[String],
prefill_policy_config: Option<&PolicyConfig>,
@@ -83,7 +87,8 @@ impl RouterFactory {
ctx.router_config.worker_startup_check_interval_secs,
ctx.router_config.retry.clone(),
ctx.router_config.circuit_breaker.clone(),
)?;
)
.await?;
Ok(Box::new(router))
}

View File

@@ -67,6 +67,7 @@ impl PDRouter {
self.timeout_secs,
self.interval_secs,
)
.await
.map_err(|_| PDRouterError::HealthCheckFailed {
url: url.to_string(),
})
@@ -349,7 +350,7 @@ impl PDRouter {
Ok(format!("Successfully removed decode server: {}", url))
}
pub fn new(
pub async fn new(
prefill_urls: Vec<(String, Option<u16>)>,
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
@@ -392,7 +393,8 @@ impl PDRouter {
&all_urls,
timeout_secs,
interval_secs,
)?;
)
.await?;
}
// Initialize cache-aware policies with workers

View File

@@ -17,7 +17,6 @@ use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tracing::{debug, error, info, warn};
@@ -52,7 +51,7 @@ pub struct Router {
impl Router {
/// Create a new router with injected policy and client
pub fn new(
pub async fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
@@ -68,7 +67,7 @@ impl Router {
// Wait for workers to be healthy (skip if empty - for service discovery mode)
if !worker_urls.is_empty() {
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs).await?;
}
let worker_urls = if dp_aware {
@@ -156,7 +155,7 @@ impl Router {
.collect()
}
pub fn wait_for_healthy_workers(
pub async fn wait_for_healthy_workers(
worker_urls: &[String],
timeout_secs: u64,
interval_secs: u64,
@@ -167,9 +166,24 @@ impl Router {
);
}
// Perform health check asynchronously
Self::wait_for_healthy_workers_async(worker_urls, timeout_secs, interval_secs).await
}
async fn wait_for_healthy_workers_async(
worker_urls: &[String],
timeout_secs: u64,
interval_secs: u64,
) -> Result<(), String> {
info!(
"Waiting for {} workers to become healthy (timeout: {}s)",
worker_urls.len(),
timeout_secs
);
let start_time = std::time::Instant::now();
let sync_client = reqwest::blocking::Client::builder()
.timeout(Duration::from_secs(timeout_secs))
let client = reqwest::Client::builder()
.timeout(Duration::from_secs(2))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
@@ -185,20 +199,48 @@ impl Router {
));
}
// Perform all health checks concurrently
let mut health_checks = Vec::new();
for url in worker_urls {
let client_clone = client.clone();
let url_clone = url.clone();
let check_health = tokio::spawn(async move {
let health_url = format!("{}/health", url_clone);
match client_clone.get(&health_url).send().await {
Ok(res) => {
if res.status().is_success() {
None
} else {
Some((url_clone, format!("status: {}", res.status())))
}
}
Err(_) => Some((url_clone, "not ready".to_string())),
}
});
health_checks.push(check_health);
}
// Wait for all health checks to complete
let results = futures::future::join_all(health_checks).await;
let mut all_healthy = true;
let mut unhealthy_workers = Vec::new();
for url in worker_urls {
match sync_client.get(&format!("{}/health", url)).send() {
Ok(res) => {
if !res.status().is_success() {
all_healthy = false;
unhealthy_workers.push((url, format!("status: {}", res.status())));
}
for result in results {
match result {
Ok(None) => {
// Worker is healthy
}
Err(_) => {
Ok(Some((url, reason))) => {
all_healthy = false;
unhealthy_workers.push((url, "not ready".to_string()));
unhealthy_workers.push((url, reason));
}
Err(e) => {
all_healthy = false;
unhealthy_workers
.push(("unknown".to_string(), format!("task error: {}", e)));
}
}
}
@@ -208,11 +250,12 @@ impl Router {
return Ok(());
} else {
debug!(
"Waiting for {} workers to become healthy ({} unhealthy)",
"Waiting for {} workers to become healthy ({} unhealthy: {:?})",
worker_urls.len(),
unhealthy_workers.len()
unhealthy_workers.len(),
unhealthy_workers
);
thread::sleep(Duration::from_secs(interval_secs));
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
}
}
}
@@ -1246,19 +1289,19 @@ mod tests {
assert_eq!(result.unwrap(), "http://worker1:8080");
}
#[test]
fn test_wait_for_healthy_workers_empty_list() {
// Empty list will timeout as there are no workers to check
let result = Router::wait_for_healthy_workers(&[], 1, 1);
#[tokio::test]
async fn test_wait_for_healthy_workers_empty_list() {
// Empty list will return error immediately
let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
assert!(result.unwrap_err().contains("no workers provided"));
}
#[test]
fn test_wait_for_healthy_workers_invalid_urls() {
#[tokio::test]
async fn test_wait_for_healthy_workers_invalid_urls() {
// This test will timeout quickly since the URLs are invalid
let result =
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
assert!(result.is_err());
assert!(result.unwrap_err().contains("Timeout"));
}