[router] multi model registration fix (#10481)
This commit is contained in:
@@ -5,7 +5,7 @@
|
||||
//! - Multi-Router Mode (enable_igw=true): RouterManager coordinates everything
|
||||
|
||||
use crate::config::RouterConfig;
|
||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry};
|
||||
use crate::core::{CircuitBreakerConfig, Worker, WorkerFactory, WorkerRegistry, WorkerType};
|
||||
use crate::protocols::spec::{
|
||||
ChatCompletionRequest, CompletionRequest, EmbeddingRequest, GenerateRequest, RerankRequest,
|
||||
ResponsesRequest,
|
||||
@@ -56,10 +56,6 @@ pub struct RouterManager {
|
||||
/// Default router for requests without specific routing
|
||||
default_router: Option<RouterId>,
|
||||
|
||||
/// Model to router mapping for model-aware routing
|
||||
/// Multiple models can be served by the same router
|
||||
model_routers: Arc<DashMap<String, Vec<RouterId>>>,
|
||||
|
||||
/// HTTP client for querying worker info
|
||||
client: reqwest::Client,
|
||||
|
||||
@@ -81,7 +77,6 @@ impl RouterManager {
|
||||
policy_registry,
|
||||
routers: Arc::new(DashMap::new()),
|
||||
default_router: None,
|
||||
model_routers: Arc::new(DashMap::new()),
|
||||
client,
|
||||
config,
|
||||
}
|
||||
@@ -92,19 +87,11 @@ impl RouterManager {
|
||||
&mut self,
|
||||
id: RouterId,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
models: Vec<String>,
|
||||
_models: Vec<String>, // Keep parameter for backward compatibility but ignore it
|
||||
) {
|
||||
// Store router
|
||||
self.routers.insert(id.clone(), router);
|
||||
|
||||
// Update model mappings
|
||||
for model in models {
|
||||
self.model_routers
|
||||
.entry(model)
|
||||
.or_default()
|
||||
.push(id.clone());
|
||||
}
|
||||
|
||||
// Set as default if first router
|
||||
if self.default_router.is_none() {
|
||||
self.default_router = Some(id.clone());
|
||||
@@ -122,14 +109,29 @@ impl RouterManager {
|
||||
self.routers.len()
|
||||
}
|
||||
|
||||
/// Get router for a specific model
|
||||
/// Get router for a specific model based on worker types
|
||||
pub fn get_router_for_model(&self, model_id: &str) -> Option<Arc<dyn RouterTrait>> {
|
||||
// First try model-specific routers
|
||||
if let Some(router_ids) = self.model_routers.get(model_id) {
|
||||
if let Some(router_id) = router_ids.first() {
|
||||
if let Some(router) = self.routers.get(router_id) {
|
||||
return Some(router.clone());
|
||||
}
|
||||
// Query workers for this model from registry
|
||||
let workers = self.worker_registry.get_by_model(model_id);
|
||||
|
||||
if !workers.is_empty() {
|
||||
// Determine router based on worker types
|
||||
let has_pd_workers = workers.iter().any(|w| {
|
||||
matches!(
|
||||
w.worker_type(),
|
||||
WorkerType::Prefill { .. } | WorkerType::Decode
|
||||
)
|
||||
});
|
||||
|
||||
let router_id = if has_pd_workers {
|
||||
RouterId::new("http-pd".to_string())
|
||||
} else {
|
||||
RouterId::new("http-regular".to_string())
|
||||
};
|
||||
|
||||
// Return the router if it exists
|
||||
if let Some(router) = self.routers.get(&router_id) {
|
||||
return Some(router.clone());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -240,6 +242,17 @@ impl RouterManager {
|
||||
let policy_hint = labels.get("policy").map(|s| s.as_str());
|
||||
let policy = self.policy_registry.on_worker_added(&model_id, policy_hint);
|
||||
|
||||
// Log which type of router would handle this worker (for debugging)
|
||||
let expected_router = match config.worker_type.as_deref() {
|
||||
Some("prefill") | Some("decode") => "http-pd",
|
||||
_ => "http-regular",
|
||||
};
|
||||
|
||||
info!(
|
||||
"Worker for model '{}' would be handled by '{}' router based on type",
|
||||
model_id, expected_router
|
||||
);
|
||||
|
||||
info!(
|
||||
"Added worker {} with URL {} for model {} using policy {}",
|
||||
worker_id.as_str(),
|
||||
@@ -272,8 +285,9 @@ impl RouterManager {
|
||||
|
||||
if let Some(_worker) = self.worker_registry.remove_by_url(url) {
|
||||
// Notify PolicyRegistry about worker removal
|
||||
if let Some(model_id) = model_id {
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
if let Some(ref model_id) = model_id {
|
||||
self.policy_registry.on_worker_removed(model_id);
|
||||
|
||||
info!("Removed worker with URL {} for model {}", url, model_id);
|
||||
} else {
|
||||
info!("Removed worker with URL {}", url);
|
||||
@@ -406,14 +420,10 @@ impl RouterManager {
|
||||
})
|
||||
.unwrap_or(false);
|
||||
|
||||
// If model specified, find routers serving that model
|
||||
// If model specified, use get_router_for_model
|
||||
let candidate_routers = if let Some(model) = model_id {
|
||||
// Get routers for specific model
|
||||
if let Some(router_ids) = self.model_routers.get(model) {
|
||||
router_ids
|
||||
.iter()
|
||||
.filter_map(|id| self.routers.get(id).map(|r| r.clone()))
|
||||
.collect::<Vec<_>>()
|
||||
if let Some(router) = self.get_router_for_model(model) {
|
||||
vec![router]
|
||||
} else {
|
||||
Vec::new()
|
||||
}
|
||||
@@ -547,14 +557,10 @@ impl RouterTrait for RouterManager {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
/// Get available models - aggregate from all routers
|
||||
/// Get available models - query from worker registry
|
||||
async fn get_models(&self, _req: Request<Body>) -> Response {
|
||||
// Return models that have registered routers
|
||||
let models = self
|
||||
.model_routers
|
||||
.iter()
|
||||
.map(|entry| entry.key().clone())
|
||||
.collect::<Vec<_>>();
|
||||
// Get models from worker registry
|
||||
let models = self.worker_registry.get_models();
|
||||
|
||||
if models.is_empty() {
|
||||
(StatusCode::SERVICE_UNAVAILABLE, "No models available").into_response()
|
||||
|
||||
Reference in New Issue
Block a user