[router] add worker self discovery for metadata (#11638)
This commit is contained in:
@@ -11,6 +11,7 @@ use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, DPAwareWorkerBuilder, HealthConfig,
|
||||
Worker, WorkerFactory, WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::grpc_client::SglangSchedulerClient;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::worker_spec::{
|
||||
FlushCacheResult, WorkerConfigRequest, WorkerLoadInfo, WorkerLoadsResult,
|
||||
@@ -21,6 +22,7 @@ use once_cell::sync::Lazy;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::{watch, Mutex};
|
||||
@@ -55,6 +57,21 @@ pub struct DpInfo {
|
||||
pub model_id: String,
|
||||
}
|
||||
|
||||
/// Worker discovery results gathered from backend endpoints
|
||||
struct WorkerDiscovery {
|
||||
labels: HashMap<String, String>,
|
||||
grpc_client: Option<SglangSchedulerClient>,
|
||||
}
|
||||
|
||||
impl WorkerDiscovery {
|
||||
fn new() -> Self {
|
||||
Self {
|
||||
labels: HashMap::new(),
|
||||
grpc_client: None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Unified worker management
|
||||
pub struct WorkerManager;
|
||||
|
||||
@@ -318,7 +335,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
|
||||
}
|
||||
}
|
||||
@@ -363,7 +381,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
|
||||
}
|
||||
|
||||
@@ -408,7 +427,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
|
||||
}
|
||||
|
||||
@@ -448,7 +468,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(worker, registry, &mut registered_workers, policy_registry);
|
||||
info!(
|
||||
"Registered gRPC worker at {} (will connect on first use)",
|
||||
@@ -497,7 +518,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(
|
||||
worker,
|
||||
registry,
|
||||
@@ -522,7 +544,8 @@ impl WorkerManager {
|
||||
None,
|
||||
circuit_breaker_config.clone(),
|
||||
health_config.clone(),
|
||||
);
|
||||
)
|
||||
.await;
|
||||
Self::register_worker(
|
||||
worker,
|
||||
registry,
|
||||
@@ -563,12 +586,9 @@ impl WorkerManager {
|
||||
}
|
||||
let mut labels = config.labels.clone();
|
||||
|
||||
// Use provided model_id or default to "unknown"
|
||||
let model_id = config
|
||||
.model_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| "unknown".to_string());
|
||||
labels.insert("model_id".to_string(), model_id.clone());
|
||||
if let Some(model_id) = &config.model_id {
|
||||
labels.insert("model_id".to_string(), model_id.clone());
|
||||
}
|
||||
if let Some(priority) = config.priority {
|
||||
labels.insert("priority".to_string(), priority.to_string());
|
||||
}
|
||||
@@ -620,12 +640,14 @@ impl WorkerManager {
|
||||
Some(labels.clone()),
|
||||
circuit_breaker_config,
|
||||
health_config,
|
||||
);
|
||||
)
|
||||
.await;
|
||||
|
||||
worker.set_healthy(false);
|
||||
context.worker_registry.register(worker.clone());
|
||||
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
let model_id = worker.model_id().to_string();
|
||||
context
|
||||
.policy_registry
|
||||
.on_worker_added(&model_id, policy_hint);
|
||||
@@ -793,7 +815,8 @@ impl WorkerManager {
|
||||
labels,
|
||||
circuit_breaker_config,
|
||||
health_config,
|
||||
);
|
||||
)
|
||||
.await;
|
||||
|
||||
let model_id = worker.model_id().to_string();
|
||||
context.worker_registry.register(worker.clone());
|
||||
@@ -893,7 +916,7 @@ impl WorkerManager {
|
||||
}
|
||||
|
||||
/// Create a basic worker
|
||||
fn create_basic_worker(
|
||||
async fn create_basic_worker(
|
||||
url: String,
|
||||
worker_type: WorkerType,
|
||||
connection_mode: ConnectionMode,
|
||||
@@ -902,6 +925,16 @@ impl WorkerManager {
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
health_config: HealthConfig,
|
||||
) -> Arc<dyn Worker> {
|
||||
let discovery =
|
||||
Self::discover_worker_metadata(&url, &connection_mode, api_key.as_deref()).await;
|
||||
|
||||
let mut final_labels = discovery.labels;
|
||||
if let Some(custom_labels) = labels {
|
||||
for (key, value) in custom_labels {
|
||||
final_labels.insert(key, value);
|
||||
}
|
||||
}
|
||||
|
||||
let mut builder = BasicWorkerBuilder::new(url)
|
||||
.worker_type(worker_type)
|
||||
.connection_mode(connection_mode)
|
||||
@@ -912,8 +945,12 @@ impl WorkerManager {
|
||||
builder = builder.api_key(key);
|
||||
}
|
||||
|
||||
if let Some(worker_labels) = labels {
|
||||
builder = builder.labels(worker_labels);
|
||||
if !final_labels.is_empty() {
|
||||
builder = builder.labels(final_labels);
|
||||
}
|
||||
|
||||
if let Some(client) = discovery.grpc_client {
|
||||
builder = builder.grpc_client(client);
|
||||
}
|
||||
|
||||
let worker = builder.build();
|
||||
@@ -1084,6 +1121,306 @@ impl WorkerManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Gather worker metadata directly from the backend before registration.
|
||||
async fn discover_worker_metadata(
|
||||
url: &str,
|
||||
connection_mode: &ConnectionMode,
|
||||
api_key: Option<&str>,
|
||||
) -> WorkerDiscovery {
|
||||
match connection_mode {
|
||||
ConnectionMode::Http => Self::discover_http_metadata(url, api_key).await,
|
||||
ConnectionMode::Grpc { .. } => Self::discover_grpc_metadata(url).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn discover_http_metadata(url: &str, api_key: Option<&str>) -> WorkerDiscovery {
|
||||
let mut discovery = WorkerDiscovery::new();
|
||||
|
||||
match Self::get_model_info(url, api_key).await {
|
||||
Ok(model_info) => {
|
||||
if let Some(model_path) = model_info.get("model_path").and_then(|v| v.as_str()) {
|
||||
if !model_path.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_path".to_string(), model_path.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(tokenizer_path) =
|
||||
model_info.get("tokenizer_path").and_then(|v| v.as_str())
|
||||
{
|
||||
if !tokenizer_path.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(served_model_name) =
|
||||
model_info.get("served_model_name").and_then(|v| v.as_str())
|
||||
{
|
||||
if !served_model_name.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"served_model_name".to_string(),
|
||||
served_model_name.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(weight_version) =
|
||||
model_info.get("weight_version").and_then(|v| v.as_str())
|
||||
{
|
||||
if !weight_version.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("weight_version".to_string(), weight_version.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(model_type) = model_info.get("model_type").and_then(|v| v.as_str()) {
|
||||
if !model_type.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_type".to_string(), model_type.to_string());
|
||||
}
|
||||
}
|
||||
if let Some(is_generation) =
|
||||
model_info.get("is_generation").and_then(|v| v.as_bool())
|
||||
{
|
||||
discovery
|
||||
.labels
|
||||
.insert("is_generation".to_string(), is_generation.to_string());
|
||||
}
|
||||
if let Some(preferred_sampling_params) = model_info
|
||||
.get("preferred_sampling_params")
|
||||
.and_then(|v| v.as_str())
|
||||
{
|
||||
if !preferred_sampling_params.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"preferred_sampling_params".to_string(),
|
||||
preferred_sampling_params.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
if let Some(max_context_length) = model_info
|
||||
.get("max_context_length")
|
||||
.and_then(|v| v.as_i64())
|
||||
{
|
||||
discovery.labels.insert(
|
||||
"max_context_length".to_string(),
|
||||
max_context_length.to_string(),
|
||||
);
|
||||
}
|
||||
if let Some(max_req_input_len) =
|
||||
model_info.get("max_req_input_len").and_then(|v| v.as_i64())
|
||||
{
|
||||
discovery.labels.insert(
|
||||
"max_req_input_len".to_string(),
|
||||
max_req_input_len.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Worker discovery: failed to fetch HTTP model info from {}: {}",
|
||||
url, e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
match Self::get_server_info(url, api_key).await {
|
||||
Ok(server_info) => {
|
||||
if let Some(model_id) = server_info.model_id {
|
||||
if !model_id.is_empty() {
|
||||
let normalized = Self::normalize_model_identifier(&model_id);
|
||||
discovery.labels.insert("model_id".to_string(), normalized);
|
||||
}
|
||||
}
|
||||
if let Some(model_path) = server_info.model_path {
|
||||
if !model_path.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_path".to_string(), model_path);
|
||||
}
|
||||
}
|
||||
if let Some(version) = server_info.version {
|
||||
if !version.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("server_version".to_string(), version);
|
||||
}
|
||||
}
|
||||
if let Some(max_total_tokens) = server_info.max_total_tokens {
|
||||
discovery
|
||||
.labels
|
||||
.insert("max_total_tokens".to_string(), max_total_tokens.to_string());
|
||||
}
|
||||
if let Some(max_prefill_tokens) = server_info.max_prefill_tokens {
|
||||
discovery.labels.insert(
|
||||
"max_prefill_tokens".to_string(),
|
||||
max_prefill_tokens.to_string(),
|
||||
);
|
||||
}
|
||||
if let Some(max_running_requests) = server_info.max_running_requests {
|
||||
discovery.labels.insert(
|
||||
"max_running_requests".to_string(),
|
||||
max_running_requests.to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Worker discovery: failed to fetch HTTP server info from {}: {}",
|
||||
url, e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Self::finalize_model_id(&mut discovery.labels);
|
||||
|
||||
discovery
|
||||
}
|
||||
|
||||
async fn discover_grpc_metadata(url: &str) -> WorkerDiscovery {
|
||||
let mut discovery = WorkerDiscovery::new();
|
||||
|
||||
let client = match SglangSchedulerClient::connect(url).await {
|
||||
Ok(client) => client,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Worker discovery: failed to connect to gRPC worker {}: {}",
|
||||
url, e
|
||||
);
|
||||
return discovery;
|
||||
}
|
||||
};
|
||||
|
||||
match client.get_model_info().await {
|
||||
Ok(model_info) => {
|
||||
if !model_info.model_path.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_path".to_string(), model_info.model_path.clone());
|
||||
}
|
||||
if !model_info.tokenizer_path.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"tokenizer_path".to_string(),
|
||||
model_info.tokenizer_path.clone(),
|
||||
);
|
||||
}
|
||||
if !model_info.served_model_name.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"served_model_name".to_string(),
|
||||
model_info.served_model_name.clone(),
|
||||
);
|
||||
let normalized =
|
||||
Self::normalize_model_identifier(&model_info.served_model_name);
|
||||
discovery.labels.insert("model_id".to_string(), normalized);
|
||||
}
|
||||
if !model_info.weight_version.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"weight_version".to_string(),
|
||||
model_info.weight_version.clone(),
|
||||
);
|
||||
}
|
||||
if !model_info.model_type.is_empty() {
|
||||
discovery
|
||||
.labels
|
||||
.insert("model_type".to_string(), model_info.model_type.clone());
|
||||
}
|
||||
if !model_info.preferred_sampling_params.is_empty() {
|
||||
discovery.labels.insert(
|
||||
"preferred_sampling_params".to_string(),
|
||||
model_info.preferred_sampling_params.clone(),
|
||||
);
|
||||
}
|
||||
discovery.labels.insert(
|
||||
"is_generation".to_string(),
|
||||
model_info.is_generation.to_string(),
|
||||
);
|
||||
if model_info.max_context_length > 0 {
|
||||
discovery.labels.insert(
|
||||
"max_context_length".to_string(),
|
||||
model_info.max_context_length.to_string(),
|
||||
);
|
||||
}
|
||||
if model_info.max_req_input_len > 0 {
|
||||
discovery.labels.insert(
|
||||
"max_req_input_len".to_string(),
|
||||
model_info.max_req_input_len.to_string(),
|
||||
);
|
||||
}
|
||||
if model_info.vocab_size > 0 {
|
||||
discovery
|
||||
.labels
|
||||
.insert("vocab_size".to_string(), model_info.vocab_size.to_string());
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Worker discovery: failed to fetch gRPC model info from {}: {}",
|
||||
url, e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if !discovery.labels.contains_key("model_id") {
|
||||
Self::finalize_model_id(&mut discovery.labels);
|
||||
}
|
||||
|
||||
discovery.grpc_client = Some(client);
|
||||
discovery
|
||||
}
|
||||
|
||||
fn normalize_model_identifier(value: &str) -> String {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.contains('/') || trimmed.contains('\\') {
|
||||
Self::derive_model_id_from_path(trimmed)
|
||||
} else {
|
||||
trimmed.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
fn finalize_model_id(labels: &mut HashMap<String, String>) {
|
||||
let has_model_id = labels
|
||||
.get("model_id")
|
||||
.map(|v| !v.trim().is_empty())
|
||||
.unwrap_or(false);
|
||||
if has_model_id {
|
||||
return;
|
||||
}
|
||||
|
||||
if let Some(served_name) = labels.get("served_model_name").cloned() {
|
||||
if !served_name.trim().is_empty() {
|
||||
let normalized = Self::normalize_model_identifier(&served_name);
|
||||
labels.insert("model_id".to_string(), normalized);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(model_path) = labels.get("model_path").cloned() {
|
||||
if !model_path.trim().is_empty() {
|
||||
let derived = Self::derive_model_id_from_path(&model_path);
|
||||
if !derived.is_empty() {
|
||||
labels.insert("model_id".to_string(), derived);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn derive_model_id_from_path(path: &str) -> String {
|
||||
let trimmed = path.trim_end_matches(['/', '\\']);
|
||||
if trimmed.is_empty() {
|
||||
return path.to_string();
|
||||
}
|
||||
|
||||
let candidate = Path::new(trimmed)
|
||||
.file_name()
|
||||
.and_then(|p| p.to_str())
|
||||
.map(|s| s.to_string());
|
||||
|
||||
match candidate {
|
||||
Some(name) if !name.is_empty() => name,
|
||||
_ => trimmed.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Parse server info from JSON response
|
||||
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
|
||||
Ok(ServerInfo {
|
||||
@@ -1499,6 +1836,7 @@ impl Drop for LoadMonitor {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_parse_server_info() {
|
||||
@@ -1533,4 +1871,49 @@ mod tests {
|
||||
assert_eq!(info.model_id, None);
|
||||
assert_eq!(info.dp_size, None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_model_id_from_path() {
|
||||
let path = "/raid/models/meta-llama/Llama-3.1-8B-Instruct";
|
||||
let derived = WorkerManager::derive_model_id_from_path(path);
|
||||
assert_eq!(derived, "Llama-3.1-8B-Instruct");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_derive_model_id_trailing_slash() {
|
||||
let path = "/models/foo/bar/";
|
||||
let derived = WorkerManager::derive_model_id_from_path(path);
|
||||
assert_eq!(derived, "bar");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_finalize_model_id_prefers_existing() {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("model_id".to_string(), "manual-id".to_string());
|
||||
labels.insert("served_model_name".to_string(), "auto-id".to_string());
|
||||
WorkerManager::finalize_model_id(&mut labels);
|
||||
assert_eq!(labels.get("model_id").unwrap(), "manual-id");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_finalize_model_id_prefers_served_name() {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("served_model_name".to_string(), "served-name".to_string());
|
||||
WorkerManager::finalize_model_id(&mut labels);
|
||||
assert_eq!(labels.get("model_id").unwrap(), "served-name");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_finalize_model_id_falls_back_to_path() {
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("model_path".to_string(), "/models/alpha".to_string());
|
||||
WorkerManager::finalize_model_id(&mut labels);
|
||||
assert_eq!(labels.get("model_id").unwrap(), "alpha");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_normalize_model_identifier_from_path() {
|
||||
let normalized = WorkerManager::normalize_model_identifier("/raid/models/foo/bar-model");
|
||||
assert_eq!(normalized, "bar-model");
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user