[router] fix worker registration in multi model mode (#10486)
This commit is contained in:
@@ -804,6 +804,37 @@ impl WorkerFactory {
|
|||||||
Box::new(worker)
|
Box::new(worker)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create a prefill worker with labels
|
||||||
|
pub fn create_prefill_with_labels(
|
||||||
|
url: String,
|
||||||
|
bootstrap_port: Option<u16>,
|
||||||
|
labels: std::collections::HashMap<String, String>,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
let mut worker = BasicWorker::new(url.clone(), WorkerType::Prefill { bootstrap_port })
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config);
|
||||||
|
|
||||||
|
// Add labels to metadata
|
||||||
|
worker.metadata.labels = labels;
|
||||||
|
|
||||||
|
Box::new(worker)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a decode worker with labels
|
||||||
|
pub fn create_decode_with_labels(
|
||||||
|
url: String,
|
||||||
|
labels: std::collections::HashMap<String, String>,
|
||||||
|
circuit_breaker_config: CircuitBreakerConfig,
|
||||||
|
) -> Box<dyn Worker> {
|
||||||
|
let mut worker = BasicWorker::new(url.clone(), WorkerType::Decode)
|
||||||
|
.with_circuit_breaker_config(circuit_breaker_config);
|
||||||
|
|
||||||
|
// Add labels to metadata
|
||||||
|
worker.metadata.labels = labels;
|
||||||
|
|
||||||
|
Box::new(worker)
|
||||||
|
}
|
||||||
|
|
||||||
/// Create a DP-aware worker of specified type
|
/// Create a DP-aware worker of specified type
|
||||||
pub fn create_dp_aware(
|
pub fn create_dp_aware(
|
||||||
base_url: String,
|
base_url: String,
|
||||||
|
|||||||
@@ -41,7 +41,6 @@ impl RouterId {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Router Manager - Central coordinator for routers and workers
|
/// Router Manager - Central coordinator for routers and workers
|
||||||
/// Only created when enable_igw=true
|
|
||||||
pub struct RouterManager {
|
pub struct RouterManager {
|
||||||
/// Worker registry (single source of truth in multi-router mode)
|
/// Worker registry (single source of truth in multi-router mode)
|
||||||
worker_registry: Arc<WorkerRegistry>,
|
worker_registry: Arc<WorkerRegistry>,
|
||||||
@@ -49,7 +48,7 @@ pub struct RouterManager {
|
|||||||
/// Policy registry for managing model-to-policy mappings
|
/// Policy registry for managing model-to-policy mappings
|
||||||
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
policy_registry: Arc<crate::policies::PolicyRegistry>,
|
||||||
|
|
||||||
/// All routers managed by this manager (max 4 routers in Phase 2)
|
/// All routers managed by this manager
|
||||||
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
/// RouterId examples: "http-regular", "http-pd", "grpc-regular", "grpc-pd"
|
||||||
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
routers: Arc<DashMap<RouterId, Arc<dyn RouterTrait>>>,
|
||||||
|
|
||||||
@@ -83,12 +82,7 @@ impl RouterManager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/// Register a router with the manager
|
/// Register a router with the manager
|
||||||
pub fn register_router(
|
pub fn register_router(&mut self, id: RouterId, router: Arc<dyn RouterTrait>) {
|
||||||
&mut self,
|
|
||||||
id: RouterId,
|
|
||||||
router: Arc<dyn RouterTrait>,
|
|
||||||
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
|
|
||||||
) {
|
|
||||||
// Store router
|
// Store router
|
||||||
self.routers.insert(id.clone(), router);
|
self.routers.insert(id.clone(), router);
|
||||||
|
|
||||||
@@ -210,32 +204,28 @@ impl RouterManager {
|
|||||||
labels.insert("chat_template".to_string(), chat_template);
|
labels.insert("chat_template".to_string(), chat_template);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create worker based on type
|
|
||||||
// Note: For prefill and decode workers, we can't easily add labels after creation
|
|
||||||
// since they return Box<dyn Worker>. We'll need to enhance WorkerFactory in the future.
|
|
||||||
let worker = match config.worker_type.as_deref() {
|
let worker = match config.worker_type.as_deref() {
|
||||||
Some("prefill") => {
|
Some("prefill") => WorkerFactory::create_prefill_with_labels(
|
||||||
// For now, prefill workers won't have custom labels
|
config.url.clone(),
|
||||||
// TODO: Enhance WorkerFactory to accept labels for prefill workers
|
config.bootstrap_port,
|
||||||
WorkerFactory::create_prefill(config.url.clone(), config.bootstrap_port)
|
labels.clone(),
|
||||||
}
|
CircuitBreakerConfig::default(),
|
||||||
Some("decode") => {
|
),
|
||||||
// For now, decode workers won't have custom labels
|
Some("decode") => WorkerFactory::create_decode_with_labels(
|
||||||
// TODO: Enhance WorkerFactory to accept labels for decode workers
|
config.url.clone(),
|
||||||
WorkerFactory::create_decode(config.url.clone())
|
labels.clone(),
|
||||||
}
|
CircuitBreakerConfig::default(),
|
||||||
_ => {
|
),
|
||||||
// Regular workers can have labels
|
_ => WorkerFactory::create_regular_with_labels(
|
||||||
WorkerFactory::create_regular_with_labels(
|
config.url.clone(),
|
||||||
config.url.clone(),
|
labels.clone(),
|
||||||
labels.clone(),
|
CircuitBreakerConfig::default(),
|
||||||
CircuitBreakerConfig::default(),
|
),
|
||||||
)
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Register worker
|
// Register worker
|
||||||
let worker_id = self.worker_registry.register(Arc::from(worker));
|
let worker_arc: Arc<dyn Worker> = Arc::from(worker);
|
||||||
|
let worker_id = self.worker_registry.register(worker_arc.clone());
|
||||||
|
|
||||||
// Notify PolicyRegistry about the new worker
|
// Notify PolicyRegistry about the new worker
|
||||||
// Extract policy hint from labels if provided
|
// Extract policy hint from labels if provided
|
||||||
@@ -262,7 +252,6 @@ impl RouterManager {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Return worker info
|
// Return worker info
|
||||||
let worker_arc = self.worker_registry.get(&worker_id).unwrap();
|
|
||||||
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
|
let worker_info = self.worker_to_info(worker_id.as_str(), &worker_arc);
|
||||||
|
|
||||||
Ok(WorkerApiResponse {
|
Ok(WorkerApiResponse {
|
||||||
@@ -375,7 +364,11 @@ impl RouterManager {
|
|||||||
model_id: worker.model_id().to_string(),
|
model_id: worker.model_id().to_string(),
|
||||||
priority: worker.priority(),
|
priority: worker.priority(),
|
||||||
cost: worker.cost(),
|
cost: worker.cost(),
|
||||||
worker_type: format!("{:?}", worker.worker_type()),
|
worker_type: match worker.worker_type() {
|
||||||
|
WorkerType::Regular => "regular".to_string(),
|
||||||
|
WorkerType::Prefill { .. } => "prefill".to_string(),
|
||||||
|
WorkerType::Decode => "decode".to_string(),
|
||||||
|
},
|
||||||
is_healthy: worker.is_healthy(),
|
is_healthy: worker.is_healthy(),
|
||||||
load: worker.load(),
|
load: worker.load(),
|
||||||
connection_mode: format!("{:?}", worker.connection_mode()),
|
connection_mode: format!("{:?}", worker.connection_mode()),
|
||||||
@@ -387,11 +380,6 @@ impl RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: calculate_stats removed - using WorkerRegistry::stats() instead
|
|
||||||
|
|
||||||
// === Phase 2: Router Management ===
|
|
||||||
// Note: Dynamic router creation removed - routers are created and registered externally
|
|
||||||
|
|
||||||
/// Get the appropriate router for a request based on headers and request content
|
/// Get the appropriate router for a request based on headers and request content
|
||||||
pub fn select_router_for_request(
|
pub fn select_router_for_request(
|
||||||
&self,
|
&self,
|
||||||
@@ -474,11 +462,6 @@ impl RouterManager {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Note: Default implementation removed as RouterManager now requires AppContext
|
|
||||||
// which cannot be defaulted. RouterManager must be created with explicit context.
|
|
||||||
|
|
||||||
// === Phase 2: RouterManager as RouterTrait ===
|
|
||||||
|
|
||||||
/// RouterManager implements RouterTrait to act as a meta-router
|
/// RouterManager implements RouterTrait to act as a meta-router
|
||||||
/// that delegates requests to the appropriate underlying router
|
/// that delegates requests to the appropriate underlying router
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|||||||
@@ -317,8 +317,9 @@ async fn create_worker(
|
|||||||
State(state): State<Arc<AppState>>,
|
State(state): State<Arc<AppState>>,
|
||||||
Json(config): Json<WorkerConfigRequest>,
|
Json(config): Json<WorkerConfigRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Check if RouterManager is available (enable_igw=true)
|
// Check if the router is actually a RouterManager (enable_igw=true)
|
||||||
if let Some(router_manager) = &state.context.router_manager {
|
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||||
|
// Call RouterManager's add_worker method directly with the full config
|
||||||
match router_manager.add_worker(config).await {
|
match router_manager.add_worker(config).await {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||||
@@ -347,7 +348,7 @@ async fn create_worker(
|
|||||||
|
|
||||||
/// GET /workers - List all workers with details
|
/// GET /workers - List all workers with details
|
||||||
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
||||||
if let Some(router_manager) = &state.context.router_manager {
|
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||||
let response = router_manager.list_workers();
|
let response = router_manager.list_workers();
|
||||||
Json(response).into_response()
|
Json(response).into_response()
|
||||||
} else {
|
} else {
|
||||||
@@ -358,7 +359,11 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
|||||||
let mut worker_info = serde_json::json!({
|
let mut worker_info = serde_json::json!({
|
||||||
"url": worker.url(),
|
"url": worker.url(),
|
||||||
"model_id": worker.model_id(),
|
"model_id": worker.model_id(),
|
||||||
"worker_type": format!("{:?}", worker.worker_type()),
|
"worker_type": match worker.worker_type() {
|
||||||
|
WorkerType::Regular => "regular",
|
||||||
|
WorkerType::Prefill { .. } => "prefill",
|
||||||
|
WorkerType::Decode => "decode",
|
||||||
|
},
|
||||||
"is_healthy": worker.is_healthy(),
|
"is_healthy": worker.is_healthy(),
|
||||||
"load": worker.load(),
|
"load": worker.load(),
|
||||||
"connection_mode": format!("{:?}", worker.connection_mode()),
|
"connection_mode": format!("{:?}", worker.connection_mode()),
|
||||||
@@ -386,7 +391,7 @@ async fn list_workers_rest(State(state): State<Arc<AppState>>) -> Response {
|
|||||||
|
|
||||||
/// GET /workers/{url} - Get specific worker info
|
/// GET /workers/{url} - Get specific worker info
|
||||||
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||||
if let Some(router_manager) = &state.context.router_manager {
|
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||||
if let Some(worker) = router_manager.get_worker(&url) {
|
if let Some(worker) = router_manager.get_worker(&url) {
|
||||||
Json(worker).into_response()
|
Json(worker).into_response()
|
||||||
} else {
|
} else {
|
||||||
@@ -417,7 +422,7 @@ async fn get_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>)
|
|||||||
|
|
||||||
/// DELETE /workers/{url} - Remove a worker
|
/// DELETE /workers/{url} - Remove a worker
|
||||||
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
async fn delete_worker(State(state): State<Arc<AppState>>, Path(url): Path<String>) -> Response {
|
||||||
if let Some(router_manager) = &state.context.router_manager {
|
if let Some(router_manager) = state.router.as_any().downcast_ref::<RouterManager>() {
|
||||||
match router_manager.remove_worker_from_registry(&url) {
|
match router_manager.remove_worker_from_registry(&url) {
|
||||||
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
|
||||||
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
Err(error) => (StatusCode::BAD_REQUEST, Json(error)).into_response(),
|
||||||
@@ -603,7 +608,6 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
router_manager.register_router(
|
router_manager.register_router(
|
||||||
RouterId::new("http-regular".to_string()),
|
RouterId::new("http-regular".to_string()),
|
||||||
Arc::from(http_regular),
|
Arc::from(http_regular),
|
||||||
vec![], // Models will be determined by workers
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
@@ -624,11 +628,8 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
{
|
{
|
||||||
Ok(http_pd) => {
|
Ok(http_pd) => {
|
||||||
info!("Created HTTP PD router");
|
info!("Created HTTP PD router");
|
||||||
router_manager.register_router(
|
router_manager
|
||||||
RouterId::new("http-pd".to_string()),
|
.register_router(RouterId::new("http-pd".to_string()), Arc::from(http_pd));
|
||||||
Arc::from(http_pd),
|
|
||||||
vec![],
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to create HTTP PD router: {e}");
|
warn!("Failed to create HTTP PD router: {e}");
|
||||||
|
|||||||
Reference in New Issue
Block a user