[router] multi model registration fix (#10481)
This commit is contained in:
@@ -5,7 +5,7 @@
|
|||||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||||
|
|
||||||
use crate::config::RouterConfig;
|
use crate::config::RouterConfig;
|
||||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
|
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType};
|
||||||
use crate::protocols::spec::{
|
use crate::protocols::spec::{
|
||||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||||
ResponsesRequest,
|
ResponsesRequest,
|
||||||
@@ -56,10 +56,6 @@ pub struct RouterManager {
|
|||||||
/// Default router for requests without specific routing
|
/// Default router for requests without specific routing
|
||||||
default_router: Option<RouterId>,
|
default_router: Option<RouterId>,
|
||||||
|
|
||||||
/// Model to router mapping for model-aware routing
|
|
||||||
/// Multiple models can be served by the same router
|
|
||||||
model_routers: Arc<DashMap<String, Vec<RouterId>>>,
|
|
||||||
|
|
||||||
/// HTTP client for querying worker info
|
/// HTTP client for querying worker info
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
|
||||||
@@ -81,7 +77,6 @@ impl RouterManager {
|
|||||||
policy_registry,
|
policy_registry,
|
||||||
routers: Arc::new(DashMap::new()),
|
routers: Arc::new(DashMap::new()),
|
||||||
default_router: None,
|
default_router: None,
|
||||||
model_routers: Arc::new(DashMap::new()),
|
|
||||||
client,
|
client,
|
||||||
config,
|
config,
|
||||||
}
|
}
|
||||||
@@ -92,19 +87,11 @@ impl RouterManager {
|
|||||||
&mut self,
|
&mut self,
|
||||||
id: RouterId,
|
id: RouterId,
|
||||||
router: Arc<dyn RouterTrait>,
|
router: Arc<dyn RouterTrait>,
|
||||||
models: Vec<String>,
|
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
|
||||||
) {
|
) {
|
||||||
// Store router
|
// Store router
|
||||||
self.routers.insert(id.clone(), router);
|
self.routers.insert(id.clone(), router);
|
||||||
|
|
||||||
// Update model mappings
|
|
||||||
for model in models {
|
|
||||||
self.model_routers
|
|
||||||
.entry(model)
|
|
||||||
.or_default()
|
|
||||||
.push(id.clone());
|
|
||||||
}
|
|
||||||
|
|
||||||
// Set as default if first router
|
// Set as default if first router
|
||||||
if self.default_router.is_none() {
|
if self.default_router.is_none() {
|
||||||
self.default_router = Some(id.clone());
|
self.default_router = Some(id.clone());
|
||||||
@@ -122,14 +109,29 @@ impl RouterManager {
|
|||||||
self.routers.len()
|
self.routers.len()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get router for a specific model
|
/// Get router for a specific model based on worker types
|
||||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||||
// First try model-specific routers
|
// Query workers for this model from registry
|
||||||
if let Some(router_ids) = self.model_routers.get(model_id) {
|
let workers = self.worker_registry.get_by_model(model_id);
|
||||||
if let Some(router_id) = router_ids.first() {
|
|
||||||
if let Some(router) = self.routers.get(router_id) {
|
if !workers.is_empty() {
|
||||||
return Some(router.clone());
|
// Determine router based on worker types
|
||||||
}
|
let has_pd_workers = workers.iter().any(|w| {
|
||||||
|
matches!(
|
||||||
|
w.worker_type(),
|
||||||
|
WorkerType::Prefill { .. } | WorkerType::Decode
|
||||||
|
)
|
||||||
|
});
|
||||||
|
|
||||||
|
let router_id = if has_pd_workers {
|
||||||
|
RouterId::new("http-pd".to_string())
|
||||||
|
} else {
|
||||||
|
RouterId::new("http-regular".to_string())
|
||||||
|
};
|
||||||
|
|
||||||
|
// Return the router if it exists
|
||||||
|
if let Some(router) = self.routers.get(&router_id) {
|
||||||
|
return Some(router.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -240,6 +242,17 @@ impl RouterManager {
|
|||||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||||
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
|
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
|
||||||
|
|
||||||
|
// Log which type of router would handle this worker (for debugging)
|
||||||
|
let expected_router = match config.worker_type.as_deref() {
|
||||||
|
Some("prefill") | Some("decode") => "http-pd",
|
||||||
|
_ => "http-regular",
|
||||||
|
};
|
||||||
|
|
||||||
|
info!(
|
||||||
|
"Worker for model '{}' would be handled by '{}' router based on type",
|
||||||
|
model_id, expected_router
|
||||||
|
);
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Added worker {} with URL {} for model {} using policy {}",
|
"Added worker {} with URL {} for model {} using policy {}",
|
||||||
worker_id.as_str(),
|
worker_id.as_str(),
|
||||||
@@ -272,8 +285,9 @@ impl RouterManager {
|
|||||||
|
|
||||||
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
|
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
|
||||||
// Notify PolicyRegistry about worker removal
|
// Notify PolicyRegistry about worker removal
|
||||||
if let Some(model_id) = model_id {
|
if let Some(ref model_id) = model_id {
|
||||||
self.policy_registry.on_worker_removed(&model_id);
|
self.policy_registry.on_worker_removed(model_id);
|
||||||
|
|
||||||
info!("Removed worker with URL {} for model {}", url, model_id);
|
info!("Removed worker with URL {} for model {}", url, model_id);
|
||||||
} else {
|
} else {
|
||||||
info!("Removed worker with URL {}", url);
|
info!("Removed worker with URL {}", url);
|
||||||
@@ -406,14 +420,10 @@ impl RouterManager {
|
|||||||
})
|
})
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
// If model specified, find routers serving that model
|
// If model specified, use get_router_for_model
|
||||||
let candidate_routers = if let Some(model) = model_id {
|
let candidate_routers = if let Some(model) = model_id {
|
||||||
// Get routers for specific model
|
if let Some(router) = self.get_router_for_model(model) {
|
||||||
if let Some(router_ids) = self.model_routers.get(model) {
|
vec![router]
|
||||||
router_ids
|
|
||||||
.iter()
|
|
||||||
.filter_map(|id| self.routers.get(id).map(|r| r.clone()))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
} else {
|
} else {
|
||||||
Vec::new()
|
Vec::new()
|
||||||
}
|
}
|
||||||
@@ -547,14 +557,10 @@ impl RouterTrait for RouterManager {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get available models - aggregate from all routers
|
/// Get available models - query from worker registry
|
||||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||||
// Return models that have registered routers
|
// Get models from worker registry
|
||||||
let models = self
|
let models = self.worker_registry.get_models();
|
||||||
.model_routers
|
|
||||||
.iter()
|
|
||||||
.map(|entry| entry.key().clone())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
if models.is_empty() {
|
if models.is_empty() {
|
||||||
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
|
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
|
||||||
|
|||||||
Reference in New Issue
Block a user