[router] refactor router and worker management 2.5/n (#10677)
This commit is contained in:
@@ -9,6 +9,7 @@ use super::{
|
||||
RoundRobinPolicy,
|
||||
};
|
||||
use crate::config::types::PolicyConfig;
|
||||
use crate::core::Worker;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use tracing::{debug, info, warn};
|
||||
@@ -255,6 +256,81 @@ impl PolicyRegistry {
|
||||
.map(Arc::clone)
|
||||
.unwrap_or_else(|| self.get_default_policy())
|
||||
}
|
||||
|
||||
/// Initialize cache-aware policy with workers if applicable
|
||||
/// This should be called after workers are registered for a model
|
||||
pub fn init_cache_aware_policy(&self, model_id: &str, workers: &[Arc<dyn Worker>]) {
|
||||
// Get the policy for this model
|
||||
if let Some(policy) = self.get_policy(model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
|
||||
debug!(
|
||||
"Initializing cache-aware policy with {} workers for model {}",
|
||||
workers.len(),
|
||||
model_id
|
||||
);
|
||||
cache_aware.init_workers(workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a worker from cache-aware policy if applicable
|
||||
/// This should be called when a worker is being removed
|
||||
pub fn remove_worker_from_cache_aware(&self, model_id: &str, worker_url: &str) {
|
||||
// Get the policy for this model
|
||||
if let Some(policy) = self.get_policy(model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy.as_any().downcast_ref::<CacheAwarePolicy>() {
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
debug!(
|
||||
"Removed worker {} from cache-aware policy for model {}",
|
||||
worker_url, model_id
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Initialize cache-aware policies for PD mode (prefill and decode)
|
||||
pub fn init_pd_cache_aware_policies(
|
||||
&self,
|
||||
prefill_workers: &[Arc<dyn Worker>],
|
||||
decode_workers: &[Arc<dyn Worker>],
|
||||
) {
|
||||
// Initialize prefill policy if it's cache-aware
|
||||
if let Some(prefill_policy) = self.prefill_policy.read().unwrap().as_ref() {
|
||||
if prefill_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) =
|
||||
prefill_policy.as_any().downcast_ref::<CacheAwarePolicy>()
|
||||
{
|
||||
if !prefill_workers.is_empty() {
|
||||
debug!(
|
||||
"Initializing prefill cache-aware policy with {} workers",
|
||||
prefill_workers.len()
|
||||
);
|
||||
cache_aware.init_workers(prefill_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize decode policy if it's cache-aware
|
||||
if let Some(decode_policy) = self.decode_policy.read().unwrap().as_ref() {
|
||||
if decode_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = decode_policy.as_any().downcast_ref::<CacheAwarePolicy>()
|
||||
{
|
||||
if !decode_workers.is_empty() {
|
||||
debug!(
|
||||
"Initializing decode cache-aware policy with {} workers",
|
||||
decode_workers.len()
|
||||
);
|
||||
cache_aware.init_workers(decode_workers);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Debug for PolicyRegistry {
|
||||
|
||||
@@ -232,18 +232,12 @@ impl PDRouter {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
info!("Added prefill server: {}", url);
|
||||
Ok(format!("Successfully added prefill server: {}", url))
|
||||
@@ -272,18 +266,12 @@ impl PDRouter {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
@@ -307,17 +295,9 @@ impl PDRouter {
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
// Get the policy for this model to update cache-aware if needed
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove from cache-aware policy if applicable
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, url);
|
||||
}
|
||||
|
||||
if removed.is_some() {
|
||||
@@ -348,17 +328,9 @@ impl PDRouter {
|
||||
// Notify PolicyRegistry about the removed worker
|
||||
self.policy_registry.on_worker_removed(&model_id);
|
||||
|
||||
// Get the policy for this model to update cache-aware if needed
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(url);
|
||||
}
|
||||
}
|
||||
}
|
||||
// Remove from cache-aware policy if applicable
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, url);
|
||||
}
|
||||
|
||||
if removed.is_some() {
|
||||
@@ -2226,15 +2198,6 @@ mod tests {
|
||||
assert_eq!(prefill_workers.len(), 1);
|
||||
}
|
||||
|
||||
// ============= Bootstrap Injection Tests =============
|
||||
// Note: These tests are commented out as we've moved to the optimized bootstrap injection
|
||||
// approach that doesn't use the Bootstrap trait on GenerateReqInput anymore.
|
||||
|
||||
// TODO: Add new tests for the optimized bootstrap injection approach using
|
||||
// RequestWithBootstrap and BatchRequestWithBootstrap wrappers
|
||||
|
||||
// ============= Worker Selection Tests =============
|
||||
|
||||
#[tokio::test]
|
||||
async fn test_select_healthy_prefill_worker() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
@@ -70,21 +70,15 @@ impl Router {
|
||||
window_duration: Duration::from_secs(circuit_breaker_config.window_duration_secs),
|
||||
};
|
||||
|
||||
// Initialize cache-aware policy with workers if needed
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
if default_policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = default_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.init_workers(&workers);
|
||||
}
|
||||
}
|
||||
// Cache-aware policies are initialized in WorkerInitializer
|
||||
|
||||
// Setup load monitoring for PowerOfTwo policy
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Get default policy to check if we need load monitoring
|
||||
let default_policy = ctx.policy_registry.get_default_policy();
|
||||
|
||||
// Check if default policy is power_of_two for load monitoring
|
||||
let load_monitor_handle = if default_policy.name() == "power_of_two" {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
@@ -964,19 +958,13 @@ impl Router {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, update it with all workers for this model
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
|
||||
worker_added = true;
|
||||
}
|
||||
@@ -1000,20 +988,12 @@ impl Router {
|
||||
|
||||
// Notify PolicyRegistry about the new worker
|
||||
let model_id = worker_arc.model_id();
|
||||
let policy = self.policy_registry.on_worker_added(model_id, None);
|
||||
self.policy_registry.on_worker_added(model_id, None);
|
||||
|
||||
// If this is a cache-aware policy, add this worker to it
|
||||
if policy.name() == "cache_aware" {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>(
|
||||
) {
|
||||
// Get all workers for this model
|
||||
let model_workers =
|
||||
self.worker_registry.get_by_model_fast(model_id);
|
||||
cache_aware.init_workers(&model_workers);
|
||||
}
|
||||
}
|
||||
// Initialize cache-aware policy if applicable
|
||||
let model_workers = self.worker_registry.get_by_model_fast(model_id);
|
||||
self.policy_registry
|
||||
.init_cache_aware_policy(model_id, &model_workers);
|
||||
}
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
@@ -1084,20 +1064,11 @@ impl Router {
|
||||
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
|
||||
// If any models are using cache aware policy, remove the workers from the tree
|
||||
// Check each removed worker's model and get its policy
|
||||
for dp_url in removed_workers.iter() {
|
||||
if let Some(worker) = self.worker_registry.get_by_url(dp_url) {
|
||||
let model_id = worker.model_id();
|
||||
if let Some(policy) = self.policy_registry.get_policy(model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(dp_url);
|
||||
info!("Removed worker from cache-aware tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(model_id, dp_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -1118,16 +1089,8 @@ impl Router {
|
||||
RouterMetrics::set_active_workers(self.worker_registry.get_all().len());
|
||||
}
|
||||
|
||||
// If the model is using cache aware policy, remove the worker from the tree
|
||||
if let Some(policy) = self.policy_registry.get_policy(&model_id) {
|
||||
if let Some(cache_aware) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker_by_url(worker_url);
|
||||
info!("Removed worker from cache-aware tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
self.policy_registry
|
||||
.remove_worker_from_cache_aware(&model_id, worker_url);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
|
||||
use crate::config::types::{ConnectionMode as ConfigConnectionMode, RouterConfig, RoutingMode};
|
||||
use crate::core::{
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, WorkerRegistry,
|
||||
BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode, HealthConfig, Worker, WorkerRegistry,
|
||||
WorkerType,
|
||||
};
|
||||
use crate::policies::PolicyRegistry;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tracing::{info, warn};
|
||||
@@ -19,6 +21,7 @@ impl WorkerInitializer {
|
||||
pub async fn initialize_workers(
|
||||
config: &RouterConfig,
|
||||
worker_registry: &Arc<WorkerRegistry>,
|
||||
policy_registry: Option<&Arc<PolicyRegistry>>,
|
||||
) -> Result<(), String> {
|
||||
info!("Initializing workers for routing mode: {:?}", config.mode);
|
||||
|
||||
@@ -29,6 +32,7 @@ impl WorkerInitializer {
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -42,6 +46,7 @@ impl WorkerInitializer {
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
)
|
||||
.await?;
|
||||
Self::create_decode_workers(
|
||||
@@ -49,6 +54,7 @@ impl WorkerInitializer {
|
||||
&config.connection_mode,
|
||||
config,
|
||||
worker_registry,
|
||||
policy_registry,
|
||||
)
|
||||
.await?;
|
||||
}
|
||||
@@ -76,6 +82,7 @@ impl WorkerInitializer {
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
policy_registry: Option<&Arc<PolicyRegistry>>,
|
||||
) -> Result<(), String> {
|
||||
info!("Creating {} regular workers", urls.len());
|
||||
|
||||
@@ -100,6 +107,8 @@ impl WorkerInitializer {
|
||||
success_threshold: config.health_check.success_threshold,
|
||||
};
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
@@ -109,8 +118,28 @@ impl WorkerInitializer {
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
|
||||
let worker_id = registry.register(Arc::new(worker));
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
let worker_id = registry.register(Arc::clone(&worker_arc));
|
||||
info!("Registered regular worker {} with ID {:?}", url, worker_id);
|
||||
|
||||
// Track workers by model for cache-aware policy initialization
|
||||
registered_workers
|
||||
.entry(model_id.to_string())
|
||||
.or_default()
|
||||
.push(Arc::clone(&worker_arc));
|
||||
|
||||
// Notify policy registry about the worker
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
policy_reg.on_worker_added(model_id, None);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies with all workers for each model
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
for (model_id, workers) in registered_workers {
|
||||
policy_reg.init_cache_aware_policy(&model_id, &workers);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -122,6 +151,7 @@ impl WorkerInitializer {
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
policy_registry: Option<&Arc<PolicyRegistry>>,
|
||||
) -> Result<(), String> {
|
||||
info!("Creating {} prefill workers", prefill_entries.len());
|
||||
|
||||
@@ -149,6 +179,8 @@ impl WorkerInitializer {
|
||||
success_threshold: config.health_check.success_threshold,
|
||||
};
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for (url, bootstrap_port) in prefill_entries {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
@@ -160,8 +192,33 @@ impl WorkerInitializer {
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
|
||||
let worker_id = registry.register(Arc::new(worker));
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
let worker_id = registry.register(Arc::clone(&worker_arc));
|
||||
info!("Registered prefill worker {} with ID {:?}", url, worker_id);
|
||||
|
||||
// Track workers by model for cache-aware policy initialization
|
||||
registered_workers
|
||||
.entry(model_id.to_string())
|
||||
.or_default()
|
||||
.push(Arc::clone(&worker_arc));
|
||||
|
||||
// Notify policy registry about the worker
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
policy_reg.on_worker_added(model_id, None);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies for PD mode
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
// Collect all prefill workers
|
||||
let all_prefill_workers: Vec<Arc<dyn Worker>> = registered_workers
|
||||
.values()
|
||||
.flat_map(|workers| workers.iter().cloned())
|
||||
.collect();
|
||||
|
||||
// Initialize PD policies (will handle both prefill and decode, but we only have prefill here)
|
||||
policy_reg.init_pd_cache_aware_policies(&all_prefill_workers, &[]);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -173,6 +230,7 @@ impl WorkerInitializer {
|
||||
config_connection_mode: &ConfigConnectionMode,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
policy_registry: Option<&Arc<PolicyRegistry>>,
|
||||
) -> Result<(), String> {
|
||||
info!("Creating {} decode workers", urls.len());
|
||||
|
||||
@@ -197,6 +255,8 @@ impl WorkerInitializer {
|
||||
success_threshold: config.health_check.success_threshold,
|
||||
};
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in urls {
|
||||
// TODO: Add DP-aware support when we have dp_rank/dp_size info
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
@@ -206,8 +266,33 @@ impl WorkerInitializer {
|
||||
.health_config(health_config.clone())
|
||||
.build();
|
||||
|
||||
let worker_id = registry.register(Arc::new(worker));
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
let worker_id = registry.register(Arc::clone(&worker_arc));
|
||||
info!("Registered decode worker {} with ID {:?}", url, worker_id);
|
||||
|
||||
// Track workers by model for cache-aware policy initialization
|
||||
registered_workers
|
||||
.entry(model_id.to_string())
|
||||
.or_default()
|
||||
.push(Arc::clone(&worker_arc));
|
||||
|
||||
// Notify policy registry about the worker
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
policy_reg.on_worker_added(model_id, None);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies for PD mode
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
// Collect all decode workers
|
||||
let all_decode_workers: Vec<Arc<dyn Worker>> = registered_workers
|
||||
.values()
|
||||
.flat_map(|workers| workers.iter().cloned())
|
||||
.collect();
|
||||
|
||||
// Initialize PD policies (will handle both prefill and decode, but we only have decode here)
|
||||
policy_reg.init_pd_cache_aware_policies(&[], &all_decode_workers);
|
||||
}
|
||||
|
||||
Ok(())
|
||||
@@ -281,7 +366,8 @@ impl WorkerInitializer {
|
||||
worker_type: WorkerType,
|
||||
config: &RouterConfig,
|
||||
registry: &Arc<WorkerRegistry>,
|
||||
grpc_clients: &mut std::collections::HashMap<String, crate::grpc::SglangSchedulerClient>,
|
||||
policy_registry: Option<&Arc<PolicyRegistry>>,
|
||||
grpc_clients: &mut HashMap<String, crate::grpc::SglangSchedulerClient>,
|
||||
) -> Result<(), String> {
|
||||
info!(
|
||||
"Creating {} gRPC workers of type {:?}",
|
||||
@@ -307,6 +393,8 @@ impl WorkerInitializer {
|
||||
success_threshold: config.health_check.success_threshold,
|
||||
};
|
||||
|
||||
let mut registered_workers: HashMap<String, Vec<Arc<dyn Worker>>> = HashMap::new();
|
||||
|
||||
for url in worker_urls {
|
||||
if let Some(client) = grpc_clients.remove(url) {
|
||||
let worker = BasicWorkerBuilder::new(url.clone())
|
||||
@@ -317,13 +405,33 @@ impl WorkerInitializer {
|
||||
.grpc_client(client)
|
||||
.build();
|
||||
|
||||
let worker_id = registry.register(Arc::new(worker));
|
||||
let worker_arc = Arc::new(worker) as Arc<dyn Worker>;
|
||||
let model_id = worker_arc.model_id();
|
||||
let worker_id = registry.register(Arc::clone(&worker_arc));
|
||||
info!("Registered gRPC worker {} with ID {:?}", url, worker_id);
|
||||
|
||||
// Track workers by model for cache-aware policy initialization
|
||||
registered_workers
|
||||
.entry(model_id.to_string())
|
||||
.or_default()
|
||||
.push(Arc::clone(&worker_arc));
|
||||
|
||||
// Notify policy registry about the worker
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
policy_reg.on_worker_added(model_id, None);
|
||||
}
|
||||
} else {
|
||||
warn!("No gRPC client available for worker {}, skipping", url);
|
||||
}
|
||||
}
|
||||
|
||||
// Initialize cache-aware policies with all workers for each model
|
||||
if let Some(policy_reg) = policy_registry {
|
||||
for (model_id, workers) in registered_workers {
|
||||
policy_reg.init_cache_aware_policy(&model_id, &workers);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -595,15 +595,17 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
||||
|
||||
let app_context = Arc::new(app_context);
|
||||
|
||||
// Initialize workers before creating routers
|
||||
// This separates worker lifecycle from router lifecycle
|
||||
info!(
|
||||
"Initializing workers for routing mode: {:?}",
|
||||
config.router_config.mode
|
||||
);
|
||||
WorkerInitializer::initialize_workers(&config.router_config, &app_context.worker_registry)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
|
||||
WorkerInitializer::initialize_workers(
|
||||
&config.router_config,
|
||||
&app_context.worker_registry,
|
||||
Some(&app_context.policy_registry),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| format!("Failed to initialize workers: {}", e))?;
|
||||
|
||||
let worker_stats = app_context.worker_registry.stats();
|
||||
info!(
|
||||
|
||||
@@ -104,7 +104,7 @@ impl TestContext {
|
||||
// Initialize workers in the registry before creating router
|
||||
if !worker_urls.is_empty() {
|
||||
use sglang_router_rs::routers::WorkerInitializer;
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
|
||||
.await
|
||||
.expect("Failed to initialize workers");
|
||||
}
|
||||
|
||||
@@ -48,7 +48,7 @@ impl TestContext {
|
||||
// Initialize workers in the registry before creating router
|
||||
if !worker_urls.is_empty() {
|
||||
use sglang_router_rs::routers::WorkerInitializer;
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
|
||||
.await
|
||||
.expect("Failed to initialize workers");
|
||||
}
|
||||
|
||||
@@ -49,7 +49,7 @@ impl TestContext {
|
||||
// Initialize workers in the registry before creating router
|
||||
if !worker_urls.is_empty() {
|
||||
use sglang_router_rs::routers::WorkerInitializer;
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry)
|
||||
WorkerInitializer::initialize_workers(&config, &app_context.worker_registry, None)
|
||||
.await
|
||||
.expect("Failed to initialize workers");
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user