[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -11,6 +11,7 @@ pub mod error;
|
||||
pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
pub mod worker_registry;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
pub use circuit_breaker::{
|
||||
@@ -22,3 +23,4 @@ pub use worker::{
|
||||
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
||||
|
||||
@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
fn can_handle(&self, _req: &serde_json::Value) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// === Multi-router support ===
|
||||
|
||||
// TODO: Enhanced worker discovery
|
||||
// The Worker trait should handle async discovery of metadata from the worker itself
|
||||
// rather than having service discovery or other components query /get_server_info.
|
||||
// This keeps service discovery decoupled from worker-specific APIs.
|
||||
//
|
||||
// Proposed additions:
|
||||
// - async fn discover_metadata(&mut self) -> Result<(), Error>
|
||||
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
|
||||
// - async fn validate_configuration(&self) -> Result<(), Error>
|
||||
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
|
||||
// - Make worker creation async to allow metadata discovery during initialization
|
||||
//
|
||||
// This way service discovery just calls router.add_worker() and the worker
|
||||
// handles its own metadata discovery internally.
|
||||
|
||||
/// Get the model ID this worker serves
|
||||
fn model_id(&self) -> &str {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("model_id")
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("unknown")
|
||||
}
|
||||
|
||||
/// Get the priority of this worker (higher value = higher priority)
|
||||
fn priority(&self) -> u32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("priority")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(50) // Default priority is 50 (mid-range)
|
||||
}
|
||||
|
||||
/// Get the cost factor of this worker (1.0 = baseline)
|
||||
fn cost(&self) -> f32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("cost")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0)
|
||||
}
|
||||
|
||||
/// Get the tokenizer path for this worker (gRPC mode only)
|
||||
fn tokenizer_path(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tokenizer_path")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the reasoning parser type for this worker (gRPC mode only)
|
||||
fn reasoning_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("reasoning_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the tool parser type for this worker (gRPC mode only)
|
||||
fn tool_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tool_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the chat template for this worker (gRPC mode only)
|
||||
fn chat_template(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("chat_template")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection mode for worker communication
|
||||
@@ -724,6 +800,21 @@ impl WorkerFactory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a regular worker with custom labels (for multi-router support)
|
||||
pub fn create_regular_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(circuit_breaker_config);
|
||||
|
||||
// Add labels to metadata
|
||||
worker.metadata.labels = labels;
|
||||
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
|
||||
}
|
||||
|
||||
impl HealthChecker {
|
||||
/// Create a new HealthChecker
|
||||
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
|
||||
Self { handle, shutdown }
|
||||
}
|
||||
|
||||
/// Shutdown the health checker gracefully
|
||||
pub async fn shutdown(self) {
|
||||
self.shutdown.store(true, Ordering::Release);
|
||||
@@ -950,7 +1046,7 @@ impl HealthChecker {
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
@@ -1602,9 +1698,11 @@ mod tests {
|
||||
// Test HealthChecker background task
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_startup() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
@@ -1617,9 +1715,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_shutdown() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
|
||||
526
sgl-router/src/core/worker_registry.rs
Normal file
526
sgl-router/src/core/worker_registry.rs
Normal file
@@ -0,0 +1,526 @@
|
||||
//! Worker Registry for multi-router support
|
||||
//!
|
||||
//! Provides centralized registry for workers with model-based indexing
|
||||
|
||||
use crate::core::{ConnectionMode, Worker, WorkerType};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Opaque identifier assigned to a worker.
///
/// A thin newtype over `String`; it is hashable and comparable so it can be
/// used as a map key.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct WorkerId(String);
|
||||
|
||||
impl WorkerId {
|
||||
/// Create a new worker ID
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string())
|
||||
}
|
||||
|
||||
/// Create a worker ID from a string
|
||||
pub fn from_string(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
|
||||
/// Get the ID as a string
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WorkerId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for the model index to reduce complexity
|
||||
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
|
||||
|
||||
/// Worker registry with model-based indexing
|
||||
#[derive(Debug)]
|
||||
pub struct WorkerRegistry {
|
||||
/// All workers indexed by ID
|
||||
workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
|
||||
|
||||
/// Workers indexed by model ID (stores WorkerId for reference)
|
||||
model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
|
||||
|
||||
/// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
|
||||
model_index: ModelIndex,
|
||||
|
||||
/// Workers indexed by worker type
|
||||
type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
|
||||
|
||||
/// Workers indexed by connection mode
|
||||
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
|
||||
|
||||
/// URL to worker ID mapping (for backward compatibility)
|
||||
url_to_id: Arc<DashMap<String, WorkerId>>,
|
||||
}
|
||||
|
||||
impl WorkerRegistry {
|
||||
/// Create a new worker registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
workers: Arc::new(DashMap::new()),
|
||||
model_workers: Arc::new(DashMap::new()),
|
||||
model_index: Arc::new(DashMap::new()),
|
||||
type_workers: Arc::new(DashMap::new()),
|
||||
connection_workers: Arc::new(DashMap::new()),
|
||||
url_to_id: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new worker
|
||||
pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
|
||||
let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) {
|
||||
// Worker with this URL already exists, update it
|
||||
existing_id.clone()
|
||||
} else {
|
||||
WorkerId::new()
|
||||
};
|
||||
|
||||
// Store worker
|
||||
self.workers.insert(worker_id.clone(), worker.clone());
|
||||
|
||||
// Update URL mapping
|
||||
self.url_to_id
|
||||
.insert(worker.url().to_string(), worker_id.clone());
|
||||
|
||||
// Update model index (both ID-based and optimized)
|
||||
let model_id = worker.model_id().to_string();
|
||||
self.model_workers
|
||||
.entry(model_id.clone())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
// Update optimized model index for O(1) lookups
|
||||
self.model_index
|
||||
.entry(model_id)
|
||||
.or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
|
||||
.write()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.push(worker.clone());
|
||||
|
||||
// Update type index
|
||||
self.type_workers
|
||||
.entry(worker.worker_type())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
// Update connection mode index
|
||||
self.connection_workers
|
||||
.entry(worker.connection_mode())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
worker_id
|
||||
}
|
||||
|
||||
/// Remove a worker by ID
|
||||
pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
||||
if let Some((_, worker)) = self.workers.remove(worker_id) {
|
||||
// Remove from URL mapping
|
||||
self.url_to_id.remove(worker.url());
|
||||
|
||||
// Remove from model index (both ID-based and optimized)
|
||||
if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) {
|
||||
model_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
// Remove from optimized model index
|
||||
if let Some(model_index_entry) = self.model_index.get(worker.model_id()) {
|
||||
let worker_url = worker.url();
|
||||
model_index_entry
|
||||
.write()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.retain(|w| w.url() != worker_url);
|
||||
}
|
||||
|
||||
// Remove from type index
|
||||
if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) {
|
||||
type_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
// Remove from connection mode index
|
||||
if let Some(mut conn_workers) =
|
||||
self.connection_workers.get_mut(&worker.connection_mode())
|
||||
{
|
||||
conn_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
Some(worker)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker by URL
|
||||
pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
||||
if let Some((_, worker_id)) = self.url_to_id.remove(url) {
|
||||
self.remove(&worker_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a worker by ID
|
||||
pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
||||
self.workers.get(worker_id).map(|entry| entry.clone())
|
||||
}
|
||||
|
||||
/// Get a worker by URL
|
||||
pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
||||
self.url_to_id.get(url).and_then(|id| self.get(&id))
|
||||
}
|
||||
|
||||
/// Get all workers for a model
|
||||
pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
||||
self.model_workers
|
||||
.get(model_id)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers for a model (O(1) optimized version)
|
||||
/// This method uses the pre-indexed model_index for fast lookups
|
||||
pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
||||
self.model_index
|
||||
.get(model_id)
|
||||
.map(|workers| {
|
||||
workers
|
||||
.read()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.clone()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers by worker type
|
||||
pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
|
||||
self.type_workers
|
||||
.get(worker_type)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all prefill workers (regardless of bootstrap_port)
|
||||
pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.workers
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let worker = entry.value();
|
||||
match worker.worker_type() {
|
||||
WorkerType::Prefill { .. } => Some(worker.clone()),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all decode workers
|
||||
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.get_by_type(&WorkerType::Decode)
|
||||
}
|
||||
|
||||
/// Get all workers by connection mode
|
||||
pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
|
||||
self.connection_workers
|
||||
.get(connection_mode)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers
|
||||
pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all workers with their IDs
|
||||
pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all worker URLs
|
||||
pub fn get_all_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| entry.value().url().to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all model IDs with workers
|
||||
pub fn get_models(&self) -> Vec<String> {
|
||||
self.model_workers
|
||||
.iter()
|
||||
.filter(|entry| !entry.value().is_empty())
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get workers filtered by multiple criteria
|
||||
///
|
||||
/// This method allows flexible filtering of workers based on:
|
||||
/// - model_id: Filter by specific model
|
||||
/// - worker_type: Filter by worker type (Regular, Prefill, Decode)
|
||||
/// - connection_mode: Filter by connection mode (Http, Grpc)
|
||||
/// - healthy_only: Only return healthy workers
|
||||
pub fn get_workers_filtered(
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
worker_type: Option<WorkerType>,
|
||||
connection_mode: Option<ConnectionMode>,
|
||||
healthy_only: bool,
|
||||
) -> Vec<Arc<dyn Worker>> {
|
||||
// Start with the most efficient collection based on filters
|
||||
// Use model index when possible as it's O(1) lookup
|
||||
let workers = if let Some(model) = model_id {
|
||||
self.get_by_model_fast(model)
|
||||
} else {
|
||||
self.get_all()
|
||||
};
|
||||
|
||||
// Apply remaining filters
|
||||
workers
|
||||
.into_iter()
|
||||
.filter(|w| {
|
||||
// Check worker_type if specified
|
||||
if let Some(ref wtype) = worker_type {
|
||||
if w.worker_type() != *wtype {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection_mode if specified
|
||||
if let Some(ref conn) = connection_mode {
|
||||
if w.connection_mode() != *conn {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check health if required
|
||||
if healthy_only && !w.is_healthy() {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get worker statistics
|
||||
pub fn stats(&self) -> WorkerRegistryStats {
|
||||
let total_workers = self.workers.len();
|
||||
let total_models = self.get_models().len();
|
||||
|
||||
let mut healthy_count = 0;
|
||||
let mut total_load = 0;
|
||||
let mut regular_count = 0;
|
||||
let mut prefill_count = 0;
|
||||
let mut decode_count = 0;
|
||||
|
||||
for worker in self.get_all() {
|
||||
if worker.is_healthy() {
|
||||
healthy_count += 1;
|
||||
}
|
||||
total_load += worker.load();
|
||||
|
||||
match worker.worker_type() {
|
||||
WorkerType::Regular => regular_count += 1,
|
||||
WorkerType::Prefill { .. } => prefill_count += 1,
|
||||
WorkerType::Decode => decode_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
WorkerRegistryStats {
|
||||
total_workers,
|
||||
total_models,
|
||||
healthy_workers: healthy_count,
|
||||
total_load,
|
||||
regular_workers: regular_count,
|
||||
prefill_workers: prefill_count,
|
||||
decode_workers: decode_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a health checker for all workers in the registry
|
||||
/// This should be called once after the registry is populated with workers
|
||||
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_clone = shutdown.clone();
|
||||
let workers_ref = self.workers.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
||||
|
||||
// Counter for periodic load reset (every 10 health check cycles)
|
||||
let mut check_count = 0u64;
|
||||
const LOAD_RESET_INTERVAL: u64 = 10;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Check for shutdown signal
|
||||
if shutdown_clone.load(Ordering::Acquire) {
|
||||
tracing::debug!("Registry health checker shutting down");
|
||||
break;
|
||||
}
|
||||
|
||||
// Get all workers from registry
|
||||
let workers: Vec<Arc<dyn crate::core::Worker>> = workers_ref
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect();
|
||||
|
||||
// Perform health checks
|
||||
for worker in &workers {
|
||||
let _ = worker.check_health_async().await; // Use async version directly
|
||||
}
|
||||
|
||||
// Reset loads periodically
|
||||
check_count += 1;
|
||||
if check_count % LOAD_RESET_INTERVAL == 0 {
|
||||
tracing::debug!("Resetting worker loads (cycle {})", check_count);
|
||||
for worker in &workers {
|
||||
worker.reset_load();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
crate::core::HealthChecker::new(handle, shutdown)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WorkerRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics snapshot for the worker registry.
///
/// All fields are plain counters, so the type also derives `Default`
/// (all zeros) and equality to make snapshots easy to compare.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct WorkerRegistryStats {
    /// Number of workers in the registry.
    pub total_workers: usize,
    /// Number of models that have at least one worker.
    pub total_models: usize,
    /// Number of workers reporting healthy.
    pub healthy_workers: usize,
    /// Sum of the load reported by every worker.
    pub total_load: usize,
    /// Number of workers of type `Regular`.
    pub regular_workers: usize,
    /// Number of workers of type `Prefill`.
    pub prefill_workers: usize,
    /// Number of workers of type `Decode`.
    pub decode_workers: usize,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::core::{CircuitBreakerConfig, WorkerFactory};
    use std::collections::HashMap;

    /// Build a regular worker at `url` carrying the given metadata labels.
    fn labeled_worker(url: &str, labels: &[(&str, &str)]) -> Arc<dyn Worker> {
        let labels: HashMap<String, String> = labels
            .iter()
            .map(|(key, value)| (key.to_string(), value.to_string()))
            .collect();

        // WorkerFactory returns Box<dyn Worker>; the registry stores Arc<dyn Worker>.
        Arc::from(WorkerFactory::create_regular_with_labels(
            url.to_string(),
            labels,
            CircuitBreakerConfig::default(),
        ))
    }

    #[test]
    fn test_worker_registry() {
        let registry = WorkerRegistry::new();

        let worker = labeled_worker(
            "http://worker1:8080",
            &[("model_id", "llama-3-8b"), ("priority", "50"), ("cost", "0.8")],
        );
        let worker_id = registry.register(worker);

        // The worker is reachable through every index.
        assert!(registry.get(&worker_id).is_some());
        assert!(registry.get_by_url("http://worker1:8080").is_some());
        assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
        assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
        assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);

        // Stats reflect the single registration.
        let stats = registry.stats();
        assert_eq!(stats.total_workers, 1);
        assert_eq!(stats.total_models, 1);

        // After removal the ID no longer resolves.
        registry.remove(&worker_id);
        assert!(registry.get(&worker_id).is_none());
    }

    #[test]
    fn test_model_index_fast_lookup() {
        let registry = WorkerRegistry::new();

        // Two workers for llama-3, one for gpt-4.
        registry.register(labeled_worker(
            "http://worker1:8080",
            &[("model_id", "llama-3")],
        ));
        registry.register(labeled_worker(
            "http://worker2:8080",
            &[("model_id", "llama-3")],
        ));
        registry.register(labeled_worker(
            "http://worker3:8080",
            &[("model_id", "gpt-4")],
        ));

        // llama-3 resolves to both of its workers.
        let llama_workers = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers.len(), 2);
        let urls: Vec<&str> = llama_workers.iter().map(|w| w.url()).collect();
        assert!(urls.contains(&"http://worker1:8080"));
        assert!(urls.contains(&"http://worker2:8080"));

        // gpt-4 resolves to its single worker.
        let gpt_workers = registry.get_by_model_fast("gpt-4");
        assert_eq!(gpt_workers.len(), 1);
        assert_eq!(gpt_workers[0].url(), "http://worker3:8080");

        // An unknown model resolves to nothing.
        let unknown_workers = registry.get_by_model_fast("unknown-model");
        assert_eq!(unknown_workers.len(), 0);

        // The ID-based and worker-valued lookups agree.
        let llama_workers_slow = registry.get_by_model("llama-3");
        assert_eq!(llama_workers.len(), llama_workers_slow.len());

        // Removal is reflected in the worker-valued model index.
        registry.remove_by_url("http://worker1:8080");
        let llama_workers_after = registry.get_by_model_fast("llama-3");
        assert_eq!(llama_workers_after.len(), 1);
        assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
    }
}
|
||||
Reference in New Issue
Block a user