[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -155,6 +155,82 @@ pub trait Worker: Send + Sync + fmt::Debug {
|
||||
fn can_handle(&self, _req: &serde_json::Value) -> bool {
    // Default implementation: the request is not inspected and every
    // request is accepted.
    true
}
|
||||
|
||||
// === Multi-router support ===
|
||||
|
||||
// TODO: Enhanced Worker Discovery
|
||||
// The Worker trait should handle async discovery of metadata from the worker itself
|
||||
// rather than having service discovery or other components query /get_server_info.
|
||||
// This keeps service discovery decoupled from worker-specific APIs.
|
||||
//
|
||||
// Proposed additions:
|
||||
// - async fn discover_metadata(&mut self) -> Result<(), Error>
|
||||
// Query /get_server_info and populate metadata labels with model_id, priority, cost, etc.
|
||||
// - async fn validate_configuration(&self) -> Result<(), Error>
|
||||
// Ensure worker has required configuration for its mode (e.g., tokenizer for gRPC)
|
||||
// - Make worker creation async to allow metadata discovery during initialization
|
||||
//
|
||||
// This way service discovery just calls router.add_worker() and the worker
|
||||
// handles its own metadata discovery internally.
|
||||
|
||||
/// Get the model ID this worker serves
|
||||
fn model_id(&self) -> &str {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("model_id")
|
||||
.map(|s| s.as_str())
|
||||
.unwrap_or("unknown")
|
||||
}
|
||||
|
||||
/// Get the priority of this worker (higher value = higher priority)
|
||||
fn priority(&self) -> u32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("priority")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(50) // Default priority is 50 (mid-range)
|
||||
}
|
||||
|
||||
/// Get the cost factor of this worker (1.0 = baseline)
|
||||
fn cost(&self) -> f32 {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("cost")
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(1.0)
|
||||
}
|
||||
|
||||
/// Get the tokenizer path for this worker (gRPC mode only)
|
||||
fn tokenizer_path(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tokenizer_path")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the reasoning parser type for this worker (gRPC mode only)
|
||||
fn reasoning_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("reasoning_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the tool parser type for this worker (gRPC mode only)
|
||||
fn tool_parser(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("tool_parser")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
|
||||
/// Get the chat template for this worker (gRPC mode only)
|
||||
fn chat_template(&self) -> Option<&str> {
|
||||
self.metadata()
|
||||
.labels
|
||||
.get("chat_template")
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
/// Connection mode for worker communication
|
||||
@@ -724,6 +800,21 @@ impl WorkerFactory {
|
||||
)
|
||||
}
|
||||
|
||||
/// Create a regular worker with custom labels (for multi-router support)
|
||||
pub fn create_regular_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
let mut worker = BasicWorker::new(url.clone(), WorkerType::Regular)
|
||||
.with_circuit_breaker_config(circuit_breaker_config);
|
||||
|
||||
// Add labels to metadata
|
||||
worker.metadata.labels = labels;
|
||||
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
@@ -941,6 +1032,11 @@ impl fmt::Debug for HealthChecker {
|
||||
}
|
||||
|
||||
impl HealthChecker {
|
||||
/// Create a new HealthChecker
|
||||
pub fn new(handle: tokio::task::JoinHandle<()>, shutdown: Arc<AtomicBool>) -> Self {
|
||||
Self { handle, shutdown }
|
||||
}
|
||||
|
||||
/// Shutdown the health checker gracefully
|
||||
pub async fn shutdown(self) {
|
||||
self.shutdown.store(true, Ordering::Release);
|
||||
@@ -950,7 +1046,7 @@ impl HealthChecker {
|
||||
|
||||
/// Start an async background health checker for a collection of workers
|
||||
pub fn start_health_checker(
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<Box<dyn Worker>>>>,
|
||||
workers: std::sync::Arc<std::sync::RwLock<Vec<std::sync::Arc<dyn Worker>>>>,
|
||||
check_interval_secs: u64,
|
||||
) -> HealthChecker {
|
||||
let shutdown = Arc::new(AtomicBool::new(false));
|
||||
@@ -1602,9 +1698,11 @@ mod tests {
|
||||
// Test HealthChecker background task
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_startup() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
@@ -1617,9 +1715,11 @@ mod tests {
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_health_checker_shutdown() {
|
||||
let workers = Arc::new(RwLock::new(vec![WorkerFactory::create_regular(
|
||||
let worker = Arc::new(BasicWorker::new(
|
||||
"http://w1:8080".to_string(),
|
||||
)]));
|
||||
WorkerType::Regular,
|
||||
)) as Arc<dyn Worker>;
|
||||
let workers = Arc::new(RwLock::new(vec![worker]));
|
||||
|
||||
let checker = start_health_checker(workers.clone(), 60);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user