542 lines
18 KiB
Rust
542 lines
18 KiB
Rust
//! Worker Registry for multi-router support
|
|
//!
|
|
//! Provides a centralized registry for workers with model-based indexing.
|
|
|
|
use crate::core::{ConnectionMode, Worker, WorkerType};
|
|
use dashmap::DashMap;
|
|
use std::sync::{Arc, RwLock};
|
|
use uuid::Uuid;
|
|
|
|
/// Opaque identifier under which a worker is stored in the registry.
///
/// Wraps a `String`; usable as a hash-map key.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct WorkerId(String);
|
|
|
|
impl WorkerId {
|
|
/// Create a new worker ID
|
|
pub fn new() -> Self {
|
|
Self(Uuid::new_v4().to_string())
|
|
}
|
|
|
|
/// Create a worker ID from a string
|
|
pub fn from_string(s: String) -> Self {
|
|
Self(s)
|
|
}
|
|
|
|
/// Get the ID as a string
|
|
pub fn as_str(&self) -> &str {
|
|
&self.0
|
|
}
|
|
}
|
|
|
|
impl Default for WorkerId {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Shared map from model ID to a lock-protected list of worker handles.
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
|
|
|
|
/// Worker registry with model-based indexing
|
|
#[derive(Debug)]
|
|
pub struct WorkerRegistry {
|
|
/// All workers indexed by ID
|
|
workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
|
|
|
|
/// Workers indexed by model ID (stores WorkerId for reference)
|
|
model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
|
|
|
|
/// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
|
|
model_index: ModelIndex,
|
|
|
|
/// Workers indexed by worker type
|
|
type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
|
|
|
|
/// Workers indexed by connection mode
|
|
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
|
|
/// URL to worker ID mapping
|
|
url_to_id: Arc<DashMap<String, WorkerId>>,
|
|
}
|
|
|
|
impl WorkerRegistry {
|
|
/// Create a new worker registry
|
|
pub fn new() -> Self {
|
|
Self {
|
|
workers: Arc::new(DashMap::new()),
|
|
model_workers: Arc::new(DashMap::new()),
|
|
model_index: Arc::new(DashMap::new()),
|
|
type_workers: Arc::new(DashMap::new()),
|
|
connection_workers: Arc::new(DashMap::new()),
|
|
url_to_id: Arc::new(DashMap::new()),
|
|
}
|
|
}
|
|
|
|
/// Register a new worker
|
|
pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
|
|
let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) {
|
|
// Worker with this URL already exists, update it
|
|
existing_id.clone()
|
|
} else {
|
|
WorkerId::new()
|
|
};
|
|
|
|
// Store worker
|
|
self.workers.insert(worker_id.clone(), worker.clone());
|
|
|
|
// Update URL mapping
|
|
self.url_to_id
|
|
.insert(worker.url().to_string(), worker_id.clone());
|
|
|
|
// Update model index (both ID-based and optimized)
|
|
let model_id = worker.model_id().to_string();
|
|
self.model_workers
|
|
.entry(model_id.clone())
|
|
.or_default()
|
|
.push(worker_id.clone());
|
|
|
|
// Update optimized model index for O(1) lookups
|
|
self.model_index
|
|
.entry(model_id)
|
|
.or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
|
|
.write()
|
|
.expect("RwLock for model_index is poisoned")
|
|
.push(worker.clone());
|
|
|
|
// Update type index
|
|
self.type_workers
|
|
.entry(worker.worker_type())
|
|
.or_default()
|
|
.push(worker_id.clone());
|
|
|
|
// Update connection mode index
|
|
self.connection_workers
|
|
.entry(worker.connection_mode())
|
|
.or_default()
|
|
.push(worker_id.clone());
|
|
|
|
worker_id
|
|
}
|
|
|
|
/// Remove a worker by ID
|
|
pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
|
if let Some((_, worker)) = self.workers.remove(worker_id) {
|
|
// Remove from URL mapping
|
|
self.url_to_id.remove(worker.url());
|
|
|
|
// Remove from model index (both ID-based and optimized)
|
|
if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) {
|
|
model_workers.retain(|id| id != worker_id);
|
|
}
|
|
|
|
// Remove from optimized model index
|
|
if let Some(model_index_entry) = self.model_index.get(worker.model_id()) {
|
|
let worker_url = worker.url();
|
|
model_index_entry
|
|
.write()
|
|
.expect("RwLock for model_index is poisoned")
|
|
.retain(|w| w.url() != worker_url);
|
|
}
|
|
|
|
// Remove from type index
|
|
if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) {
|
|
type_workers.retain(|id| id != worker_id);
|
|
}
|
|
|
|
// Remove from connection mode index
|
|
if let Some(mut conn_workers) =
|
|
self.connection_workers.get_mut(&worker.connection_mode())
|
|
{
|
|
conn_workers.retain(|id| id != worker_id);
|
|
}
|
|
|
|
Some(worker)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Remove a worker by URL
|
|
pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
|
if let Some((_, worker_id)) = self.url_to_id.remove(url) {
|
|
self.remove(&worker_id)
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
/// Get a worker by ID
|
|
pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
|
self.workers.get(worker_id).map(|entry| entry.clone())
|
|
}
|
|
|
|
/// Get a worker by URL
|
|
pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
|
self.url_to_id.get(url).and_then(|id| self.get(&id))
|
|
}
|
|
|
|
/// Get all workers for a model
|
|
pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
|
self.model_workers
|
|
.get(model_id)
|
|
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Get all workers for a model (O(1) optimized version)
|
|
/// This method uses the pre-indexed model_index for fast lookups
|
|
pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
|
self.model_index
|
|
.get(model_id)
|
|
.map(|workers| {
|
|
workers
|
|
.read()
|
|
.expect("RwLock for model_index is poisoned")
|
|
.clone()
|
|
})
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Get all workers by worker type
|
|
pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
|
|
self.type_workers
|
|
.get(worker_type)
|
|
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Get all prefill workers (regardless of bootstrap_port)
|
|
pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
|
|
self.workers
|
|
.iter()
|
|
.filter_map(|entry| {
|
|
let worker = entry.value();
|
|
match worker.worker_type() {
|
|
WorkerType::Prefill { .. } => Some(worker.clone()),
|
|
_ => None,
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Get all decode workers
|
|
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
|
|
self.get_by_type(&WorkerType::Decode)
|
|
}
|
|
|
|
/// Get all workers by connection mode
|
|
pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
|
|
self.connection_workers
|
|
.get(connection_mode)
|
|
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
|
.unwrap_or_default()
|
|
}
|
|
|
|
/// Get all workers
|
|
pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
|
|
self.workers
|
|
.iter()
|
|
.map(|entry| entry.value().clone())
|
|
.collect()
|
|
}
|
|
|
|
/// Get all workers with their IDs
|
|
pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
|
|
self.workers
|
|
.iter()
|
|
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
|
.collect()
|
|
}
|
|
|
|
/// Get all worker URLs
|
|
pub fn get_all_urls(&self) -> Vec<String> {
|
|
self.workers
|
|
.iter()
|
|
.map(|entry| entry.value().url().to_string())
|
|
.collect()
|
|
}
|
|
|
|
pub fn get_all_urls_with_api_key(&self) -> Vec<(String, Option<String>)> {
|
|
self.workers
|
|
.iter()
|
|
.map(|entry| {
|
|
(
|
|
entry.value().url().to_string(),
|
|
entry.value().api_key().clone(),
|
|
)
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Get all model IDs with workers
|
|
pub fn get_models(&self) -> Vec<String> {
|
|
self.model_workers
|
|
.iter()
|
|
.filter(|entry| !entry.value().is_empty())
|
|
.map(|entry| entry.key().clone())
|
|
.collect()
|
|
}
|
|
|
|
/// Get workers filtered by multiple criteria
|
|
///
|
|
/// This method allows flexible filtering of workers based on:
|
|
/// - model_id: Filter by specific model
|
|
/// - worker_type: Filter by worker type (Regular, Prefill, Decode)
|
|
/// - connection_mode: Filter by connection mode (Http, Grpc)
|
|
/// - healthy_only: Only return healthy workers
|
|
pub fn get_workers_filtered(
|
|
&self,
|
|
model_id: Option<&str>,
|
|
worker_type: Option<WorkerType>,
|
|
connection_mode: Option<ConnectionMode>,
|
|
healthy_only: bool,
|
|
) -> Vec<Arc<dyn Worker>> {
|
|
// Start with the most efficient collection based on filters
|
|
// Use model index when possible as it's O(1) lookup
|
|
let workers = if let Some(model) = model_id {
|
|
self.get_by_model_fast(model)
|
|
} else {
|
|
self.get_all()
|
|
};
|
|
|
|
// Apply remaining filters
|
|
workers
|
|
.into_iter()
|
|
.filter(|w| {
|
|
// Check worker_type if specified
|
|
if let Some(ref wtype) = worker_type {
|
|
if w.worker_type() != *wtype {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Check connection_mode if specified
|
|
if let Some(ref conn) = connection_mode {
|
|
if w.connection_mode() != *conn {
|
|
return false;
|
|
}
|
|
}
|
|
|
|
// Check health if required
|
|
if healthy_only && !w.is_healthy() {
|
|
return false;
|
|
}
|
|
|
|
true
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
/// Get worker statistics
|
|
pub fn stats(&self) -> WorkerRegistryStats {
|
|
let total_workers = self.workers.len();
|
|
let total_models = self.get_models().len();
|
|
|
|
let mut healthy_count = 0;
|
|
let mut total_load = 0;
|
|
let mut regular_count = 0;
|
|
let mut prefill_count = 0;
|
|
let mut decode_count = 0;
|
|
|
|
for worker in self.get_all() {
|
|
if worker.is_healthy() {
|
|
healthy_count += 1;
|
|
}
|
|
total_load += worker.load();
|
|
|
|
match worker.worker_type() {
|
|
WorkerType::Regular => regular_count += 1,
|
|
WorkerType::Prefill { .. } => prefill_count += 1,
|
|
WorkerType::Decode => decode_count += 1,
|
|
}
|
|
}
|
|
|
|
WorkerRegistryStats {
|
|
total_workers,
|
|
total_models,
|
|
healthy_workers: healthy_count,
|
|
total_load,
|
|
regular_workers: regular_count,
|
|
prefill_workers: prefill_count,
|
|
decode_workers: decode_count,
|
|
}
|
|
}
|
|
|
|
/// Start a health checker for all workers in the registry
|
|
/// This should be called once after the registry is populated with workers
|
|
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
|
|
use std::sync::atomic::{AtomicBool, Ordering};
|
|
use std::sync::Arc;
|
|
|
|
let shutdown = Arc::new(AtomicBool::new(false));
|
|
let shutdown_clone = shutdown.clone();
|
|
let workers_ref = self.workers.clone();
|
|
|
|
let handle = tokio::spawn(async move {
|
|
let mut interval =
|
|
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
|
|
|
// Counter for periodic load reset (every 10 health check cycles)
|
|
let mut check_count = 0u64;
|
|
const LOAD_RESET_INTERVAL: u64 = 10;
|
|
|
|
loop {
|
|
interval.tick().await;
|
|
|
|
// Check for shutdown signal
|
|
if shutdown_clone.load(Ordering::Acquire) {
|
|
tracing::debug!("Registry health checker shutting down");
|
|
break;
|
|
}
|
|
|
|
// Get all workers from registry
|
|
let workers: Vec<Arc<dyn crate::core::Worker>> = workers_ref
|
|
.iter()
|
|
.map(|entry| entry.value().clone())
|
|
.collect();
|
|
|
|
// Perform health checks
|
|
for worker in &workers {
|
|
let _ = worker.check_health_async().await; // Use async version directly
|
|
}
|
|
|
|
// Reset loads periodically
|
|
check_count += 1;
|
|
if check_count.is_multiple_of(LOAD_RESET_INTERVAL) {
|
|
tracing::debug!("Resetting worker loads (cycle {})", check_count);
|
|
for worker in &workers {
|
|
worker.reset_load();
|
|
}
|
|
}
|
|
}
|
|
});
|
|
|
|
crate::core::HealthChecker::new(handle, shutdown)
|
|
}
|
|
}
|
|
|
|
impl Default for WorkerRegistry {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
/// Aggregate counters describing the registry, as produced by `WorkerRegistry::stats`.
///
/// Derives `Default` (all counters zero) and `PartialEq`/`Eq` so snapshots can
/// be built partially and compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkerRegistryStats {
    /// Number of workers in the registry.
    pub total_workers: usize,
    /// Number of models that have at least one worker ID indexed.
    pub total_models: usize,
    /// Number of workers that reported healthy.
    pub healthy_workers: usize,
    /// Sum of the load reported by every worker.
    pub total_load: usize,
    /// Number of workers of type `Regular`.
    pub regular_workers: usize,
    /// Number of workers of type `Prefill`.
    pub prefill_workers: usize,
    /// Number of workers of type `Decode`.
    pub decode_workers: usize,
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig};
    use std::collections::HashMap;

    /// Build a `Regular` worker at `url` with the given labels, the default
    /// circuit breaker config and a fixed test API key.
    fn regular_worker(url: &str, labels: HashMap<String, String>) -> Arc<dyn Worker> {
        let boxed: Box<dyn Worker> = Box::new(
            BasicWorkerBuilder::new(url)
                .worker_type(WorkerType::Regular)
                .labels(labels)
                .circuit_breaker_config(CircuitBreakerConfig::default())
                .api_key("test_api_key")
                .build(),
        );
        // Workers are built boxed and converted to `Arc` for registration.
        Arc::from(boxed)
    }

    /// Labels containing only a `model_id` entry.
    fn model_labels(model_id: &str) -> HashMap<String, String> {
        HashMap::from([("model_id".to_string(), model_id.to_string())])
    }

    #[test]
    fn test_worker_registry() {
        let registry = WorkerRegistry::new();

        // Worker carrying a model label plus a couple of extra labels.
        let mut labels = model_labels("llama-3-8b");
        labels.insert("priority".to_string(), "50".to_string());
        labels.insert("cost".to_string(), "0.8".to_string());

        let worker_id = registry.register(regular_worker("http://worker1:8080", labels));

        // The worker is reachable through every index.
        assert!(registry.get(&worker_id).is_some());
        assert!(registry.get_by_url("http://worker1:8080").is_some());
        assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
        assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
        assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);

        let stats = registry.stats();
        assert_eq!(stats.total_workers, 1);
        assert_eq!(stats.total_models, 1);

        // After removal the ID no longer resolves.
        registry.remove(&worker_id);
        assert!(registry.get(&worker_id).is_none());
    }

    #[test]
    fn test_model_index_fast_lookup() {
        let registry = WorkerRegistry::new();

        // Two workers for "llama-3", one for "gpt-4".
        registry.register(regular_worker("http://worker1:8080", model_labels("llama-3")));
        registry.register(regular_worker("http://worker2:8080", model_labels("llama-3")));
        registry.register(regular_worker("http://worker3:8080", model_labels("gpt-4")));

        let llama_workers = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers.len(), 2);
        let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
        assert!(urls.contains(&"http://worker1:8080".to_string()));
        assert!(urls.contains(&"http://worker2:8080".to_string()));

        let gpt_workers = registry.get_by_model_fast("gpt-4");
        assert_eq!(gpt_workers.len(), 1);
        assert_eq!(gpt_workers[0].url(), "http://worker3:8080");

        let unknown_workers = registry.get_by_model_fast("unknown-model");
        assert_eq!(unknown_workers.len(), 0);

        // The fast and ID-based lookups agree on the count.
        let llama_workers_slow = registry.get_by_model("llama-3");
        assert_eq!(llama_workers.len(), llama_workers_slow.len());

        // Removing by URL also updates the fast index.
        registry.remove_by_url("http://worker1:8080");
        let llama_workers_after = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers_after.len(), 1);
        assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
    }
}
|