[router] fix worker registration in multi model mode (#10486)
This commit is contained in:
@@ -804,6 +804,37 @@ impl WorkerFactory {
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a prefill worker with labels
|
||||
pub fn create_prefill_with_labels(
|
||||
url: String,
|
||||
bootstrap_port: Option<u16>,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
let mut worker = BasicWorker::new(url.clone(), WorkerType::Prefill { bootstrap_port })
|
||||
.with_circuit_breaker_config(circuit_breaker_config);
|
||||
|
||||
// Add labels to metadata
|
||||
worker.metadata.labels = labels;
|
||||
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a decode worker with labels
|
||||
pub fn create_decode_with_labels(
|
||||
url: String,
|
||||
labels: std::collections::HashMap<String, String>,
|
||||
circuit_breaker_config: CircuitBreakerConfig,
|
||||
) -> Box<dyn Worker> {
|
||||
let mut worker = BasicWorker::new(url.clone(), WorkerType::Decode)
|
||||
.with_circuit_breaker_config(circuit_breaker_config);
|
||||
|
||||
// Add labels to metadata
|
||||
worker.metadata.labels = labels;
|
||||
|
||||
Box::new(worker)
|
||||
}
|
||||
|
||||
/// Create a DP-aware worker of specified type
|
||||
pub fn create_dp_aware(
|
||||
base_url: String,
|
||||
|
||||
@@ -41,7 +41,6 @@ impl RouterId {
|
||||
}
|
||||
|
||||
/// Router Manager - Central coordinator for routers and workers
|
||||
/// Only created when enable_igw=true
|
||||
pub struct RouterManager {
|
||||
/// Worker registry (single source of truth in multi-router mode)
|
||||
worker_registry: Arc<WorkerRegistry>,
|
||||
@@ -49,7 +48,7 @@ pub struct RouterManager {
|
||||
/// Policy registry for managing model-to-policy mappings
|
||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||
|
||||
/// All routers managed by this manager (max 4 routers in Phase 2)
|
||||
/// All routers managed by this manager
|
||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||
|
||||
@@ -83,12 +82,7 @@ impl RouterManager {
|
||||
}
|
||||
|
||||
/// Register a router with the manager
|
||||
pub fn register_router(
|
||||
&mut self,
|
||||
id: RouterId,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
|
||||
) {
|
||||
pub fn register_router(&mut self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
||||
// Store router
|
||||
self.routers.insert(id.clone(), router);
|
||||
|
||||
@@ -210,32 +204,28 @@ impl RouterManager {
|
||||
labels.insert("chat_template".to_string(), chat_template);
|
||||
}
|
||||
|
||||
// Create worker based on type
|
||||
// Note: For prefill and decode workers, we can't easily add labels after creation
|
||||
// since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
|
||||
let worker = match config.worker_type.as_deref() {
|
||||
Some("prefill") => {
|
||||
// For now, prefill workers won't have custom labels
|
||||
// TODO: Enhance WorkerFactory to accept labels for prefill workers
|
||||
WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
|
||||
}
|
||||
Some("decode") => {
|
||||
// For now, decode workers won't have custom labels
|
||||
// TODO: Enhance WorkerFactory to accept labels for decode workers
|
||||
WorkerFactory::create_decode(config.url.clone())
|
||||
}
|
||||
_ => {
|
||||
// Regular workers can have labels
|
||||
WorkerFactory::create_regular_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
)
|
||||
}
|
||||
Some("prefill") => WorkerFactory::create_prefill_with_labels(
|
||||
config.url.clone(),
|
||||
config.bootstrap_port,
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
Some("decode") => WorkerFactory::create_decode_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
_ => WorkerFactory::create_regular_with_labels(
|
||||
config.url.clone(),
|
||||
labels.clone(),
|
||||
CircuitBreakerConfig::default(),
|
||||
),
|
||||
};
|
||||
|
||||
// Register worker
|
||||
let worker_id = self.worker_registry.register(Arc::from(worker));
|
||||
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
|
||||
let worker_id = self.worker_registry.register(worker_arc.clone());
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
// Extract policy hint from labels if provided
|
||||
@@ -262,7 +252,6 @@ impl RouterManager {
|
||||
);
|
||||
|
||||
// Return worker info
|
||||
let worker_arc = self.worker_registry.get(&worker_id).unwrap();
|
||||
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
|
||||
|
||||
Ok(WorkerApiResponse {
|
||||
@@ -375,7 +364,11 @@ impl RouterManager {
|
||||
model_id: worker.model_id().to_string(),
|
||||
priority: worker.priority(),
|
||||
cost: worker.cost(),
|
||||
worker_type: format!("{:?}", worker.worker_type()),
|
||||
worker_type: match worker.worker_type() {
|
||||
WorkerType::Regular => "regular".to_string(),
|
||||
WorkerType::Prefill { .. } => "prefill".to_string(),
|
||||
WorkerType::Decode => "decode".to_string(),
|
||||
},
|
||||
is_healthy: worker.is_healthy(),
|
||||
load: worker.load(),
|
||||
connection_mode: format!("{:?}", worker.connection_mode()),
|
||||
@@ -387,11 +380,6 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Note: calculate_stats removed - using WorkerRegistry::stats() instead
|
||||
|
||||
// === Phase 2: Router Management ===
|
||||
// Note: Dynamic router creation removed - routers are created and registered externally
|
||||
|
||||
/// Get the appropriate router for a request based on headers and request content
|
||||
pub fn select_router_for_request(
|
||||
&self,
|
||||
@@ -474,11 +462,6 @@ impl RouterManager {
|
||||
}
|
||||
}
|
||||
|
||||
// Note: Default implementation removed as RouterManager now requires AppContext
|
||||
// which cannot be defaulted. RouterManager must be created with explicit context.
|
||||
|
||||
// === Phase 2: RouterManager as RouterTrait ===
|
||||
|
||||
/// RouterManager implements RouterTrait to act as a meta-router
|
||||
/// that delegates requests to the appropriate underlying router
|
||||
#[async_trait]
|
||||
|
||||
@@ -317,8 +317,9 @@ async fn create_worker(
|
||||
State(state): State<Arc<AppState>>,
|
||||
Json(config): Json<WorkerConfigRequest>,
|
||||
) -> Response {
|
||||
// Check if RouterManager is available (enable_igw=true)
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
// Check if the router is actually a RouterManager (enable_igw=true)
|
||||
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||
// Call RouterManager's add_worker method directly with the full config
|
||||
match router_manager.add_worker(config).await {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
@@ -347,7 +348,7 @@ async fn create_worker(
|
||||
|
||||
/// GET /workers - List all workers with details
|
||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||
let response = router_manager.list_workers();
|
||||
Json(response).into_response()
|
||||
} else {
|
||||
@@ -358,7 +359,11 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
let mut worker_info = serde_json::json!({
|
||||
"url": worker.url(),
|
||||
"model_id": worker.model_id(),
|
||||
"worker_type": format!("{:?}", worker.worker_type()),
|
||||
"worker_type": match worker.worker_type() {
|
||||
WorkerType::Regular => "regular",
|
||||
WorkerType::Prefill { .. } => "prefill",
|
||||
WorkerType::Decode => "decode",
|
||||
},
|
||||
"is_healthy": worker.is_healthy(),
|
||||
"load": worker.load(),
|
||||
"connection_mode": format!("{:?}", worker.connection_mode()),
|
||||
@@ -386,7 +391,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||
|
||||
/// GET /workers/{url} - Get specific worker info
|
||||
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||
if let Some(worker) = router_manager.get_worker(&url) {
|
||||
Json(worker).into_response()
|
||||
} else {
|
||||
@@ -417,7 +422,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
|
||||
|
||||
/// DELETE /workers/{url} - Remove a worker
|
||||
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||
if let Some(router_manager) = &state.context.router_manager {
|
||||
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||
match router_manager.remove_worker_from_registry(&url) {
|
||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||
@@ -603,7 +608,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-regular".to_string()),
|
||||
Arc::from(http_regular),
|
||||
vec![], // Models will be determined by workers
|
||||
);
|
||||
}
|
||||
Err(e) => {
|
||||
@@ -624,11 +628,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
{
|
||||
Ok(http_pd) => {
|
||||
info!("Created HTTP PD router");
|
||||
router_manager.register_router(
|
||||
RouterId::new("http-pd".to_string()),
|
||||
Arc::from(http_pd),
|
||||
vec![],
|
||||
);
|
||||
router_manager
|
||||
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Failed to create HTTP PD router: {e}");
|
||||
|
||||
Reference in New Issue
Block a user