Files
sglang/sgl-router/src/routers/router_manager.rs

723 lines
22 KiB
Rust

//! Router Manager for coordinating multiple routers and workers
//!
//! Provides centralized management based on enable_igw flag:
//! - Single Router Mode (enable_igw=false): Router owns workers directly
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
use std::sync::Arc;
use async_trait::async_trait;
use axum::{
body::Body,
extract::Request,
http::{HeaderMap, StatusCode},
response::{IntoResponse, Response},
};
use dashmap::DashMap;
use serde_json::Value;
use tracing::{debug, info, warn};
use crate::{
config::{ConnectionMode, RoutingMode},
core::{WorkerRegistry, WorkerType},
protocols::{
chat::ChatCompletionRequest,
classify::ClassifyRequest,
completion::CompletionRequest,
embedding::EmbeddingRequest,
generate::GenerateRequest,
rerank::RerankRequest,
responses::{ResponsesGetParams, ResponsesRequest},
},
routers::RouterTrait,
server::{AppContext, ServerConfig},
};
/// Key under which a router is registered in a [`RouterManager`].
///
/// Wraps the identifier string (for example `"http-regular"` or `"grpc-pd"`) so
/// router keys are not confused with arbitrary strings.
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
pub struct RouterId(String);

impl RouterId {
    /// Wraps `id` as a router identifier; no validation is performed.
    pub fn new(id: String) -> Self {
        RouterId(id)
    }

    /// Borrows the underlying identifier string.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}
/// Coordinates one or more routers that share a single [`WorkerRegistry`].
pub struct RouterManager {
    /// Registered routers, keyed by their [`RouterId`].
    routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
    /// ID of the router used when no model-specific router applies
    /// (and always used in single-router mode).
    default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
    /// Shared worker registry, consulted to map a model to a router.
    worker_registry: Arc<WorkerRegistry>,
    /// `true` in multi-router (IGW) mode, `false` in single-router mode.
    enable_igw: bool,
}
impl RouterManager {
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
Self {
worker_registry,
routers: Arc::new(DashMap::new()),
default_router: Arc::new(std::sync::RwLock::new(None)),
enable_igw: false, // Will be set properly in from_config
}
}
pub async fn from_config(
config: &ServerConfig,
app_context: &Arc<AppContext>,
) -> Result<Arc<Self>, String> {
use crate::routers::RouterFactory;
let mut manager = Self::new(app_context.worker_registry.clone());
manager.enable_igw = config.router_config.enable_igw;
let manager = Arc::new(manager);
if config.router_config.enable_igw {
info!("Initializing RouterManager in multi-router mode (IGW)");
match RouterFactory::create_regular_router(app_context).await {
Ok(http_regular) => {
info!("Created HTTP Regular router");
manager.register_router(
RouterId::new("http-regular".to_string()),
Arc::from(http_regular),
);
}
Err(e) => {
warn!("Failed to create HTTP Regular router: {e}");
}
}
match RouterFactory::create_pd_router(
None,
None,
&config.router_config.policy,
app_context,
)
.await
{
Ok(http_pd) => {
info!("Created HTTP PD router");
manager
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
}
Err(e) => {
warn!("Failed to create HTTP PD router: {e}");
}
}
// TODO: Add gRPC routers once we have dynamic tokenizer loading
info!(
"RouterManager initialized with {} routers for multi-router mode",
manager.router_count()
);
} else {
info!("Initializing RouterManager in single-router mode");
let single_router = Arc::from(RouterFactory::create_router(app_context).await?);
let router_id = Self::determine_router_id(
&config.router_config.mode,
&config.router_config.connection_mode,
);
info!("Created single router with ID: {}", router_id.as_str());
manager.register_router(router_id.clone(), single_router);
manager.set_default_router(router_id);
}
if manager.router_count() == 0 {
return Err("No routers could be initialized".to_string());
}
Ok(manager)
}
pub fn determine_router_id(
routing_mode: &RoutingMode,
connection_mode: &ConnectionMode,
) -> RouterId {
match (connection_mode, routing_mode) {
(ConnectionMode::Http, RoutingMode::Regular { .. }) => {
RouterId::new("http-regular".to_string())
}
(ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => {
RouterId::new("http-pd".to_string())
}
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
RouterId::new("http-openai".to_string())
}
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
RouterId::new("grpc-regular".to_string())
}
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
RouterId::new("grpc-pd".to_string())
}
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
RouterId::new("grpc-regular".to_string())
}
}
}
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
self.routers.insert(id.clone(), router);
let mut default_router = self.default_router.write().unwrap();
if default_router.is_none() {
*default_router = Some(id.clone());
info!("Set default router to {}", id.as_str());
}
}
pub fn set_default_router(&self, id: RouterId) {
let mut default_router = self.default_router.write().unwrap();
*default_router = Some(id);
}
pub fn router_count(&self) -> usize {
self.routers.len()
}
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
let workers = self.worker_registry.get_by_model(model_id);
if !workers.is_empty() {
let has_pd_workers = workers.iter().any(|w| {
matches!(
w.worker_type(),
WorkerType::Prefill { .. } | WorkerType::Decode
)
});
let router_id = if has_pd_workers {
RouterId::new("http-pd".to_string())
} else {
RouterId::new("http-regular".to_string())
};
if let Some(router) = self.routers.get(&router_id) {
return Some(router.clone());
}
}
let default_router = self.default_router.read().unwrap();
if let Some(ref default_id) = *default_router {
self.routers.get(default_id).map(|r| r.clone())
} else {
None
}
}
pub fn select_router_for_request(
&self,
headers: Option<&HeaderMap>,
model_id: Option<&str>,
) -> Option<Arc<dyn RouterTrait>> {
// In single-router mode (enable_igw=false), always use the default router
if !self.enable_igw {
let default_router = self.default_router.read().unwrap();
if let Some(ref default_id) = *default_router {
debug!(
"Single-router mode: using default router {} for model {:?}",
default_id.as_str(),
model_id
);
return self.routers.get(default_id).map(|r| r.clone());
}
}
// Multi-router mode logic follows
let _priority_threshold = headers.and_then(|h| {
h.get("x-worker-priority")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<u32>().ok())
});
let _max_cost = headers.and_then(|h| {
h.get("x-max-cost")
.and_then(|v| v.to_str().ok())
.and_then(|s| s.parse::<f32>().ok())
});
let prefer_pd = headers
.and_then(|h| {
h.get("x-prefer-pd")
.and_then(|v| v.to_str().ok())
.map(|s| s == "true" || s == "1")
})
.unwrap_or(false);
let candidate_routers = if let Some(model) = model_id {
if let Some(router) = self.get_router_for_model(model) {
vec![router]
} else {
Vec::new()
}
} else {
self.routers
.iter()
.map(|entry| entry.value().clone())
.collect::<Vec<_>>()
};
if candidate_routers.is_empty() {
return None;
}
let mut best_router = None;
let mut best_score = 0.0;
for router in candidate_routers {
let mut score = 1.0;
let is_pd = router.is_pd_mode();
if prefer_pd && is_pd {
score += 2.0;
} else if !prefer_pd && !is_pd {
score += 1.0;
}
// TODO: Once routers expose worker stats, we can evaluate:
// - Average worker priority vs priority_threshold
// - Average worker cost vs max_cost
// - Current load and health status
if score > best_score {
best_score = score;
best_router = Some(router);
}
}
best_router
}
}
#[async_trait]
impl RouterTrait for RouterManager {
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }

    /// Currently always reports `503 Service Unavailable`.
    async fn health_generate(&self, _req: Request<Body>) -> Response {
        // TODO: Should check if any router has healthy workers
        (
            StatusCode::SERVICE_UNAVAILABLE,
            "No routers with healthy workers available",
        )
            .into_response()
    }

    /// Reports manager-level counts: registered routers and registered workers.
    async fn get_server_info(&self, _req: Request<Body>) -> Response {
        // TODO: Aggregate info from all routers with healthy workers
        (
            StatusCode::OK,
            serde_json::json!({
                "router_manager": true,
                "routers_count": self.routers.len(),
                "workers_count": self.worker_registry.get_all().len()
            })
            .to_string(),
        )
            .into_response()
    }

    /// Lists the models known to the worker registry; `503` when there are none.
    async fn get_models(&self, _req: Request<Body>) -> Response {
        let models = self.worker_registry.get_models();
        if models.is_empty() {
            (StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
        } else {
            (
                StatusCode::OK,
                serde_json::json!({ "models": models }).to_string(),
            )
                .into_response()
        }
    }

    /// Not implemented yet; always `501 Not Implemented`.
    async fn get_model_info(&self, _req: Request<Body>) -> Response {
        // TODO: Extract model from request and route to appropriate router
        (
            StatusCode::NOT_IMPLEMENTED,
            "Model info endpoint not yet implemented in RouterManager",
        )
            .into_response()
    }

    /// Delegates to a router chosen from the headers alone (no model hint is used).
    async fn route_generate(
        &self,
        headers: Option<&HeaderMap>,
        body: &GenerateRequest,
        _model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.route_generate(headers, body, None).await,
            None => (
                StatusCode::NOT_FOUND,
                "No router available for this request",
            )
                .into_response(),
        }
    }

    /// Delegates to the router selected for `body.model` and forwards that same model.
    async fn route_chat(
        &self,
        headers: Option<&HeaderMap>,
        body: &ChatCompletionRequest,
        _model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, Some(&body.model)) {
            Some(router) => router.route_chat(headers, body, Some(&body.model)).await,
            None => (
                StatusCode::NOT_FOUND,
                format!("Model '{}' not found or no router available", body.model),
            )
                .into_response(),
        }
    }

    /// Delegates to the router selected for `body.model` and forwards that same model.
    async fn route_completion(
        &self,
        headers: Option<&HeaderMap>,
        body: &CompletionRequest,
        _model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, Some(&body.model)) {
            Some(router) => {
                router
                    .route_completion(headers, body, Some(&body.model))
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!("Model '{}' not found or no router available", body.model),
            )
                .into_response(),
        }
    }

    /// Delegates using the body's model if present, else the caller-supplied `model_id`.
    async fn route_responses(
        &self,
        headers: Option<&HeaderMap>,
        body: &ResponsesRequest,
        model_id: Option<&str>,
    ) -> Response {
        let selected_model = body.model.as_deref().or(model_id);
        match self.select_router_for_request(headers, selected_model) {
            Some(router) => router.route_responses(headers, body, selected_model).await,
            None => (
                StatusCode::NOT_FOUND,
                "No router available to handle responses request",
            )
                .into_response(),
        }
    }

    /// Not implemented in this mode; always `501 Not Implemented`.
    async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
        (
            StatusCode::NOT_IMPLEMENTED,
            "responses api not yet implemented in inference gateway mode",
        )
            .into_response()
    }

    /// Not implemented in this mode; always `501 Not Implemented`.
    async fn list_response_input_items(
        &self,
        _headers: Option<&HeaderMap>,
        _response_id: &str,
    ) -> Response {
        (
            StatusCode::NOT_IMPLEMENTED,
            "responses api not yet implemented in inference gateway mode",
        )
            .into_response()
    }

    /// Delegates to a router chosen from the headers alone.
    async fn get_response(
        &self,
        headers: Option<&HeaderMap>,
        response_id: &str,
        params: &ResponsesGetParams,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.get_response(headers, response_id, params).await,
            None => (
                StatusCode::NOT_FOUND,
                format!("No router available to get response '{}'", response_id),
            )
                .into_response(),
        }
    }

    /// Delegates to a router chosen from the headers alone.
    async fn cancel_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.cancel_response(headers, response_id).await,
            None => (
                StatusCode::NOT_FOUND,
                format!("No router available to cancel response '{}'", response_id),
            )
                .into_response(),
        }
    }

    /// Delegates to the router selected for `body.model` and forwards that same model.
    async fn route_embeddings(
        &self,
        headers: Option<&HeaderMap>,
        body: &EmbeddingRequest,
        _model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, Some(&body.model)) {
            Some(router) => {
                router
                    .route_embeddings(headers, body, Some(&body.model))
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!("Model '{}' not found or no router available", body.model),
            )
                .into_response(),
        }
    }

    /// Delegates to a router chosen from the headers alone, forwarding `model_id` unchanged.
    async fn route_rerank(
        &self,
        headers: Option<&HeaderMap>,
        body: &RerankRequest,
        model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.route_rerank(headers, body, model_id).await,
            None => (
                StatusCode::NOT_FOUND,
                "No router available for rerank request",
            )
                .into_response(),
        }
    }

    /// Delegates to the router selected for `body.model` and forwards that same model.
    async fn route_classify(
        &self,
        headers: Option<&HeaderMap>,
        body: &ClassifyRequest,
        _model_id: Option<&str>,
    ) -> Response {
        match self.select_router_for_request(headers, Some(&body.model)) {
            // Forward the model the router was selected for. The unrelated `model_id`
            // argument was previously forwarded, unlike chat/completion/embeddings.
            Some(router) => {
                router
                    .route_classify(headers, body, Some(&body.model))
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!("Model '{}' not found or no router available", body.model),
            )
                .into_response(),
        }
    }

    fn router_type(&self) -> &'static str {
        "manager"
    }

    // Conversations API delegates: each picks a router from the headers alone.

    async fn create_conversation(&self, headers: Option<&HeaderMap>, body: &Value) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.create_conversation(headers, body).await,
            None => (
                StatusCode::NOT_FOUND,
                "No router available to create conversation",
            )
                .into_response(),
        }
    }

    async fn get_conversation(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.get_conversation(headers, conversation_id).await,
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to get conversation '{}'",
                    conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn update_conversation(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
        body: &Value,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => {
                router
                    .update_conversation(headers, conversation_id, body)
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to update conversation '{}'",
                    conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn delete_conversation(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => router.delete_conversation(headers, conversation_id).await,
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to delete conversation '{}'",
                    conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn list_conversation_items(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
        limit: Option<usize>,
        order: Option<String>,
        after: Option<String>,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => {
                router
                    .list_conversation_items(headers, conversation_id, limit, order, after)
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to list conversation items for '{}'",
                    conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn create_conversation_items(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
        body: &Value,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => {
                router
                    .create_conversation_items(headers, conversation_id, body)
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to create conversation items for '{}'",
                    conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn get_conversation_item(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
        item_id: &str,
        include: Option<Vec<String>>,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => {
                router
                    .get_conversation_item(headers, conversation_id, item_id, include)
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to get conversation item '{}' in '{}'",
                    item_id, conversation_id
                ),
            )
                .into_response(),
        }
    }

    async fn delete_conversation_item(
        &self,
        headers: Option<&HeaderMap>,
        conversation_id: &str,
        item_id: &str,
    ) -> Response {
        match self.select_router_for_request(headers, None) {
            Some(router) => {
                router
                    .delete_conversation_item(headers, conversation_id, item_id)
                    .await
            }
            None => (
                StatusCode::NOT_FOUND,
                format!(
                    "No router available to delete conversation item '{}' in '{}'",
                    item_id, conversation_id
                ),
            )
                .into_response(),
        }
    }
}
impl std::fmt::Debug for RouterManager {
    /// Formats a summary of the manager (counts, default router ID, mode) instead of
    /// dumping the registered routers themselves.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Formatting must not panic, so a poisoned lock is tolerated. The guarded value is
        // a plain `Option<RouterId>` that is only ever replaced wholesale.
        let default_router = self
            .default_router
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
            .clone();
        f.debug_struct("RouterManager")
            .field("routers_count", &self.routers.len())
            .field("workers_count", &self.worker_registry.get_all().len())
            .field("default_router", &default_router)
            .field("enable_igw", &self.enable_igw)
            .finish()
    }
}