[router] refactor router and worker management 4/n (#10756)
Co-authored-by: Chang Su <chang.s.su@oracle.com>
This commit is contained in:
@@ -41,6 +41,7 @@ pub struct PDRouter {
|
|||||||
pub prefill_client: Client,
|
pub prefill_client: Client,
|
||||||
pub retry_config: RetryConfig,
|
pub retry_config: RetryConfig,
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
|
pub enable_igw: bool,
|
||||||
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
prefill_drain_tx: mpsc::Sender<reqwest::Response>,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -317,6 +318,7 @@ impl PDRouter {
|
|||||||
prefill_drain_tx,
|
prefill_drain_tx,
|
||||||
retry_config: ctx.router_config.effective_retry_config(),
|
retry_config: ctx.router_config.effective_retry_config(),
|
||||||
api_key: ctx.router_config.api_key.clone(),
|
api_key: ctx.router_config.api_key.clone(),
|
||||||
|
enable_igw: ctx.router_config.enable_igw,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -849,7 +851,14 @@ impl PDRouter {
|
|||||||
request_text: Option<&str>,
|
request_text: Option<&str>,
|
||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
|
) -> Result<(Arc<dyn Worker>, Arc<dyn Worker>), String> {
|
||||||
let prefill_workers = if let Some(model) = model_id {
|
let effective_model_id = if !self.enable_igw { None } else { model_id };
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Selecting PD pair: enable_igw={}, model_id={:?}, effective_model_id={:?}",
|
||||||
|
self.enable_igw, model_id, effective_model_id
|
||||||
|
);
|
||||||
|
|
||||||
|
let prefill_workers = if let Some(model) = effective_model_id {
|
||||||
self.worker_registry
|
self.worker_registry
|
||||||
.get_by_model_fast(model)
|
.get_by_model_fast(model)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -859,7 +868,7 @@ impl PDRouter {
|
|||||||
self.worker_registry.get_prefill_workers()
|
self.worker_registry.get_prefill_workers()
|
||||||
};
|
};
|
||||||
|
|
||||||
let decode_workers = if let Some(model) = model_id {
|
let decode_workers = if let Some(model) = effective_model_id {
|
||||||
self.worker_registry
|
self.worker_registry
|
||||||
.get_by_model_fast(model)
|
.get_by_model_fast(model)
|
||||||
.into_iter()
|
.into_iter()
|
||||||
@@ -1797,6 +1806,7 @@ mod tests {
|
|||||||
prefill_drain_tx: mpsc::channel(100).0,
|
prefill_drain_tx: mpsc::channel(100).0,
|
||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
api_key: Some("test_api_key".to_string()),
|
api_key: Some("test_api_key".to_string()),
|
||||||
|
enable_igw: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -35,6 +35,7 @@ pub struct Router {
|
|||||||
policy_registry: Arc<PolicyRegistry>,
|
policy_registry: Arc<PolicyRegistry>,
|
||||||
client: Client,
|
client: Client,
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
|
enable_igw: bool,
|
||||||
retry_config: RetryConfig,
|
retry_config: RetryConfig,
|
||||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
@@ -93,6 +94,7 @@ impl Router {
|
|||||||
policy_registry: ctx.policy_registry.clone(),
|
policy_registry: ctx.policy_registry.clone(),
|
||||||
client: ctx.client.clone(),
|
client: ctx.client.clone(),
|
||||||
dp_aware: ctx.router_config.dp_aware,
|
dp_aware: ctx.router_config.dp_aware,
|
||||||
|
enable_igw: ctx.router_config.enable_igw,
|
||||||
retry_config: ctx.router_config.effective_retry_config(),
|
retry_config: ctx.router_config.effective_retry_config(),
|
||||||
_worker_loads: worker_loads,
|
_worker_loads: worker_loads,
|
||||||
_load_monitor_handle: load_monitor_handle,
|
_load_monitor_handle: load_monitor_handle,
|
||||||
@@ -162,9 +164,11 @@ impl Router {
|
|||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
text: Option<&str>,
|
text: Option<&str>,
|
||||||
) -> Option<Arc<dyn Worker>> {
|
) -> Option<Arc<dyn Worker>> {
|
||||||
|
let effective_model_id = if !self.enable_igw { None } else { model_id };
|
||||||
|
|
||||||
// Get workers for the specified model O(1), filtered by connection mode
|
// Get workers for the specified model O(1), filtered by connection mode
|
||||||
let workers = self.worker_registry.get_workers_filtered(
|
let workers = self.worker_registry.get_workers_filtered(
|
||||||
model_id,
|
effective_model_id,
|
||||||
Some(WorkerType::Regular),
|
Some(WorkerType::Regular),
|
||||||
Some(ConnectionMode::Http),
|
Some(ConnectionMode::Http),
|
||||||
false, // get all workers, we'll filter by is_available() next
|
false, // get all workers, we'll filter by is_available() next
|
||||||
@@ -1106,6 +1110,7 @@ mod tests {
|
|||||||
retry_config: RetryConfig::default(),
|
retry_config: RetryConfig::default(),
|
||||||
_worker_loads: Arc::new(rx),
|
_worker_loads: Arc::new(rx),
|
||||||
_load_monitor_handle: None,
|
_load_monitor_handle: None,
|
||||||
|
enable_igw: false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@
|
|||||||
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
//! - Single Router Mode (enable_igw=false): Router owns workers directly
|
||||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||||
|
|
||||||
use crate::core::{Worker, WorkerRegistry, WorkerType};
|
use crate::config::{ConnectionMode, RoutingMode};
|
||||||
|
use crate::core::{WorkerRegistry, WorkerType};
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
};
|
};
|
||||||
use crate::routers::RouterTrait;
|
use crate::routers::RouterTrait;
|
||||||
|
use crate::server::{AppContext, ServerConfig};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use axum::{
|
use axum::{
|
||||||
body::Body,
|
body::Body,
|
||||||
@@ -19,9 +21,8 @@ use axum::{
|
|||||||
};
|
};
|
||||||
use dashmap::DashMap;
|
use dashmap::DashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tracing::info;
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
/// Router identifier
|
|
||||||
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
#[derive(Debug, Clone, Hash, Eq, PartialEq)]
|
||||||
pub struct RouterId(String);
|
pub struct RouterId(String);
|
||||||
|
|
||||||
@@ -35,30 +36,120 @@ impl RouterId {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Router Manager - Central coordinator for routers and workers
|
|
||||||
pub struct RouterManager {
|
pub struct RouterManager {
|
||||||
/// Worker registry (single source of truth in multi-router mode)
|
|
||||||
worker_registry: Arc<WorkerRegistry>,
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
|
|
||||||
/// All routers managed by this manager
|
|
||||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
|
||||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||||
|
|
||||||
/// Default router for requests without specific routing
|
|
||||||
default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
|
default_router: Arc<std::sync::RwLock<Option<RouterId>>>,
|
||||||
|
enable_igw: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RouterManager {
|
impl RouterManager {
|
||||||
/// Create a new router manager with shared registries
|
|
||||||
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
|
pub fn new(worker_registry: Arc<WorkerRegistry>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
worker_registry,
|
worker_registry,
|
||||||
routers: Arc::new(DashMap::new()),
|
routers: Arc::new(DashMap::new()),
|
||||||
default_router: Arc::new(std::sync::RwLock::new(None)),
|
default_router: Arc::new(std::sync::RwLock::new(None)),
|
||||||
|
enable_igw: false, // Will be set properly in from_config
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn from_config(
|
||||||
|
config: &ServerConfig,
|
||||||
|
app_context: &Arc<AppContext>,
|
||||||
|
) -> Result<Arc<Self>, String> {
|
||||||
|
use crate::routers::RouterFactory;
|
||||||
|
|
||||||
|
let mut manager = Self::new(app_context.worker_registry.clone());
|
||||||
|
manager.enable_igw = config.router_config.enable_igw;
|
||||||
|
let manager = Arc::new(manager);
|
||||||
|
|
||||||
|
if config.router_config.enable_igw {
|
||||||
|
info!("Initializing RouterManager in multi-router mode (IGW)");
|
||||||
|
|
||||||
|
match RouterFactory::create_regular_router(app_context).await {
|
||||||
|
Ok(http_regular) => {
|
||||||
|
info!("Created HTTP Regular router");
|
||||||
|
manager.register_router(
|
||||||
|
RouterId::new("http-regular".to_string()),
|
||||||
|
Arc::from(http_regular),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to create HTTP Regular router: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
match RouterFactory::create_pd_router(
|
||||||
|
None,
|
||||||
|
None,
|
||||||
|
&config.router_config.policy,
|
||||||
|
app_context,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(http_pd) => {
|
||||||
|
info!("Created HTTP PD router");
|
||||||
|
manager
|
||||||
|
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!("Failed to create HTTP PD router: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Add gRPC routers once we have dynamic tokenizer loading
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"RouterManager initialized with {} routers for multi-router mode",
|
||||||
|
manager.router_count()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
info!("Initializing RouterManager in single-router mode");
|
||||||
|
|
||||||
|
let single_router = Arc::from(RouterFactory::create_router(app_context).await?);
|
||||||
|
let router_id = Self::determine_router_id(
|
||||||
|
&config.router_config.mode,
|
||||||
|
&config.router_config.connection_mode,
|
||||||
|
);
|
||||||
|
|
||||||
|
info!("Created single router with ID: {}", router_id.as_str());
|
||||||
|
manager.register_router(router_id.clone(), single_router);
|
||||||
|
manager.set_default_router(router_id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if manager.router_count() == 0 {
|
||||||
|
return Err("No routers could be initialized".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(manager)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn determine_router_id(
|
||||||
|
routing_mode: &RoutingMode,
|
||||||
|
connection_mode: &ConnectionMode,
|
||||||
|
) -> RouterId {
|
||||||
|
match (connection_mode, routing_mode) {
|
||||||
|
(ConnectionMode::Http, RoutingMode::Regular { .. }) => {
|
||||||
|
RouterId::new("http-regular".to_string())
|
||||||
|
}
|
||||||
|
(ConnectionMode::Http, RoutingMode::PrefillDecode { .. }) => {
|
||||||
|
RouterId::new("http-pd".to_string())
|
||||||
|
}
|
||||||
|
(ConnectionMode::Http, RoutingMode::OpenAI { .. }) => {
|
||||||
|
RouterId::new("http-openai".to_string())
|
||||||
|
}
|
||||||
|
(ConnectionMode::Grpc, RoutingMode::Regular { .. }) => {
|
||||||
|
RouterId::new("grpc-regular".to_string())
|
||||||
|
}
|
||||||
|
(ConnectionMode::Grpc, RoutingMode::PrefillDecode { .. }) => {
|
||||||
|
RouterId::new("grpc-pd".to_string())
|
||||||
|
}
|
||||||
|
(ConnectionMode::Grpc, RoutingMode::OpenAI { .. }) => {
|
||||||
|
RouterId::new("grpc-regular".to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Register a router with the manager
|
|
||||||
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
pub fn register_router(&self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
||||||
self.routers.insert(id.clone(), router);
|
self.routers.insert(id.clone(), router);
|
||||||
|
|
||||||
@@ -69,18 +160,15 @@ impl RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Set the default router
|
|
||||||
pub fn set_default_router(&self, id: RouterId) {
|
pub fn set_default_router(&self, id: RouterId) {
|
||||||
let mut default_router = self.default_router.write().unwrap();
|
let mut default_router = self.default_router.write().unwrap();
|
||||||
*default_router = Some(id);
|
*default_router = Some(id);
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the number of registered routers
|
|
||||||
pub fn router_count(&self) -> usize {
|
pub fn router_count(&self) -> usize {
|
||||||
self.routers.len()
|
self.routers.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get router for a specific model based on worker types
|
|
||||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||||
let workers = self.worker_registry.get_by_model(model_id);
|
let workers = self.worker_registry.get_by_model(model_id);
|
||||||
|
|
||||||
@@ -111,21 +199,25 @@ impl RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get workers for routing decision
|
|
||||||
pub fn get_workers_for_request(&self, model_id: Option<&str>) -> Vec<Arc<dyn Worker>> {
|
|
||||||
if let Some(model) = model_id {
|
|
||||||
self.worker_registry.get_by_model(model)
|
|
||||||
} else {
|
|
||||||
self.worker_registry.get_all()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Get the appropriate router for a request based on headers and request content
|
|
||||||
pub fn select_router_for_request(
|
pub fn select_router_for_request(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Option<Arc<dyn RouterTrait>> {
|
) -> Option<Arc<dyn RouterTrait>> {
|
||||||
|
// In single-router mode (enable_igw=false), always use the default router
|
||||||
|
if !self.enable_igw {
|
||||||
|
let default_router = self.default_router.read().unwrap();
|
||||||
|
if let Some(ref default_id) = *default_router {
|
||||||
|
debug!(
|
||||||
|
"Single-router mode: using default router {} for model {:?}",
|
||||||
|
default_id.as_str(),
|
||||||
|
model_id
|
||||||
|
);
|
||||||
|
return self.routers.get(default_id).map(|r| r.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Multi-router mode logic follows
|
||||||
let _priority_threshold = headers.and_then(|h| {
|
let _priority_threshold = headers.and_then(|h| {
|
||||||
h.get("x-worker-priority")
|
h.get("x-worker-priority")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
@@ -176,10 +268,6 @@ impl RouterManager {
|
|||||||
score += 1.0;
|
score += 1.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get workers for this router and evaluate based on priority/cost
|
|
||||||
// Note: This would require routers to expose their workers or stats
|
|
||||||
// For now, we'll use a simple selection based on router type
|
|
||||||
|
|
||||||
// TODO: Once routers expose worker stats, we can evaluate:
|
// TODO: Once routers expose worker stats, we can evaluate:
|
||||||
// - Average worker priority vs priority_threshold
|
// - Average worker priority vs priority_threshold
|
||||||
// - Average worker cost vs max_cost
|
// - Average worker cost vs max_cost
|
||||||
@@ -201,16 +289,11 @@ impl RouterTrait for RouterManager {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Health check - return 503 if no routers available
|
|
||||||
async fn health(&self, _req: Request<Body>) -> Response {
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
// Health check should succeed if RouterManager exists, even without routers
|
|
||||||
// Individual router health can be checked via specific endpoints
|
|
||||||
(StatusCode::OK, "RouterManager is healthy").into_response()
|
(StatusCode::OK, "RouterManager is healthy").into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Health generate - check if any router can handle generate requests
|
|
||||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Return 503 since we have no routers with workers
|
|
||||||
// TODO: Should check if any router has healthy workers
|
// TODO: Should check if any router has healthy workers
|
||||||
(
|
(
|
||||||
StatusCode::SERVICE_UNAVAILABLE,
|
StatusCode::SERVICE_UNAVAILABLE,
|
||||||
@@ -219,10 +302,8 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get server information - aggregate from all routers
|
|
||||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
// TODO: Aggregate info from all routers with healthy workers
|
// TODO: Aggregate info from all routers with healthy workers
|
||||||
// For now, return basic info about the RouterManager
|
|
||||||
(
|
(
|
||||||
StatusCode::OK,
|
StatusCode::OK,
|
||||||
serde_json::json!({
|
serde_json::json!({
|
||||||
@@ -235,9 +316,7 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get available models - query from worker registry
|
|
||||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||||
// Get models from worker registry
|
|
||||||
let models = self.worker_registry.get_models();
|
let models = self.worker_registry.get_models();
|
||||||
|
|
||||||
if models.is_empty() {
|
if models.is_empty() {
|
||||||
@@ -254,10 +333,8 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get model information
|
|
||||||
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
async fn get_model_info(&self, _req: Request<Body>) -> Response {
|
||||||
// TODO: Extract model from request and route to appropriate router
|
// TODO: Extract model from request and route to appropriate router
|
||||||
// For now, return not implemented
|
|
||||||
(
|
(
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
"Model info endpoint not yet implemented in RouterManager",
|
"Model info endpoint not yet implemented in RouterManager",
|
||||||
@@ -265,22 +342,17 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route a generate request
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &GenerateRequest,
|
body: &GenerateRequest,
|
||||||
_model_id: Option<&str>,
|
_model_id: Option<&str>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Select router based on headers
|
|
||||||
// GenerateRequest doesn't have a model field
|
|
||||||
let router = self.select_router_for_request(headers, None);
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
// In multi-model mode, pass None since GenerateRequest doesn't have model field
|
|
||||||
router.route_generate(headers, body, None).await
|
router.route_generate(headers, body, None).await
|
||||||
} else {
|
} else {
|
||||||
// Return 404 when no router is available for the request
|
|
||||||
(
|
(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
"No router available for this request",
|
"No router available for this request",
|
||||||
@@ -289,7 +361,6 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route a chat completion request
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
@@ -299,10 +370,8 @@ impl RouterTrait for RouterManager {
|
|||||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||||
|
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
// In multi-model mode, pass the model_id to the router
|
|
||||||
router.route_chat(headers, body, Some(&body.model)).await
|
router.route_chat(headers, body, Some(&body.model)).await
|
||||||
} else {
|
} else {
|
||||||
// Return 404 when the specified model is not found
|
|
||||||
(
|
(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
format!("Model '{}' not found or no router available", body.model),
|
format!("Model '{}' not found or no router available", body.model),
|
||||||
@@ -311,7 +380,6 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route a completion request
|
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
@@ -321,12 +389,10 @@ impl RouterTrait for RouterManager {
|
|||||||
let router = self.select_router_for_request(headers, Some(&body.model));
|
let router = self.select_router_for_request(headers, Some(&body.model));
|
||||||
|
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
// In multi-model mode, pass the model_id to the router
|
|
||||||
router
|
router
|
||||||
.route_completion(headers, body, Some(&body.model))
|
.route_completion(headers, body, Some(&body.model))
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
// Return 404 when the specified model is not found
|
|
||||||
(
|
(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
format!("Model '{}' not found or no router available", body.model),
|
format!("Model '{}' not found or no router available", body.model),
|
||||||
@@ -348,26 +414,6 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
|
||||||
(
|
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
|
||||||
"responses api not yet implemented in inference gateway mode",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_response_input_items(
|
|
||||||
&self,
|
|
||||||
_headers: Option<&HeaderMap>,
|
|
||||||
_response_id: &str,
|
|
||||||
) -> Response {
|
|
||||||
(
|
|
||||||
StatusCode::NOT_IMPLEMENTED,
|
|
||||||
"responses api not yet implemented in inference gateway mode",
|
|
||||||
)
|
|
||||||
.into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
async fn get_response(&self, headers: Option<&HeaderMap>, response_id: &str) -> Response {
|
||||||
let router = self.select_router_for_request(headers, None);
|
let router = self.select_router_for_request(headers, None);
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
@@ -394,7 +440,26 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route embeddings request
|
async fn delete_response(&self, _headers: Option<&HeaderMap>, _response_id: &str) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"responses api not yet implemented in inference gateway mode",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list_response_input_items(
|
||||||
|
&self,
|
||||||
|
_headers: Option<&HeaderMap>,
|
||||||
|
_response_id: &str,
|
||||||
|
) -> Response {
|
||||||
|
(
|
||||||
|
StatusCode::NOT_IMPLEMENTED,
|
||||||
|
"responses api not yet implemented in inference gateway mode",
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
}
|
||||||
|
|
||||||
async fn route_embeddings(
|
async fn route_embeddings(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
@@ -408,7 +473,6 @@ impl RouterTrait for RouterManager {
|
|||||||
.route_embeddings(headers, body, Some(&body.model))
|
.route_embeddings(headers, body, Some(&body.model))
|
||||||
.await
|
.await
|
||||||
} else {
|
} else {
|
||||||
// Return 404 when the specified model is not found
|
|
||||||
(
|
(
|
||||||
StatusCode::NOT_FOUND,
|
StatusCode::NOT_FOUND,
|
||||||
format!("Model '{}' not found or no router available", body.model),
|
format!("Model '{}' not found or no router available", body.model),
|
||||||
@@ -417,14 +481,12 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Route rerank request
|
|
||||||
async fn route_rerank(
|
async fn route_rerank(
|
||||||
&self,
|
&self,
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &RerankRequest,
|
body: &RerankRequest,
|
||||||
model_id: Option<&str>,
|
model_id: Option<&str>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Try to select a router based on headers
|
|
||||||
let router = self.select_router_for_request(headers, None);
|
let router = self.select_router_for_request(headers, None);
|
||||||
|
|
||||||
if let Some(router) = router {
|
if let Some(router) = router {
|
||||||
@@ -438,10 +500,8 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Flush cache on all routers and workers
|
|
||||||
async fn flush_cache(&self) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
// TODO: Call flush_cache on all routers that have workers
|
// TODO: Call flush_cache on all routers that have workers
|
||||||
// For now, return success if we have any routers
|
|
||||||
if self.routers.is_empty() {
|
if self.routers.is_empty() {
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||||
} else {
|
} else {
|
||||||
@@ -450,9 +510,7 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get worker loads from all routers
|
|
||||||
async fn get_worker_loads(&self) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
// Return worker loads from the registry
|
|
||||||
let workers = self.worker_registry.get_all();
|
let workers = self.worker_registry.get_all();
|
||||||
let loads: Vec<serde_json::Value> = workers
|
let loads: Vec<serde_json::Value> = workers
|
||||||
.iter()
|
.iter()
|
||||||
@@ -476,12 +534,10 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get router type name
|
|
||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
"manager"
|
"manager"
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Server readiness check - check if any router is ready
|
|
||||||
fn readiness(&self) -> Response {
|
fn readiness(&self) -> Response {
|
||||||
if self.routers.is_empty() {
|
if self.routers.is_empty() {
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
(StatusCode::SERVICE_UNAVAILABLE, "No routers configured").into_response()
|
||||||
@@ -492,9 +548,6 @@ impl RouterTrait for RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: get_first_available_router removed - we now properly handle
|
|
||||||
// router selection based on model and worker availability
|
|
||||||
|
|
||||||
impl std::fmt::Debug for RouterManager {
|
impl std::fmt::Debug for RouterManager {
|
||||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||||
f.debug_struct("RouterManager")
|
f.debug_struct("RouterManager")
|
||||||
|
|||||||
@@ -14,10 +14,7 @@ use crate::{
|
|||||||
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse},
|
||||||
},
|
},
|
||||||
reasoning_parser::ParserFactory,
|
reasoning_parser::ParserFactory,
|
||||||
routers::{
|
routers::{router_manager::RouterManager, RouterTrait},
|
||||||
router_manager::{RouterId, RouterManager},
|
|
||||||
RouterFactory, RouterTrait,
|
|
||||||
},
|
|
||||||
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
|
service_discovery::{start_service_discovery, ServiceDiscoveryConfig},
|
||||||
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
|
tokenizer::{factory as tokenizer_factory, traits::Tokenizer},
|
||||||
tool_parser::ParserRegistry,
|
tool_parser::ParserRegistry,
|
||||||
@@ -64,10 +61,8 @@ impl AppContext {
|
|||||||
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
|
let rate_limit_tokens = rate_limit_tokens_per_second.unwrap_or(max_concurrent_requests);
|
||||||
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
|
let rate_limiter = Arc::new(TokenBucket::new(max_concurrent_requests, rate_limit_tokens));
|
||||||
|
|
||||||
// Initialize gRPC-specific components only when in gRPC mode
|
|
||||||
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
|
let (tokenizer, reasoning_parser_factory, tool_parser_registry) =
|
||||||
if router_config.connection_mode == ConnectionMode::Grpc {
|
if router_config.connection_mode == ConnectionMode::Grpc {
|
||||||
// Get tokenizer path (required for gRPC mode)
|
|
||||||
let tokenizer_path = router_config
|
let tokenizer_path = router_config
|
||||||
.tokenizer_path
|
.tokenizer_path
|
||||||
.clone()
|
.clone()
|
||||||
@@ -77,7 +72,6 @@ impl AppContext {
|
|||||||
.to_string()
|
.to_string()
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
// Initialize all gRPC components
|
|
||||||
let tokenizer = Some(
|
let tokenizer = Some(
|
||||||
tokenizer_factory::create_tokenizer(&tokenizer_path)
|
tokenizer_factory::create_tokenizer(&tokenizer_path)
|
||||||
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
|
.map_err(|e| format!("Failed to create tokenizer: {e}"))?,
|
||||||
@@ -87,7 +81,6 @@ impl AppContext {
|
|||||||
|
|
||||||
(tokenizer, reasoning_parser_factory, tool_parser_registry)
|
(tokenizer, reasoning_parser_factory, tool_parser_registry)
|
||||||
} else {
|
} else {
|
||||||
// HTTP mode doesn't need these components
|
|
||||||
(None, None, None)
|
(None, None, None)
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -96,7 +89,6 @@ impl AppContext {
|
|||||||
|
|
||||||
let router_manager = None;
|
let router_manager = None;
|
||||||
|
|
||||||
// Initialize response storage based on configuration
|
|
||||||
let response_storage: SharedResponseStorage = match router_config.history_backend {
|
let response_storage: SharedResponseStorage = match router_config.history_backend {
|
||||||
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
|
HistoryBackend::Memory => Arc::new(MemoryResponseStorage::new()),
|
||||||
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
|
HistoryBackend::None => Arc::new(NoOpResponseStorage::new()),
|
||||||
@@ -125,12 +117,10 @@ pub struct AppState {
|
|||||||
pub router_manager: Option<Arc<RouterManager>>,
|
pub router_manager: Option<Arc<RouterManager>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Fallback handler for unmatched routes
|
|
||||||
async fn sink_handler() -> Response {
|
async fn sink_handler() -> Response {
|
||||||
StatusCode::NOT_FOUND.into_response()
|
StatusCode::NOT_FOUND.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
// Health check endpoints
|
|
||||||
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
|
async fn liveness(State(state): State<Arc<AppState>>) -> Response {
|
||||||
state.router.liveness()
|
state.router.liveness()
|
||||||
}
|
}
|
||||||
@@ -257,7 +247,6 @@ async fn v1_responses_delete(
|
|||||||
Path(response_id): Path<String>,
|
Path(response_id): Path<String>,
|
||||||
headers: http::HeaderMap,
|
headers: http::HeaderMap,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Python server does not support this yet
|
|
||||||
state
|
state
|
||||||
.router
|
.router
|
||||||
.delete_response(Some(&headers), &response_id)
|
.delete_response(Some(&headers), &response_id)
|
||||||
@@ -269,15 +258,12 @@ async fn v1_responses_list_input_items(
|
|||||||
Path(response_id): Path<String>,
|
Path(response_id): Path<String>,
|
||||||
headers: http::HeaderMap,
|
headers: http::HeaderMap,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Python server does not support this yet
|
|
||||||
state
|
state
|
||||||
.router
|
.router
|
||||||
.list_response_input_items(Some(&headers), &response_id)
|
.list_response_input_items(Some(&headers), &response_id)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Worker management endpoints (Legacy) ----------
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct AddWorkerQuery {
|
struct AddWorkerQuery {
|
||||||
url: String,
|
url: String,
|
||||||
@@ -288,7 +274,6 @@ async fn add_worker(
|
|||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Use centralized WorkerManager with full context
|
|
||||||
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
|
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -298,7 +283,6 @@ async fn add_worker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
||||||
// Use centralized WorkerManager instead of router's get_worker_urls
|
|
||||||
let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
||||||
Json(json!({ "urls": worker_list })).into_response()
|
Json(json!({ "urls": worker_list })).into_response()
|
||||||
}
|
}
|
||||||
@@ -307,7 +291,6 @@ async fn remove_worker(
|
|||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
|
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Use centralized WorkerManager with full context
|
|
||||||
let result = WorkerManager::remove_worker(&url, &state.context);
|
let result = WorkerManager::remove_worker(&url, &state.context);
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -324,14 +307,10 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
|
|||||||
state.router.get_worker_loads().await
|
state.router.get_worker_loads().await
|
||||||
}
|
}
|
||||||
|
|
||||||
// ---------- Worker management endpoints (RESTful) ----------
|
|
||||||
|
|
||||||
/// POST /workers - Add a new worker with full configuration
|
|
||||||
async fn create_worker(
|
async fn create_worker(
|
||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(config): Json<WorkerConfigRequest>,
|
Json(config): Json<WorkerConfigRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// In single router mode, use centralized WorkerManager with full context
|
|
||||||
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
|
let result = WorkerManager::add_worker_from_config(&config, &state.context).await;
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -353,9 +332,7 @@ async fn create_worker(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /workers - List all workers with details
|
|
||||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||||
// In single router mode, get detailed worker info from registry
|
|
||||||
let workers = state.context.worker_registry.get_all();
|
let workers = state.context.worker_registry.get_all();
|
||||||
let response = serde_json::json!({
|
let response = serde_json::json!({
|
||||||
"workers": workers.iter().map(|worker| {
|
"workers": workers.iter().map(|worker| {
|
||||||
@@ -374,7 +351,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
|||||||
"cost": worker.cost(),
|
"cost": worker.cost(),
|
||||||
});
|
});
|
||||||
|
|
||||||
// Add bootstrap_port for Prefill workers
|
|
||||||
if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
|
if let WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
|
||||||
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
|
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
|
||||||
}
|
}
|
||||||
@@ -391,7 +367,6 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
|||||||
Json(response).into_response()
|
Json(response).into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// GET /workers/{url} - Get specific worker info
|
|
||||||
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||||
let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
let workers = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
||||||
if workers.contains(&url) {
|
if workers.contains(&url) {
|
||||||
@@ -410,9 +385,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// DELETE /workers/{url} - Remove a worker
|
|
||||||
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||||
// In single router mode, use centralized WorkerManager with full context
|
|
||||||
let result = WorkerManager::remove_worker(&url, &state.context);
|
let result = WorkerManager::remove_worker(&url, &state.context);
|
||||||
|
|
||||||
match result {
|
match result {
|
||||||
@@ -447,14 +420,12 @@ pub struct ServerConfig {
|
|||||||
pub request_id_headers: Option<Vec<String>>,
|
pub request_id_headers: Option<Vec<String>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the Axum application with all routes and middleware
|
|
||||||
pub fn build_app(
|
pub fn build_app(
|
||||||
app_state: Arc<AppState>,
|
app_state: Arc<AppState>,
|
||||||
max_payload_size: usize,
|
max_payload_size: usize,
|
||||||
request_id_headers: Vec<String>,
|
request_id_headers: Vec<String>,
|
||||||
cors_allowed_origins: Vec<String>,
|
cors_allowed_origins: Vec<String>,
|
||||||
) -> Router {
|
) -> Router {
|
||||||
// Create routes
|
|
||||||
let protected_routes = Router::new()
|
let protected_routes = Router::new()
|
||||||
.route("/generate", post(generate))
|
.route("/generate", post(generate))
|
||||||
.route("/v1/chat/completions", post(v1_chat_completions))
|
.route("/v1/chat/completions", post(v1_chat_completions))
|
||||||
@@ -494,20 +465,17 @@ pub fn build_app(
|
|||||||
.route("/flush_cache", post(flush_cache))
|
.route("/flush_cache", post(flush_cache))
|
||||||
.route("/get_loads", get(get_loads));
|
.route("/get_loads", get(get_loads));
|
||||||
|
|
||||||
// Worker management routes
|
|
||||||
let worker_routes = Router::new()
|
let worker_routes = Router::new()
|
||||||
.route("/workers", post(create_worker))
|
.route("/workers", post(create_worker))
|
||||||
.route("/workers", get(list_workers_rest))
|
.route("/workers", get(list_workers_rest))
|
||||||
.route("/workers/{url}", get(get_worker))
|
.route("/workers/{url}", get(get_worker))
|
||||||
.route("/workers/{url}", delete(delete_worker));
|
.route("/workers/{url}", delete(delete_worker));
|
||||||
|
|
||||||
// Build app with all routes and middleware
|
|
||||||
Router::new()
|
Router::new()
|
||||||
.merge(protected_routes)
|
.merge(protected_routes)
|
||||||
.merge(public_routes)
|
.merge(public_routes)
|
||||||
.merge(admin_routes)
|
.merge(admin_routes)
|
||||||
.merge(worker_routes)
|
.merge(worker_routes)
|
||||||
// Request body size limiting
|
|
||||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||||
max_payload_size,
|
max_payload_size,
|
||||||
))
|
))
|
||||||
@@ -519,7 +487,6 @@ pub fn build_app(
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
|
pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Error>> {
|
||||||
// Only initialize logging if not already done (for Python bindings support)
|
|
||||||
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
|
static LOGGING_INITIALIZED: AtomicBool = AtomicBool::new(false);
|
||||||
|
|
||||||
let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
|
let _log_guard = if !LOGGING_INITIALIZED.swap(true, Ordering::SeqCst) {
|
||||||
@@ -545,9 +512,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
None
|
None
|
||||||
};
|
};
|
||||||
|
|
||||||
// Initialize prometheus metrics exporter
|
if let Some(prometheus_config) = &config.prometheus_config {
|
||||||
if let Some(prometheus_config) = config.prometheus_config {
|
metrics::start_prometheus(prometheus_config.clone());
|
||||||
metrics::start_prometheus(prometheus_config);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
@@ -569,7 +535,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
.build()
|
.build()
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
// Create the application context with all dependencies
|
|
||||||
let app_context = AppContext::new(
|
let app_context = AppContext::new(
|
||||||
config.router_config.clone(),
|
config.router_config.clone(),
|
||||||
client.clone(),
|
client.clone(),
|
||||||
@@ -597,67 +562,9 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
worker_stats.total_workers, worker_stats.healthy_workers
|
worker_stats.total_workers, worker_stats.healthy_workers
|
||||||
);
|
);
|
||||||
|
|
||||||
// Create the appropriate router based on enable_igw flag
|
let router_manager = RouterManager::from_config(&config, &app_context).await?;
|
||||||
let (router, router_manager): (Arc<dyn RouterTrait>, Option<Arc<RouterManager>>) =
|
let router: Arc<dyn RouterTrait> = router_manager.clone();
|
||||||
if config.router_config.enable_igw {
|
|
||||||
info!("Multi-router mode enabled (enable_igw=true)");
|
|
||||||
|
|
||||||
// Create RouterManager with shared registries from AppContext
|
|
||||||
let router_manager = Arc::new(RouterManager::new(app_context.worker_registry.clone()));
|
|
||||||
|
|
||||||
// 1. HTTP Regular Router
|
|
||||||
match RouterFactory::create_regular_router(&app_context).await {
|
|
||||||
Ok(http_regular) => {
|
|
||||||
info!("Created HTTP Regular router");
|
|
||||||
router_manager.register_router(
|
|
||||||
RouterId::new("http-regular".to_string()),
|
|
||||||
Arc::from(http_regular),
|
|
||||||
);
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to create HTTP Regular router: {e}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. HTTP PD Router
|
|
||||||
match RouterFactory::create_pd_router(
|
|
||||||
None,
|
|
||||||
None,
|
|
||||||
&config.router_config.policy,
|
|
||||||
&app_context,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(http_pd) => {
|
|
||||||
info!("Created HTTP PD router");
|
|
||||||
router_manager
|
|
||||||
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
|
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
warn!("Failed to create HTTP PD router: {e}");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: Add gRPC routers once we have dynamic tokenizer loading
|
|
||||||
|
|
||||||
info!(
|
|
||||||
"RouterManager initialized with {} routers",
|
|
||||||
router_manager.router_count()
|
|
||||||
);
|
|
||||||
(
|
|
||||||
router_manager.clone() as Arc<dyn RouterTrait>,
|
|
||||||
Some(router_manager),
|
|
||||||
)
|
|
||||||
} else {
|
|
||||||
info!("Single router mode (enable_igw=false)");
|
|
||||||
// Create single router with the context
|
|
||||||
(
|
|
||||||
Arc::from(RouterFactory::create_router(&app_context).await?),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Start health checker for all workers in the registry
|
|
||||||
let _health_checker = app_context
|
let _health_checker = app_context
|
||||||
.worker_registry
|
.worker_registry
|
||||||
.start_health_checker(config.router_config.health_check.check_interval_secs);
|
.start_health_checker(config.router_config.health_check.check_interval_secs);
|
||||||
@@ -666,14 +573,12 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
config.router_config.health_check.check_interval_secs
|
config.router_config.health_check.check_interval_secs
|
||||||
);
|
);
|
||||||
|
|
||||||
// Set up concurrency limiter with queue if configured
|
|
||||||
let (limiter, processor) = middleware::ConcurrencyLimiter::new(
|
let (limiter, processor) = middleware::ConcurrencyLimiter::new(
|
||||||
app_context.rate_limiter.clone(),
|
app_context.rate_limiter.clone(),
|
||||||
config.router_config.queue_size,
|
config.router_config.queue_size,
|
||||||
Duration::from_secs(config.router_config.queue_timeout_secs),
|
Duration::from_secs(config.router_config.queue_timeout_secs),
|
||||||
);
|
);
|
||||||
|
|
||||||
// Start queue processor if enabled
|
|
||||||
if let Some(processor) = processor {
|
if let Some(processor) = processor {
|
||||||
spawn(processor.run());
|
spawn(processor.run());
|
||||||
info!(
|
info!(
|
||||||
@@ -682,21 +587,18 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create app state with router and context
|
|
||||||
let app_state = Arc::new(AppState {
|
let app_state = Arc::new(AppState {
|
||||||
router,
|
router,
|
||||||
context: app_context.clone(),
|
context: app_context.clone(),
|
||||||
concurrency_queue_tx: limiter.queue_tx.clone(),
|
concurrency_queue_tx: limiter.queue_tx.clone(),
|
||||||
router_manager,
|
router_manager: Some(router_manager),
|
||||||
});
|
});
|
||||||
// Start the service discovery if enabled
|
|
||||||
if let Some(service_discovery_config) = config.service_discovery_config {
|
if let Some(service_discovery_config) = config.service_discovery_config {
|
||||||
if service_discovery_config.enabled {
|
if service_discovery_config.enabled {
|
||||||
let app_context_arc = Arc::clone(&app_state.context);
|
let app_context_arc = Arc::clone(&app_state.context);
|
||||||
match start_service_discovery(service_discovery_config, app_context_arc).await {
|
match start_service_discovery(service_discovery_config, app_context_arc).await {
|
||||||
Ok(handle) => {
|
Ok(handle) => {
|
||||||
info!("Service discovery started");
|
info!("Service discovery started");
|
||||||
// Spawn a task to handle the service discovery thread
|
|
||||||
spawn(async move {
|
spawn(async move {
|
||||||
if let Err(e) = handle.await {
|
if let Err(e) = handle.await {
|
||||||
error!("Service discovery task failed: {:?}", e);
|
error!("Service discovery task failed: {:?}", e);
|
||||||
@@ -725,7 +627,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
]
|
]
|
||||||
});
|
});
|
||||||
|
|
||||||
// Build the application
|
|
||||||
let app = build_app(
|
let app = build_app(
|
||||||
app_state,
|
app_state,
|
||||||
config.max_payload_size,
|
config.max_payload_size,
|
||||||
@@ -744,7 +645,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Graceful shutdown handler
|
|
||||||
async fn shutdown_signal() {
|
async fn shutdown_signal() {
|
||||||
let ctrl_c = async {
|
let ctrl_c = async {
|
||||||
signal::ctrl_c()
|
signal::ctrl_c()
|
||||||
@@ -773,19 +673,16 @@ async fn shutdown_signal() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// CORS Layer Creation
|
|
||||||
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
|
fn create_cors_layer(allowed_origins: Vec<String>) -> tower_http::cors::CorsLayer {
|
||||||
use tower_http::cors::Any;
|
use tower_http::cors::Any;
|
||||||
|
|
||||||
let cors = if allowed_origins.is_empty() {
|
let cors = if allowed_origins.is_empty() {
|
||||||
// Allow all origins if none specified
|
|
||||||
tower_http::cors::CorsLayer::new()
|
tower_http::cors::CorsLayer::new()
|
||||||
.allow_origin(Any)
|
.allow_origin(Any)
|
||||||
.allow_methods(Any)
|
.allow_methods(Any)
|
||||||
.allow_headers(Any)
|
.allow_headers(Any)
|
||||||
.expose_headers(Any)
|
.expose_headers(Any)
|
||||||
} else {
|
} else {
|
||||||
// Restrict to specific origins
|
|
||||||
let origins: Vec<http::HeaderValue> = allowed_origins
|
let origins: Vec<http::HeaderValue> = allowed_origins
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.filter_map(|origin| origin.parse().ok())
|
.filter_map(|origin| origin.parse().ok())
|
||||||
|
|||||||
Reference in New Issue
Block a user