[router] allow one router to support different model families and serving mode (#10244)

This commit is contained in:
Simo Lin
2025-09-12 19:18:27 -04:00
committed by GitHub
parent 321fecab74
commit 2f173ea074
28 changed files with 3528 additions and 837 deletions

View File

@@ -11,6 +11,7 @@ pub mod error;
pub mod retry;
pub mod token_bucket;
pub mod worker;
pub mod worker_registry;
// Re-export commonly used types at the module level
pub use circuit_breaker::{
@@ -22,3 +23,4 @@ pub use worker::{
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
};
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};

View File

@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
fn can_handle(&self, _req: &serde_json::Value) -> bool {
true
}
// === Multi-router support ===
// TODO: - Enhanced Worker Discovery
// The Worker trait should handle async discovery of metadata from the worker itself
// rather than having service discovery or other components query /get_server_info.
// This keeps service discovery decoupled from worker-specific APIs.
//
// Proposed additions:
// - async fn discover_metadata(&mut self) -> Result<(), Error>
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
// - async fn validate_configuration(&self) -> Result<(), Error>
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
// - Make worker creation async to allow metadata discovery during initialization
//
// This way service discovery just calls router.add_worker() and the worker
// handles its own metadata discovery internally.
/// Get the model ID this worker serves
fn model_id(&self) -> &str {
self.metadata()
.labels
.get("model_id")
.map(|s| s.as_str())
.unwrap_or("unknown")
}
/// Get the priority of this worker (higher value = higher priority)
fn priority(&self) -> u32 {
self.metadata()
.labels
.get("priority")
.and_then(|s| s.parse().ok())
.unwrap_or(50) // Default priority is 50 (mid-range)
}
/// Get the cost factor of this worker (1.0 = baseline)
fn cost(&self) -> f32 {
self.metadata()
.labels
.get("cost")
.and_then(|s| s.parse().ok())
.unwrap_or(1.0)
}
/// Get the tokenizer path for this worker (gRPC mode only)
fn tokenizer_path(&self) -> Option<&str> {
self.metadata()
.labels
.get("tokenizer_path")
.map(|s| s.as_str())
}
/// Get the reasoning parser type for this worker (gRPC mode only)
fn reasoning_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("reasoning_parser")
.map(|s| s.as_str())
}
/// Get the tool parser type for this worker (gRPC mode only)
fn tool_parser(&self) -> Option<&str> {
self.metadata()
.labels
.get("tool_parser")
.map(|s| s.as_str())
}
/// Get the chat template for this worker (gRPC mode only)
fn chat_template(&self) -> Option<&str> {
self.metadata()
.labels
.get("chat_template")
.map(|s| s.as_str())
}
}
/// Connection mode for worker communication
@@ -724,6 +800,21 @@ impl WorkerFactory {
)
}
/// Create a regular worker with custom labels (for multi-router support)
pub fn create_regular_with_labels(
url: String,
labels: std::collections::HashMap<String, String>,
circuit_breaker_config: CircuitBreakerConfig,
) -> Box<dyn Worker> {
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
.with_circuit_breaker_config(circuit_breaker_config);
// Add labels to metadata
worker.metadata.labels = labels;
Box::new(worker)
}
/// Create a DP-aware worker of specified type
pub fn create_dp_aware(
base_url: String,
@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
}
impl HealthChecker {
/// Create a new HealthChecker
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
Self { handle, shutdown }
}
/// Shutdown the health checker gracefully
pub async fn shutdown(self) {
self.shutdown.store(true, Ordering::Release);
@@ -950,7 +1046,7 @@ impl HealthChecker {
/// Start an async background health checker for a collection of workers
pub fn start_health_checker(
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
check_interval_secs: u64,
) -> HealthChecker {
let shutdown = Arc::new(AtomicBool::new(false));
@@ -1602,9 +1698,11 @@ mod tests {
// Test HealthChecker background task
#[tokio::test]
async fn test_health_checker_startup() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);
@@ -1617,9 +1715,11 @@ mod tests {
#[tokio::test]
async fn test_health_checker_shutdown() {
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
let worker = Arc::new(BasicWorker::new(
"http://w1:8080".to_string(),
)]));
WorkerType::Regular,
)) as Arc<dyn Worker>;
let workers = Arc::new(RwLock::new(vec![worker]));
let checker = start_health_checker(workers.clone(), 60);

View File

@@ -0,0 +1,526 @@
//! Worker Registry for multi-router support
//!
//! Provides centralized registry for workers with model-based indexing
use crate::core::{ConnectionMode, Worker, WorkerType};
use dashmap::DashMap;
use std::sync::{Arc, RwLock};
use uuid::Uuid;
/// Unique identifier for a worker
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String);
impl WorkerId {
/// Create a new worker ID
pub fn new() -> Self {
Self(Uuid::new_v4().to_string())
}
/// Create a worker ID from a string
pub fn from_string(s: String) -> Self {
Self(s)
}
/// Get the ID as a string
pub fn as_str(&self) -> &str {
&self.0
}
}
impl Default for WorkerId {
fn default() -> Self {
Self::new()
}
}
/// Type alias for the model index to reduce complexity
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
/// Worker registry with model-based indexing
#[derive(Debug)]
pub struct WorkerRegistry {
/// All workers indexed by ID
workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
/// Workers indexed by model ID (stores WorkerId for reference)
model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
/// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
model_index: ModelIndex,
/// Workers indexed by worker type
type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
/// Workers indexed by connection mode
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
/// URL to worker ID mapping (for backward compatibility)
url_to_id: Arc<DashMap<String, WorkerId>>,
}
impl WorkerRegistry {
/// Create a new worker registry
pub fn new() -> Self {
Self {
workers: Arc::new(DashMap::new()),
model_workers: Arc::new(DashMap::new()),
model_index: Arc::new(DashMap::new()),
type_workers: Arc::new(DashMap::new()),
connection_workers: Arc::new(DashMap::new()),
url_to_id: Arc::new(DashMap::new()),
}
}
/// Register a new worker
pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) {
// Worker with this URL already exists, update it
existing_id.clone()
} else {
WorkerId::new()
};
// Store worker
self.workers.insert(worker_id.clone(), worker.clone());
// Update URL mapping
self.url_to_id
.insert(worker.url().to_string(), worker_id.clone());
// Update model index (both ID-based and optimized)
let model_id = worker.model_id().to_string();
self.model_workers
.entry(model_id.clone())
.or_default()
.push(worker_id.clone());
// Update optimized model index for O(1) lookups
self.model_index
.entry(model_id)
.or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
.write()
.expect("RwLock for model_index is poisoned")
.push(worker.clone());
// Update type index
self.type_workers
.entry(worker.worker_type())
.or_default()
.push(worker_id.clone());
// Update connection mode index
self.connection_workers
.entry(worker.connection_mode())
.or_default()
.push(worker_id.clone());
worker_id
}
/// Remove a worker by ID
pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
if let Some((_, worker)) = self.workers.remove(worker_id) {
// Remove from URL mapping
self.url_to_id.remove(worker.url());
// Remove from model index (both ID-based and optimized)
if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) {
model_workers.retain(|id| id != worker_id);
}
// Remove from optimized model index
if let Some(model_index_entry) = self.model_index.get(worker.model_id()) {
let worker_url = worker.url();
model_index_entry
.write()
.expect("RwLock for model_index is poisoned")
.retain(|w| w.url() != worker_url);
}
// Remove from type index
if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) {
type_workers.retain(|id| id != worker_id);
}
// Remove from connection mode index
if let Some(mut conn_workers) =
self.connection_workers.get_mut(&worker.connection_mode())
{
conn_workers.retain(|id| id != worker_id);
}
Some(worker)
} else {
None
}
}
/// Remove a worker by URL
pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
if let Some((_, worker_id)) = self.url_to_id.remove(url) {
self.remove(&worker_id)
} else {
None
}
}
/// Get a worker by ID
pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
self.workers.get(worker_id).map(|entry| entry.clone())
}
/// Get a worker by URL
pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
self.url_to_id.get(url).and_then(|id| self.get(&id))
}
/// Get all workers for a model
pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
self.model_workers
.get(model_id)
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
.unwrap_or_default()
}
/// Get all workers for a model (O(1) optimized version)
/// This method uses the pre-indexed model_index for fast lookups
pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
self.model_index
.get(model_id)
.map(|workers| {
workers
.read()
.expect("RwLock for model_index is poisoned")
.clone()
})
.unwrap_or_default()
}
/// Get all workers by worker type
pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
self.type_workers
.get(worker_type)
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
.unwrap_or_default()
}
/// Get all prefill workers (regardless of bootstrap_port)
pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
self.workers
.iter()
.filter_map(|entry| {
let worker = entry.value();
match worker.worker_type() {
WorkerType::Prefill { .. } => Some(worker.clone()),
_ => None,
}
})
.collect()
}
/// Get all decode workers
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
self.get_by_type(&WorkerType::Decode)
}
/// Get all workers by connection mode
pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
self.connection_workers
.get(connection_mode)
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
.unwrap_or_default()
}
/// Get all workers
pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
self.workers
.iter()
.map(|entry| entry.value().clone())
.collect()
}
/// Get all workers with their IDs
pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
self.workers
.iter()
.map(|entry| (entry.key().clone(), entry.value().clone()))
.collect()
}
/// Get all worker URLs
pub fn get_all_urls(&self) -> Vec<String> {
self.workers
.iter()
.map(|entry| entry.value().url().to_string())
.collect()
}
/// Get all model IDs with workers
pub fn get_models(&self) -> Vec<String> {
self.model_workers
.iter()
.filter(|entry| !entry.value().is_empty())
.map(|entry| entry.key().clone())
.collect()
}
/// Get workers filtered by multiple criteria
///
/// This method allows flexible filtering of workers based on:
/// - model_id: Filter by specific model
/// - worker_type: Filter by worker type (Regular, Prefill, Decode)
/// - connection_mode: Filter by connection mode (Http, Grpc)
/// - healthy_only: Only return healthy workers
pub fn get_workers_filtered(
&self,
model_id: Option<&str>,
worker_type: Option<WorkerType>,
connection_mode: Option<ConnectionMode>,
healthy_only: bool,
) -> Vec<Arc<dyn Worker>> {
// Start with the most efficient collection based on filters
// Use model index when possible as it's O(1) lookup
let workers = if let Some(model) = model_id {
self.get_by_model_fast(model)
} else {
self.get_all()
};
// Apply remaining filters
workers
.into_iter()
.filter(|w| {
// Check worker_type if specified
if let Some(ref wtype) = worker_type {
if w.worker_type() != *wtype {
return false;
}
}
// Check connection_mode if specified
if let Some(ref conn) = connection_mode {
if w.connection_mode() != *conn {
return false;
}
}
// Check health if required
if healthy_only && !w.is_healthy() {
return false;
}
true
})
.collect()
}
/// Get worker statistics
pub fn stats(&self) -> WorkerRegistryStats {
let total_workers = self.workers.len();
let total_models = self.get_models().len();
let mut healthy_count = 0;
let mut total_load = 0;
let mut regular_count = 0;
let mut prefill_count = 0;
let mut decode_count = 0;
for worker in self.get_all() {
if worker.is_healthy() {
healthy_count += 1;
}
total_load += worker.load();
match worker.worker_type() {
WorkerType::Regular => regular_count += 1,
WorkerType::Prefill { .. } => prefill_count += 1,
WorkerType::Decode => decode_count += 1,
}
}
WorkerRegistryStats {
total_workers,
total_models,
healthy_workers: healthy_count,
total_load,
regular_workers: regular_count,
prefill_workers: prefill_count,
decode_workers: decode_count,
}
}
/// Start a health checker for all workers in the registry
/// This should be called once after the registry is populated with workers
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
let shutdown = Arc::new(AtomicBool::new(false));
let shutdown_clone = shutdown.clone();
let workers_ref = self.workers.clone();
let handle = tokio::spawn(async move {
let mut interval =
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
// Counter for periodic load reset (every 10 health check cycles)
let mut check_count = 0u64;
const LOAD_RESET_INTERVAL: u64 = 10;
loop {
interval.tick().await;
// Check for shutdown signal
if shutdown_clone.load(Ordering::Acquire) {
tracing::debug!("Registry health checker shutting down");
break;
}
// Get all workers from registry
let workers: Vec<Arc<dyn crate::core::Worker>> = workers_ref
.iter()
.map(|entry| entry.value().clone())
.collect();
// Perform health checks
for worker in &workers {
let _ = worker.check_health_async().await; // Use async version directly
}
// Reset loads periodically
check_count += 1;
if check_count % LOAD_RESET_INTERVAL == 0 {
tracing::debug!("Resetting worker loads (cycle {})", check_count);
for worker in &workers {
worker.reset_load();
}
}
}
});
crate::core::HealthChecker::new(handle, shutdown)
}
}
impl Default for WorkerRegistry {
fn default() -> Self {
Self::new()
}
}
/// Statistics for the worker registry
#[derive(Debug, Clone)]
pub struct WorkerRegistryStats {
pub total_workers: usize,
pub total_models: usize,
pub healthy_workers: usize,
pub total_load: usize,
pub regular_workers: usize,
pub prefill_workers: usize,
pub decode_workers: usize,
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::{CircuitBreakerConfig, WorkerFactory};
use std::collections::HashMap;
#[test]
fn test_worker_registry() {
let registry = WorkerRegistry::new();
// Create a worker with labels
let mut labels = HashMap::new();
labels.insert("model_id".to_string(), "llama-3-8b".to_string());
labels.insert("priority".to_string(), "50".to_string());
labels.insert("cost".to_string(), "0.8".to_string());
let worker = WorkerFactory::create_regular_with_labels(
"http://worker1:8080".to_string(),
labels,
CircuitBreakerConfig::default(),
);
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
let worker_id = registry.register(Arc::from(worker));
// Verify registration
assert!(registry.get(&worker_id).is_some());
assert!(registry.get_by_url("http://worker1:8080").is_some());
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
// Test stats
let stats = registry.stats();
assert_eq!(stats.total_workers, 1);
assert_eq!(stats.total_models, 1);
// Remove worker
registry.remove(&worker_id);
assert!(registry.get(&worker_id).is_none());
}
#[test]
fn test_model_index_fast_lookup() {
let registry = WorkerRegistry::new();
// Create workers for different models
let mut labels1 = HashMap::new();
labels1.insert("model_id".to_string(), "llama-3".to_string());
let worker1 = WorkerFactory::create_regular_with_labels(
"http://worker1:8080".to_string(),
labels1,
CircuitBreakerConfig::default(),
);
let mut labels2 = HashMap::new();
labels2.insert("model_id".to_string(), "llama-3".to_string());
let worker2 = WorkerFactory::create_regular_with_labels(
"http://worker2:8080".to_string(),
labels2,
CircuitBreakerConfig::default(),
);
let mut labels3 = HashMap::new();
labels3.insert("model_id".to_string(), "gpt-4".to_string());
let worker3 = WorkerFactory::create_regular_with_labels(
"http://worker3:8080".to_string(),
labels3,
CircuitBreakerConfig::default(),
);
// Register workers
registry.register(Arc::from(worker1));
registry.register(Arc::from(worker2));
registry.register(Arc::from(worker3));
// Test get_by_model_fast for llama-3
let llama_workers = registry.get_by_model_fast("llama-3");
assert_eq!(llama_workers.len(), 2);
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
assert!(urls.contains(&"http://worker1:8080".to_string()));
assert!(urls.contains(&"http://worker2:8080".to_string()));
// Test get_by_model_fast for gpt-4
let gpt_workers = registry.get_by_model_fast("gpt-4");
assert_eq!(gpt_workers.len(), 1);
assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
// Test get_by_model_fast for non-existent model
let unknown_workers = registry.get_by_model_fast("unknown-model");
assert_eq!(unknown_workers.len(), 0);
// Test that both get_by_model and get_by_model_fast return same results
let llama_workers_slow = registry.get_by_model("llama-3");
assert_eq!(llama_workers.len(), llama_workers_slow.len());
// Test removal updates the model index
registry.remove_by_url("http://worker1:8080");
let llama_workers_after = registry.get_by_model_fast("llama-3");
assert_eq!(llama_workers_after.len(), 1);
assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
}
}