[router] allow one router to support different model families and serving mode (#10244)
This commit is contained in:
@@ -1,12 +1,16 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::WorkerRegistry;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::metrics::{self, PrometheusConfig};
|
||||
use crate::middleware::TokenBucket;
|
||||
use crate::policies::PolicyRegistry;
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, GenerateRequest, RerankRequest, ResponsesRequest,
|
||||
V1RerankReqInput,
|
||||
};
|
||||
use crate::protocols::worker_spec::{WorkerApiResponse, WorkerConfigRequest, WorkerErrorResponse};
|
||||
use crate::reasoning_parser::ParserFactory;
|
||||
use crate::routers::router_manager::{RouterId, RouterManager};
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use crate::tokenizer::{factory as tokenizer_factory, traits::Tokenizer};
|
||||
@@ -36,6 +40,9 @@ pub struct AppContext {
|
||||
pub tokenizer: Option<Arc<dyn Tokenizer>>,
|
||||
pub reasoning_parser_factory: Option<ParserFactory>,
|
||||
pub tool_parser_registry: Option<&'static ParserRegistry>,
|
||||
pub worker_registry: Arc<WorkerRegistry>, // Shared worker registry
|
||||
pub policy_registry: Arc<PolicyRegistry>, // Shared policy registry
|
||||
pub router_manager: Option<Arc<RouterManager>>, // Only present when enable_igw=true
|
||||
}
|
||||
|
||||
impl AppContext {
|
||||
@@ -75,6 +82,15 @@ impl AppContext {
|
||||
(None, None, None)
|
||||
};
|
||||
|
||||
// Initialize shared registries
|
||||
let worker_registry = Arc::new(WorkerRegistry::new());
|
||||
let policy_registry = Arc::new(PolicyRegistry::new(
|
||||
router_config.policy.clone(), // Use default policy from config
|
||||
));
|
||||
|
||||
// Initialize RouterManager only when enable_igw is true
|
||||
let router_manager = None; // Will be initialized in startup() based on config
|
||||
|
||||
Ok(Self {
|
||||
client,
|
||||
router_config,
|
||||
@@ -82,6 +98,9 @@ impl AppContext {
|
||||
tokenizer,
|
||||
reasoning_parser_factory,
|
||||
tool_parser_registry,
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
router_manager,
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -134,7 +153,10 @@ async fn generate(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<GenerateRequest>,
|
||||
) -> Response {
|
||||
state.router.route_generate(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_generate(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn v1_chat_completions(
|
||||
@@ -142,7 +164,7 @@ async fn v1_chat_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ChatCompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_chat(Some(&headers), &body).await
|
||||
state.router.route_chat(Some(&headers), &body, None).await
|
||||
}
|
||||
|
||||
async fn v1_completions(
|
||||
@@ -150,7 +172,10 @@ async fn v1_completions(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<CompletionRequest>,
|
||||
) -> Response {
|
||||
state.router.route_completion(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_completion(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn rerank(
|
||||
@@ -158,7 +183,7 @@ async fn rerank(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<RerankRequest>,
|
||||
) -> Response {
|
||||
state.router.route_rerank(Some(&headers), &body).await
|
||||
state.router.route_rerank(Some(&headers), &body, None).await
|
||||
}
|
||||
|
||||
async fn v1_rerank(
|
||||
@@ -168,7 +193,7 @@ async fn v1_rerank(
|
||||
) -> Response {
|
||||
state
|
||||
.router
|
||||
.route_rerank(Some(&headers), &body.into())
|
||||
.route_rerank(Some(&headers), &body.into(), None)
|
||||
.await
|
||||
}
|
||||
|
||||
@@ -177,7 +202,10 @@ async fn v1_responses(
|
||||
headers: http::HeaderMap,
|
||||
Json(body): Json<ResponsesRequest>,
|
||||
) -> Response {
|
||||
state.router.route_responses(Some(&headers), &body).await
|
||||
state
|
||||
.router
|
||||
.route_responses(Some(&headers), &body, None)
|
||||
.await
|
||||
}
|
||||
|
||||
// Worker management endpoints
|
||||
@@ -232,6 +260,137 @@ async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Respons
|
||||
state.router.get_worker_loads().await
|
||||
}
|
||||
|
||||
// New RESTful worker management endpoints (when enable_igw=true)
|
||||
|
||||
/// POST /workers - Add a new worker with full configuration
|
||||
async fn create_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(config): Json<WorkerConfigRequest>,
|
||||
) -> Response {
|
||||
// Check if RouterManager is available (enable_igw=true)
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
match router_manager.add_worker(config).await {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
}
|
||||
} else {
|
||||
// In single router mode, use the router's add_worker with basic config
|
||||
match state.router.add_worker(&config.url).await {
|
||||
Ok(message) => {
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message,
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
Err(error) => {
|
||||
let error_response = WorkerErrorResponse {
|
||||
error,
|
||||
code: "ADD_WORKER_FAILED".to_string(),
|
||||
};
|
||||
(StatusCode::BAD_REQUEST, Json(error_response)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /workers - List all workers with details
|
||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
let response = router_manager.list_workers();
|
||||
Json(response).into_response()
|
||||
} else {
|
||||
// In single router mode, get detailed worker info from registry
|
||||
let workers = state.context.worker_registry.get_all();
|
||||
let response = serde_json::json!({
|
||||
"workers": workers.iter().map(|worker| {
|
||||
let mut worker_info = serde_json::json!({
|
||||
"url": worker.url(),
|
||||
"model_id": worker.model_id(),
|
||||
"worker_type": format!("{:?}", worker.worker_type()),
|
||||
"is_healthy": worker.is_healthy(),
|
||||
"load": worker.load(),
|
||||
"connection_mode": format!("{:?}", worker.connection_mode()),
|
||||
"priority": worker.priority(),
|
||||
"cost": worker.cost(),
|
||||
});
|
||||
|
||||
// Add bootstrap_port for Prefill workers
|
||||
if let crate::core::WorkerType::Prefill { bootstrap_port } = worker.worker_type() {
|
||||
worker_info["bootstrap_port"] = serde_json::json!(bootstrap_port);
|
||||
}
|
||||
|
||||
worker_info
|
||||
}).collect::<Vec<_>>(),
|
||||
"total": workers.len(),
|
||||
"stats": {
|
||||
"prefill_count": state.context.worker_registry.get_prefill_workers().len(),
|
||||
"decode_count": state.context.worker_registry.get_decode_workers().len(),
|
||||
"regular_count": state.context.worker_registry.get_by_type(&crate::core::WorkerType::Regular).len(),
|
||||
}
|
||||
});
|
||||
Json(response).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
/// GET /workers/{url} - Get specific worker info
|
||||
async fn get_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
axum::extract::Path(url): axum::extract::Path<String>,
|
||||
) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
if let Some(worker) = router_manager.get_worker(&url) {
|
||||
Json(worker).into_response()
|
||||
} else {
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
}
|
||||
} else {
|
||||
// In single router mode, check if worker exists
|
||||
let workers = state.router.get_worker_urls();
|
||||
if workers.contains(&url) {
|
||||
let worker_info = serde_json::json!({
|
||||
"url": url,
|
||||
"model_id": "unknown",
|
||||
"is_healthy": true
|
||||
});
|
||||
Json(worker_info).into_response()
|
||||
} else {
|
||||
let error = WorkerErrorResponse {
|
||||
error: format!("Worker {} not found", url),
|
||||
code: "WORKER_NOT_FOUND".to_string(),
|
||||
};
|
||||
(StatusCode::NOT_FOUND, Json(error)).into_response()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// DELETE /workers/{url} - Remove a worker
|
||||
async fn delete_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
axum::extract::Path(url): axum::extract::Path<String>,
|
||||
) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
match router_manager.remove_worker_from_registry(&url) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
}
|
||||
} else {
|
||||
// In single router mode, use router's remove_worker
|
||||
state.router.remove_worker(&url);
|
||||
let response = WorkerApiResponse {
|
||||
success: true,
|
||||
message: format!("Worker {} removed successfully", url),
|
||||
worker: None,
|
||||
};
|
||||
(StatusCode::OK, Json(response)).into_response()
|
||||
}
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
@@ -281,11 +440,19 @@ pub fn build_app(
|
||||
.route("/flush_cache", post(flush_cache))
|
||||
.route("/get_loads", get(get_loads));
|
||||
|
||||
// Worker management routes
|
||||
let worker_routes = Router::new()
|
||||
.route("/workers", post(create_worker))
|
||||
.route("/workers", get(list_workers_rest))
|
||||
.route("/workers/{url}", get(get_worker))
|
||||
.route("/workers/{url}", axum::routing::delete(delete_worker));
|
||||
|
||||
// Build app with all routes and middleware
|
||||
Router::new()
|
||||
.merge(protected_routes)
|
||||
.merge(public_routes)
|
||||
.merge(admin_routes)
|
||||
.merge(worker_routes)
|
||||
// Request body size limiting
|
||||
.layer(tower_http::limit::RequestBodyLimitLayer::new(
|
||||
max_payload_size,
|
||||
@@ -355,15 +522,100 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
// Create the application context with all dependencies
|
||||
let app_context = Arc::new(AppContext::new(
|
||||
let app_context = AppContext::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
config.router_config.max_concurrent_requests,
|
||||
config.router_config.rate_limit_tokens_per_second,
|
||||
)?);
|
||||
)?;
|
||||
|
||||
// Create router with the context
|
||||
let router = RouterFactory::create_router(&app_context).await?;
|
||||
let app_context = Arc::new(app_context);
|
||||
|
||||
// Create the appropriate router based on enable_igw flag
|
||||
let router: Box<dyn RouterTrait> = if config.router_config.enable_igw {
|
||||
info!("Multi-router mode enabled (enable_igw=true)");
|
||||
|
||||
// Create RouterManager with shared registries from AppContext
|
||||
let mut router_manager = RouterManager::new(
|
||||
config.router_config.clone(),
|
||||
client.clone(),
|
||||
app_context.worker_registry.clone(),
|
||||
app_context.policy_registry.clone(),
|
||||
);
|
||||
|
||||
// Create HTTP routers at startup (with empty worker lists)
|
||||
// Workers will be added to these routers dynamically via RouterManager's worker registry
|
||||
|
||||
// 1. HTTP Regular Router
|
||||
match RouterFactory::create_regular_router(
|
||||
&[], // Empty worker list - workers added later
|
||||
&app_context,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(http_regular) => {
|
||||
info!("Created HTTP Regular router");
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-regular".to_string()),
|
||||
Arc::from(http_regular),
|
||||
vec![], // Models will be determined by workers
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP Regular router: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// 2. HTTP PD Router
|
||||
match RouterFactory::create_pd_router(
|
||||
&[], // Empty prefill URLs
|
||||
&[], // Empty decode URLs
|
||||
None, // Use default prefill policy
|
||||
None, // Use default decode policy
|
||||
&config.router_config.policy,
|
||||
&app_context,
|
||||
)
|
||||
.await
|
||||
{
|
||||
Ok(http_pd) => {
|
||||
info!("Created HTTP PD router");
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-pd".to_string()),
|
||||
Arc::from(http_pd),
|
||||
vec![],
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP PD router: {}", e);
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: Add gRPC routers once we have dynamic tokenizer loading
|
||||
// Currently gRPC routers require tokenizer to be initialized first,
|
||||
// but each model needs its own tokenizer. Once we implement dynamic
|
||||
// tokenizer loading per model, we can enable gRPC routers here:
|
||||
// - RouterType::GrpcRegular (RouterId: "grpc-regular")
|
||||
// - RouterType::GrpcPd (RouterId: "grpc-pd")
|
||||
|
||||
info!(
|
||||
"RouterManager initialized with {} routers",
|
||||
router_manager.router_count()
|
||||
);
|
||||
Box::new(router_manager)
|
||||
} else {
|
||||
info!("Single router mode (enable_igw=false)");
|
||||
// Create single router with the context
|
||||
RouterFactory::create_router(&app_context).await?
|
||||
};
|
||||
|
||||
// Start health checker for all workers in the registry
|
||||
let _health_checker = app_context
|
||||
.worker_registry
|
||||
.start_health_checker(config.router_config.health_check.check_interval_secs);
|
||||
info!(
|
||||
"Started health checker for workers with {}s interval",
|
||||
config.router_config.health_check.check_interval_secs
|
||||
);
|
||||
|
||||
// Set up concurrency limiter with queue if configured
|
||||
let (limiter, processor) = crate::middleware::ConcurrencyLimiter::new(
|
||||
|
||||
Reference in New Issue
Block a user