[router] Refactor router and policy traits with dependency injection (#7987)

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Keru Yang <rukeyang@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
Simo Lin
2025-07-18 14:24:24 -07:00
committed by GitHub
parent 1f76fc8747
commit c8f31042a8
24 changed files with 3190 additions and 1944 deletions

View File

@@ -1,4 +1,4 @@
use crate::router::Router;
use crate::routers::RouterTrait;
use futures::{StreamExt, TryStreamExt};
use k8s_openapi::api::core::v1::Pod;
@@ -176,7 +176,7 @@ impl PodInfo {
pub async fn start_service_discovery(
config: ServiceDiscoveryConfig,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
) -> Result<task::JoinHandle<()>, kube::Error> {
// Don't initialize anything if service discovery is disabled
if !config.enabled {
@@ -346,7 +346,7 @@ pub async fn start_service_discovery(
async fn handle_pod_event(
pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
port: u16,
pd_mode: bool,
) {
@@ -379,17 +379,32 @@ async fn handle_pod_event(
pod_info.name, pod_info.pod_type, worker_url
);
// Handle PD mode with specific pod types
let result = if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker management
if let Some(pod_type) = &pod_info.pod_type {
router
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port)
.await
// Need to import PDRouter type
use crate::routers::pd_router::PDRouter;
// Try to downcast to PDRouter
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => pd_router
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
.await
.map_err(|e| e.to_string()),
Some(PodType::Decode) => pd_router
.add_decode_server(worker_url.clone())
.await
.map_err(|e| e.to_string()),
Some(PodType::Regular) | None => {
// Fall back to regular add_worker for regular pods
router.add_worker(&worker_url).await
}
}
} else {
Err("Pod type is None in PD mode".to_string())
Err("PD mode enabled but router is not a PDRouter".to_string())
}
} else {
// Fallback to regular worker management
// Regular mode or no pod type specified
router.add_worker(&worker_url).await
};
@@ -412,7 +427,7 @@ async fn handle_pod_event(
async fn handle_pod_deletion(
pod_info: &PodInfo,
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
port: u16,
pd_mode: bool,
) {
@@ -435,18 +450,34 @@ async fn handle_pod_deletion(
pod_info.name, pod_info.pod_type, worker_url
);
// Handle PD mode removal
if pd_mode && pod_info.pod_type.is_some() {
// Use PD-aware worker removal
if let Some(pod_type) = &pod_info.pod_type {
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await {
error!(
"Failed to remove PD worker {} from router: {}",
worker_url, e
);
use crate::routers::pd_router::PDRouter;
// Try to downcast to PDRouter for PD-specific removal
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
match &pod_info.pod_type {
Some(PodType::Prefill) => {
if let Err(e) = pd_router.remove_prefill_server(&worker_url).await {
error!("Failed to remove prefill server {}: {}", worker_url, e);
}
}
Some(PodType::Decode) => {
if let Err(e) = pd_router.remove_decode_server(&worker_url).await {
error!("Failed to remove decode server {}: {}", worker_url, e);
}
}
Some(PodType::Regular) | None => {
// Fall back to regular remove_worker
router.remove_worker(&worker_url);
}
}
} else {
// PD mode but not a PDRouter, use generic removal
router.remove_worker(&worker_url);
}
} else {
// Fallback to regular worker removal
// Regular mode removal
router.remove_worker(&worker_url);
}
} else {
@@ -462,11 +493,9 @@ async fn handle_pod_deletion(
#[cfg(test)]
mod tests {
use super::*;
use crate::router::Router;
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
use std::sync::RwLock;
// Helper function to create a Pod for testing PodInfo::from_pod
fn create_k8s_pod(
@@ -546,14 +575,14 @@ mod tests {
}
// Helper to create a Router instance for testing event handlers
fn create_test_router() -> Arc<Router> {
let workers = Arc::new(RwLock::new(Vec::new()));
Arc::new(Router::Random {
workers,
timeout_secs: 5,
interval_secs: 1,
_health_checker: None,
})
fn create_test_router() -> Arc<dyn RouterTrait> {
use crate::config::PolicyConfig;
use crate::policies::PolicyFactory;
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1).unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}
// Helper to create a PD config for testing