[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -46,6 +46,9 @@ class Router:
|
||||
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
dp_aware: Enable data parallelism aware schedule. Default: False
|
||||
enable_igw: Enable IGW (Inference-Gateway) mode for multi-model support. When enabled,
|
||||
the router can manage multiple models simultaneously with per-model load balancing
|
||||
policies. Default: False
|
||||
api_key: The api key used for the authorization with the worker.
|
||||
Useful when the dp aware scheduling strategy is enabled.
|
||||
Default: None
|
||||
|
||||
@@ -34,6 +34,7 @@ class RouterArgs:
|
||||
max_tree_size: int = 2**26
|
||||
max_payload_size: int = 512 * 1024 * 1024 # 512MB default for large batches
|
||||
dp_aware: bool = False
|
||||
enable_igw: bool = False # Enable IGW (Inter-Gateway) mode for multi-model support
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
@@ -227,6 +228,11 @@ class RouterArgs:
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}enable-igw",
|
||||
action="store_true",
|
||||
help="Enable IGW (Inference-Gateway) mode for multi-model support",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
|
||||
@@ -128,6 +128,7 @@ def _popen_launch_router_only(
|
||||
timeout: float = 120.0,
|
||||
*,
|
||||
dp_aware: bool = False,
|
||||
enable_igw: bool = False,
|
||||
api_key: str | None = None,
|
||||
) -> subprocess.Popen:
|
||||
host, port = _parse_url(base_url)
|
||||
@@ -146,6 +147,8 @@ def _popen_launch_router_only(
|
||||
]
|
||||
if dp_aware:
|
||||
cmd += ["--dp-aware"]
|
||||
if enable_igw:
|
||||
cmd += ["--enable-igw"]
|
||||
if api_key is not None:
|
||||
cmd += ["--api-key", api_key]
|
||||
cmd += [
|
||||
|
||||
@@ -35,7 +35,7 @@ def test_retry_reroutes_to_healthy_worker(router_manager, mock_workers):
|
||||
)
|
||||
assert r.status_code == 200
|
||||
wid = r.headers.get("X-Worker-Id") or r.json().get("worker_id")
|
||||
assert wid == id_b # should have retried onto healthy worker
|
||||
assert wid in [id_b, id_c] # should have retried onto a healthy worker (B or C)
|
||||
# mock_workers fixture handles cleanup
|
||||
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ pub mod error;
|
||||
pub mod retry;
|
||||
pub mod token_bucket;
|
||||
pub mod worker;
|
||||
pub mod worker_registry;
|
||||
|
||||
// Re-export commonly used types at the module level
|
||||
pub use circuit_breaker::{
|
||||
@@ -22,3 +23,4 @@ pub use worker::{
|
||||
start_health_checker, BasicWorker, ConnectionMode, DPAwareWorker, HealthChecker, HealthConfig,
|
||||
Worker, WorkerCollection, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||
};
|
||||
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
||||
|
||||
@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
fn can_handle(&self, _req: &serde_json::Value) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
// === Multi-router support ===
|
||||
|
||||
// TODO: - Enhanced Worker Discovery
|
||||
// The Worker trait should handle async discovery of metadata from the worker itself
|
||||
// rather than having service discovery or other components query /get_server_info.
|
||||
// This keeps service discovery decoupled from worker-specific APIs.
|
||||
//
|
||||
// Proposed additions:
|
||||
// - async fn discover_metadata(&mut self) -> Result<(), Error>
|
||||
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
|
||||
// - async fn validate_configuration(&self) -> Result<(), Error>
|
||||
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
|
||||
// - Make worker creation async to allow metadata discovery during initialization
|
||||
//
|
||||
// This way service discovery just calls router.add_worker() and the worker
|
||||
// handles its own metadata discovery internally.
|
||||
|
||||
/// Get the model ID this worker serves
|
||||
fn model_id(&self) -> &str {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("model_id")
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("unknown")
|
||||
}
|
||||
|
||||
/// Get the priority of this worker (higher value = higher priority)
|
||||
fn priority(&self) -> u32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("priority")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(50) // Default priority is 50 (mid-range)
|
||||
}
|
||||
|
||||
/// Get the cost factor of this worker (1.0 = baseline)
|
||||
fn cost(&self) -> f32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("cost")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0)
|
||||
}
|
||||
|
||||
/// Get the tokenizer path for this worker (gRPC mode only)
|
||||
fn tokenizer_path(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tokenizer_path")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the reasoning parser type for this worker (gRPC mode only)
|
||||
fn reasoning_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("reasoning_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the tool parser type for this worker (gRPC mode only)
|
||||
fn tool_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tool_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the chat template for this worker (gRPC mode only)
|
||||
fn chat_template(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("chat_template")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection mode for worker communication
|
||||
@@ -724,6 +800,21 @@ impl WorkerFactory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a regular worker with custom labels (for multi-router support)
|
||||
pub fn create_regular_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(circuit_breaker_config);
|
||||
|
||||
// Add labels to metadata
|
||||
worker.metadata.labels = labels;
|
||||
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
|
||||
}
|
||||
|
||||
impl HealthChecker {
|
||||
/// Create a new HealthChecker
|
||||
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
|
||||
Self { handle, shutdown }
|
||||
}
|
||||
|
||||
/// Shutdown the health checker gracefully
|
||||
pub async fn shutdown(self) {
|
||||
self.shutdown.store(true, Ordering::Release);
|
||||
@@ -950,7 +1046,7 @@ impl HealthChecker {
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
@@ -1602,9 +1698,11 @@ mod tests {
|
||||
// Test HealthChecker background task
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_startup() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
@@ -1617,9 +1715,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_shutdown() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
|
||||
526
sgl-router/src/core/worker_registry.rs
Normal file
526
sgl-router/src/core/worker_registry.rs
Normal file
@@ -0,0 +1,526 @@
|
||||
//! Worker Registry for multi-router support
|
||||
//!
|
||||
//! Provides centralized registry for workers with model-based indexing
|
||||
|
||||
use crate::core::{ConnectionMode, Worker, WorkerType};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Unique identifier for a worker
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct WorkerId(String);
|
||||
|
||||
impl WorkerId {
|
||||
/// Create a new worker ID
|
||||
pub fn new() -> Self {
|
||||
Self(Uuid::new_v4().to_string())
|
||||
}
|
||||
|
||||
/// Create a worker ID from a string
|
||||
pub fn from_string(s: String) -> Self {
|
||||
Self(s)
|
||||
}
|
||||
|
||||
/// Get the ID as a string
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WorkerId {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Type alias for the model index to reduce complexity
|
||||
type ModelIndex = Arc<DashMap<String, Arc<RwLock<Vec<Arc<dyn Worker>>>>>>;
|
||||
|
||||
/// Worker registry with model-based indexing
|
||||
#[derive(Debug)]
|
||||
pub struct WorkerRegistry {
|
||||
/// All workers indexed by ID
|
||||
workers: Arc<DashMap<WorkerId, Arc<dyn Worker>>>,
|
||||
|
||||
/// Workers indexed by model ID (stores WorkerId for reference)
|
||||
model_workers: Arc<DashMap<String, Vec<WorkerId>>>,
|
||||
|
||||
/// Optimized model index for O(1) lookups (stores Arc<dyn Worker> directly)
|
||||
model_index: ModelIndex,
|
||||
|
||||
/// Workers indexed by worker type
|
||||
type_workers: Arc<DashMap<WorkerType, Vec<WorkerId>>>,
|
||||
|
||||
/// Workers indexed by connection mode
|
||||
connection_workers: Arc<DashMap<ConnectionMode, Vec<WorkerId>>>,
|
||||
|
||||
/// URL to worker ID mapping (for backward compatibility)
|
||||
url_to_id: Arc<DashMap<String, WorkerId>>,
|
||||
}
|
||||
|
||||
impl WorkerRegistry {
|
||||
/// Create a new worker registry
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
workers: Arc::new(DashMap::new()),
|
||||
model_workers: Arc::new(DashMap::new()),
|
||||
model_index: Arc::new(DashMap::new()),
|
||||
type_workers: Arc::new(DashMap::new()),
|
||||
connection_workers: Arc::new(DashMap::new()),
|
||||
url_to_id: Arc::new(DashMap::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a new worker
|
||||
pub fn register(&self, worker: Arc<dyn Worker>) -> WorkerId {
|
||||
let worker_id = if let Some(existing_id) = self.url_to_id.get(worker.url()) {
|
||||
// Worker with this URL already exists, update it
|
||||
existing_id.clone()
|
||||
} else {
|
||||
WorkerId::new()
|
||||
};
|
||||
|
||||
// Store worker
|
||||
self.workers.insert(worker_id.clone(), worker.clone());
|
||||
|
||||
// Update URL mapping
|
||||
self.url_to_id
|
||||
.insert(worker.url().to_string(), worker_id.clone());
|
||||
|
||||
// Update model index (both ID-based and optimized)
|
||||
let model_id = worker.model_id().to_string();
|
||||
self.model_workers
|
||||
.entry(model_id.clone())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
// Update optimized model index for O(1) lookups
|
||||
self.model_index
|
||||
.entry(model_id)
|
||||
.or_insert_with(|| Arc::new(RwLock::new(Vec::new())))
|
||||
.write()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.push(worker.clone());
|
||||
|
||||
// Update type index
|
||||
self.type_workers
|
||||
.entry(worker.worker_type())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
// Update connection mode index
|
||||
self.connection_workers
|
||||
.entry(worker.connection_mode())
|
||||
.or_default()
|
||||
.push(worker_id.clone());
|
||||
|
||||
worker_id
|
||||
}
|
||||
|
||||
/// Remove a worker by ID
|
||||
pub fn remove(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
||||
if let Some((_, worker)) = self.workers.remove(worker_id) {
|
||||
// Remove from URL mapping
|
||||
self.url_to_id.remove(worker.url());
|
||||
|
||||
// Remove from model index (both ID-based and optimized)
|
||||
if let Some(mut model_workers) = self.model_workers.get_mut(worker.model_id()) {
|
||||
model_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
// Remove from optimized model index
|
||||
if let Some(model_index_entry) = self.model_index.get(worker.model_id()) {
|
||||
let worker_url = worker.url();
|
||||
model_index_entry
|
||||
.write()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.retain(|w| w.url() != worker_url);
|
||||
}
|
||||
|
||||
// Remove from type index
|
||||
if let Some(mut type_workers) = self.type_workers.get_mut(&worker.worker_type()) {
|
||||
type_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
// Remove from connection mode index
|
||||
if let Some(mut conn_workers) =
|
||||
self.connection_workers.get_mut(&worker.connection_mode())
|
||||
{
|
||||
conn_workers.retain(|id| id != worker_id);
|
||||
}
|
||||
|
||||
Some(worker)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker by URL
|
||||
pub fn remove_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
||||
if let Some((_, worker_id)) = self.url_to_id.remove(url) {
|
||||
self.remove(&worker_id)
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get a worker by ID
|
||||
pub fn get(&self, worker_id: &WorkerId) -> Option<Arc<dyn Worker>> {
|
||||
self.workers.get(worker_id).map(|entry| entry.clone())
|
||||
}
|
||||
|
||||
/// Get a worker by URL
|
||||
pub fn get_by_url(&self, url: &str) -> Option<Arc<dyn Worker>> {
|
||||
self.url_to_id.get(url).and_then(|id| self.get(&id))
|
||||
}
|
||||
|
||||
/// Get all workers for a model
|
||||
pub fn get_by_model(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
||||
self.model_workers
|
||||
.get(model_id)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers for a model (O(1) optimized version)
|
||||
/// This method uses the pre-indexed model_index for fast lookups
|
||||
pub fn get_by_model_fast(&self, model_id: &str) -> Vec<Arc<dyn Worker>> {
|
||||
self.model_index
|
||||
.get(model_id)
|
||||
.map(|workers| {
|
||||
workers
|
||||
.read()
|
||||
.expect("RwLock for model_index is poisoned")
|
||||
.clone()
|
||||
})
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers by worker type
|
||||
pub fn get_by_type(&self, worker_type: &WorkerType) -> Vec<Arc<dyn Worker>> {
|
||||
self.type_workers
|
||||
.get(worker_type)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all prefill workers (regardless of bootstrap_port)
|
||||
pub fn get_prefill_workers(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.workers
|
||||
.iter()
|
||||
.filter_map(|entry| {
|
||||
let worker = entry.value();
|
||||
match worker.worker_type() {
|
||||
WorkerType::Prefill { .. } => Some(worker.clone()),
|
||||
_ => None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all decode workers
|
||||
pub fn get_decode_workers(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.get_by_type(&WorkerType::Decode)
|
||||
}
|
||||
|
||||
/// Get all workers by connection mode
|
||||
pub fn get_by_connection(&self, connection_mode: &ConnectionMode) -> Vec<Arc<dyn Worker>> {
|
||||
self.connection_workers
|
||||
.get(connection_mode)
|
||||
.map(|ids| ids.iter().filter_map(|id| self.get(id)).collect())
|
||||
.unwrap_or_default()
|
||||
}
|
||||
|
||||
/// Get all workers
|
||||
pub fn get_all(&self) -> Vec<Arc<dyn Worker>> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all workers with their IDs
|
||||
pub fn get_all_with_ids(&self) -> Vec<(WorkerId, Arc<dyn Worker>)> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| (entry.key().clone(), entry.value().clone()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all worker URLs
|
||||
pub fn get_all_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.iter()
|
||||
.map(|entry| entry.value().url().to_string())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get all model IDs with workers
|
||||
pub fn get_models(&self) -> Vec<String> {
|
||||
self.model_workers
|
||||
.iter()
|
||||
.filter(|entry| !entry.value().is_empty())
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get workers filtered by multiple criteria
|
||||
///
|
||||
/// This method allows flexible filtering of workers based on:
|
||||
/// - model_id: Filter by specific model
|
||||
/// - worker_type: Filter by worker type (Regular, Prefill, Decode)
|
||||
/// - connection_mode: Filter by connection mode (Http, Grpc)
|
||||
/// - healthy_only: Only return healthy workers
|
||||
pub fn get_workers_filtered(
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
worker_type: Option<WorkerType>,
|
||||
connection_mode: Option<ConnectionMode>,
|
||||
healthy_only: bool,
|
||||
) -> Vec<Arc<dyn Worker>> {
|
||||
// Start with the most efficient collection based on filters
|
||||
// Use model index when possible as it's O(1) lookup
|
||||
let workers = if let Some(model) = model_id {
|
||||
self.get_by_model_fast(model)
|
||||
} else {
|
||||
self.get_all()
|
||||
};
|
||||
|
||||
// Apply remaining filters
|
||||
workers
|
||||
.into_iter()
|
||||
.filter(|w| {
|
||||
// Check worker_type if specified
|
||||
if let Some(ref wtype) = worker_type {
|
||||
if w.worker_type() != *wtype {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check connection_mode if specified
|
||||
if let Some(ref conn) = connection_mode {
|
||||
if w.connection_mode() != *conn {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// Check health if required
|
||||
if healthy_only && !w.is_healthy() {
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get worker statistics
|
||||
pub fn stats(&self) -> WorkerRegistryStats {
|
||||
let total_workers = self.workers.len();
|
||||
let total_models = self.get_models().len();
|
||||
|
||||
let mut healthy_count = 0;
|
||||
let mut total_load = 0;
|
||||
let mut regular_count = 0;
|
||||
let mut prefill_count = 0;
|
||||
let mut decode_count = 0;
|
||||
|
||||
for worker in self.get_all() {
|
||||
if worker.is_healthy() {
|
||||
healthy_count += 1;
|
||||
}
|
||||
total_load += worker.load();
|
||||
|
||||
match worker.worker_type() {
|
||||
WorkerType::Regular => regular_count += 1,
|
||||
WorkerType::Prefill { .. } => prefill_count += 1,
|
||||
WorkerType::Decode => decode_count += 1,
|
||||
}
|
||||
}
|
||||
|
||||
WorkerRegistryStats {
|
||||
total_workers,
|
||||
total_models,
|
||||
healthy_workers: healthy_count,
|
||||
total_load,
|
||||
regular_workers: regular_count,
|
||||
prefill_workers: prefill_count,
|
||||
decode_workers: decode_count,
|
||||
}
|
||||
}
|
||||
|
||||
/// Start a health checker for all workers in the registry
|
||||
/// This should be called once after the registry is populated with workers
|
||||
pub fn start_health_checker(&self, check_interval_secs: u64) -> crate::core::HealthChecker {
|
||||
use std::sync::atomic::{AtomicBool, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
let shutdown_clone = shutdown.clone();
|
||||
let workers_ref = self.workers.clone();
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
let mut interval =
|
||||
tokio::time::interval(tokio::time::Duration::from_secs(check_interval_secs));
|
||||
|
||||
// Counter for periodic load reset (every 10 health check cycles)
|
||||
let mut check_count = 0u64;
|
||||
const LOAD_RESET_INTERVAL: u64 = 10;
|
||||
|
||||
loop {
|
||||
interval.tick().await;
|
||||
|
||||
// Check for shutdown signal
|
||||
if shutdown_clone.load(Ordering::Acquire) {
|
||||
tracing::debug!("Registry health checker shutting down");
|
||||
break;
|
||||
}
|
||||
|
||||
// Get all workers from registry
|
||||
let workers: Vec<Arc<dyn crate::core::Worker>> = workers_ref
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect();
|
||||
|
||||
// Perform health checks
|
||||
for worker in &workers {
|
||||
let _ = worker.check_health_async().await; // Use async version directly
|
||||
}
|
||||
|
||||
// Reset loads periodically
|
||||
check_count += 1;
|
||||
if check_count % LOAD_RESET_INTERVAL == 0 {
|
||||
tracing::debug!("Resetting worker loads (cycle {})", check_count);
|
||||
for worker in &workers {
|
||||
worker.reset_load();
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
crate::core::HealthChecker::new(handle, shutdown)
|
||||
}
|
||||
}
|
||||
|
||||
impl Default for WorkerRegistry {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
}
|
||||
|
||||
/// Statistics for the worker registry
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct WorkerRegistryStats {
|
||||
pub total_workers: usize,
|
||||
pub total_models: usize,
|
||||
pub healthy_workers: usize,
|
||||
pub total_load: usize,
|
||||
pub regular_workers: usize,
|
||||
pub prefill_workers: usize,
|
||||
pub decode_workers: usize,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::core::{CircuitBreakerConfig, WorkerFactory};
|
||||
use std::collections::HashMap;
|
||||
|
||||
#[test]
|
||||
fn test_worker_registry() {
|
||||
let registry = WorkerRegistry::new();
|
||||
|
||||
// Create a worker with labels
|
||||
let mut labels = HashMap::new();
|
||||
labels.insert("model_id".to_string(), "llama-3-8b".to_string());
|
||||
labels.insert("priority".to_string(), "50".to_string());
|
||||
labels.insert("cost".to_string(), "0.8".to_string());
|
||||
|
||||
let worker = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker1:8080".to_string(),
|
||||
labels,
|
||||
CircuitBreakerConfig::default(),
|
||||
);
|
||||
|
||||
// Register worker (WorkerFactory returns Box<dyn Worker>, convert to Arc)
|
||||
let worker_id = registry.register(Arc::from(worker));
|
||||
|
||||
// Verify registration
|
||||
assert!(registry.get(&worker_id).is_some());
|
||||
assert!(registry.get_by_url("http://worker1:8080").is_some());
|
||||
assert_eq!(registry.get_by_model("llama-3-8b").len(), 1);
|
||||
assert_eq!(registry.get_by_type(&WorkerType::Regular).len(), 1);
|
||||
assert_eq!(registry.get_by_connection(&ConnectionMode::Http).len(), 1);
|
||||
|
||||
// Test stats
|
||||
let stats = registry.stats();
|
||||
assert_eq!(stats.total_workers, 1);
|
||||
assert_eq!(stats.total_models, 1);
|
||||
|
||||
// Remove worker
|
||||
registry.remove(&worker_id);
|
||||
assert!(registry.get(&worker_id).is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_model_index_fast_lookup() {
|
||||
let registry = WorkerRegistry::new();
|
||||
|
||||
// Create workers for different models
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker1 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker1:8080".to_string(),
|
||||
labels1,
|
||||
CircuitBreakerConfig::default(),
|
||||
);
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker2 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker2:8080".to_string(),
|
||||
labels2,
|
||||
CircuitBreakerConfig::default(),
|
||||
);
|
||||
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("model_id".to_string(), "gpt-4".to_string());
|
||||
let worker3 = WorkerFactory::create_regular_with_labels(
|
||||
"http://worker3:8080".to_string(),
|
||||
labels3,
|
||||
CircuitBreakerConfig::default(),
|
||||
);
|
||||
|
||||
// Register workers
|
||||
registry.register(Arc::from(worker1));
|
||||
registry.register(Arc::from(worker2));
|
||||
registry.register(Arc::from(worker3));
|
||||
|
||||
// Test get_by_model_fast for llama-3
|
||||
let llama_workers = registry.get_by_model_fast("llama-3");
|
||||
assert_eq!(llama_workers.len(), 2);
|
||||
let urls: Vec<String> = llama_workers.iter().map(|w| w.url().to_string()).collect();
|
||||
assert!(urls.contains(&"http://worker1:8080".to_string()));
|
||||
assert!(urls.contains(&"http://worker2:8080".to_string()));
|
||||
|
||||
// Test get_by_model_fast for gpt-4
|
||||
let gpt_workers = registry.get_by_model_fast("gpt-4");
|
||||
assert_eq!(gpt_workers.len(), 1);
|
||||
assert_eq!(gpt_workers[0].url(), "http://worker3:8080");
|
||||
|
||||
// Test get_by_model_fast for non-existent model
|
||||
let unknown_workers = registry.get_by_model_fast("unknown-model");
|
||||
assert_eq!(unknown_workers.len(), 0);
|
||||
|
||||
// Test that both get_by_model and get_by_model_fast return same results
|
||||
let llama_workers_slow = registry.get_by_model("llama-3");
|
||||
assert_eq!(llama_workers.len(), llama_workers_slow.len());
|
||||
|
||||
// Test removal updates the model index
|
||||
registry.remove_by_url("http://worker1:8080");
|
||||
let llama_workers_after = registry.get_by_model_fast("llama-3");
|
||||
assert_eq!(llama_workers_after.len(), 1);
|
||||
assert_eq!(llama_workers_after[0].url(), "http://worker2:8080");
|
||||
}
|
||||
}
|
||||
@@ -63,6 +63,7 @@ use super::{get_healthy_worker_indices, CacheAwareConfig, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::tree::Tree;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, Mutex};
|
||||
use std::thread;
|
||||
use std::time::Duration;
|
||||
@@ -72,10 +73,11 @@ use tracing::debug;
|
||||
///
|
||||
/// Routes requests based on cache affinity when load is balanced,
|
||||
/// switches to shortest-queue routing when load is imbalanced.
|
||||
/// Maintains separate trees per model for multi-model support.
|
||||
#[derive(Debug)]
|
||||
pub struct CacheAwarePolicy {
|
||||
config: CacheAwareConfig,
|
||||
tree: Arc<Mutex<Tree>>,
|
||||
trees: Arc<Mutex<HashMap<String, Tree>>>, // model_id -> Tree
|
||||
eviction_handle: Option<thread::JoinHandle<()>>,
|
||||
}
|
||||
|
||||
@@ -85,20 +87,26 @@ impl CacheAwarePolicy {
|
||||
}
|
||||
|
||||
pub fn with_config(config: CacheAwareConfig) -> Self {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
let trees = Arc::new(Mutex::new(HashMap::<String, Tree>::new()));
|
||||
|
||||
// Start background eviction thread if configured
|
||||
let eviction_handle = if config.eviction_interval_secs > 0 {
|
||||
let tree_clone = Arc::clone(&tree);
|
||||
let trees_clone = Arc::clone(&trees);
|
||||
let max_tree_size = config.max_tree_size;
|
||||
let interval = config.eviction_interval_secs;
|
||||
|
||||
Some(thread::spawn(move || loop {
|
||||
thread::sleep(Duration::from_secs(interval));
|
||||
|
||||
if let Ok(tree_guard) = tree_clone.lock() {
|
||||
tree_guard.evict_tenant_by_size(max_tree_size);
|
||||
debug!("Cache eviction completed, max_size: {}", max_tree_size);
|
||||
if let Ok(mut trees_guard) = trees_clone.lock() {
|
||||
// Evict for all model trees
|
||||
for (model_id, tree) in trees_guard.iter_mut() {
|
||||
tree.evict_tenant_by_size(max_tree_size);
|
||||
debug!(
|
||||
"Cache eviction completed for model {}, max_size: {}",
|
||||
model_id, max_tree_size
|
||||
);
|
||||
}
|
||||
}
|
||||
}))
|
||||
} else {
|
||||
@@ -107,38 +115,97 @@ impl CacheAwarePolicy {
|
||||
|
||||
Self {
|
||||
config,
|
||||
tree,
|
||||
trees,
|
||||
eviction_handle,
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize the tree with worker URLs (used only during initial setup)
|
||||
pub fn init_workers(&self, workers: &[Box<dyn Worker>]) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
pub fn init_workers(&self, workers: &[Arc<dyn Worker>]) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Group workers by model
|
||||
let mut model_workers: HashMap<String, Vec<&Arc<dyn Worker>>> = HashMap::new();
|
||||
for worker in workers {
|
||||
tree.insert("", worker.url());
|
||||
// Use "default" for unknown/empty model_ids for backward compatibility
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
model_workers.entry(tree_key).or_default().push(worker);
|
||||
}
|
||||
|
||||
// Initialize tree for each model
|
||||
for (tree_key, model_workers) in model_workers {
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
for worker in model_workers {
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a single worker to the tree (incremental update)
|
||||
pub fn add_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
pub fn add_worker(&self, worker: &dyn Worker) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// For backward compatibility: if model_id is "unknown" or empty,
|
||||
// use a default tree. This preserves existing behavior for single-model routers.
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a worker by URL and model (for backward compatibility)
|
||||
pub fn add_worker_by_url(&self, url: &str, model_id: &str) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let tree = trees.entry(model_id.to_string()).or_insert_with(Tree::new);
|
||||
tree.insert("", url);
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the tree
|
||||
pub fn remove_worker(&self, url: &str) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.remove_tenant(url);
|
||||
pub fn remove_worker(&self, worker: &dyn Worker) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Use same logic as add_worker for consistency
|
||||
let model_id = worker.model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
if let Some(tree) = trees.get_mut(&tree_key) {
|
||||
tree.remove_tenant(worker.url());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker by URL (removes from all model trees for backward compatibility)
|
||||
pub fn remove_worker_by_url(&self, url: &str) {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
// Remove from all trees since we don't know which model it belongs to
|
||||
for (_model_id, tree) in trees.iter_mut() {
|
||||
tree.remove_tenant(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Run cache eviction to prevent unbounded growth
|
||||
pub fn evict_cache(&self, max_size: usize) {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
for (model_id, tree) in trees.iter_mut() {
|
||||
tree.evict_tenant_by_size(max_size);
|
||||
debug!(
|
||||
"Cache eviction for model {}, max_size: {}",
|
||||
model_id, max_size
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -146,7 +213,7 @@ impl CacheAwarePolicy {
|
||||
impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
@@ -155,6 +222,18 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
return None;
|
||||
}
|
||||
|
||||
// Group workers by model (using "default" for unknown/empty model_ids)
|
||||
let mut model_workers: HashMap<String, Vec<usize>> = HashMap::new();
|
||||
for idx in &healthy_indices {
|
||||
let model_id = workers[*idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
model_workers.entry(tree_key).or_default().push(*idx);
|
||||
}
|
||||
|
||||
// Get current load statistics
|
||||
let loads: Vec<usize> = workers.iter().map(|w| w.load()).collect();
|
||||
let max_load = *loads.iter().max().unwrap_or(&0);
|
||||
@@ -187,7 +266,14 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
|
||||
// Even in imbalanced mode, update the tree to maintain cache state
|
||||
if let Some(text) = request_text {
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let model_id = workers[min_load_idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert(text, workers[min_load_idx].url());
|
||||
}
|
||||
}
|
||||
@@ -203,43 +289,85 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
// Use cache-aware routing when balanced
|
||||
let text = request_text.unwrap_or("");
|
||||
|
||||
if let Ok(tree) = self.tree.lock() {
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
if let Ok(mut trees) = self.trees.lock() {
|
||||
let mut best_match_idx: Option<usize> = None;
|
||||
let mut best_match_rate: f32 = 0.0;
|
||||
|
||||
let selected_url = if match_rate > self.config.cache_threshold {
|
||||
RouterMetrics::record_cache_hit();
|
||||
matched_worker.to_string()
|
||||
} else {
|
||||
RouterMetrics::record_cache_miss();
|
||||
tree.get_smallest_tenant()
|
||||
};
|
||||
// Find best match across all models
|
||||
for (model_id, worker_indices) in &model_workers {
|
||||
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
|
||||
|
||||
// Find the index of the selected worker
|
||||
if let Some(selected_idx) = workers.iter().position(|w| w.url() == selected_url) {
|
||||
// Only proceed if the worker is healthy
|
||||
if workers[selected_idx].is_healthy() {
|
||||
// Update the tree with this request
|
||||
tree.insert(text, &selected_url);
|
||||
let (matched_text, matched_worker) = tree.prefix_match(text);
|
||||
let match_rate = if text.is_empty() {
|
||||
0.0
|
||||
} else {
|
||||
matched_text.chars().count() as f32 / text.chars().count() as f32
|
||||
};
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(&selected_url);
|
||||
|
||||
return Some(selected_idx);
|
||||
// Check if this model has the best match
|
||||
if match_rate > best_match_rate {
|
||||
// Find the worker index for this URL
|
||||
if let Some(idx) = worker_indices
|
||||
.iter()
|
||||
.find(|&&idx| workers[idx].url() == matched_worker)
|
||||
{
|
||||
best_match_idx = Some(*idx);
|
||||
best_match_rate = match_rate;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Selected worker no longer exists, remove it from tree
|
||||
tree.remove_tenant(&selected_url);
|
||||
debug!("Removed stale worker {} from cache tree", selected_url);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker
|
||||
return healthy_indices.first().copied();
|
||||
// Select worker based on cache threshold
|
||||
let selected_idx = if let (Some(idx), true) = (
|
||||
best_match_idx,
|
||||
best_match_rate > self.config.cache_threshold,
|
||||
) {
|
||||
RouterMetrics::record_cache_hit();
|
||||
idx
|
||||
} else {
|
||||
RouterMetrics::record_cache_miss();
|
||||
|
||||
// Find model with smallest tree (most cache capacity)
|
||||
let mut smallest_tree_model = String::new();
|
||||
let mut smallest_tree_size = usize::MAX;
|
||||
|
||||
for model_id in model_workers.keys() {
|
||||
let tree = trees.entry(model_id.clone()).or_insert_with(Tree::new);
|
||||
let size = tree.get_used_size_per_tenant().values().sum::<usize>();
|
||||
if size < smallest_tree_size {
|
||||
smallest_tree_size = size;
|
||||
smallest_tree_model = model_id.clone();
|
||||
}
|
||||
}
|
||||
|
||||
// Select least loaded worker from model with most cache capacity
|
||||
if let Some(worker_indices) = model_workers.get(&smallest_tree_model) {
|
||||
worker_indices
|
||||
.iter()
|
||||
.min_by_key(|&&idx| workers[idx].load())
|
||||
.copied()
|
||||
.unwrap_or(healthy_indices[0])
|
||||
} else {
|
||||
healthy_indices[0]
|
||||
}
|
||||
};
|
||||
|
||||
// Update the tree with this request
|
||||
let model_id = workers[selected_idx].model_id();
|
||||
let tree_key = if model_id.is_empty() || model_id == "unknown" {
|
||||
"default".to_string()
|
||||
} else {
|
||||
model_id.to_string()
|
||||
};
|
||||
let tree = trees.entry(tree_key).or_insert_with(Tree::new);
|
||||
tree.insert(text, workers[selected_idx].url());
|
||||
|
||||
// Increment processed counter
|
||||
workers[selected_idx].increment_processed();
|
||||
RouterMetrics::record_processed_request(workers[selected_idx].url());
|
||||
RouterMetrics::record_policy_decision(self.name(), workers[selected_idx].url());
|
||||
|
||||
return Some(selected_idx);
|
||||
}
|
||||
|
||||
// Fallback to first healthy worker if tree operations fail
|
||||
@@ -272,8 +400,8 @@ impl LoadBalancingPolicy for CacheAwarePolicy {
|
||||
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
prefill_workers: &[Arc<dyn Worker>],
|
||||
decode_workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// DEPRECATED: This method is no longer used when separate policies are configured.
|
||||
@@ -333,12 +461,12 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -378,7 +506,7 @@ mod tests {
|
||||
}
|
||||
// worker2 has load 0
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(worker1), Box::new(worker2)];
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1), Arc::new(worker2)];
|
||||
policy.init_workers(&workers);
|
||||
|
||||
// Should select worker2 (lower load) despite cache affinity
|
||||
@@ -395,12 +523,12 @@ mod tests {
|
||||
..Default::default()
|
||||
};
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -413,7 +541,7 @@ mod tests {
|
||||
policy.select_worker(&workers, Some("test2"));
|
||||
|
||||
// Remove a worker
|
||||
policy.remove_worker("http://w1:8000");
|
||||
policy.remove_worker_by_url("http://w1:8000");
|
||||
workers[0].set_healthy(false);
|
||||
|
||||
// All requests should now go to worker2
|
||||
|
||||
@@ -5,17 +5,20 @@
|
||||
|
||||
use crate::core::Worker;
|
||||
use std::fmt::Debug;
|
||||
use std::sync::Arc;
|
||||
|
||||
mod cache_aware;
|
||||
mod factory;
|
||||
mod power_of_two;
|
||||
mod random;
|
||||
mod registry;
|
||||
mod round_robin;
|
||||
|
||||
pub use cache_aware::CacheAwarePolicy;
|
||||
pub use factory::PolicyFactory;
|
||||
pub use power_of_two::PowerOfTwoPolicy;
|
||||
pub use random::RandomPolicy;
|
||||
pub use registry::PolicyRegistry;
|
||||
pub use round_robin::RoundRobinPolicy;
|
||||
|
||||
/// Core trait for load balancing policies
|
||||
@@ -26,9 +29,10 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
|
||||
/// Select a single worker from the available workers
|
||||
///
|
||||
/// This is used for regular routing mode where requests go to a single worker.
|
||||
/// Now uses Arc<dyn Worker> for better performance and to avoid unnecessary cloning.
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<usize>;
|
||||
|
||||
@@ -38,8 +42,8 @@ pub trait LoadBalancingPolicy: Send + Sync + Debug {
|
||||
/// Default implementation uses select_worker for each array independently.
|
||||
fn select_worker_pair(
|
||||
&self,
|
||||
prefill_workers: &[Box<dyn Worker>],
|
||||
decode_workers: &[Box<dyn Worker>],
|
||||
prefill_workers: &[Arc<dyn Worker>],
|
||||
decode_workers: &[Arc<dyn Worker>],
|
||||
request_text: Option<&str>,
|
||||
) -> Option<(usize, usize)> {
|
||||
// Default implementation: independently select from each pool
|
||||
@@ -105,7 +109,7 @@ impl Default for CacheAwareConfig {
|
||||
}
|
||||
|
||||
/// Helper function to filter healthy workers and return their indices
|
||||
pub(crate) fn get_healthy_worker_indices(workers: &[Box<dyn Worker>]) -> Vec<usize> {
|
||||
pub(crate) fn get_healthy_worker_indices(workers: &[Arc<dyn Worker>]) -> Vec<usize> {
|
||||
workers
|
||||
.iter()
|
||||
.enumerate()
|
||||
@@ -121,16 +125,16 @@ mod tests {
|
||||
|
||||
#[test]
|
||||
fn test_get_healthy_worker_indices() {
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
|
||||
@@ -5,7 +5,7 @@ use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use rand::Rng;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::RwLock;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::info;
|
||||
|
||||
/// Power-of-two choices policy
|
||||
@@ -41,7 +41,7 @@ impl PowerOfTwoPolicy {
|
||||
impl LoadBalancingPolicy for PowerOfTwoPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
@@ -137,8 +137,8 @@ mod tests {
|
||||
}
|
||||
// worker3 has load 0
|
||||
|
||||
let workers: Vec<Box<dyn Worker>> =
|
||||
vec![Box::new(worker1), Box::new(worker2), Box::new(worker3)];
|
||||
let workers: Vec<Arc<dyn Worker>> =
|
||||
vec![Arc::new(worker1), Arc::new(worker2), Arc::new(worker3)];
|
||||
|
||||
// Run multiple selections
|
||||
let mut selected_counts = [0; 3];
|
||||
@@ -156,12 +156,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_power_of_two_with_cached_loads() {
|
||||
let policy = PowerOfTwoPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -190,7 +190,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_power_of_two_single_worker() {
|
||||
let policy = PowerOfTwoPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
))];
|
||||
|
||||
@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use rand::Rng;
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Random selection policy
|
||||
///
|
||||
@@ -20,7 +21,7 @@ impl RandomPolicy {
|
||||
impl LoadBalancingPolicy for RandomPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
@@ -56,16 +57,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_random_selection() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -87,12 +88,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_random_with_unhealthy_workers() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -110,7 +111,7 @@ mod tests {
|
||||
#[test]
|
||||
fn test_random_no_healthy_workers() {
|
||||
let policy = RandomPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
))];
|
||||
|
||||
333
sgl-router/src/policies/registry.rs
Normal file
333
sgl-router/src/policies/registry.rs
Normal file
@@ -0,0 +1,333 @@
|
||||
/// Policy Registry for managing model-to-policy mappings
|
||||
///
|
||||
/// This registry manages the dynamic assignment of load balancing policies to models.
|
||||
/// When the first worker of a new model is added, it determines the policy for that model.
|
||||
/// All subsequent workers of the same model use the established policy.
|
||||
/// When the last worker of a model is removed, the policy mapping is cleaned up.
|
||||
use super::{
|
||||
CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy, PowerOfTwoPolicy, RandomPolicy,
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::types::PolicyConfig;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Registry for managing model-to-policy mappings
|
||||
#[derive(Clone)]
|
||||
pub struct PolicyRegistry {
|
||||
/// Model ID -> Policy instance mapping
|
||||
model_policies: Arc<RwLock<HashMap<String, Arc<dyn LoadBalancingPolicy>>>>,
|
||||
|
||||
/// Model ID -> Worker count for cleanup tracking
|
||||
model_worker_counts: Arc<RwLock<HashMap<String, usize>>>,
|
||||
|
||||
/// Default policy instance (cached)
|
||||
default_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
|
||||
/// Prefill policy for PD mode
|
||||
prefill_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
|
||||
|
||||
/// Decode policy for PD mode
|
||||
decode_policy: Arc<RwLock<Option<Arc<dyn LoadBalancingPolicy>>>>,
|
||||
}
|
||||
|
||||
impl PolicyRegistry {
|
||||
/// Create a new PolicyRegistry with a default policy
|
||||
pub fn new(default_policy_config: PolicyConfig) -> Self {
|
||||
let default_policy = Self::create_policy_from_config(&default_policy_config);
|
||||
|
||||
Self {
|
||||
model_policies: Arc::new(RwLock::new(HashMap::new())),
|
||||
model_worker_counts: Arc::new(RwLock::new(HashMap::new())),
|
||||
default_policy,
|
||||
prefill_policy: Arc::new(RwLock::new(None)),
|
||||
decode_policy: Arc::new(RwLock::new(None)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Called when a worker is added
|
||||
/// Returns the policy that should be used for this worker's model
|
||||
pub fn on_worker_added(
|
||||
&self,
|
||||
model_id: &str,
|
||||
policy_hint: Option<&str>,
|
||||
) -> Arc<dyn LoadBalancingPolicy> {
|
||||
// Increment worker count
|
||||
{
|
||||
let mut counts = self.model_worker_counts.write().unwrap();
|
||||
*counts.entry(model_id.to_string()).or_insert(0) += 1;
|
||||
debug!(
|
||||
"Worker added for model {}, count: {}",
|
||||
model_id,
|
||||
counts.get(model_id).unwrap()
|
||||
);
|
||||
}
|
||||
|
||||
// Check if model already has a policy
|
||||
{
|
||||
let policies = self.model_policies.read().unwrap();
|
||||
if let Some(existing_policy) = policies.get(model_id) {
|
||||
debug!(
|
||||
"Model {} already has policy: {}",
|
||||
model_id,
|
||||
existing_policy.name()
|
||||
);
|
||||
return Arc::clone(existing_policy);
|
||||
}
|
||||
}
|
||||
|
||||
// New model - determine policy
|
||||
let policy = self.determine_policy_for_model(model_id, policy_hint);
|
||||
|
||||
info!(
|
||||
"Assigning policy {} to new model {}",
|
||||
policy.name(),
|
||||
model_id
|
||||
);
|
||||
|
||||
// Store policy for this model
|
||||
{
|
||||
let mut policies = self.model_policies.write().unwrap();
|
||||
policies.insert(model_id.to_string(), Arc::clone(&policy));
|
||||
}
|
||||
|
||||
policy
|
||||
}
|
||||
|
||||
/// Called when a worker is removed
|
||||
pub fn on_worker_removed(&self, model_id: &str) {
|
||||
let should_cleanup = {
|
||||
let mut counts = self.model_worker_counts.write().unwrap();
|
||||
if let Some(count) = counts.get_mut(model_id) {
|
||||
*count = count.saturating_sub(1);
|
||||
debug!("Worker removed for model {}, count: {}", model_id, *count);
|
||||
if *count == 0 {
|
||||
counts.remove(model_id);
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
} else {
|
||||
warn!(
|
||||
"Attempted to remove worker for model {} with no registered workers",
|
||||
model_id
|
||||
);
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
// Clean up policy if this was the last worker
|
||||
if should_cleanup {
|
||||
let mut policies = self.model_policies.write().unwrap();
|
||||
if let Some(policy) = policies.remove(model_id) {
|
||||
info!(
|
||||
"Removed policy {} for model {} (last worker removed)",
|
||||
policy.name(),
|
||||
model_id
|
||||
);
|
||||
// Policy will be dropped here, cleaning up any resources
|
||||
drop(policy);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the policy for a model
|
||||
pub fn get_policy(&self, model_id: &str) -> Option<Arc<dyn LoadBalancingPolicy>> {
|
||||
self.model_policies.read().unwrap().get(model_id).cloned()
|
||||
}
|
||||
|
||||
/// Get the default policy
|
||||
pub fn get_default_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
|
||||
Arc::clone(&self.default_policy)
|
||||
}
|
||||
|
||||
/// Get policy for a model, or default if not found
|
||||
pub fn get_policy_or_default(&self, model_id: &str) -> Arc<dyn LoadBalancingPolicy> {
|
||||
self.get_policy(model_id)
|
||||
.unwrap_or_else(|| self.get_default_policy())
|
||||
}
|
||||
|
||||
/// Determine policy for a new model
|
||||
fn determine_policy_for_model(
|
||||
&self,
|
||||
model_id: &str,
|
||||
policy_hint: Option<&str>,
|
||||
) -> Arc<dyn LoadBalancingPolicy> {
|
||||
// 1. Check policy hint from worker
|
||||
if let Some(policy_type) = policy_hint {
|
||||
debug!("Using policy hint '{}' for model {}", policy_type, model_id);
|
||||
return self.create_policy_from_type(policy_type);
|
||||
}
|
||||
|
||||
// 2. Use default policy
|
||||
debug!("Using default policy for model {}", model_id);
|
||||
Arc::clone(&self.default_policy)
|
||||
}
|
||||
|
||||
/// Create a policy from a type string
|
||||
fn create_policy_from_type(&self, policy_type: &str) -> Arc<dyn LoadBalancingPolicy> {
|
||||
match policy_type {
|
||||
"round_robin" => Arc::new(RoundRobinPolicy::new()),
|
||||
"random" => Arc::new(RandomPolicy::new()),
|
||||
"cache_aware" => Arc::new(CacheAwarePolicy::new()),
|
||||
"power_of_two" => Arc::new(PowerOfTwoPolicy::new()),
|
||||
_ => {
|
||||
warn!("Unknown policy type '{}', using default", policy_type);
|
||||
Arc::clone(&self.default_policy)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a policy from a PolicyConfig
|
||||
fn create_policy_from_config(config: &PolicyConfig) -> Arc<dyn LoadBalancingPolicy> {
|
||||
match config {
|
||||
PolicyConfig::RoundRobin => Arc::new(RoundRobinPolicy::new()),
|
||||
PolicyConfig::Random => Arc::new(RandomPolicy::new()),
|
||||
PolicyConfig::CacheAware {
|
||||
cache_threshold,
|
||||
balance_abs_threshold,
|
||||
balance_rel_threshold,
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
} => {
|
||||
let cache_config = CacheAwareConfig {
|
||||
cache_threshold: *cache_threshold,
|
||||
balance_abs_threshold: *balance_abs_threshold,
|
||||
balance_rel_threshold: *balance_rel_threshold,
|
||||
eviction_interval_secs: *eviction_interval_secs,
|
||||
max_tree_size: *max_tree_size,
|
||||
};
|
||||
Arc::new(CacheAwarePolicy::with_config(cache_config))
|
||||
}
|
||||
PolicyConfig::PowerOfTwo { .. } => Arc::new(PowerOfTwoPolicy::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get current model->policy mappings (for debugging/monitoring)
|
||||
pub fn get_all_mappings(&self) -> HashMap<String, String> {
|
||||
let policies = self.model_policies.read().unwrap();
|
||||
policies
|
||||
.iter()
|
||||
.map(|(model, policy)| (model.clone(), policy.name().to_string()))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Get worker counts per model
|
||||
pub fn get_worker_counts(&self) -> HashMap<String, usize> {
|
||||
self.model_worker_counts.read().unwrap().clone()
|
||||
}
|
||||
|
||||
/// Clear all policies (useful for testing)
|
||||
pub fn clear(&self) {
|
||||
let mut policies = self.model_policies.write().unwrap();
|
||||
policies.clear();
|
||||
let mut counts = self.model_worker_counts.write().unwrap();
|
||||
counts.clear();
|
||||
}
|
||||
|
||||
/// Set the prefill policy for PD mode
|
||||
pub fn set_prefill_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
|
||||
let mut prefill_policy = self.prefill_policy.write().unwrap();
|
||||
*prefill_policy = Some(policy);
|
||||
}
|
||||
|
||||
/// Set the decode policy for PD mode
|
||||
pub fn set_decode_policy(&self, policy: Arc<dyn LoadBalancingPolicy>) {
|
||||
let mut decode_policy = self.decode_policy.write().unwrap();
|
||||
*decode_policy = Some(policy);
|
||||
}
|
||||
|
||||
/// Get the prefill policy for PD mode, or default if not set
|
||||
pub fn get_prefill_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
|
||||
let prefill_policy = self.prefill_policy.read().unwrap();
|
||||
prefill_policy
|
||||
.as_ref()
|
||||
.map(Arc::clone)
|
||||
.unwrap_or_else(|| self.get_default_policy())
|
||||
}
|
||||
|
||||
/// Get the decode policy for PD mode, or default if not set
|
||||
pub fn get_decode_policy(&self) -> Arc<dyn LoadBalancingPolicy> {
|
||||
let decode_policy = self.decode_policy.read().unwrap();
|
||||
decode_policy
|
||||
.as_ref()
|
||||
.map(Arc::clone)
|
||||
.unwrap_or_else(|| self.get_default_policy())
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PolicyRegistry {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("PolicyRegistry")
|
||||
.field("model_policies", &self.model_policies)
|
||||
.field("model_worker_counts", &self.model_worker_counts)
|
||||
.field("default_policy", &self.default_policy.name())
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_policy_registry_basic() {
|
||||
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
|
||||
|
||||
// First worker of a model sets the policy
|
||||
let policy1 = registry.on_worker_added("llama-3", Some("cache_aware"));
|
||||
assert_eq!(policy1.name(), "cache_aware");
|
||||
|
||||
// Second worker of same model uses existing policy
|
||||
let policy2 = registry.on_worker_added("llama-3", Some("round_robin"));
|
||||
assert_eq!(policy2.name(), "cache_aware"); // Ignores hint, uses existing
|
||||
|
||||
// Different model can have different policy
|
||||
let policy3 = registry.on_worker_added("gpt-4", Some("random"));
|
||||
assert_eq!(policy3.name(), "random");
|
||||
|
||||
// Check mappings
|
||||
let mappings = registry.get_all_mappings();
|
||||
assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
|
||||
assert_eq!(mappings.get("gpt-4").unwrap(), "random");
|
||||
|
||||
// Check worker counts
|
||||
let counts = registry.get_worker_counts();
|
||||
assert_eq!(*counts.get("llama-3").unwrap(), 2);
|
||||
assert_eq!(*counts.get("gpt-4").unwrap(), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_policy_registry_cleanup() {
|
||||
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
|
||||
|
||||
// Add workers
|
||||
registry.on_worker_added("llama-3", Some("cache_aware"));
|
||||
registry.on_worker_added("llama-3", None);
|
||||
assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&2));
|
||||
|
||||
// Remove one worker - policy should remain
|
||||
registry.on_worker_removed("llama-3");
|
||||
assert!(registry.get_policy("llama-3").is_some());
|
||||
assert_eq!(registry.get_worker_counts().get("llama-3"), Some(&1));
|
||||
|
||||
// Remove last worker - policy should be cleaned up
|
||||
registry.on_worker_removed("llama-3");
|
||||
assert!(registry.get_policy("llama-3").is_none());
|
||||
assert_eq!(registry.get_worker_counts().get("llama-3"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_default_policy() {
|
||||
let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);
|
||||
|
||||
// No hint, no template - uses default
|
||||
let policy = registry.on_worker_added("unknown-model", None);
|
||||
assert_eq!(policy.name(), "round_robin");
|
||||
|
||||
// Get default directly
|
||||
let default = registry.get_default_policy();
|
||||
assert_eq!(default.name(), "round_robin");
|
||||
}
|
||||
}
|
||||
@@ -4,6 +4,7 @@ use super::{get_healthy_worker_indices, LoadBalancingPolicy};
|
||||
use crate::core::Worker;
|
||||
use crate::metrics::RouterMetrics;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
/// Round-robin selection policy
|
||||
///
|
||||
@@ -24,7 +25,7 @@ impl RoundRobinPolicy {
|
||||
impl LoadBalancingPolicy for RoundRobinPolicy {
|
||||
fn select_worker(
|
||||
&self,
|
||||
workers: &[Box<dyn Worker>],
|
||||
workers: &[Arc<dyn Worker>],
|
||||
_request_text: Option<&str>,
|
||||
) -> Option<usize> {
|
||||
let healthy_indices = get_healthy_worker_indices(workers);
|
||||
@@ -64,16 +65,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_round_robin_selection() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -90,16 +91,16 @@ mod tests {
|
||||
#[test]
|
||||
fn test_round_robin_with_unhealthy_workers() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w3:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
@@ -118,12 +119,12 @@ mod tests {
|
||||
#[test]
|
||||
fn test_round_robin_reset() {
|
||||
let policy = RoundRobinPolicy::new();
|
||||
let workers: Vec<Box<dyn Worker>> = vec![
|
||||
Box::new(BasicWorker::new(
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w1:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
Box::new(BasicWorker::new(
|
||||
Arc::new(BasicWorker::new(
|
||||
"http://w2:8000".to_string(),
|
||||
WorkerType::Regular,
|
||||
)),
|
||||
|
||||
@@ -3,3 +3,4 @@
|
||||
|
||||
pub mod spec;
|
||||
pub mod validation;
|
||||
pub mod worker_spec;
|
||||
|
||||
198
sgl-router/src/protocols/worker_spec.rs
Normal file
198
sgl-router/src/protocols/worker_spec.rs
Normal file
@@ -0,0 +1,198 @@
|
||||
//! Worker management API specifications
|
||||
//!
|
||||
//! Defines the request/response structures for worker management endpoints
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// Worker configuration for API requests
|
||||
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||
pub struct WorkerConfigRequest {
|
||||
/// Worker URL (required)
|
||||
pub url: String,
|
||||
|
||||
/// Model ID (optional, will query from server if not provided)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_id: Option<String>,
|
||||
|
||||
/// Worker priority (optional, default: 50, higher = preferred)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<u32>,
|
||||
|
||||
/// Worker cost factor (optional, default: 1.0)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost: Option<f32>,
|
||||
|
||||
/// Worker type (optional: "regular", "prefill", "decode")
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worker_type: Option<String>,
|
||||
|
||||
/// Bootstrap port for prefill workers (optional)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub bootstrap_port: Option<u16>,
|
||||
|
||||
// gRPC-specific configuration (optional, ignored in HTTP mode)
|
||||
/// Tokenizer path for gRPC mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokenizer_path: Option<String>,
|
||||
|
||||
/// Reasoning parser type for gRPC mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_parser: Option<String>,
|
||||
|
||||
/// Tool parser type for gRPC mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_parser: Option<String>,
|
||||
|
||||
/// Chat template for gRPC mode
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template: Option<String>,
|
||||
|
||||
/// Additional labels (optional)
|
||||
#[serde(default, skip_serializing_if = "HashMap::is_empty")]
|
||||
pub labels: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Worker information for API responses
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerInfo {
|
||||
/// Worker unique identifier
|
||||
pub id: String,
|
||||
|
||||
/// Worker URL
|
||||
pub url: String,
|
||||
|
||||
/// Model ID this worker serves
|
||||
pub model_id: String,
|
||||
|
||||
/// Worker priority
|
||||
pub priority: u32,
|
||||
|
||||
/// Worker cost factor
|
||||
pub cost: f32,
|
||||
|
||||
/// Worker type
|
||||
pub worker_type: String,
|
||||
|
||||
/// Whether the worker is healthy
|
||||
pub is_healthy: bool,
|
||||
|
||||
/// Current load on the worker
|
||||
pub load: usize,
|
||||
|
||||
/// Connection mode (http or grpc)
|
||||
pub connection_mode: String,
|
||||
|
||||
// gRPC-specific fields (None for HTTP workers)
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokenizer_path: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_parser: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_parser: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template: Option<String>,
|
||||
|
||||
/// Additional metadata
|
||||
#[serde(skip_serializing_if = "HashMap::is_empty")]
|
||||
pub metadata: HashMap<String, String>,
|
||||
}
|
||||
|
||||
/// Worker list response
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerListResponse {
|
||||
/// List of workers
|
||||
pub workers: Vec<WorkerInfo>,
|
||||
|
||||
/// Total count
|
||||
pub total: usize,
|
||||
|
||||
/// Statistics
|
||||
pub stats: WorkerStats,
|
||||
}
|
||||
|
||||
/// Worker statistics
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerStats {
|
||||
pub total_workers: usize,
|
||||
pub healthy_workers: usize,
|
||||
pub total_models: usize,
|
||||
pub total_load: usize,
|
||||
pub by_type: WorkerTypeStats,
|
||||
}
|
||||
|
||||
/// Worker statistics by type
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerTypeStats {
|
||||
pub regular: usize,
|
||||
pub prefill: usize,
|
||||
pub decode: usize,
|
||||
}
|
||||
|
||||
/// Worker update request
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct WorkerUpdateRequest {
|
||||
/// Update priority
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<u32>,
|
||||
|
||||
/// Update cost
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost: Option<f32>,
|
||||
|
||||
/// Update labels
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub labels: Option<HashMap<String, String>>,
|
||||
}
|
||||
|
||||
/// Generic API response
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerApiResponse {
|
||||
pub success: bool,
|
||||
pub message: String,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worker: Option<WorkerInfo>,
|
||||
}
|
||||
|
||||
/// Error response
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub struct WorkerErrorResponse {
|
||||
pub error: String,
|
||||
pub code: String,
|
||||
}
|
||||
|
||||
/// Server info response from /get_server_info endpoint
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct ServerInfo {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_id: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub model_path: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub priority: Option<u32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub cost: Option<f32>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub worker_type: Option<String>,
|
||||
|
||||
// gRPC-specific
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tokenizer_path: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub reasoning_parser: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub tool_parser: Option<String>,
|
||||
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub chat_template: Option<String>,
|
||||
}
|
||||
@@ -15,11 +15,6 @@ pub struct RouterFactory;
|
||||
impl RouterFactory {
|
||||
/// Create a router instance from application context
|
||||
pub async fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Check if IGW mode is enabled
|
||||
if ctx.router_config.enable_igw {
|
||||
return Self::create_igw_router(ctx).await;
|
||||
}
|
||||
|
||||
// Check connection mode and route to appropriate implementation
|
||||
match ctx.router_config.connection_mode {
|
||||
ConnectionMode::Grpc => {
|
||||
@@ -53,8 +48,7 @@ impl RouterFactory {
|
||||
// Route to HTTP implementation based on routing mode
|
||||
match &ctx.router_config.mode {
|
||||
RoutingMode::Regular { worker_urls } => {
|
||||
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||
.await
|
||||
Self::create_regular_router(worker_urls, ctx).await
|
||||
}
|
||||
RoutingMode::PrefillDecode {
|
||||
prefill_urls,
|
||||
@@ -80,23 +74,19 @@ impl RouterFactory {
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a regular router with injected policy
|
||||
async fn create_regular_router(
|
||||
/// Create a regular router
|
||||
pub async fn create_regular_router(
|
||||
worker_urls: &[String],
|
||||
policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policy
|
||||
let policy = PolicyFactory::create_from_config(policy_config);
|
||||
|
||||
// Create regular router with injected policy and context
|
||||
let router = Router::new(worker_urls.to_vec(), policy, ctx).await?;
|
||||
// Create regular router with context
|
||||
let router = Router::new(worker_urls.to_vec(), ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create a PD router with injected policy
|
||||
async fn create_pd_router(
|
||||
pub async fn create_pd_router(
|
||||
prefill_urls: &[(String, Option<u16>)],
|
||||
decode_urls: &[String],
|
||||
prefill_policy_config: Option<&PolicyConfig>,
|
||||
@@ -104,21 +94,18 @@ impl RouterFactory {
|
||||
main_policy_config: &PolicyConfig,
|
||||
ctx: &Arc<AppContext>,
|
||||
) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||
// Initialize policies in PolicyRegistry - use specific policies if provided, otherwise fall back to main policy
|
||||
let prefill_policy =
|
||||
PolicyFactory::create_from_config(prefill_policy_config.unwrap_or(main_policy_config));
|
||||
let decode_policy =
|
||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||
|
||||
// Create PD router with separate policies and context
|
||||
let router = PDRouter::new(
|
||||
prefill_urls.to_vec(),
|
||||
decode_urls.to_vec(),
|
||||
prefill_policy,
|
||||
decode_policy,
|
||||
ctx,
|
||||
)
|
||||
.await?;
|
||||
// Set the prefill and decode policies in the registry
|
||||
ctx.policy_registry.set_prefill_policy(prefill_policy);
|
||||
ctx.policy_registry.set_decode_policy(decode_policy);
|
||||
|
||||
// Create PD router with context (policies are in PolicyRegistry)
|
||||
let router = PDRouter::new(prefill_urls.to_vec(), decode_urls.to_vec(), ctx).await?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
@@ -186,10 +173,4 @@ impl RouterFactory {
|
||||
|
||||
Ok(Box::new(router))
|
||||
}
|
||||
|
||||
/// Create an IGW router (placeholder for future implementation)
|
||||
async fn create_igw_router(_ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||
// For now, return an error indicating IGW is not yet implemented
|
||||
Err("IGW mode is not yet implemented".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -27,9 +27,9 @@ use tracing::{info, warn};
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcPDRouter {
|
||||
/// Prefill worker connections
|
||||
prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
prefill_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
/// Decode worker connections
|
||||
decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
decode_workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
/// gRPC clients for prefill workers
|
||||
prefill_grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// gRPC clients for decode workers
|
||||
@@ -127,7 +127,7 @@ impl GrpcPDRouter {
|
||||
}
|
||||
|
||||
// Create Prefill Worker trait objects with gRPC connection mode
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
let prefill_workers: Vec<Arc<dyn Worker>> = prefill_urls
|
||||
.iter()
|
||||
.map(|(url, bootstrap_port)| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
@@ -147,12 +147,12 @@ impl GrpcPDRouter {
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
Arc::new(worker) as Arc<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Create Decode Worker trait objects with gRPC connection mode
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
let decode_workers: Vec<Arc<dyn Worker>> = decode_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::with_connection_mode(
|
||||
@@ -168,7 +168,7 @@ impl GrpcPDRouter {
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
Arc::new(worker) as Arc<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
|
||||
@@ -269,6 +269,7 @@ impl RouterTrait for GrpcPDRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -277,6 +278,7 @@ impl RouterTrait for GrpcPDRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -285,6 +287,7 @@ impl RouterTrait for GrpcPDRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -293,6 +296,7 @@ impl RouterTrait for GrpcPDRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -305,6 +309,7 @@ impl RouterTrait for GrpcPDRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::RerankRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
@@ -27,7 +27,7 @@ use tracing::{info, warn};
|
||||
#[allow(dead_code)] // Fields will be used once implementation is complete
|
||||
pub struct GrpcRouter {
|
||||
/// Worker connections
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
workers: Arc<RwLock<Vec<Arc<dyn Worker>>>>,
|
||||
/// gRPC clients for each worker
|
||||
grpc_clients: Arc<RwLock<HashMap<String, SglangSchedulerClient>>>,
|
||||
/// Load balancing policy
|
||||
@@ -103,7 +103,7 @@ impl GrpcRouter {
|
||||
}
|
||||
|
||||
// Create Worker trait objects with gRPC connection mode
|
||||
let mut workers: Vec<Box<dyn Worker>> = Vec::new();
|
||||
let mut workers: Vec<Arc<dyn Worker>> = Vec::new();
|
||||
|
||||
// Move clients from the HashMap to the workers
|
||||
for url in &worker_urls {
|
||||
@@ -123,7 +123,7 @@ impl GrpcRouter {
|
||||
})
|
||||
.with_grpc_client(client);
|
||||
|
||||
workers.push(Box::new(worker) as Box<dyn Worker>);
|
||||
workers.push(Arc::new(worker) as Arc<dyn Worker>);
|
||||
} else {
|
||||
warn!("No gRPC client for worker {}, skipping", url);
|
||||
}
|
||||
@@ -202,6 +202,7 @@ impl RouterTrait for GrpcRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::GenerateRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -210,6 +211,7 @@ impl RouterTrait for GrpcRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -218,6 +220,7 @@ impl RouterTrait for GrpcRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::CompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -226,6 +229,7 @@ impl RouterTrait for GrpcRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
@@ -238,6 +242,7 @@ impl RouterTrait for GrpcRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::RerankRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(StatusCode::NOT_IMPLEMENTED).into_response()
|
||||
}
|
||||
|
||||
@@ -186,6 +186,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &GenerateRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Generate endpoint is SGLang-specific, not supported for OpenAI backend
|
||||
(
|
||||
@@ -199,6 +200,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
if !self.circuit_breaker.can_execute() {
|
||||
return (StatusCode::SERVICE_UNAVAILABLE, "Circuit breaker open").into_response();
|
||||
@@ -326,6 +328,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &CompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Completion endpoint not implemented for OpenAI backend
|
||||
(
|
||||
@@ -339,6 +342,7 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &crate::protocols::spec::ResponsesRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
@@ -383,7 +387,12 @@ impl super::super::RouterTrait for OpenAIRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, _headers: Option<&HeaderMap>, _body: &RerankRequest) -> Response {
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &RerankRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Rerank endpoint not implemented for OpenAI backend",
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,10 +1,10 @@
|
||||
use crate::config::types::RetryConfig;
|
||||
use crate::core::{
|
||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthChecker, HealthConfig,
|
||||
RetryExecutor, Worker, WorkerFactory, WorkerType,
|
||||
is_retryable_status, BasicWorker, CircuitBreakerConfig, HealthConfig, RetryExecutor, Worker,
|
||||
WorkerRegistry, WorkerType,
|
||||
};
|
||||
use crate::metrics::RouterMetrics;
|
||||
use crate::policies::LoadBalancingPolicy;
|
||||
use crate::policies::{LoadBalancingPolicy, PolicyRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, GenerationRequest, RerankRequest,
|
||||
RerankResponse, RerankResult, ResponsesRequest,
|
||||
@@ -22,7 +22,7 @@ use axum::{
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio_stream::wrappers::UnboundedReceiverStream;
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -30,8 +30,8 @@ use tracing::{debug, error, info, warn};
|
||||
/// Regular router that uses injected load balancing policies
|
||||
#[derive(Debug)]
|
||||
pub struct Router {
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
client: Client,
|
||||
worker_startup_timeout_secs: u64,
|
||||
worker_startup_check_interval_secs: u64,
|
||||
@@ -41,7 +41,6 @@ pub struct Router {
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
|
||||
impl Router {
|
||||
@@ -49,7 +48,6 @@ impl Router {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub async fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
ctx: &Arc<crate::server::AppContext>,
|
||||
) -> Result<Self, String> {
|
||||
// Update active workers gauge
|
||||
@@ -82,45 +80,51 @@ impl Router {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Create Worker trait objects from URLs with health check config
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
.map(|url| {
|
||||
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
Box::new(worker) as Box<dyn Worker>
|
||||
})
|
||||
.collect();
|
||||
// Register workers in the registry
|
||||
// In IGW mode, we need to fetch model info from workers
|
||||
for url in &worker_urls {
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
// For now, create worker without model_id
|
||||
let worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(core_cb_config.clone())
|
||||
.with_health_config(HealthConfig {
|
||||
timeout_secs: ctx.router_config.health_check.timeout_secs,
|
||||
check_interval_secs: ctx.router_config.health_check.check_interval_secs,
|
||||
endpoint: ctx.router_config.health_check.endpoint.clone(),
|
||||
failure_threshold: ctx.router_config.health_check.failure_threshold,
|
||||
success_threshold: ctx.router_config.health_check.success_threshold,
|
||||
});
|
||||
|
||||
// Initialize policy with workers if needed (e.g., for cache-aware)
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
let worker_arc = Arc::new(worker);
|
||||
ctx.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = ctx.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy and it's the first worker for this model,
|
||||
// initialize it with the worker
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let worker_dyn: Arc<dyn Worker> = worker_arc.clone();
|
||||
cache_aware.init_workers(std::slice::from_ref(&worker_dyn));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let workers = Arc::new(RwLock::new(workers));
|
||||
let health_checker = crate::core::start_health_checker(
|
||||
Arc::clone(&workers),
|
||||
ctx.router_config.worker_startup_check_interval_secs,
|
||||
);
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
let load_monitor_handle = if policy.name() == "power_of_two" {
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_interval = ctx.router_config.worker_startup_check_interval_secs;
|
||||
let policy_clone = Arc::clone(&policy);
|
||||
let policy_clone = default_policy.clone();
|
||||
let client_clone = ctx.client.clone();
|
||||
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
@@ -138,8 +142,8 @@ impl Router {
|
||||
};
|
||||
|
||||
Ok(Router {
|
||||
workers,
|
||||
policy,
|
||||
worker_registry: ctx.worker_registry.clone(),
|
||||
policy_registry: ctx.policy_registry.clone(),
|
||||
client: ctx.client.clone(),
|
||||
worker_startup_timeout_secs: ctx.router_config.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: ctx
|
||||
@@ -151,18 +155,21 @@ impl Router {
|
||||
circuit_breaker_config: core_cb_config,
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
_health_checker: Some(health_checker),
|
||||
})
|
||||
}
|
||||
|
||||
/// Get the current list of worker URLs
|
||||
pub fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url().to_string())
|
||||
.collect()
|
||||
self.worker_registry.get_all_urls()
|
||||
}
|
||||
|
||||
/// Get worker URLs for a specific model
|
||||
pub fn get_worker_urls_for_model(&self, model_id: Option<&str>) -> Vec<String> {
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
workers.iter().map(|w| w.url().to_string()).collect()
|
||||
}
|
||||
|
||||
pub async fn wait_for_healthy_workers(
|
||||
@@ -332,11 +339,27 @@ impl Router {
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
if workers_guard.is_empty() {
|
||||
let workers = self.worker_registry.get_all();
|
||||
if workers.is_empty() {
|
||||
Err("No workers are available".to_string())
|
||||
} else {
|
||||
Ok(workers_guard[0].url().to_string())
|
||||
Ok(workers[0].url().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
fn select_first_worker_for_model(&self, model_id: Option<&str>) -> Result<String, String> {
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
if workers.is_empty() {
|
||||
Err(format!(
|
||||
"No workers are available for model: {:?}",
|
||||
model_id
|
||||
))
|
||||
} else {
|
||||
Ok(workers[0].url().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -447,20 +470,35 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// New method to route typed requests directly
|
||||
/// Select worker considering circuit breaker state
|
||||
fn select_worker_with_circuit_breaker(&self, text: Option<&str>) -> Option<Box<dyn Worker>> {
|
||||
let workers = self.workers.read().ok()?;
|
||||
let available: Vec<Box<dyn Worker>> = workers
|
||||
/// Select worker for a specific model considering circuit breaker state
|
||||
fn select_worker_for_model(
|
||||
&self,
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
// Get workers for the specified model (O(1) lookup if model_id is provided)
|
||||
let workers = match model_id {
|
||||
Some(model) => self.worker_registry.get_by_model_fast(model),
|
||||
None => self.worker_registry.get_all(),
|
||||
};
|
||||
|
||||
let available: Vec<Arc<dyn Worker>> = workers
|
||||
.iter()
|
||||
.filter(|w| w.is_available())
|
||||
.map(|w| w.clone_worker())
|
||||
.cloned()
|
||||
.collect();
|
||||
if available.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let idx = self.policy.select_worker(&available, text)?;
|
||||
Some(available[idx].clone_worker())
|
||||
|
||||
// Get the appropriate policy for this model
|
||||
let policy = match model_id {
|
||||
Some(model) => self.policy_registry.get_policy_or_default(model),
|
||||
None => self.policy_registry.get_default_policy(),
|
||||
};
|
||||
|
||||
let idx = policy.select_worker(&available, text)?;
|
||||
Some(available[idx].clone())
|
||||
}
|
||||
|
||||
pub async fn route_typed_request<T: GenerationRequest + serde::Serialize + Clone>(
|
||||
@@ -468,6 +506,7 @@ impl Router {
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
let start = Instant::now();
|
||||
let is_stream = typed_req.is_stream();
|
||||
@@ -477,7 +516,7 @@ impl Router {
|
||||
&self.retry_config,
|
||||
// operation per attempt
|
||||
|_: u32| async {
|
||||
let worker = match self.select_worker_with_circuit_breaker(Some(&text)) {
|
||||
let worker = match self.select_worker_for_model(model_id, Some(&text)) {
|
||||
Some(w) => w,
|
||||
None => {
|
||||
RouterMetrics::record_request_error(route, "no_available_workers");
|
||||
@@ -490,7 +529,13 @@ impl Router {
|
||||
};
|
||||
|
||||
// Optional load tracking for cache-aware policy
|
||||
let load_incremented = if self.policy.name() == "cache_aware" {
|
||||
// Get the policy for this model to check if it's cache-aware
|
||||
let policy = match model_id {
|
||||
Some(model) => self.policy_registry.get_policy_or_default(model),
|
||||
None => self.policy_registry.get_default_policy(),
|
||||
};
|
||||
|
||||
let load_incremented = if policy.name() == "cache_aware" {
|
||||
worker.increment_load();
|
||||
RouterMetrics::set_running_requests(worker.url(), worker.load());
|
||||
true
|
||||
@@ -654,11 +699,9 @@ impl Router {
|
||||
|
||||
// Decrement load on error if it was incremented
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -687,13 +730,9 @@ impl Router {
|
||||
Err(e) => {
|
||||
// IMPORTANT: Decrement load on error before returning
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -704,18 +743,16 @@ impl Router {
|
||||
|
||||
// Decrement load counter for non-streaming requests if it was incremented
|
||||
if load_incremented {
|
||||
if let Ok(workers_guard) = self.workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
|
||||
response
|
||||
} else if load_incremented {
|
||||
// For streaming with load tracking, we need to manually decrement when done
|
||||
let workers = Arc::clone(&self.workers);
|
||||
let registry = Arc::clone(&self.worker_registry);
|
||||
let worker_url = worker_url.to_string();
|
||||
|
||||
// Preserve headers for streaming response
|
||||
@@ -739,17 +776,10 @@ impl Router {
|
||||
.windows(12)
|
||||
.any(|window| window == b"data: [DONE]")
|
||||
{
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) =
|
||||
workers_guard.iter().find(|w| w.url() == worker_url)
|
||||
{
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(
|
||||
&worker_url,
|
||||
worker.load(),
|
||||
);
|
||||
decremented = true;
|
||||
}
|
||||
if let Some(worker) = registry.get_by_url(&worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
decremented = true;
|
||||
}
|
||||
}
|
||||
if tx.send(Ok(bytes)).is_err() {
|
||||
@@ -763,11 +793,9 @@ impl Router {
|
||||
}
|
||||
}
|
||||
if !decremented {
|
||||
if let Ok(workers_guard) = workers.read() {
|
||||
if let Some(worker) = workers_guard.iter().find(|w| w.url() == worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
}
|
||||
if let Some(worker) = registry.get_by_url(&worker_url) {
|
||||
worker.decrement_load();
|
||||
RouterMetrics::set_running_requests(&worker_url, worker.load());
|
||||
}
|
||||
}
|
||||
});
|
||||
@@ -839,7 +867,6 @@ impl Router {
|
||||
match client.get(format!("{}/health", worker_url)).send().await {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if self.dp_aware {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
@@ -848,46 +875,77 @@ impl Router {
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
if workers_guard.iter().any(|w| w.url() == dp_url) {
|
||||
if self.worker_registry.get_by_url(dp_url).is_some() {
|
||||
warn!("Worker {} already exists", dp_url);
|
||||
continue;
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
let new_worker = WorkerFactory::create_regular_with_config(
|
||||
dp_url.to_string(),
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
workers_guard.push(new_worker);
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker =
|
||||
BasicWorker::new(dp_url.to_string(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
|
||||
worker_added = true;
|
||||
}
|
||||
if !worker_added {
|
||||
return Err(format!("No worker added for {}", worker_url));
|
||||
}
|
||||
} else {
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
if self.worker_registry.get_by_url(worker_url).is_some() {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular_with_config(
|
||||
worker_url.to_string(),
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
workers_guard.push(new_worker);
|
||||
|
||||
// TODO: In IGW mode, fetch model_id from worker's /get_model_info endpoint
|
||||
let new_worker =
|
||||
BasicWorker::new(worker_url.to_string(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(
|
||||
self.circuit_breaker_config.clone(),
|
||||
);
|
||||
|
||||
let worker_arc = Arc::new(new_worker);
|
||||
self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, add this worker to it
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
// Get all workers for this model
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
|
||||
// If cache aware policy, initialize the worker in the tree
|
||||
if let Some(cache_aware) =
|
||||
self.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
// Get updated workers after adding
|
||||
drop(workers_guard);
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
cache_aware.init_workers(&workers_guard);
|
||||
}
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
return Ok(format!("Successfully added worker: {}", worker_url));
|
||||
} else {
|
||||
@@ -931,66 +989,73 @@ impl Router {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
// without contacting the remote worker
|
||||
let mut candidate_workers: Vec<String> = Vec::new();
|
||||
let mut removed_workers: Vec<String> = Vec::new();
|
||||
let worker_url_prefix = format!("{}@", worker_url);
|
||||
|
||||
{
|
||||
// find the candidate workers to be removed
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
for w in workers_guard.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
candidate_workers.push(w.url().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
// Find and remove all workers with matching prefix
|
||||
let all_workers = self.worker_registry.get_all();
|
||||
for w in all_workers.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
// Get model_id before removing
|
||||
let model_id = w.model_id().to_string();
|
||||
|
||||
{
|
||||
// do the removing on the worker_urls
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
for dp_url in candidate_workers.iter() {
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", dp_url);
|
||||
removed_workers.push(dp_url.to_string());
|
||||
if self.worker_registry.remove_by_url(w.url()).is_some() {
|
||||
info!("Removed worker: {}", w.url());
|
||||
removed_workers.push(w.url().to_string());
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", dp_url);
|
||||
continue;
|
||||
warn!("Worker {} not found, skipping removal", w.url());
|
||||
}
|
||||
}
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
for dp_url in removed_workers.iter() {
|
||||
cache_aware.remove_worker(dp_url);
|
||||
info!("Removed worker from tree: {}", dp_url);
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
// If any models are using cache aware policy, remove the workers from the tree
|
||||
// Check each removed worker's model and get its policy
|
||||
for dp_url in removed_workers.iter() {
|
||||
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
|
||||
let model_id = worker.model_id();
|
||||
if let Some(policy) = self.policy_registry.get_policy(model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(dp_url);
|
||||
info!("Removed worker from cache-aware tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
// Get the worker first to extract model_id
|
||||
let model_id = if let Some(worker) = self.worker_registry.get_by_url(worker_url) {
|
||||
worker.model_id().to_string()
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
};
|
||||
|
||||
if self.worker_registry.remove_by_url(worker_url).is_some() {
|
||||
info!("Removed worker: {}", worker_url);
|
||||
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker(worker_url);
|
||||
info!("Removed worker from tree: {}", worker_url);
|
||||
// If the model is using cache aware policy, remove the worker from the tree
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
info!("Removed worker from cache-aware tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1171,7 +1236,7 @@ impl RouterTrait for Router {
|
||||
}
|
||||
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
let workers = self.workers.read().unwrap();
|
||||
let workers = self.worker_registry.get_all();
|
||||
let unhealthy_servers: Vec<_> = workers
|
||||
.iter()
|
||||
.filter(|w| !w.is_healthy())
|
||||
@@ -1209,16 +1274,19 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/generate").await
|
||||
self.route_typed_request(headers, body, "/generate", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/chat/completions")
|
||||
self.route_typed_request(headers, body, "/v1/chat/completions", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1226,8 +1294,9 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/completions")
|
||||
self.route_typed_request(headers, body, "/v1/completions", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1235,8 +1304,9 @@ impl RouterTrait for Router {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ResponsesRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
self.route_typed_request(headers, body, "/v1/responses")
|
||||
self.route_typed_request(headers, body, "/v1/responses", model_id)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -1244,11 +1314,18 @@ impl RouterTrait for Router {
|
||||
todo!()
|
||||
}
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response {
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &RerankRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
if let Err(e) = body.validate() {
|
||||
return (StatusCode::BAD_REQUEST, e).into_response();
|
||||
}
|
||||
let response = self.route_typed_request(headers, body, "/v1/rerank").await;
|
||||
let response = self
|
||||
.route_typed_request(headers, body, "/v1/rerank", model_id)
|
||||
.await;
|
||||
if response.status().is_success() {
|
||||
match Self::build_rerank_response(body, response).await {
|
||||
Ok(rerank_response) => rerank_response,
|
||||
@@ -1340,19 +1417,15 @@ impl RouterTrait for Router {
|
||||
|
||||
fn readiness(&self) -> Response {
|
||||
// Regular router is ready if it has at least one healthy worker
|
||||
let healthy_count = self
|
||||
.workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.filter(|w| w.is_healthy())
|
||||
.count();
|
||||
let workers = self.worker_registry.get_all();
|
||||
let healthy_count = workers.iter().filter(|w| w.is_healthy()).count();
|
||||
let total_workers = workers.len();
|
||||
|
||||
if healthy_count > 0 {
|
||||
Json(serde_json::json!({
|
||||
"status": "ready",
|
||||
"healthy_workers": healthy_count,
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
"total_workers": total_workers
|
||||
}))
|
||||
.into_response()
|
||||
} else {
|
||||
@@ -1361,7 +1434,7 @@ impl RouterTrait for Router {
|
||||
Json(serde_json::json!({
|
||||
"status": "not_ready",
|
||||
"reason": "no healthy workers available",
|
||||
"total_workers": self.workers.read().unwrap().len()
|
||||
"total_workers": total_workers
|
||||
})),
|
||||
)
|
||||
.into_response()
|
||||
@@ -1372,18 +1445,25 @@ impl RouterTrait for Router {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::policies::RandomPolicy;
|
||||
use std::collections::HashMap;
|
||||
|
||||
fn create_test_regular_router() -> Router {
|
||||
let workers = vec![
|
||||
WorkerFactory::create_regular("http://worker1:8080".to_string()),
|
||||
WorkerFactory::create_regular("http://worker2:8080".to_string()),
|
||||
];
|
||||
// Create registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(
|
||||
crate::config::types::PolicyConfig::RoundRobin,
|
||||
));
|
||||
|
||||
// Register test workers
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
||||
worker_registry.register(Arc::new(worker1));
|
||||
worker_registry.register(Arc::new(worker2));
|
||||
|
||||
let (_, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
Router {
|
||||
workers: Arc::new(RwLock::new(workers)),
|
||||
policy: Arc::new(RandomPolicy::new()),
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
worker_startup_timeout_secs: 5,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
@@ -1393,7 +1473,6 @@ mod tests {
|
||||
circuit_breaker_config: CircuitBreakerConfig::default(),
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
_health_checker: None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1413,7 +1492,9 @@ mod tests {
|
||||
let result = router.select_first_worker();
|
||||
|
||||
assert!(result.is_ok());
|
||||
assert_eq!(result.unwrap(), "http://worker1:8080");
|
||||
let url = result.unwrap();
|
||||
// DashMap doesn't guarantee order, so just check we get one of the workers
|
||||
assert!(url == "http://worker1:8080" || url == "http://worker2:8080");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -17,6 +17,7 @@ pub mod factory;
|
||||
pub mod grpc;
|
||||
pub mod header_utils;
|
||||
pub mod http;
|
||||
pub mod router_manager;
|
||||
|
||||
pub use factory::RouterFactory;
|
||||
// Re-export HTTP routers for convenience (keeps routers::openai_router path working)
|
||||
@@ -63,14 +64,19 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response;
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
|
||||
-> Response;
|
||||
async fn route_generate(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Route a completion request
|
||||
@@ -78,6 +84,7 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Route a responses request
|
||||
@@ -85,11 +92,17 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ResponsesRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response;
|
||||
|
||||
async fn route_rerank(&self, headers: Option<&HeaderMap>, body: &RerankRequest) -> Response;
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &RerankRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response;
|
||||
|
||||
/// Flush cache on all workers
|
||||
async fn flush_cache(&self) -> Response;
|
||||
|
||||
766
sgl-router/src/routers/router_manager.rs
Normal file
766
sgl-router/src/routers/router_manager.rs
Normal file
@@ -0,0 +1,766 @@
|
||||
//! Router Manager for coordinating multiple routers and workers
|
||||
//!
|
||||
//! Provides centralized management based on enable_igw flag:
|
||||
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
};
|
||||
use crate::protocols::worker_spec::{
|
||||
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
|
||||
WorkerListResponse, WorkerStats, WorkerTypeStats,
|
||||
};
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
extract::Request,
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::{IntoResponse, Response},
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Router identifier
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct RouterId(String);
|
||||
|
||||
impl RouterId {
|
||||
pub fn new(id: String) -> Self {
|
||||
Self(id)
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
/// Router Manager - Central coordinator for routers and workers
|
||||
/// Only created when enable_igw=true
|
||||
pub struct RouterManager {
|
||||
/// Worker registry (single source of truth in multi-router mode)
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
|
||||
/// Policy registry for managing model-to-policy mappings
|
||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||
|
||||
/// All routers managed by this manager (max 4 routers in Phase 2)
|
||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||
|
||||
/// Default router for requests without specific routing
|
||||
default_router: Option<RouterId>,
|
||||
|
||||
/// Model to router mapping for model-aware routing
|
||||
/// Multiple models can be served by the same router
|
||||
model_routers: Arc<DashMap<String, Vec<RouterId>>>,
|
||||
|
||||
/// HTTP client for querying worker info
|
||||
client: reqwest::Client,
|
||||
|
||||
/// Configuration
|
||||
#[allow(dead_code)] // May be used in future enhancements
|
||||
config: RouterConfig,
|
||||
}
|
||||
|
||||
impl RouterManager {
|
||||
/// Create a new router manager with shared registries
|
||||
pub fn new(
|
||||
config: RouterConfig,
|
||||
client: reqwest::Client,
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||
) -> Self {
|
||||
Self {
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
routers: Arc::new(DashMap::new()),
|
||||
default_router: None,
|
||||
model_routers: Arc::new(DashMap::new()),
|
||||
client,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a router with the manager
|
||||
pub fn register_router(
|
||||
&mut self,
|
||||
id: RouterId,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
models: Vec<String>,
|
||||
) {
|
||||
// Store router
|
||||
self.routers.insert(id.clone(), router);
|
||||
|
||||
// Update model mappings
|
||||
for model in models {
|
||||
self.model_routers
|
||||
.entry(model)
|
||||
.or_default()
|
||||
.push(id.clone());
|
||||
}
|
||||
|
||||
// Set as default if first router
|
||||
if self.default_router.is_none() {
|
||||
self.default_router = Some(id.clone());
|
||||
info!("Set default router to {}", id.as_str());
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the default router
|
||||
pub fn set_default_router(&mut self, id: RouterId) {
|
||||
self.default_router = Some(id);
|
||||
}
|
||||
|
||||
/// Get the number of registered routers
|
||||
pub fn router_count(&self) -> usize {
|
||||
self.routers.len()
|
||||
}
|
||||
|
||||
/// Get router for a specific model
|
||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||
// First try model-specific routers
|
||||
if let Some(router_ids) = self.model_routers.get(model_id) {
|
||||
if let Some(router_id) = router_ids.first() {
|
||||
if let Some(router) = self.routers.get(router_id) {
|
||||
return Some(router.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default router
|
||||
if let Some(ref default_id) = self.default_router {
|
||||
self.routers.get(default_id).map(|r| r.clone())
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Get workers for routing decision
|
||||
pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec<Arc<dyn Worker>> {
|
||||
if let Some(model) = model_id {
|
||||
self.worker_registry.get_by_model(model)
|
||||
} else {
|
||||
self.worker_registry.get_all()
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a worker to the registry
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
config: WorkerConfigRequest,
|
||||
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
|
||||
// Build labels from configuration
|
||||
let mut labels = config.labels.clone();
|
||||
|
||||
// Query server info if model_id not provided
|
||||
let model_id = if let Some(model_id) = config.model_id {
|
||||
model_id
|
||||
} else {
|
||||
match self.query_server_info(&config.url).await {
|
||||
Ok(info) => {
|
||||
// Extract model_id from server info
|
||||
info.model_id
|
||||
.or_else(|| {
|
||||
info.model_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to query server info from {}: {}", config.url, e);
|
||||
"unknown".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Add configuration to labels
|
||||
labels.insert("model_id".to_string(), model_id.clone());
|
||||
|
||||
if let Some(priority) = config.priority {
|
||||
labels.insert("priority".to_string(), priority.to_string());
|
||||
}
|
||||
|
||||
if let Some(cost) = config.cost {
|
||||
labels.insert("cost".to_string(), cost.to_string());
|
||||
}
|
||||
|
||||
// Add gRPC-specific configuration if provided
|
||||
if let Some(tokenizer_path) = config.tokenizer_path {
|
||||
labels.insert("tokenizer_path".to_string(), tokenizer_path);
|
||||
}
|
||||
|
||||
if let Some(reasoning_parser) = config.reasoning_parser {
|
||||
labels.insert("reasoning_parser".to_string(), reasoning_parser);
|
||||
}
|
||||
|
||||
if let Some(tool_parser) = config.tool_parser {
|
||||
labels.insert("tool_parser".to_string(), tool_parser);
|
||||
}
|
||||
|
||||
if let Some(chat_template) = config.chat_template {
|
||||
labels.insert("chat_template".to_string(), chat_template);
|
||||
}
|
||||
|
||||
// Create worker based on type
|
||||
// Note: For prefill and decode workers, we can't easily add labels after creation
|
||||
// since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => {
|
||||
// For now, prefill workers won't have custom labels
|
||||
// TODO: Enhance WorkerFactory to accept labels for prefill workers
|
||||
WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
|
||||
}
|
||||
Some("decode") => {
|
||||
// For now, decode workers won't have custom labels
|
||||
// TODO: Enhance WorkerFactory to accept labels for decode workers
|
||||
WorkerFactory::create_decode(config.url.clone())
|
||||
}
|
||||
_ => {
|
||||
// Regular workers can have labels
|
||||
WorkerFactory::create_regular_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
)
|
||||
}
|
||||
};
|
||||
|
||||
// Register worker
|
||||
let worker_id = self.worker_registry.register(Arc::from(worker));
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
// Extract policy hint from labels if provided
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
|
||||
|
||||
info!(
|
||||
"Added worker {} with URL {} for model {} using policy {}",
|
||||
worker_id.as_str(),
|
||||
config.url,
|
||||
model_id,
|
||||
policy.name()
|
||||
);
|
||||
|
||||
// Return worker info
|
||||
let worker_arc = self.worker_registry.get(&worker_id).unwrap();
|
||||
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
|
||||
|
||||
Ok(WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} added successfully", worker_id.as_str()),
|
||||
worker: Some(worker_info),
|
||||
})
|
||||
}
|
||||
|
||||
/// Remove a worker from the registry
|
||||
pub fn remove_worker_from_registry(
|
||||
&self,
|
||||
url: &str,
|
||||
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
|
||||
// Get worker to extract model_id before removing
|
||||
let model_id = self
|
||||
.worker_registry
|
||||
.get_by_url(url)
|
||||
.map(|worker| worker.model_id().to_string());
|
||||
|
||||
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
|
||||
// Notify PolicyRegistry about worker removal
|
||||
if let Some(model_id) = model_id {
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
info!("Removed worker with URL {} for model {}", url, model_id);
|
||||
} else {
|
||||
info!("Removed worker with URL {}", url);
|
||||
}
|
||||
|
||||
Ok(WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} removed successfully", url),
|
||||
worker: None,
|
||||
})
|
||||
} else {
|
||||
Err(WorkerErrorResponse {
|
||||
error: format!("Worker with URL {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// List all workers
|
||||
pub fn list_workers(&self) -> WorkerListResponse {
|
||||
let workers = self.worker_registry.get_all_with_ids();
|
||||
let worker_infos: Vec<WorkerInfo> = workers
|
||||
.iter()
|
||||
.map(|(id, w)| self.worker_to_info(id.as_str(), w))
|
||||
.collect();
|
||||
|
||||
let total = worker_infos.len();
|
||||
|
||||
// Get stats from the worker registry
|
||||
let registry_stats = self.worker_registry.stats();
|
||||
|
||||
// Convert WorkerRegistryStats to WorkerStats
|
||||
let stats = WorkerStats {
|
||||
total_workers: registry_stats.total_workers,
|
||||
healthy_workers: registry_stats.healthy_workers,
|
||||
total_models: registry_stats.total_models,
|
||||
total_load: registry_stats.total_load,
|
||||
by_type: WorkerTypeStats {
|
||||
regular: registry_stats.regular_workers,
|
||||
prefill: registry_stats.prefill_workers,
|
||||
decode: registry_stats.decode_workers,
|
||||
},
|
||||
};
|
||||
|
||||
WorkerListResponse {
|
||||
workers: worker_infos,
|
||||
total,
|
||||
stats,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get worker by URL
|
||||
pub fn get_worker(&self, url: &str) -> Option<WorkerInfo> {
|
||||
self.worker_registry
|
||||
.get_by_url(url)
|
||||
.map(|w| self.worker_to_info("unknown", &w))
|
||||
}
|
||||
|
||||
/// Query server info from a worker URL
|
||||
async fn query_server_info(&self, url: &str) -> Result<ServerInfo, String> {
|
||||
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
|
||||
|
||||
match self.client.get(&info_url).send().await {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
response
|
||||
.json::<ServerInfo>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse server info: {}", e))
|
||||
} else {
|
||||
Err(format!("Server returned status: {}", response.status()))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("Failed to connect to server: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Worker to WorkerInfo
|
||||
fn worker_to_info(&self, id: &str, worker: &Arc<dyn Worker>) -> WorkerInfo {
|
||||
let metadata = worker.metadata();
|
||||
|
||||
WorkerInfo {
|
||||
id: id.to_string(),
|
||||
url: worker.url().to_string(),
|
||||
model_id: worker.model_id().to_string(),
|
||||
priority: worker.priority(),
|
||||
cost: worker.cost(),
|
||||
worker_type: format!("{:?}", worker.worker_type()),
|
||||
is_healthy: worker.is_healthy(),
|
||||
load: worker.load(),
|
||||
connection_mode: format!("{:?}", worker.connection_mode()),
|
||||
tokenizer_path: worker.tokenizer_path().map(|s| s.to_string()),
|
||||
reasoning_parser: worker.reasoning_parser().map(|s| s.to_string()),
|
||||
tool_parser: worker.tool_parser().map(|s| s.to_string()),
|
||||
chat_template: worker.chat_template().map(|s| s.to_string()),
|
||||
metadata: metadata.labels.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
// Note: calculate_stats removed - using WorkerRegistry::stats() instead
|
||||
|
||||
// === Phase 2: Router Management ===
|
||||
// Note: Dynamic router creation removed - routers are created and registered externally
|
||||
|
||||
/// Get the appropriate router for a request based on headers and request content
|
||||
pub fn select_router_for_request(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
model_id: Option<&str>,
|
||||
) -> Option<Arc<dyn RouterTrait>> {
|
||||
// Extract priority and cost preferences from headers if available
|
||||
let _priority_threshold = headers.and_then(|h| {
|
||||
h.get("x-worker-priority")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<u32>().ok())
|
||||
});
|
||||
|
||||
let _max_cost = headers.and_then(|h| {
|
||||
h.get("x-max-cost")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.and_then(|s| s.parse::<f32>().ok())
|
||||
});
|
||||
|
||||
// Check if PD (prefill-decode) mode is preferred from headers
|
||||
let prefer_pd = headers
|
||||
.and_then(|h| {
|
||||
h.get("x-prefer-pd")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(|s| s == "true" || s == "1")
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// If model specified, find routers serving that model
|
||||
let candidate_routers = if let Some(model) = model_id {
|
||||
// Get routers for specific model
|
||||
if let Some(router_ids) = self.model_routers.get(model) {
|
||||
router_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.routers.get(id).map(|r| r.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
} else {
|
||||
// No model specified, consider all routers
|
||||
self.routers
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
.collect::<Vec<_>>()
|
||||
};
|
||||
|
||||
if candidate_routers.is_empty() {
|
||||
// No routers found for the specified model
|
||||
return None;
|
||||
}
|
||||
|
||||
// Score routers based on worker attributes and request preferences
|
||||
let mut best_router = None;
|
||||
let mut best_score = 0.0;
|
||||
|
||||
for router in candidate_routers {
|
||||
let mut score = 1.0;
|
||||
|
||||
// Check if this is a PD router
|
||||
let is_pd = router.is_pd_mode();
|
||||
if prefer_pd && is_pd {
|
||||
score += 2.0; // Bonus for matching PD preference
|
||||
} else if !prefer_pd && !is_pd {
|
||||
score += 1.0; // Bonus for matching regular preference
|
||||
}
|
||||
|
||||
// Get workers for this router and evaluate based on priority/cost
|
||||
// Note: This would require routers to expose their workers or stats
|
||||
// For now, we'll use a simple selection based on router type
|
||||
|
||||
// TODO: Once routers expose worker stats, we can evaluate:
|
||||
// - Average worker priority vs priority_threshold
|
||||
// - Average worker cost vs max_cost
|
||||
// - Current load and health status
|
||||
|
||||
if score > best_score {
|
||||
best_score = score;
|
||||
best_router = Some(router);
|
||||
}
|
||||
}
|
||||
|
||||
best_router
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Default implementation removed as RouterManager now requires AppContext
|
||||
// which cannot be defaulted. RouterManager must be created with explicit context.
|
||||
|
||||
// === Phase 2: RouterManager as RouterTrait ===
|
||||
|
||||
/// RouterManager implements RouterTrait to act as a meta-router
|
||||
/// that delegates requests to the appropriate underlying router
|
||||
#[async_trait]
|
||||
impl WorkerManagement for RouterManager {
|
||||
/// Add a worker - in multi-router mode, this adds to the registry
|
||||
async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
|
||||
// Create a basic worker config request
|
||||
let config = WorkerConfigRequest {
|
||||
url: worker_url.to_string(),
|
||||
model_id: None,
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: std::collections::HashMap::new(),
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
match self.add_worker(config).await {
|
||||
Ok(response) => Ok(response.message),
|
||||
Err(e) => Err(e.error),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the registry
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
let _ = self.remove_worker_from_registry(worker_url);
|
||||
}
|
||||
|
||||
/// Get all worker URLs from the registry
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.worker_registry.get_all_urls()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for RouterManager {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
self
|
||||
}
|
||||
|
||||
/// Health check - return 503 if no routers available
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
// Health check should succeed if RouterManager exists, even without routers
|
||||
// Individual router health can be checked via specific endpoints
|
||||
(StatusCode::OK, "RouterManager is healthy").into_response()
|
||||
}
|
||||
|
||||
/// Health generate - check if any router can handle generate requests
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// Return 503 since we have no routers with workers
|
||||
// TODO: Should check if any router has healthy workers
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"No routers with healthy workers available",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get server information - aggregate from all routers
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
// TODO: Aggregate info from all routers with healthy workers
|
||||
// For now, return basic info about the RouterManager
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
"router_manager": true,
|
||||
"routers_count": self.routers.len(),
|
||||
"workers_count": self.worker_registry.get_all().len()
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get available models - aggregate from all routers
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
// Return models that have registered routers
|
||||
let models = self
|
||||
.model_routers
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if models.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
|
||||
} else {
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
"models": models
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
// TODO: Extract model from request and route to appropriate router
|
||||
// For now, return not implemented
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Model info endpoint not yet implemented in RouterManager",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers
|
||||
// GenerateRequest doesn't have a model field
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass None since GenerateRequest doesn't have model field
|
||||
router.route_generate(headers, body, None).await
|
||||
} else {
|
||||
// Return 404 when no router is available for the request
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No router available for this request",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass the model_id to the router
|
||||
router.route_chat(headers, body, Some(&body.model)).await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass the model_id to the router
|
||||
router
|
||||
.route_completion(headers, body, Some(&body.model))
|
||||
.await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
async fn route_responses(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_body: &ResponsesRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
todo!()
|
||||
}
|
||||
|
||||
/// Route embeddings request
|
||||
async fn route_embeddings(&self, headers: Option<&HeaderMap>, body: Body) -> Response {
|
||||
// Try to select a router based on headers
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
router.route_embeddings(headers, body).await
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No router available for embeddings request",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Route rerank request
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &RerankRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Try to select a router based on headers
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
router.route_rerank(headers, body, model_id).await
|
||||
} else {
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No router available for rerank request",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush cache on all routers and workers
|
||||
async fn flush_cache(&self) -> Response {
|
||||
// TODO: Call flush_cache on all routers that have workers
|
||||
// For now, return success if we have any routers
|
||||
if self.routers.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||
} else {
|
||||
// TODO: Actually flush cache on all routers
|
||||
(StatusCode::OK, "Cache flush requested").into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get worker loads from all routers
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
// Return worker loads from the registry
|
||||
let workers = self.worker_registry.get_all();
|
||||
let loads: Vec<serde_json::Value> = workers
|
||||
.iter()
|
||||
.map(|w| {
|
||||
serde_json::json!({
|
||||
"url": w.url(),
|
||||
"model": w.model_id(),
|
||||
"load": w.load(),
|
||||
"is_healthy": w.is_healthy()
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
"workers": loads
|
||||
})
|
||||
.to_string(),
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str {
|
||||
"manager"
|
||||
}
|
||||
|
||||
/// Server readiness check - check if any router is ready
|
||||
fn readiness(&self) -> Response {
|
||||
if self.routers.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||
} else {
|
||||
// TODO: Check readiness of all routers
|
||||
(StatusCode::OK, "Ready").into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Note: get_first_available_router removed - we now properly handle
|
||||
// router selection based on model and worker availability
|
||||
|
||||
impl std::fmt::Debug for RouterManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouterManager")
|
||||
.field("routers_count", &self.routers.len())
|
||||
.field("workers_count", &self.worker_registry.get_all().len())
|
||||
.field("default_router", &self.default_router)
|
||||
.finish()
|
||||
}
|
||||
}
|
||||
@@ -1,12 +1,16 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::WorkerRegistry;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
V1RerankReqInput,
|
||||
};
|
||||
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::router_manager::{RouterId, RouterManager};
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
|
||||
@@ -36,6 +40,9 @@ pub struct AppContext {
|
||||
pub tokenizer: Option<Arc<dyn Tokenizer>>,
|
||||
pub reasoning_parser_factory: Option<ParserFactory>,
|
||||
pub tool_parser_registry: Option<&'static ParserRegistry>,
|
||||
pub worker_registry: Arc<WorkerRegistry>, // Shared worker registry
|
||||
pub policy_registry: Arc<PolicyRegistry>, // Shared policy registry
|
||||
pub router_manager: Option<Arc<RouterManager>>, // Only present when enable_igw=true
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -75,6 +82,15 @@ impl AppContext {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
// Initialize shared registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(
|
||||
router_config.policy.clone(), // Use default policy from config
|
||||
));
|
||||
|
||||
// Initialize RouterManager only when enable_igw is true
|
||||
let router_manager = None; // Will be initialized in startup() based on config
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
@@ -82,6 +98,9 @@ impl AppContext {
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
router_manager,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -134,7 +153,10 @@ async fn generate(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<GenerateRequest>,
|
||||
) -> Response {
|
||||
state.router.route_generate(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_generate(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn v1_chat_completions(
|
||||
@@ -142,7 +164,7 @@ async fn v1_chat_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_chat(Some(&headers), &body).await
|
||||
state.router.route_chat(Some(&headers), &body, None).await
|
||||
}
|
||||
|
||||
async fn v1_completions(
|
||||
@@ -150,7 +172,10 @@ async fn v1_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<CompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_completion(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_completion(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn rerank(
|
||||
@@ -158,7 +183,7 @@ async fn rerank(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<RerankRequest>,
|
||||
) -> Response {
|
||||
state.router.route_rerank(Some(&headers), &body).await
|
||||
state.router.route_rerank(Some(&headers), &body, None).await
|
||||
}
|
||||
|
||||
async fn v1_rerank(
|
||||
@@ -168,7 +193,7 @@ async fn v1_rerank(
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_rerank(Some(&headers), &body.into())
|
||||
.route_rerank(Some(&headers), &body.into(), None)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -177,7 +202,10 @@ async fn v1_responses(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ResponsesRequest>,
|
||||
) -> Response {
|
||||
state.router.route_responses(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_responses(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
// Worker management endpoints
|
||||
@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
|
||||
state.router.get_worker_loads().await
|
||||
}
|
||||
|
||||
// New RESTful worker management endpoints (when enable_igw=true)
|
||||
|
||||
/// POST /workers - Add a new worker with full configuration
|
||||
async fn create_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(config): Json<WorkerConfigRequest>,
|
||||
) -> Response {
|
||||
// Check if RouterManager is available (enable_igw=true)
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
match router_manager.add_worker(config).await {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
}
|
||||
} else {
|
||||
// In single router mode, use the router's add_worker with basic config
|
||||
match state.router.add_worker(&config.url).await {
|
||||
Ok(message) => {
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message,
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
Err(error) => {
|
||||
let error_response = WorkerErrorResponse {
|
||||
error,
|
||||
code: "ADD_WORKER_FAILED".to_string(),
|
||||
};
|
||||
(StatusCode::BAD_REQUEST, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /workers - List all workers with details
|
||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
let response = router_manager.list_workers();
|
||||
Json(response).into_response()
|
||||
} else {
|
||||
// In single router mode, get detailed worker info from registry
|
||||
let workers = state.context.worker_registry.get_all();
|
||||
let response = serde_json::json!({
|
||||
"workers": workers.iter().map(|worker| {
|
||||
let mut worker_info = serde_json::json!({
|
||||
"url": worker.url(),
|
||||
"model_id": worker.model_id(),
|
||||
"worker_type": format!("{:?}", worker.worker_type()),
|
||||
"is_healthy": worker.is_healthy(),
|
||||
"load": worker.load(),
|
||||
"connection_mode": format!("{:?}", worker.connection_mode()),
|
||||
"priority": worker.priority(),
|
||||
"cost": worker.cost(),
|
||||
});
|
||||
|
||||
// Add bootstrap_port for Prefill workers
|
||||
if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
|
||||
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
|
||||
}
|
||||
|
||||
worker_info
|
||||
}).collect::<Vec<_>>(),
|
||||
"total": workers.len(),
|
||||
"stats": {
|
||||
"prefill_count": state.context.worker_registry.get_prefill_workers().len(),
|
||||
"decode_count": state.context.worker_registry.get_decode_workers().len(),
|
||||
"regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(),
|
||||
}
|
||||
});
|
||||
Json(response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /workers/{url} - Get specific worker info
|
||||
async fn get_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
axum::extract::Path(url): axum::extract::Path<String>,
|
||||
) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
if let Some(worker) = router_manager.get_worker(&url) {
|
||||
Json(worker).into_response()
|
||||
} else {
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
}
|
||||
} else {
|
||||
// In single router mode, check if worker exists
|
||||
let workers = state.router.get_worker_urls();
|
||||
if workers.contains(&url) {
|
||||
let worker_info = serde_json::json!({
|
||||
"url": url,
|
||||
"model_id": "unknown",
|
||||
"is_healthy": true
|
||||
});
|
||||
Json(worker_info).into_response()
|
||||
} else {
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DELETE /workers/{url} - Remove a worker
|
||||
async fn delete_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
axum::extract::Path(url): axum::extract::Path<String>,
|
||||
) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
match router_manager.remove_worker_from_registry(&url) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
}
|
||||
} else {
|
||||
// In single router mode, use router's remove_worker
|
||||
state.router.remove_worker(&url);
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} removed successfully", url),
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
@@ -281,11 +440,19 @@ pub fn build_app(
|
||||
.route("/flush_cache", post(flush_cache))
|
||||
.route("/get_loads", get(get_loads));
|
||||
|
||||
// Worker management routes
|
||||
let worker_routes = Router::new()
|
||||
.route("/workers", post(create_worker))
|
||||
.route("/workers", get(list_workers_rest))
|
||||
.route("/workers/{url}", get(get_worker))
|
||||
.route("/workers/{url}", axum::routing::delete(delete_worker));
|
||||
|
||||
// Build app with all routes and middleware
|
||||
Router::new()
|
||||
.merge(protected_routes)
|
||||
.merge(public_routes)
|
||||
.merge(admin_routes)
|
||||
.merge(worker_routes)
|
||||
// Request body size limiting
|
||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||
max_payload_size,
|
||||
@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// Create the application context with all dependencies
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
let app_context = AppContext::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
config.router_config.rate_limit_tokens_per_second,
|
||||
)?);
|
||||
)?;
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context).await?;
|
||||
let app_context = Arc::new(app_context);
|
||||
|
||||
// Create the appropriate router based on enable_igw flag
|
||||
let router: Box<dyn RouterTrait> = if config.router_config.enable_igw {
|
||||
info!("Multi-router mode enabled (enable_igw=true)");
|
||||
|
||||
// Create RouterManager with shared registries from AppContext
|
||||
let mut router_manager = RouterManager::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
app_context.worker_registry.clone(),
|
||||
app_context.policy_registry.clone(),
|
||||
);
|
||||
|
||||
// Create HTTP routers at startup (with empty worker lists)
|
||||
// Workers will be added to these routers dynamically via RouterManager's worker registry
|
||||
|
||||
// 1. HTTP Regular Router
|
||||
match RouterFactory::create_regular_router(
|
||||
&[], // Empty worker list - workers added later
|
||||
&app_context,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(http_regular) => {
|
||||
info!("Created HTTP Regular router");
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-regular".to_string()),
|
||||
Arc::from(http_regular),
|
||||
vec![], // Models will be determined by workers
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP Regular router: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. HTTP PD Router
|
||||
match RouterFactory::create_pd_router(
|
||||
&[], // Empty prefill URLs
|
||||
&[], // Empty decode URLs
|
||||
None, // Use default prefill policy
|
||||
None, // Use default decode policy
|
||||
&config.router_config.policy,
|
||||
&app_context,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(http_pd) => {
|
||||
info!("Created HTTP PD router");
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-pd".to_string()),
|
||||
Arc::from(http_pd),
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP PD router: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add gRPC routers once we have dynamic tokenizer loading
|
||||
// Currently gRPC routers require tokenizer to be initialized first,
|
||||
// but each model needs its own tokenizer. Once we implement dynamic
|
||||
// tokenizer loading per model, we can enable gRPC routers here:
|
||||
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
|
||||
// - RouterType::GrpcPd (RouterId: "grpc-pd")
|
||||
|
||||
info!(
|
||||
"RouterManager initialized with {} routers",
|
||||
router_manager.router_count()
|
||||
);
|
||||
Box::new(router_manager)
|
||||
} else {
|
||||
info!("Single router mode (enable_igw=false)");
|
||||
// Create single router with the context
|
||||
RouterFactory::create_router(&app_context).await?
|
||||
};
|
||||
|
||||
// Start health checker for all workers in the registry
|
||||
let _health_checker = app_context
|
||||
.worker_registry
|
||||
.start_health_checker(config.router_config.health_check.check_interval_secs);
|
||||
info!(
|
||||
"Started health checker for workers with {}s interval",
|
||||
config.router_config.health_check.check_interval_secs
|
||||
);
|
||||
|
||||
// Set up concurrency limiter with queue if configured
|
||||
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
|
||||
|
||||
@@ -579,9 +579,8 @@ mod tests {
|
||||
|
||||
// Helper to create a Router instance for testing event handlers
|
||||
async fn create_test_router() -> Arc<dyn RouterTrait> {
|
||||
use crate::config::{PolicyConfig, RouterConfig};
|
||||
use crate::config::RouterConfig;
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::routers::http::router::Router;
|
||||
use crate::server::AppContext;
|
||||
|
||||
@@ -591,15 +590,19 @@ mod tests {
|
||||
// Create AppContext with minimal components
|
||||
let app_context = Arc::new(AppContext {
|
||||
client: reqwest::Client::new(),
|
||||
router_config,
|
||||
router_config: router_config.clone(),
|
||||
rate_limiter: Arc::new(TokenBucket::new(1000, 1000)),
|
||||
worker_registry: Arc::new(crate::core::WorkerRegistry::new()),
|
||||
policy_registry: Arc::new(crate::policies::PolicyRegistry::new(
|
||||
router_config.policy.clone(),
|
||||
)),
|
||||
tokenizer: None, // HTTP mode doesn't need tokenizer
|
||||
reasoning_parser_factory: None, // HTTP mode doesn't need reasoning parser
|
||||
tool_parser_registry: None, // HTTP mode doesn't need tool parser
|
||||
router_manager: None, // Test doesn't need router manager
|
||||
});
|
||||
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
let router = Router::new(vec![], policy, &app_context).await.unwrap();
|
||||
let router = Router::new(vec![], &app_context).await.unwrap();
|
||||
Arc::new(router) as Arc<dyn RouterTrait>
|
||||
}
|
||||
|
||||
|
||||
129
sgl-router/tests/cache_aware_backward_compat_test.rs
Normal file
129
sgl-router/tests/cache_aware_backward_compat_test.rs
Normal file
@@ -0,0 +1,129 @@
|
||||
use sglang_router_rs::core::{BasicWorker, Worker, WorkerType};
|
||||
use sglang_router_rs::policies::{CacheAwareConfig, CacheAwarePolicy, LoadBalancingPolicy};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_backward_compatibility_with_empty_model_id() {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 2,
|
||||
balance_rel_threshold: 1.5,
|
||||
eviction_interval_secs: 0, // Disable background eviction for testing
|
||||
max_tree_size: 100,
|
||||
};
|
||||
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with empty model_id (simulating existing routers)
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - should default to "unknown"
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "unknown".to_string());
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels2);
|
||||
|
||||
// Add workers - should both go to "default" tree
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
|
||||
// Create worker list
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker1.clone()), Arc::new(worker2.clone())];
|
||||
|
||||
// Select worker - should work without errors
|
||||
let selected = policy.select_worker(&workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select a worker");
|
||||
|
||||
// Remove workers - should work without errors
|
||||
policy.remove_worker(&worker1);
|
||||
policy.remove_worker(&worker2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_mixed_model_ids() {
|
||||
let config = CacheAwareConfig {
|
||||
cache_threshold: 0.5,
|
||||
balance_abs_threshold: 2,
|
||||
balance_rel_threshold: 1.5,
|
||||
eviction_interval_secs: 0,
|
||||
max_tree_size: 100,
|
||||
};
|
||||
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with different model_id scenarios
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - defaults to "unknown" which goes to "default" tree
|
||||
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels2);
|
||||
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("model_id".to_string(), "unknown".to_string());
|
||||
let worker3 = BasicWorker::new("http://worker3:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels3);
|
||||
|
||||
let mut labels4 = HashMap::new();
|
||||
labels4.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker4 = BasicWorker::new("http://worker4:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels4);
|
||||
|
||||
// Add all workers
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
policy.add_worker(&worker3);
|
||||
policy.add_worker(&worker4);
|
||||
|
||||
// Test selection with default workers only
|
||||
let default_workers: Vec<Arc<dyn Worker>> =
|
||||
vec![Arc::new(worker1.clone()), Arc::new(worker3.clone())];
|
||||
let selected = policy.select_worker(&default_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from default workers");
|
||||
|
||||
// Test selection with specific model workers only
|
||||
let llama_workers: Vec<Arc<dyn Worker>> =
|
||||
vec![Arc::new(worker2.clone()), Arc::new(worker4.clone())];
|
||||
let selected = policy.select_worker(&llama_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from llama-3 workers");
|
||||
|
||||
// Test selection with mixed workers
|
||||
let all_workers: Vec<Arc<dyn Worker>> = vec![
|
||||
Arc::new(worker1.clone()),
|
||||
Arc::new(worker2.clone()),
|
||||
Arc::new(worker3.clone()),
|
||||
Arc::new(worker4.clone()),
|
||||
];
|
||||
let selected = policy.select_worker(&all_workers, Some("test request"));
|
||||
assert!(selected.is_some(), "Should select from all workers");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_remove_worker_by_url_backward_compat() {
|
||||
let config = CacheAwareConfig::default();
|
||||
let policy = CacheAwarePolicy::with_config(config);
|
||||
|
||||
// Create workers with different model_ids
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("model_id".to_string(), "llama-3".to_string());
|
||||
let worker1 = BasicWorker::new("http://worker1:8080".to_string(), WorkerType::Regular)
|
||||
.with_labels(labels1);
|
||||
|
||||
let worker2 = BasicWorker::new("http://worker2:8080".to_string(), WorkerType::Regular);
|
||||
// No model_id label - defaults to "unknown"
|
||||
|
||||
// Add workers
|
||||
policy.add_worker(&worker1);
|
||||
policy.add_worker(&worker2);
|
||||
|
||||
// Remove by URL (backward compatibility method)
|
||||
// Should remove from all trees since we don't know the model
|
||||
policy.remove_worker_by_url("http://worker1:8080");
|
||||
|
||||
// Verify removal worked
|
||||
let workers: Vec<Arc<dyn Worker>> = vec![Arc::new(worker2.clone())];
|
||||
let selected = policy.select_worker(&workers, Some("test"));
|
||||
assert_eq!(selected, Some(0), "Should only have worker2 left");
|
||||
}
|
||||
168
sgl-router/tests/policy_registry_integration.rs
Normal file
168
sgl-router/tests/policy_registry_integration.rs
Normal file
@@ -0,0 +1,168 @@
|
||||
//! Integration tests for PolicyRegistry with RouterManager
|
||||
|
||||
use sglang_router_rs::config::{PolicyConfig, RouterConfig};
|
||||
use sglang_router_rs::core::WorkerRegistry;
|
||||
use sglang_router_rs::policies::PolicyRegistry;
|
||||
use sglang_router_rs::protocols::worker_spec::WorkerConfigRequest;
|
||||
use sglang_router_rs::routers::router_manager::RouterManager;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_policy_registry_with_router_manager() {
|
||||
// Create RouterConfig
|
||||
let config = RouterConfig {
|
||||
enable_igw: true,
|
||||
policy: PolicyConfig::RoundRobin,
|
||||
..Default::default()
|
||||
};
|
||||
|
||||
// Create HTTP client
|
||||
let client = reqwest::Client::new();
|
||||
|
||||
// Create shared registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(PolicyConfig::RoundRobin));
|
||||
|
||||
// Create RouterManager with shared registries
|
||||
let _router_manager = RouterManager::new(
|
||||
config,
|
||||
client,
|
||||
worker_registry.clone(),
|
||||
policy_registry.clone(),
|
||||
);
|
||||
|
||||
// Test adding workers with different models and policies
|
||||
|
||||
// Add first worker for llama-3 with cache_aware policy hint
|
||||
let mut labels1 = HashMap::new();
|
||||
labels1.insert("policy".to_string(), "cache_aware".to_string());
|
||||
|
||||
let _worker1_config = WorkerConfigRequest {
|
||||
url: "http://worker1:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels1,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// This would normally connect to a real worker, but for testing we'll just verify the structure
|
||||
// In a real test, we'd need to mock the worker or use a test server
|
||||
|
||||
// Verify PolicyRegistry has the correct policy for llama-3
|
||||
let _llama_policy = policy_registry.get_policy("llama-3");
|
||||
// After first worker is added, llama-3 should have a policy
|
||||
|
||||
// Add second worker for llama-3 with different policy hint (should be ignored)
|
||||
let mut labels2 = HashMap::new();
|
||||
labels2.insert("policy".to_string(), "random".to_string());
|
||||
|
||||
let _worker2_config = WorkerConfigRequest {
|
||||
url: "http://worker2:8000".to_string(),
|
||||
model_id: Some("llama-3".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels2,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// The second worker should use the same policy as the first (cache_aware)
|
||||
|
||||
// Add worker for different model (gpt-4) with random policy
|
||||
let mut labels3 = HashMap::new();
|
||||
labels3.insert("policy".to_string(), "random".to_string());
|
||||
|
||||
let _worker3_config = WorkerConfigRequest {
|
||||
url: "http://worker3:8000".to_string(),
|
||||
model_id: Some("gpt-4".to_string()),
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: labels3,
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
// Verify gpt-4 has random policy
|
||||
let _gpt_policy = policy_registry.get_policy("gpt-4");
|
||||
|
||||
// Test removing workers
|
||||
// When we remove both llama-3 workers, the policy should be cleaned up
|
||||
|
||||
println!("PolicyRegistry integration test structure created");
|
||||
println!("Note: This test requires mocking or test servers to fully execute");
|
||||
}
|
||||
|
||||
#[test]
fn test_policy_registry_cleanup() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    // Registry with round-robin as the fallback/default policy.
    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // First worker for a model pins the model's policy from its hint.
    let first = registry.on_worker_added("model-1", Some("cache_aware"));
    assert_eq!(first.name(), "cache_aware");

    // A later worker's conflicting hint is ignored — the existing policy wins.
    let second = registry.on_worker_added("model-1", Some("random"));
    assert_eq!(second.name(), "cache_aware"); // Should still be cache_aware

    // The mapping is present while workers exist.
    assert!(registry.get_policy("model-1").is_some());

    // Removing one of two workers must NOT drop the policy...
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_some());

    // ...but removing the last worker cleans it up.
    registry.on_worker_removed("model-1");
    assert!(registry.get_policy("model-1").is_none());

    println!("✓ PolicyRegistry cleanup test passed");
}
|
||||
|
||||
#[test]
fn test_policy_registry_multiple_models() {
    use sglang_router_rs::config::PolicyConfig;
    use sglang_router_rs::policies::PolicyRegistry;

    // Default policy is round-robin; per-model hints may override it.
    let registry = PolicyRegistry::new(PolicyConfig::RoundRobin);

    // Register one worker per model, each with a different policy hint
    // (mistral passes no hint and should fall back to the default).
    let llama = registry.on_worker_added("llama-3", Some("cache_aware"));
    let gpt = registry.on_worker_added("gpt-4", Some("random"));
    let mistral = registry.on_worker_added("mistral", None); // Uses default

    assert_eq!(llama.name(), "cache_aware");
    assert_eq!(gpt.name(), "random");
    assert_eq!(mistral.name(), "round_robin"); // Default

    // Each model's policy must be retrievable individually.
    for model in ["llama-3", "gpt-4", "mistral"] {
        assert!(registry.get_policy(model).is_some());
    }

    // The full model → policy-name mapping reflects all three registrations.
    let mappings = registry.get_all_mappings();
    assert_eq!(mappings.len(), 3);
    assert_eq!(mappings.get("llama-3").unwrap(), "cache_aware");
    assert_eq!(mappings.get("gpt-4").unwrap(), "random");
    assert_eq!(mappings.get("mistral").unwrap(), "round_robin");

    println!("✓ PolicyRegistry multiple models test passed");
}
|
||||
@@ -197,12 +197,14 @@ async fn test_unsupported_endpoints() {
|
||||
rid: None,
|
||||
};
|
||||
|
||||
let response = router.route_generate(None, &generate_request).await;
|
||||
let response = router.route_generate(None, &generate_request, None).await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
|
||||
// Test completion endpoint (should also not be supported)
|
||||
let completion_request = create_minimal_completion_request();
|
||||
let response = router.route_completion(None, &completion_request).await;
|
||||
let response = router
|
||||
.route_completion(None, &completion_request, None)
|
||||
.await;
|
||||
assert_eq!(response.status(), StatusCode::NOT_IMPLEMENTED);
|
||||
}
|
||||
|
||||
@@ -228,7 +230,7 @@ async fn test_openai_router_chat_completion_with_mock() {
|
||||
chat_request.temperature = Some(0.7);
|
||||
|
||||
// Route the request
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
|
||||
// Should get a successful response from mock server
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
@@ -269,7 +271,9 @@ async fn test_openai_e2e_with_server() {
|
||||
let chat_request: ChatCompletionRequest =
|
||||
serde_json::from_str(&body_str).unwrap();
|
||||
|
||||
router.route_chat(Some(&parts.headers), &chat_request).await
|
||||
router
|
||||
.route_chat(Some(&parts.headers), &chat_request, None)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}),
|
||||
@@ -327,7 +331,7 @@ async fn test_openai_router_chat_streaming_with_mock() {
|
||||
});
|
||||
let chat_request: ChatCompletionRequest = serde_json::from_value(val).unwrap();
|
||||
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
|
||||
// Should be SSE
|
||||
@@ -371,7 +375,7 @@ async fn test_openai_router_circuit_breaker() {
|
||||
|
||||
// First few requests should fail and record failures
|
||||
for _ in 0..3 {
|
||||
let response = router.route_chat(None, &chat_request).await;
|
||||
let response = router.route_chat(None, &chat_request, None).await;
|
||||
// Should get either an error or circuit breaker response
|
||||
assert!(
|
||||
response.status() == StatusCode::INTERNAL_SERVER_ERROR
|
||||
|
||||
Reference in New Issue
Block a user