[router] create worker removal step and clean up worker manager (#11921)

This commit is contained in:
Simo Lin
2025-10-22 13:26:06 -07:00
committed by GitHub
parent eec9e471ca
commit 5dccf69713
23 changed files with 758 additions and 1905 deletions

View File

@@ -14,7 +14,7 @@ use sglang_router_rs::{
config::{
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
},
core::WorkerManager,
core::Job,
routers::{RouterFactory, RouterTrait},
server::AppContext,
};
@@ -112,22 +112,51 @@ impl TestContext {
// Create app context
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
// Submit worker initialization job (same as real server does)
if !worker_urls.is_empty() {
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
let job_queue = app_context
.worker_job_queue
.get()
.expect("JobQueue should be initialized");
let job = Job::InitializeWorkersFromConfig {
router_config: Box::new(config.clone()),
};
job_queue
.submit(job)
.await
.expect("Failed to initialize workers");
.expect("Failed to submit worker initialization job");
// Poll until all workers are healthy (up to 10 seconds)
let expected_count = worker_urls.len();
let start = tokio::time::Instant::now();
let timeout_duration = tokio::time::Duration::from_secs(10);
loop {
let healthy_workers = app_context
.worker_registry
.get_all()
.iter()
.filter(|w| w.is_healthy())
.count();
if healthy_workers >= expected_count {
break;
}
if start.elapsed() > timeout_duration {
panic!(
"Timeout waiting for {} workers to become healthy (only {} ready)",
expected_count, healthy_workers
);
}
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
}
}
// Create router
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);
// Wait for router to discover workers
if !workers.is_empty() {
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
}
Self {
workers,
router,
@@ -711,221 +740,6 @@ mod model_info_tests {
}
}
#[cfg(test)]
mod worker_management_tests {
use super::*;
#[tokio::test]
async fn test_add_new_worker() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18301,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
// Add the worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
// List workers to verify
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.iter().any(|w| w.as_str().unwrap() == url));
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_remove_existing_worker() {
let ctx = TestContext::new(vec![MockWorkerConfig {
port: 18302,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
}])
.await;
let app = ctx.create_app().await;
// Get the worker URL
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
let worker_url = workers[0].as_str().unwrap();
// Remove the worker
let req = Request::builder()
.method("POST")
.uri(format!("/remove_worker?url={}", worker_url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
let req = Request::builder()
.method("GET")
.uri("/list_workers")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
.await
.unwrap();
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
let workers = body_json["urls"].as_array().unwrap();
assert!(workers.is_empty());
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_worker_invalid_url() {
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Invalid URL format
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=not-a-valid-url")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Missing URL parameter
let req = Request::builder()
.method("POST")
.uri("/add_worker")
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
// Empty URL
let req = Request::builder()
.method("POST")
.uri("/add_worker?url=")
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_duplicate_worker() {
// Start a mock worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18303,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Healthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Add worker first time
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.clone().oneshot(req).await.unwrap();
assert_eq!(resp.status(), StatusCode::OK);
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
// Try to add same worker again
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Should return error for duplicate
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
worker.stop().await;
ctx.shutdown().await;
}
#[tokio::test]
async fn test_add_unhealthy_worker() {
// Start unhealthy worker
let mut worker = MockWorker::new(MockWorkerConfig {
port: 18304,
worker_type: WorkerType::Regular,
health_status: HealthStatus::Unhealthy,
response_delay_ms: 0,
fail_rate: 0.0,
});
let url = worker.start().await.unwrap();
let ctx = TestContext::new(vec![]).await;
let app = ctx.create_app().await;
// Try to add unhealthy worker
let req = Request::builder()
.method("POST")
.uri(format!("/add_worker?url={}", url))
.body(Body::empty())
.unwrap();
let resp = app.oneshot(req).await.unwrap();
// Router should reject unhealthy workers
assert!(
resp.status() == StatusCode::BAD_REQUEST
|| resp.status() == StatusCode::SERVICE_UNAVAILABLE
);
worker.stop().await;
ctx.shutdown().await;
}
}
#[cfg(test)]
mod router_policy_tests {
use super::*;

View File

@@ -66,7 +66,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
let worker_job_queue = Arc::new(OnceLock::new());
let workflow_engine = Arc::new(OnceLock::new());
Arc::new(AppContext::new(
let app_context = Arc::new(AppContext::new(
config,
client,
rate_limiter,
@@ -81,7 +81,32 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
load_monitor,
worker_job_queue,
workflow_engine,
))
));
// Initialize JobQueue after AppContext is created
let weak_context = Arc::downgrade(&app_context);
let job_queue = sglang_router_rs::core::JobQueue::new(
sglang_router_rs::core::JobQueueConfig::default(),
weak_context,
);
app_context
.worker_job_queue
.set(job_queue)
.expect("JobQueue should only be initialized once");
// Initialize WorkflowEngine and register workflows
use sglang_router_rs::core::workflow::{
create_worker_registration_workflow, create_worker_removal_workflow, WorkflowEngine,
};
let engine = Arc::new(WorkflowEngine::new());
engine.register_workflow(create_worker_registration_workflow());
engine.register_workflow(create_worker_removal_workflow());
app_context
.workflow_engine
.set(engine)
.expect("WorkflowEngine should only be initialized once");
app_context
}
// Tokenizer download configuration

View File

@@ -7,7 +7,6 @@ use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
config::{RouterConfig, RoutingMode},
core::WorkerManager,
routers::{RouterFactory, RouterTrait},
};
@@ -51,13 +50,6 @@ impl TestContext {
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
.await
.expect("Failed to initialize workers");
}
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);

View File

@@ -8,7 +8,6 @@ use reqwest::Client;
use serde_json::json;
use sglang_router_rs::{
config::{RouterConfig, RoutingMode},
core::WorkerManager,
routers::{RouterFactory, RouterTrait},
};
@@ -52,13 +51,6 @@ impl TestContext {
let app_context = common::create_test_context(config.clone());
// Initialize workers in the registry before creating router
if !worker_urls.is_empty() {
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
.await
.expect("Failed to initialize workers");
}
let router = RouterFactory::create_router(&app_context).await.unwrap();
let router = Arc::from(router);