[router] refactor router and worker management 4/n (#10756)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -41,6 +41,7 @@ pub struct PDRouter {
|
||||
pub prefill_client: Client,
|
||||
pub retry_config: RetryConfig,
|
||||
pub api_key: Option<String>,
|
||||
pub enable_igw: bool,
|
||||
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||
}
|
||||
|
||||
@@ -317,6 +318,7 @@ impl PDRouter {
|
||||
prefill_drain_tx,
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
api_key: ctx.router_config.api_key.clone(),
|
||||
enable_igw: ctx.router_config.enable_igw,
|
||||
})
|
||||
}
|
||||
|
||||
@@ -849,7 +851,14 @@ impl PDRouter {
|
||||
request_text: Option<&str>,
|
||||
model_id: Option<&str>,
|
||||
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
|
||||
let prefill_workers = if let Some(model) = model_id {
|
||||
let effective_model_id = if !self.enable_igw { None } else { model_id };
|
||||
|
||||
debug!(
|
||||
"Selecting PD pair: enable_igw={}, model_id={:?}, effective_model_id={:?}",
|
||||
self.enable_igw, model_id, effective_model_id
|
||||
);
|
||||
|
||||
let prefill_workers = if let Some(model) = effective_model_id {
|
||||
self.worker_registry
|
||||
.get_by_model_fast(model)
|
||||
.into_iter()
|
||||
@@ -859,7 +868,7 @@ impl PDRouter {
|
||||
self.worker_registry.get_prefill_workers()
|
||||
};
|
||||
|
||||
let decode_workers = if let Some(model) = model_id {
|
||||
let decode_workers = if let Some(model) = effective_model_id {
|
||||
self.worker_registry
|
||||
.get_by_model_fast(model)
|
||||
.into_iter()
|
||||
@@ -1797,6 +1806,7 @@ mod tests {
|
||||
prefill_drain_tx: mpsc::channel(100).0,
|
||||
retry_config: RetryConfig::default(),
|
||||
api_key: Some("test_api_key".to_string()),
|
||||
enable_igw: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -35,6 +35,7 @@ pub struct Router {
|
||||
policy_registry: Arc<PolicyRegistry>,
|
||||
client: Client,
|
||||
dp_aware: bool,
|
||||
enable_igw: bool,
|
||||
retry_config: RetryConfig,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
@@ -93,6 +94,7 @@ impl Router {
|
||||
policy_registry: ctx.policy_registry.clone(),
|
||||
client: ctx.client.clone(),
|
||||
dp_aware: ctx.router_config.dp_aware,
|
||||
enable_igw: ctx.router_config.enable_igw,
|
||||
retry_config: ctx.router_config.effective_retry_config(),
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
@@ -162,9 +164,11 @@ impl Router {
|
||||
model_id: Option<&str>,
|
||||
text: Option<&str>,
|
||||
) -> Option<Arc<dyn Worker>> {
|
||||
let effective_model_id = if !self.enable_igw { None } else { model_id };
|
||||
|
||||
// Get workers for the specified model O(1), filtered by connection mode
|
||||
let workers = self.worker_registry.get_workers_filtered(
|
||||
model_id,
|
||||
effective_model_id,
|
||||
Some(WorkerType::Regular),
|
||||
Some(ConnectionMode::Http),
|
||||
false, // get all workers, we'll filter by is_available() next
|
||||
@@ -1106,6 +1110,7 @@ mod tests {
|
||||
retry_config: RetryConfig::default(),
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
enable_igw: false,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,12 +4,14 @@
|
||||
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::core::{Worker, WorkerRegistry, WorkerType};
|
||||
use crate::config::{ConnectionMode, RoutingMode};
|
||||
use crate::core::{WorkerRegistry, WorkerType};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
};
|
||||
use crate::routers::RouterTrait;
|
||||
use crate::server::{AppContext, ServerConfig};
|
||||
use async_trait::async_trait;
|
||||
use axum::{
|
||||
body::Body,
|
||||
@@ -19,9 +21,8 @@ use axum::{
|
||||
};
|
||||
use dashmap::DashMap;
|
||||
use std::sync::Arc;
|
||||
use tracing::info;
|
||||
use tracing::{debug, info, warn};
|
||||
|
||||
/// Router identifier
|
||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||
pub struct RouterId(String);
|
||||
|
||||
@@ -35,30 +36,120 @@ impl RouterId {
|
||||
}
|
||||
}
|
||||
|
||||
/// Router Manager - Central coordinator for routers and workers
|
||||
pub struct RouterManager {
|
||||
/// Worker registry (single source of truth in multi-router mode)
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
|
||||
/// All routers managed by this manager
|
||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||
|
||||
/// Default router for requests without specific routing
|
||||
default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
|
||||
enable_igw: bool,
|
||||
}
|
||||
|
||||
impl RouterManager {
|
||||
/// Create a new router manager with shared registries
|
||||
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
|
||||
Self {
|
||||
worker_registry,
|
||||
routers: Arc::new(DashMap::new()),
|
||||
default_router: Arc::new(std::sync::RwLock::new(None)),
|
||||
enable_igw: false, // Will be set properly in from_config
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn from_config(
|
||||
config: &ServerConfig,
|
||||
app_context: &Arc<AppContext>,
|
||||
) -> Result<Arc<Self>, String> {
|
||||
use crate::routers::RouterFactory;
|
||||
|
||||
let mut manager = Self::new(app_context.worker_registry.clone());
|
||||
manager.enable_igw = config.router_config.enable_igw;
|
||||
let manager = Arc::new(manager);
|
||||
|
||||
if config.router_config.enable_igw {
|
||||
info!("Initializing RouterManager in multi-router mode (IGW)");
|
||||
|
||||
match RouterFactory::create_regular_router(app_context).await {
|
||||
Ok(http_regular) => {
|
||||
info!("Created HTTP Regular router");
|
||||
manager.register_router(
|
||||
RouterId::new("http-regular".to_string()),
|
||||
Arc::from(http_regular),
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP Regular router: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
match RouterFactory::create_pd_router(
|
||||
None,
|
||||
None,
|
||||
&config.router_config.policy,
|
||||
app_context,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(http_pd) => {
|
||||
info!("Created HTTP PD router");
|
||||
manager
|
||||
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP PD router: {e}");
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add gRPC routers once we have dynamic tokenizer loading
|
||||
|
||||
info!(
|
||||
"RouterManager initialized with {} routers for multi-router mode",
|
||||
manager.router_count()
|
||||
);
|
||||
} else {
|
||||
info!("Initializing RouterManager in single-router mode");
|
||||
|
||||
let single_router = Arc::from(RouterFactory::create_router(app_context).await?);
|
||||
let router_id = Self::determine_router_id(
|
||||
&config.router_config.mode,
|
||||
&config.router_config.connection_mode,
|
||||
);
|
||||
|
||||
info!("Created single router with ID: {}", router_id.as_str());
|
||||
manager.register_router(router_id.clone(), single_router);
|
||||
manager.set_default_router(router_id);
|
||||
}
|
||||
|
||||
if manager.router_count() == 0 {
|
||||
return Err("No routers could be initialized".to_string());
|
||||
}
|
||||
|
||||
Ok(manager)
|
||||
}
|
||||
|
||||
pub fn determine_router_id(
|
||||
routing_mode: &RoutingMode,
|
||||
connection_mode: &ConnectionMode,
|
||||
) -> RouterId {
|
||||
match (connection_mode, routing_mode) {
|
||||
(ConnectionMode::Http, RoutingMode::Regular { .. }) => {
|
||||
RouterId::new("http-regular".to_string())
|
||||
}
|
||||
(ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => {
|
||||
RouterId::new("http-pd".to_string())
|
||||
}
|
||||
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
|
||||
RouterId::new("http-openai".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
|
||||
RouterId::new("grpc-regular".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
|
||||
RouterId::new("grpc-pd".to_string())
|
||||
}
|
||||
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
|
||||
RouterId::new("grpc-regular".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Register a router with the manager
|
||||
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
||||
self.routers.insert(id.clone(), router);
|
||||
|
||||
@@ -69,18 +160,15 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Set the default router
|
||||
pub fn set_default_router(&self, id: RouterId) {
|
||||
let mut default_router = self.default_router.write().unwrap();
|
||||
*default_router = Some(id);
|
||||
}
|
||||
|
||||
/// Get the number of registered routers
|
||||
pub fn router_count(&self) -> usize {
|
||||
self.routers.len()
|
||||
}
|
||||
|
||||
/// Get router for a specific model based on worker types
|
||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||
let workers = self.worker_registry.get_by_model(model_id);
|
||||
|
||||
@@ -111,21 +199,25 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get workers for routing decision
|
||||
pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec<Arc<dyn Worker>> {
|
||||
if let Some(model) = model_id {
|
||||
self.worker_registry.get_by_model(model)
|
||||
} else {
|
||||
self.worker_registry.get_all()
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the appropriate router for a request based on headers and request content
|
||||
pub fn select_router_for_request(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
model_id: Option<&str>,
|
||||
) -> Option<Arc<dyn RouterTrait>> {
|
||||
// In single-router mode (enable_igw=false), always use the default router
|
||||
if !self.enable_igw {
|
||||
let default_router = self.default_router.read().unwrap();
|
||||
if let Some(ref default_id) = *default_router {
|
||||
debug!(
|
||||
"Single-router mode: using default router {} for model {:?}",
|
||||
default_id.as_str(),
|
||||
model_id
|
||||
);
|
||||
return self.routers.get(default_id).map(|r| r.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Multi-router mode logic follows
|
||||
let _priority_threshold = headers.and_then(|h| {
|
||||
h.get("x-worker-priority")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
@@ -176,10 +268,6 @@ impl RouterManager {
|
||||
score += 1.0;
|
||||
}
|
||||
|
||||
// Get workers for this router and evaluate based on priority/cost
|
||||
// Note: This would require routers to expose their workers or stats
|
||||
// For now, we'll use a simple selection based on router type
|
||||
|
||||
// TODO: Once routers expose worker stats, we can evaluate:
|
||||
// - Average worker priority vs priority_threshold
|
||||
// - Average worker cost vs max_cost
|
||||
@@ -201,16 +289,11 @@ impl RouterTrait for RouterManager {
|
||||
self
|
||||
}
|
||||
|
||||
/// Health check - return 503 if no routers available
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
// Health check should succeed if RouterManager exists, even without routers
|
||||
// Individual router health can be checked via specific endpoints
|
||||
(StatusCode::OK, "RouterManager is healthy").into_response()
|
||||
}
|
||||
|
||||
/// Health generate - check if any router can handle generate requests
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// Return 503 since we have no routers with workers
|
||||
// TODO: Should check if any router has healthy workers
|
||||
(
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
@@ -219,10 +302,8 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get server information - aggregate from all routers
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
// TODO: Aggregate info from all routers with healthy workers
|
||||
// For now, return basic info about the RouterManager
|
||||
(
|
||||
StatusCode::OK,
|
||||
serde_json::json!({
|
||||
@@ -235,9 +316,7 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get available models - query from worker registry
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
// Get models from worker registry
|
||||
let models = self.worker_registry.get_models();
|
||||
|
||||
if models.is_empty() {
|
||||
@@ -254,10 +333,8 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get model information
|
||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||
// TODO: Extract model from request and route to appropriate router
|
||||
// For now, return not implemented
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"Model info endpoint not yet implemented in RouterManager",
|
||||
@@ -265,22 +342,17 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Route a generate request
|
||||
async fn route_generate(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
_model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Select router based on headers
|
||||
// GenerateRequest doesn't have a model field
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass None since GenerateRequest doesn't have model field
|
||||
router.route_generate(headers, body, None).await
|
||||
} else {
|
||||
// Return 404 when no router is available for the request
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
"No router available for this request",
|
||||
@@ -289,7 +361,6 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a chat completion request
|
||||
async fn route_chat(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
@@ -299,10 +370,8 @@ impl RouterTrait for RouterManager {
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass the model_id to the router
|
||||
router.route_chat(headers, body, Some(&body.model)).await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
@@ -311,7 +380,6 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Route a completion request
|
||||
async fn route_completion(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
@@ -321,12 +389,10 @@ impl RouterTrait for RouterManager {
|
||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||
|
||||
if let Some(router) = router {
|
||||
// In multi-model mode, pass the model_id to the router
|
||||
router
|
||||
.route_completion(headers, body, Some(&body.model))
|
||||
.await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
@@ -348,26 +414,6 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn list_response_input_items(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_response_id: &str,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
if let Some(router) = router {
|
||||
@@ -394,7 +440,26 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Route embeddings request
|
||||
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn list_response_input_items(
|
||||
&self,
|
||||
_headers: Option<&HeaderMap>,
|
||||
_response_id: &str,
|
||||
) -> Response {
|
||||
(
|
||||
StatusCode::NOT_IMPLEMENTED,
|
||||
"responses api not yet implemented in inference gateway mode",
|
||||
)
|
||||
.into_response()
|
||||
}
|
||||
|
||||
async fn route_embeddings(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
@@ -408,7 +473,6 @@ impl RouterTrait for RouterManager {
|
||||
.route_embeddings(headers, body, Some(&body.model))
|
||||
.await
|
||||
} else {
|
||||
// Return 404 when the specified model is not found
|
||||
(
|
||||
StatusCode::NOT_FOUND,
|
||||
format!("Model '{}' not found or no router available", body.model),
|
||||
@@ -417,14 +481,12 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Route rerank request
|
||||
async fn route_rerank(
|
||||
&self,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &RerankRequest,
|
||||
model_id: Option<&str>,
|
||||
) -> Response {
|
||||
// Try to select a router based on headers
|
||||
let router = self.select_router_for_request(headers, None);
|
||||
|
||||
if let Some(router) = router {
|
||||
@@ -438,10 +500,8 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Flush cache on all routers and workers
|
||||
async fn flush_cache(&self) -> Response {
|
||||
// TODO: Call flush_cache on all routers that have workers
|
||||
// For now, return success if we have any routers
|
||||
if self.routers.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||
} else {
|
||||
@@ -450,9 +510,7 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get worker loads from all routers
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
// Return worker loads from the registry
|
||||
let workers = self.worker_registry.get_all();
|
||||
let loads: Vec<serde_json::Value> = workers
|
||||
.iter()
|
||||
@@ -476,12 +534,10 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get router type name
|
||||
fn router_type(&self) -> &'static str {
|
||||
"manager"
|
||||
}
|
||||
|
||||
/// Server readiness check - check if any router is ready
|
||||
fn readiness(&self) -> Response {
|
||||
if self.routers.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||
@@ -492,9 +548,6 @@ impl RouterTrait for RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Note: get_first_available_router removed - we now properly handle
|
||||
// router selection based on model and worker availability
|
||||
|
||||
impl std::fmt::Debug for RouterManager {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
f.debug_struct("RouterManager")
|
||||
|
||||
Reference in New Issue
Block a user