[router] allow one router to support different model families and serving mode (#10244)

This commit is contained in:
Simo Lin
2025-09-12 19:18:27 -04:00
committed by GitHub
parent 321fecab74
commit 2f173ea074
28 changed files with 3528 additions and 837 deletions

View File

@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
fn can_handle(&self, _req: &serde_json::Value) -> bool {
true
}
// === Multi-router support ===
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves
fn model_id(&self) -> &str {
self.metadata()
.labels
.get("model_id")
.map(|s| s.as_str())
.unwrap_or("unknown")
}
/// Get the priority of this worker (higher value = higher priority)
fn priority(&self) -> u32 {
self.metadata()
.labels
.get("priority")
.and_then(|s| s.parse().ok())
.unwrap_or(50) // Default priority is 50 (mid-range)
}
/// Get the cost factor of this worker (1.0 = baseline)
fn cost(&self) -> f32 {
self.metadata()
.labels
.get("cost")
.and_then(|s| s.parse().ok())
.unwrap_or(1.0)
}
/// Get the tokenizer path for this worker (gRPC mode only)
fn tokenizer_path(&self) -> Option<&str> {
self.metadata()
.labels
.get("tokenizer_path")
.map(|s| s.as_str())
}
/// Get the reasoning parser type for this worker (gRPC mode only)
fn reasoning_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("reasoning_parser")
.map(|s| s.as_str())
}
/// Get the tool parser type for this worker (gRPC mode only)
fn tool_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("tool_parser")
.map(|s| s.as_str())
}
/// Get the chat template for this worker (gRPC mode only)
fn chat_template(&self) -> Option<&str> {
self.metadata()
.labels
.get("chat_template")
.map(|s| s.as_str())
}
}
/// Connection mode for worker communication
@@ -724,6 +800,21 @@ impl WorkerFactory {
)
}
/// Create a regular worker with custom labels (for multi-router support)
pub fn create_regular_with_labels(
url: String,
labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(circuit_breaker_config);
// Add labels to metadata
worker.metadata.labels = labels;
Box::new(worker)
}
/// Create a DP-aware worker of specified type
pub fn create_dp_aware(
base_url: String,
@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
}
impl HealthChecker {
/// Create a new HealthChecker
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
Self { handle, shutdown }
}
/// Shutdown the health checker gracefully
pub async fn shutdown(self) {
self.shutdown.store(true, Ordering::Release);
@@ -950,7 +1046,7 @@ impl HealthChecker {
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
@@ -1602,9 +1698,11 @@ mod tests {
// Test HealthChecker background task
#[tokio::test]
async fn test_health_checker_startup() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);
@@ -1617,9 +1715,11 @@ mod tests {
#[tokio::test]
async fn test_health_checker_shutdown() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);