[router] fix worker registration in multi-model mode (#10486)

This commit is contained in:
Chang Su
2025-09-15 18:05:00 -07:00
committed by GitHub
parent 31fb19a0a2
commit 35ef3f2902
3 changed files with 69 additions and 54 deletions

View File

@@ -317,8 +317,9 @@ async fn create_worker(
State(state): State<Arc<AppState>>,
Json(config): Json<WorkerConfigRequest>,
) -> Response {
// Check if RouterManager is available (enable_igw=true)
if let Some(router_manager) = &state.context.router_manager {
// Check if the router is actually a RouterManager (enable_igw=true)
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
// Call RouterManager's add_worker method directly with the full config
match router_manager.add_worker(config).await {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
@@ -347,7 +348,7 @@ async fn create_worker(
/// GET /workers - List all workers with details
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
let response = router_manager.list_workers();
Json(response).into_response()
} else {
@@ -358,7 +359,11 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
let mut worker_info = serde_json::json!({
"url": worker.url(),
"model_id": worker.model_id(),
"worker_type": format!("{:?}", worker.worker_type()),
"worker_type": match worker.worker_type() {
WorkerType::Regular => "regular",
WorkerType::Prefill { .. } => "prefill",
WorkerType::Decode => "decode",
},
"is_healthy": worker.is_healthy(),
"load": worker.load(),
"connection_mode": format!("{:?}", worker.connection_mode()),
@@ -386,7 +391,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
/// GET /workers/{url} - Get specific worker info
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
if let Some(worker) = router_manager.get_worker(&url) {
Json(worker).into_response()
} else {
@@ -417,7 +422,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
/// DELETE /workers/{url} - Remove a worker
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
if let Some(router_manager) = &state.context.router_manager {
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
match router_manager.remove_worker_from_registry(&url) {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
@@ -603,7 +608,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
router_manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
vec![], // Models will be determined by workers
);
}
Err(e) => {
@@ -624,11 +628,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
{
Ok(http_pd) => {
info!("Created HTTP PD router");
router_manager.register_router(
RouterId::new("http-pd".to_string()),
Arc::from(http_pd),
vec![],
);
router_manager
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
}
Err(e) => {
warn!("Failed to create HTTP PD router: {e}");