[router] Add Rust Binary Entrypoint for SGLang Router (#9089)
This commit is contained in:
@@ -11,29 +11,32 @@ pub struct RouterFactory;
|
||||
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx).await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
} => Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
),
|
||||
} => {
|
||||
Self::create_pd_router(
|
||||
prefill_urls,
|
||||
decode_urls,
|
||||
prefill_policy.as_ref(),
|
||||
decode_policy.as_ref(),
|
||||
&ctx.router_config.policy,
|
||||
ctx,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
fn create_regular_router(
|
||||
async fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
@@ -52,13 +55,14 @@ impl RouterFactory {
|
||||
ctx.router_config.api_key.clone(),
|
||||
ctx.router_config.retry.clone(),
|
||||
ctx.router_config.circuit_breaker.clone(),
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
fn create_pd_router(
|
||||
async fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
@@ -83,7 +87,8 @@ impl RouterFactory {
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
ctx.router_config.retry.clone(),
|
||||
ctx.router_config.circuit_breaker.clone(),
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
@@ -67,6 +67,7 @@ impl PDRouter {
|
||||
self.timeout_secs,
|
||||
self.interval_secs,
|
||||
)
|
||||
.await
|
||||
.map_err(|_| PDRouterError::HealthCheckFailed {
|
||||
url: url.to_string(),
|
||||
})
|
||||
@@ -349,7 +350,7 @@ impl PDRouter {
|
||||
Ok(format!("Successfully removed decode server: {}", url))
|
||||
}
|
||||
|
||||
pub fn new(
|
||||
pub async fn new(
|
||||
prefill_urls: Vec<(String, Option<u16>)>,
|
||||
decode_urls: Vec<String>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
@@ -392,7 +393,8 @@ impl PDRouter {
|
||||
&all_urls,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
)?;
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies with workers
|
||||
|
||||
@@ -17,7 +17,6 @@ use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::thread;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -52,7 +51,7 @@ pub struct Router {
|
||||
|
||||
impl Router {
|
||||
/// Create a new router with injected policy and client
|
||||
pub fn new(
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
client: Client,
|
||||
@@ -68,7 +67,7 @@ impl Router {
|
||||
|
||||
// Wait for workers to be healthy (skip if empty - for service discovery mode)
|
||||
if !worker_urls.is_empty() {
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs).await?;
|
||||
}
|
||||
|
||||
let worker_urls = if dp_aware {
|
||||
@@ -156,7 +155,7 @@ impl Router {
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn wait_for_healthy_workers(
|
||||
pub async fn wait_for_healthy_workers(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
@@ -167,9 +166,24 @@ impl Router {
|
||||
);
|
||||
}
|
||||
|
||||
// Perform health check asynchronously
|
||||
Self::wait_for_healthy_workers_async(worker_urls, timeout_secs, interval_secs).await
|
||||
}
|
||||
|
||||
async fn wait_for_healthy_workers_async(
|
||||
worker_urls: &[String],
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<(), String> {
|
||||
info!(
|
||||
"Waiting for {} workers to become healthy (timeout: {}s)",
|
||||
worker_urls.len(),
|
||||
timeout_secs
|
||||
);
|
||||
|
||||
let start_time = std::time::Instant::now();
|
||||
let sync_client = reqwest::blocking::Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
let client = reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(2))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
@@ -185,20 +199,48 @@ impl Router {
|
||||
));
|
||||
}
|
||||
|
||||
// Perform all health checks concurrently
|
||||
let mut health_checks = Vec::new();
|
||||
for url in worker_urls {
|
||||
let client_clone = client.clone();
|
||||
let url_clone = url.clone();
|
||||
|
||||
let check_health = tokio::spawn(async move {
|
||||
let health_url = format!("{}/health", url_clone);
|
||||
match client_clone.get(&health_url).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
None
|
||||
} else {
|
||||
Some((url_clone, format!("status: {}", res.status())))
|
||||
}
|
||||
}
|
||||
Err(_) => Some((url_clone, "not ready".to_string())),
|
||||
}
|
||||
});
|
||||
|
||||
health_checks.push(check_health);
|
||||
}
|
||||
|
||||
// Wait for all health checks to complete
|
||||
let results = futures::future::join_all(health_checks).await;
|
||||
|
||||
let mut all_healthy = true;
|
||||
let mut unhealthy_workers = Vec::new();
|
||||
|
||||
for url in worker_urls {
|
||||
match sync_client.get(&format!("{}/health", url)).send() {
|
||||
Ok(res) => {
|
||||
if !res.status().is_success() {
|
||||
all_healthy = false;
|
||||
unhealthy_workers.push((url, format!("status: {}", res.status())));
|
||||
}
|
||||
for result in results {
|
||||
match result {
|
||||
Ok(None) => {
|
||||
// Worker is healthy
|
||||
}
|
||||
Err(_) => {
|
||||
Ok(Some((url, reason))) => {
|
||||
all_healthy = false;
|
||||
unhealthy_workers.push((url, "not ready".to_string()));
|
||||
unhealthy_workers.push((url, reason));
|
||||
}
|
||||
Err(e) => {
|
||||
all_healthy = false;
|
||||
unhealthy_workers
|
||||
.push(("unknown".to_string(), format!("task error: {}", e)));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -208,11 +250,12 @@ impl Router {
|
||||
return Ok(());
|
||||
} else {
|
||||
debug!(
|
||||
"Waiting for {} workers to become healthy ({} unhealthy)",
|
||||
"Waiting for {} workers to become healthy ({} unhealthy: {:?})",
|
||||
worker_urls.len(),
|
||||
unhealthy_workers.len()
|
||||
unhealthy_workers.len(),
|
||||
unhealthy_workers
|
||||
);
|
||||
thread::sleep(Duration::from_secs(interval_secs));
|
||||
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1246,19 +1289,19 @@ mod tests {
|
||||
assert_eq!(result.unwrap(), "http://worker1:8080");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wait_for_healthy_workers_empty_list() {
|
||||
// Empty list will timeout as there are no workers to check
|
||||
let result = Router::wait_for_healthy_workers(&[], 1, 1);
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_healthy_workers_empty_list() {
|
||||
// Empty list will return error immediately
|
||||
let result = Router::wait_for_healthy_workers(&[], 1, 1).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Timeout"));
|
||||
assert!(result.unwrap_err().contains("no workers provided"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_wait_for_healthy_workers_invalid_urls() {
|
||||
#[tokio::test]
|
||||
async fn test_wait_for_healthy_workers_invalid_urls() {
|
||||
// This test will timeout quickly since the URLs are invalid
|
||||
let result =
|
||||
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1);
|
||||
Router::wait_for_healthy_workers(&["http://nonexistent:8080".to_string()], 1, 1).await;
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("Timeout"));
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user