From 40e0082d8d921ee57e267aaa306ba5d605040577 Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Tue, 14 Oct 2025 22:07:25 -0400 Subject: [PATCH] [router] add worker self discovery for metadata (#11638) --- sgl-router/src/core/worker_manager.rs | 417 ++++++++++++++++++++++++-- 1 file changed, 400 insertions(+), 17 deletions(-) diff --git a/sgl-router/src/core/worker_manager.rs b/sgl-router/src/core/worker_manager.rs index 8e92e9aaf..4019ea239 100644 --- a/sgl-router/src/core/worker_manager.rs +++ b/sgl-router/src/core/worker_manager.rs @@ -11,6 +11,7 @@ use crate::core::{ BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType, }; +use crate::grpc_client::SglangSchedulerClient; use crate::policies::PolicyRegistry; use crate::protocols::worker_spec::{ FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult, @@ -21,6 +22,7 @@ use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::collections::HashMap; +use std::path::Path; use std::sync::Arc; use std::time::Duration; use tokio::sync::{watch, Mutex}; @@ -55,6 +57,21 @@ pub struct DpInfo { pub model_id: String, } +/// Worker discovery results gathered from backend endpoints +struct WorkerDiscovery { + labels: HashMap, + grpc_client: Option, +} + +impl WorkerDiscovery { + fn new() -> Self { + Self { + labels: HashMap::new(), + grpc_client: None, + } + } +} + /// Unified worker management pub struct WorkerManager; @@ -318,7 +335,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } } @@ -363,7 +381,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } @@ -408,7 +427,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); } @@ -448,7 +468,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker(worker, registry, &mut registered_workers, policy_registry); info!( "Registered gRPC worker at {} (will connect on first use)", @@ -497,7 +518,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker( worker, registry, @@ -522,7 +544,8 @@ impl WorkerManager { None, circuit_breaker_config.clone(), health_config.clone(), - ); + ) + .await; Self::register_worker( worker, registry, @@ -563,12 +586,9 @@ impl WorkerManager { } let mut labels = config.labels.clone(); - // Use provided model_id or default to "unknown" - let model_id = config - .model_id - .clone() - .unwrap_or_else(|| "unknown".to_string()); - labels.insert("model_id".to_string(), model_id.clone()); + if let Some(model_id) = &config.model_id { + labels.insert("model_id".to_string(), model_id.clone()); + } if let Some(priority) = config.priority { labels.insert("priority".to_string(), priority.to_string()); } @@ -620,12 +640,14 @@ impl WorkerManager { Some(labels.clone()), circuit_breaker_config, health_config, - ); + ) + .await; worker.set_healthy(false); context.worker_registry.register(worker.clone()); let policy_hint = labels.get("policy").map(|s| s.as_str()); + let model_id = worker.model_id().to_string(); context .policy_registry .on_worker_added(&model_id, policy_hint); @@ -793,7 +815,8 @@ impl WorkerManager { labels, circuit_breaker_config, health_config, - ); + ) + .await; let model_id = worker.model_id().to_string(); context.worker_registry.register(worker.clone()); @@ -893,7 +916,7 @@ impl WorkerManager { } /// Create a basic worker - fn create_basic_worker( + async fn create_basic_worker( url: String, worker_type: WorkerType, connection_mode: ConnectionMode, @@ -902,6 +925,16 @@ impl WorkerManager { circuit_breaker_config: CircuitBreakerConfig, health_config: HealthConfig, ) -> Arc { + let discovery = + Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await; + + let mut final_labels = discovery.labels; + if let Some(custom_labels) = labels { + for (key, value) in custom_labels { + final_labels.insert(key, value); + } + } + let mut builder = BasicWorkerBuilder::new(url) .worker_type(worker_type) .connection_mode(connection_mode) @@ -912,8 +945,12 @@ impl WorkerManager { builder = builder.api_key(key); } - if let Some(worker_labels) = labels { - builder = builder.labels(worker_labels); + if !final_labels.is_empty() { + builder = builder.labels(final_labels); + } + + if let Some(client) = discovery.grpc_client { + builder = builder.grpc_client(client); } let worker = builder.build(); @@ -1084,6 +1121,306 @@ impl WorkerManager { } } + /// Gather worker metadata directly from the backend before registration. + async fn discover_worker_metadata( + url: &str, + connection_mode: &ConnectionMode, + api_key: Option<&str>, + ) -> WorkerDiscovery { + match connection_mode { + ConnectionMode::Http => Self::discover_http_metadata(url, api_key).await, + ConnectionMode::Grpc { .. } => Self::discover_grpc_metadata(url).await, + } + } + + async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery { + let mut discovery = WorkerDiscovery::new(); + + match Self::get_model_info(url, api_key).await { + Ok(model_info) => { + if let Some(model_path) = model_info.get("model_path").and_then(|v| v.as_str()) { + if !model_path.is_empty() { + discovery + .labels + .insert("model_path".to_string(), model_path.to_string()); + } + } + if let Some(tokenizer_path) = + model_info.get("tokenizer_path").and_then(|v| v.as_str()) + { + if !tokenizer_path.is_empty() { + discovery + .labels + .insert("tokenizer_path".to_string(), tokenizer_path.to_string()); + } + } + if let Some(served_model_name) = + model_info.get("served_model_name").and_then(|v| v.as_str()) + { + if !served_model_name.is_empty() { + discovery.labels.insert( + "served_model_name".to_string(), + served_model_name.to_string(), + ); + } + } + if let Some(weight_version) = + model_info.get("weight_version").and_then(|v| v.as_str()) + { + if !weight_version.is_empty() { + discovery + .labels + .insert("weight_version".to_string(), weight_version.to_string()); + } + } + if let Some(model_type) = model_info.get("model_type").and_then(|v| v.as_str()) { + if !model_type.is_empty() { + discovery + .labels + .insert("model_type".to_string(), model_type.to_string()); + } + } + if let Some(is_generation) = + model_info.get("is_generation").and_then(|v| v.as_bool()) + { + discovery + .labels + .insert("is_generation".to_string(), is_generation.to_string()); + } + if let Some(preferred_sampling_params) = model_info + .get("preferred_sampling_params") + .and_then(|v| v.as_str()) + { + if !preferred_sampling_params.is_empty() { + discovery.labels.insert( + "preferred_sampling_params".to_string(), + preferred_sampling_params.to_string(), + ); + } + } + if let Some(max_context_length) = model_info + .get("max_context_length") + .and_then(|v| v.as_i64()) + { + discovery.labels.insert( + "max_context_length".to_string(), + max_context_length.to_string(), + ); + } + if let Some(max_req_input_len) = + model_info.get("max_req_input_len").and_then(|v| v.as_i64()) + { + discovery.labels.insert( + "max_req_input_len".to_string(), + max_req_input_len.to_string(), + ); + } + } + Err(e) => { + warn!( + "Worker discovery: failed to fetch HTTP model info from {}: {}", + url, e + ); + } + } + + match Self::get_server_info(url, api_key).await { + Ok(server_info) => { + if let Some(model_id) = server_info.model_id { + if !model_id.is_empty() { + let normalized = Self::normalize_model_identifier(&model_id); + discovery.labels.insert("model_id".to_string(), normalized); + } + } + if let Some(model_path) = server_info.model_path { + if !model_path.is_empty() { + discovery + .labels + .insert("model_path".to_string(), model_path); + } + } + if let Some(version) = server_info.version { + if !version.is_empty() { + discovery + .labels + .insert("server_version".to_string(), version); + } + } + if let Some(max_total_tokens) = server_info.max_total_tokens { + discovery + .labels + .insert("max_total_tokens".to_string(), max_total_tokens.to_string()); + } + if let Some(max_prefill_tokens) = server_info.max_prefill_tokens { + discovery.labels.insert( + "max_prefill_tokens".to_string(), + max_prefill_tokens.to_string(), + ); + } + if let Some(max_running_requests) = server_info.max_running_requests { + discovery.labels.insert( + "max_running_requests".to_string(), + max_running_requests.to_string(), + ); + } + } + Err(e) => { + warn!( + "Worker discovery: failed to fetch HTTP server info from {}: {}", + url, e + ); + } + } + + Self::finalize_model_id(&mut discovery.labels); + + discovery + } + + async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery { + let mut discovery = WorkerDiscovery::new(); + + let client = match SglangSchedulerClient::connect(url).await { + Ok(client) => client, + Err(e) => { + warn!( + "Worker discovery: failed to connect to gRPC worker {}: {}", + url, e + ); + return discovery; + } + }; + + match client.get_model_info().await { + Ok(model_info) => { + if !model_info.model_path.is_empty() { + discovery + .labels + .insert("model_path".to_string(), model_info.model_path.clone()); + } + if !model_info.tokenizer_path.is_empty() { + discovery.labels.insert( + "tokenizer_path".to_string(), + model_info.tokenizer_path.clone(), + ); + } + if !model_info.served_model_name.is_empty() { + discovery.labels.insert( + "served_model_name".to_string(), + model_info.served_model_name.clone(), + ); + let normalized = + Self::normalize_model_identifier(&model_info.served_model_name); + discovery.labels.insert("model_id".to_string(), normalized); + } + if !model_info.weight_version.is_empty() { + discovery.labels.insert( + "weight_version".to_string(), + model_info.weight_version.clone(), + ); + } + if !model_info.model_type.is_empty() { + discovery + .labels + .insert("model_type".to_string(), model_info.model_type.clone()); + } + if !model_info.preferred_sampling_params.is_empty() { + discovery.labels.insert( + "preferred_sampling_params".to_string(), + model_info.preferred_sampling_params.clone(), + ); + } + discovery.labels.insert( + "is_generation".to_string(), + model_info.is_generation.to_string(), + ); + if model_info.max_context_length > 0 { + discovery.labels.insert( + "max_context_length".to_string(), + model_info.max_context_length.to_string(), + ); + } + if model_info.max_req_input_len > 0 { + discovery.labels.insert( + "max_req_input_len".to_string(), + model_info.max_req_input_len.to_string(), + ); + } + if model_info.vocab_size > 0 { + discovery + .labels + .insert("vocab_size".to_string(), model_info.vocab_size.to_string()); + } + } + Err(e) => { + warn!( + "Worker discovery: failed to fetch gRPC model info from {}: {}", + url, e + ); + } + } + + if !discovery.labels.contains_key("model_id") { + Self::finalize_model_id(&mut discovery.labels); + } + + discovery.grpc_client = Some(client); + discovery + } + + fn normalize_model_identifier(value: &str) -> String { + let trimmed = value.trim(); + if trimmed.contains('/') || trimmed.contains('\\') { + Self::derive_model_id_from_path(trimmed) + } else { + trimmed.to_string() + } + } + + fn finalize_model_id(labels: &mut HashMap) { + let has_model_id = labels + .get("model_id") + .map(|v| !v.trim().is_empty()) + .unwrap_or(false); + if has_model_id { + return; + } + + if let Some(served_name) = labels.get("served_model_name").cloned() { + if !served_name.trim().is_empty() { + let normalized = Self::normalize_model_identifier(&served_name); + labels.insert("model_id".to_string(), normalized); + return; + } + } + + if let Some(model_path) = labels.get("model_path").cloned() { + if !model_path.trim().is_empty() { + let derived = Self::derive_model_id_from_path(&model_path); + if !derived.is_empty() { + labels.insert("model_id".to_string(), derived); + } + } + } + } + + fn derive_model_id_from_path(path: &str) -> String { + let trimmed = path.trim_end_matches(['/', '\\']); + if trimmed.is_empty() { + return path.to_string(); + } + + let candidate = Path::new(trimmed) + .file_name() + .and_then(|p| p.to_str()) + .map(|s| s.to_string()); + + match candidate { + Some(name) if !name.is_empty() => name, + _ => trimmed.to_string(), + } + } + /// Parse server info from JSON response fn parse_server_info(json: Value) -> Result { Ok(ServerInfo { @@ -1499,6 +1836,7 @@ impl Drop for LoadMonitor { #[cfg(test)] mod tests { use super::*; + use std::collections::HashMap; #[test] fn test_parse_server_info() { @@ -1533,4 +1871,49 @@ mod tests { assert_eq!(info.model_id, None); assert_eq!(info.dp_size, None); } + + #[test] + fn test_derive_model_id_from_path() { + let path = "/raid/models/meta-llama/Llama-3.1-8B-Instruct"; + let derived = WorkerManager::derive_model_id_from_path(path); + assert_eq!(derived, "Llama-3.1-8B-Instruct"); + } + + #[test] + fn test_derive_model_id_trailing_slash() { + let path = "/models/foo/bar/"; + let derived = WorkerManager::derive_model_id_from_path(path); + assert_eq!(derived, "bar"); + } + + #[test] + fn test_finalize_model_id_prefers_existing() { + let mut labels = HashMap::new(); + labels.insert("model_id".to_string(), "manual-id".to_string()); + labels.insert("served_model_name".to_string(), "auto-id".to_string()); + WorkerManager::finalize_model_id(&mut labels); + assert_eq!(labels.get("model_id").unwrap(), "manual-id"); + } + + #[test] + fn test_finalize_model_id_prefers_served_name() { + let mut labels = HashMap::new(); + labels.insert("served_model_name".to_string(), "served-name".to_string()); + WorkerManager::finalize_model_id(&mut labels); + assert_eq!(labels.get("model_id").unwrap(), "served-name"); + } + + #[test] + fn test_finalize_model_id_falls_back_to_path() { + let mut labels = HashMap::new(); + labels.insert("model_path".to_string(), "/models/alpha".to_string()); + WorkerManager::finalize_model_id(&mut labels); + assert_eq!(labels.get("model_id").unwrap(), "alpha"); + } + + #[test] + fn test_normalize_model_identifier_from_path() { + let normalized = WorkerManager::normalize_model_identifier("/raid/models/foo/bar-model"); + assert_eq!(normalized, "bar-model"); + } }