[router] refactor router and worker management 3/n (#10727)
This commit is contained in:
@@ -4,17 +4,12 @@
|
||||
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::{BasicWorkerBuilder, CircuitBreakerConfig, Worker, WorkerRegistry, WorkerType};
|
||||
use crate::core::{Worker, WorkerRegistry, WorkerType};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
};
|
||||
use crate::protocols::worker_spec::{
|
||||
ServerInfo, WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse, WorkerInfo,
|
||||
WorkerListResponse, WorkerStats, WorkerTypeStats,
|
||||
};
|
||||
use crate::routers::{RouterTrait, WorkerManagement};
|
||||
use crate::routers::RouterTrait;
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -24,7 +19,7 @@ use axum::{
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::{info, warn};
|
||||
use tracing::info;
|
||||
|
||||
/// Router identifier
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
@@ -45,48 +40,28 @@ pub struct RouterManager {
|
||||
/// Worker registry (single source of truth in multi-router mode)
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
|
||||
/// Policy registry for managing model-to-policy mappings
|
||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||
|
||||
/// All routers managed by this manager
|
||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||
|
||||
/// Default router for requests without specific routing
|
||||
default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
|
||||
|
||||
/// HTTP client for querying worker info
|
||||
client: reqwest::Client,
|
||||
|
||||
/// Configuration
|
||||
#[allow(dead_code)] // May be used in future enhancements
|
||||
config: RouterConfig,
|
||||
}
|
||||
|
||||
impl RouterManager {
|
||||
/// Create a new router manager with shared registries
|
||||
pub fn new(
|
||||
config: RouterConfig,
|
||||
client: reqwest::Client,
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||
) -> Self {
|
||||
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
|
||||
Self {
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
routers: Arc::new(DashMap::new()),
|
||||
default_router: Arc::new(std::sync::RwLock::new(None)),
|
||||
client,
|
||||
config,
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a router with the manager
|
||||
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
||||
// Store router
|
||||
self.routers.insert(id.clone(), router);
|
||||
|
||||
// Set as default if first router
|
||||
let mut default_router = self.default_router.write().unwrap();
|
||||
if default_router.is_none() {
|
||||
*default_router = Some(id.clone());
|
||||
@@ -107,11 +82,9 @@ impl RouterManager {
|
||||
|
||||
/// Get router for a specific model based on worker types
|
||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||
// Query workers for this model from registry
|
||||
let workers = self.worker_registry.get_by_model(model_id);
|
||||
|
||||
if !workers.is_empty() {
|
||||
// Determine router based on worker types
|
||||
let has_pd_workers = workers.iter().any(|w| {
|
||||
matches!(
|
||||
w.worker_type(),
|
||||
@@ -125,13 +98,11 @@ impl RouterManager {
|
||||
RouterId::new("http-regular".to_string())
|
||||
};
|
||||
|
||||
// Return the router if it exists
|
||||
if let Some(router) = self.routers.get(&router_id) {
|
||||
return Some(router.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Fall back to default router
|
||||
let default_router = self.default_router.read().unwrap();
|
||||
if let Some(ref default_id) = *default_router {
|
||||
self.routers.get(default_id).map(|r| r.clone())
|
||||
@@ -149,277 +120,12 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Add a worker to the registry
|
||||
pub async fn add_worker(
|
||||
&self,
|
||||
config: WorkerConfigRequest,
|
||||
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
|
||||
// Build labels from configuration
|
||||
let mut labels = config.labels.clone();
|
||||
|
||||
// Query server info if model_id not provided
|
||||
let model_id = if let Some(model_id) = config.model_id {
|
||||
model_id
|
||||
} else {
|
||||
match self.query_server_info(&config.url, &config.api_key).await {
|
||||
Ok(info) => {
|
||||
// Extract model_id from server info
|
||||
info.model_id
|
||||
.or_else(|| {
|
||||
info.model_path
|
||||
.as_ref()
|
||||
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
|
||||
})
|
||||
.unwrap_or_else(|| "unknown".to_string())
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to query server info from {}: {}", config.url, e);
|
||||
"unknown".to_string()
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
// Add configuration to labels
|
||||
labels.insert("model_id".to_string(), model_id.clone());
|
||||
|
||||
if let Some(priority) = config.priority {
|
||||
labels.insert("priority".to_string(), priority.to_string());
|
||||
}
|
||||
|
||||
if let Some(cost) = config.cost {
|
||||
labels.insert("cost".to_string(), cost.to_string());
|
||||
}
|
||||
|
||||
// Add gRPC-specific configuration if provided
|
||||
if let Some(tokenizer_path) = config.tokenizer_path {
|
||||
labels.insert("tokenizer_path".to_string(), tokenizer_path);
|
||||
}
|
||||
|
||||
if let Some(reasoning_parser) = config.reasoning_parser {
|
||||
labels.insert("reasoning_parser".to_string(), reasoning_parser);
|
||||
}
|
||||
|
||||
if let Some(tool_parser) = config.tool_parser {
|
||||
labels.insert("tool_parser".to_string(), tool_parser);
|
||||
}
|
||||
|
||||
if let Some(chat_template) = config.chat_template {
|
||||
labels.insert("chat_template".to_string(), chat_template);
|
||||
}
|
||||
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Prefill {
|
||||
bootstrap_port: config.bootstrap_port,
|
||||
})
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
Some("decode") => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Decode)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
_ => {
|
||||
let mut builder = BasicWorkerBuilder::new(config.url.clone())
|
||||
.worker_type(WorkerType::Regular)
|
||||
.labels(labels.clone())
|
||||
.circuit_breaker_config(CircuitBreakerConfig::default());
|
||||
|
||||
if let Some(api_key) = config.api_key.clone() {
|
||||
builder = builder.api_key(api_key);
|
||||
}
|
||||
|
||||
Box::new(builder.build()) as Box<dyn Worker>
|
||||
}
|
||||
};
|
||||
|
||||
// Register worker
|
||||
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
|
||||
let worker_id = self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
// Extract policy hint from labels if provided
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
|
||||
|
||||
// Log which type of router would handle this worker (for debugging)
|
||||
let expected_router = match config.worker_type.as_deref() {
|
||||
Some("prefill") | Some("decode") => "http-pd",
|
||||
_ => "http-regular",
|
||||
};
|
||||
|
||||
info!(
|
||||
"Worker for model '{}' would be handled by '{}' router based on type",
|
||||
model_id, expected_router
|
||||
);
|
||||
|
||||
info!(
|
||||
"Added worker {} with URL {} for model {} using policy {}",
|
||||
worker_id.as_str(),
|
||||
config.url,
|
||||
model_id,
|
||||
policy.name()
|
||||
);
|
||||
|
||||
// Return worker info
|
||||
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
|
||||
|
||||
Ok(WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} added successfully", worker_id.as_str()),
|
||||
worker: Some(worker_info),
|
||||
})
|
||||
}
|
||||
|
||||
/// Remove a worker from the registry
|
||||
pub fn remove_worker_from_registry(
|
||||
&self,
|
||||
url: &str,
|
||||
) -> Result<WorkerApiResponse, WorkerErrorResponse> {
|
||||
// Get worker to extract model_id before removing
|
||||
let model_id = self
|
||||
.worker_registry
|
||||
.get_by_url(url)
|
||||
.map(|worker| worker.model_id().to_string());
|
||||
|
||||
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
|
||||
// Notify PolicyRegistry about worker removal
|
||||
if let Some(ref model_id) = model_id {
|
||||
self.policy_registry.on_worker_removed(model_id);
|
||||
|
||||
info!("Removed worker with URL {} for model {}", url, model_id);
|
||||
} else {
|
||||
info!("Removed worker with URL {}", url);
|
||||
}
|
||||
|
||||
Ok(WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} removed successfully", url),
|
||||
worker: None,
|
||||
})
|
||||
} else {
|
||||
Err(WorkerErrorResponse {
|
||||
error: format!("Worker with URL {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// List all workers
|
||||
pub fn list_workers(&self) -> WorkerListResponse {
|
||||
let workers = self.worker_registry.get_all_with_ids();
|
||||
let worker_infos: Vec<WorkerInfo> = workers
|
||||
.iter()
|
||||
.map(|(id, w)| self.worker_to_info(id.as_str(), w))
|
||||
.collect();
|
||||
|
||||
let total = worker_infos.len();
|
||||
|
||||
// Get stats from the worker registry
|
||||
let registry_stats = self.worker_registry.stats();
|
||||
|
||||
// Convert WorkerRegistryStats to WorkerStats
|
||||
let stats = WorkerStats {
|
||||
total_workers: registry_stats.total_workers,
|
||||
healthy_workers: registry_stats.healthy_workers,
|
||||
total_models: registry_stats.total_models,
|
||||
total_load: registry_stats.total_load,
|
||||
by_type: WorkerTypeStats {
|
||||
regular: registry_stats.regular_workers,
|
||||
prefill: registry_stats.prefill_workers,
|
||||
decode: registry_stats.decode_workers,
|
||||
},
|
||||
};
|
||||
|
||||
WorkerListResponse {
|
||||
workers: worker_infos,
|
||||
total,
|
||||
stats,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get worker by URL
|
||||
pub fn get_worker(&self, url: &str) -> Option<WorkerInfo> {
|
||||
self.worker_registry
|
||||
.get_by_url(url)
|
||||
.map(|w| self.worker_to_info("unknown", &w))
|
||||
}
|
||||
|
||||
/// Query server info from a worker URL
|
||||
async fn query_server_info(
|
||||
&self,
|
||||
url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<ServerInfo, String> {
|
||||
let info_url = format!("{}/get_server_info", url.trim_end_matches('/'));
|
||||
|
||||
let mut req_builder = self.client.get(&info_url);
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
match req_builder.send().await {
|
||||
Ok(response) => {
|
||||
if response.status().is_success() {
|
||||
response
|
||||
.json::<ServerInfo>()
|
||||
.await
|
||||
.map_err(|e| format!("Failed to parse server info: {}", e))
|
||||
} else {
|
||||
Err(format!("Server returned status: {}", response.status()))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("Failed to connect to server: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
/// Convert Worker to WorkerInfo
|
||||
fn worker_to_info(&self, id: &str, worker: &Arc<dyn Worker>) -> WorkerInfo {
|
||||
let metadata = worker.metadata();
|
||||
|
||||
WorkerInfo {
|
||||
id: id.to_string(),
|
||||
url: worker.url().to_string(),
|
||||
model_id: worker.model_id().to_string(),
|
||||
priority: worker.priority(),
|
||||
cost: worker.cost(),
|
||||
worker_type: match worker.worker_type() {
|
||||
WorkerType::Regular => "regular".to_string(),
|
||||
WorkerType::Prefill { .. } => "prefill".to_string(),
|
||||
WorkerType::Decode => "decode".to_string(),
|
||||
},
|
||||
is_healthy: worker.is_healthy(),
|
||||
load: worker.load(),
|
||||
connection_mode: format!("{:?}", worker.connection_mode()),
|
||||
tokenizer_path: worker.tokenizer_path().map(|s| s.to_string()),
|
||||
reasoning_parser: worker.reasoning_parser().map(|s| s.to_string()),
|
||||
tool_parser: worker.tool_parser().map(|s| s.to_string()),
|
||||
chat_template: worker.chat_template().map(|s| s.to_string()),
|
||||
metadata: metadata.labels.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the appropriate router for a request based on headers and request content
|
||||
pub fn select_router_for_request(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
model_id: Option<&str>,
|
||||
) -> Option<Arc<dyn RouterTrait>> {
|
||||
// Extract priority and cost preferences from headers if available
|
||||
let _priority_threshold = headers.and_then(|h| {
|
||||
h.get("x-worker-priority")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
@@ -432,7 +138,6 @@ impl RouterManager {
|
||||
.and_then(|s| s.parse::<f32>().ok())
|
||||
});
|
||||
|
||||
// Check if PD (prefill-decode) mode is preferred from headers
|
||||
let prefer_pd = headers
|
||||
.and_then(|h| {
|
||||
h.get("x-prefer-pd")
|
||||
@@ -441,7 +146,6 @@ impl RouterManager {
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// If model specified, use get_router_for_model
|
||||
let candidate_routers = if let Some(model) = model_id {
|
||||
if let Some(router) = self.get_router_for_model(model) {
|
||||
vec![router]
|
||||
@@ -449,7 +153,6 @@ impl RouterManager {
|
||||
Vec::new()
|
||||
}
|
||||
} else {
|
||||
// No model specified, consider all routers
|
||||
self.routers
|
||||
.iter()
|
||||
.map(|entry| entry.value().clone())
|
||||
@@ -457,23 +160,20 @@ impl RouterManager {
|
||||
};
|
||||
|
||||
if candidate_routers.is_empty() {
|
||||
// No routers found for the specified model
|
||||
return None;
|
||||
}
|
||||
|
||||
// Score routers based on worker attributes and request preferences
|
||||
let mut best_router = None;
|
||||
let mut best_score = 0.0;
|
||||
|
||||
for router in candidate_routers {
|
||||
let mut score = 1.0;
|
||||
|
||||
// Check if this is a PD router
|
||||
let is_pd = router.is_pd_mode();
|
||||
if prefer_pd && is_pd {
|
||||
score += 2.0; // Bonus for matching PD preference
|
||||
score += 2.0;
|
||||
} else if !prefer_pd && !is_pd {
|
||||
score += 1.0; // Bonus for matching regular preference
|
||||
score += 1.0;
|
||||
}
|
||||
|
||||
// Get workers for this router and evaluate based on priority/cost
|
||||
@@ -495,49 +195,6 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// RouterManager implements RouterTrait to act as a meta-router
|
||||
/// that delegates requests to the appropriate underlying router
|
||||
#[async_trait]
|
||||
impl WorkerManagement for RouterManager {
|
||||
/// Add a worker - in multi-router mode, this adds to the registry
|
||||
async fn add_worker(
|
||||
&self,
|
||||
worker_url: &str,
|
||||
api_key: &Option<String>,
|
||||
) -> Result<String, String> {
|
||||
// Create a basic worker config request
|
||||
let config = WorkerConfigRequest {
|
||||
url: worker_url.to_string(),
|
||||
api_key: api_key.clone(),
|
||||
model_id: None,
|
||||
worker_type: None,
|
||||
priority: None,
|
||||
cost: None,
|
||||
labels: std::collections::HashMap::new(),
|
||||
bootstrap_port: None,
|
||||
tokenizer_path: None,
|
||||
reasoning_parser: None,
|
||||
tool_parser: None,
|
||||
chat_template: None,
|
||||
};
|
||||
|
||||
match self.add_worker(config).await {
|
||||
Ok(response) => Ok(response.message),
|
||||
Err(e) => Err(e.error),
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from the registry
|
||||
fn remove_worker(&self, worker_url: &str) {
|
||||
let _ = self.remove_worker_from_registry(worker_url);
|
||||
}
|
||||
|
||||
/// Get all worker URLs from the registry
|
||||
fn get_worker_urls(&self) -> Vec<String> {
|
||||
self.worker_registry.get_all_urls()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl RouterTrait for RouterManager {
|
||||
fn as_any(&self) -> &dyn std::any::Any {
|
||||
@@ -639,7 +296,6 @@ impl RouterTrait for RouterManager {
|
||||
body: &ChatCompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
@@ -662,7 +318,6 @@ impl RouterTrait for RouterManager {
|
||||
body: &CompletionRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
@@ -746,7 +401,6 @@ impl RouterTrait for RouterManager {
|
||||
body: &EmbeddingRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers and model
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
|
||||
Reference in New Issue
Block a user