[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Keru Yang <rukeyang@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
use crate::router::Router;
|
||||
use crate::routers::RouterTrait;
|
||||
|
||||
use futures::{StreamExt, TryStreamExt};
|
||||
use k8s_openapi::api::core::v1::Pod;
|
||||
@@ -176,7 +176,7 @@ impl PodInfo {
|
||||
|
||||
pub async fn start_service_discovery(
|
||||
config: ServiceDiscoveryConfig,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
) -> Result<task::JoinHandle<()>, kube::Error> {
|
||||
// Don't initialize anything if service discovery is disabled
|
||||
if !config.enabled {
|
||||
@@ -346,7 +346,7 @@ pub async fn start_service_discovery(
|
||||
async fn handle_pod_event(
|
||||
pod_info: &PodInfo,
|
||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
port: u16,
|
||||
pd_mode: bool,
|
||||
) {
|
||||
@@ -379,17 +379,32 @@ async fn handle_pod_event(
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
);
|
||||
|
||||
// Handle PD mode with specific pod types
|
||||
let result = if pd_mode && pod_info.pod_type.is_some() {
|
||||
// Use PD-aware worker management
|
||||
if let Some(pod_type) = &pod_info.pod_type {
|
||||
router
|
||||
.add_pd_worker(&worker_url, pod_type.clone(), pod_info.bootstrap_port)
|
||||
.await
|
||||
// Need to import PDRouter type
|
||||
use crate::routers::pd_router::PDRouter;
|
||||
|
||||
// Try to downcast to PDRouter
|
||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => pd_router
|
||||
.add_prefill_server(worker_url.clone(), pod_info.bootstrap_port)
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Decode) => pd_router
|
||||
.add_decode_server(worker_url.clone())
|
||||
.await
|
||||
.map_err(|e| e.to_string()),
|
||||
Some(PodType::Regular) | None => {
|
||||
// Fall back to regular add_worker for regular pods
|
||||
router.add_worker(&worker_url).await
|
||||
}
|
||||
}
|
||||
} else {
|
||||
Err("Pod type is None in PD mode".to_string())
|
||||
Err("PD mode enabled but router is not a PDRouter".to_string())
|
||||
}
|
||||
} else {
|
||||
// Fallback to regular worker management
|
||||
// Regular mode or no pod type specified
|
||||
router.add_worker(&worker_url).await
|
||||
};
|
||||
|
||||
@@ -412,7 +427,7 @@ async fn handle_pod_event(
|
||||
async fn handle_pod_deletion(
|
||||
pod_info: &PodInfo,
|
||||
tracked_pods: Arc<Mutex<HashSet<PodInfo>>>,
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
port: u16,
|
||||
pd_mode: bool,
|
||||
) {
|
||||
@@ -435,18 +450,34 @@ async fn handle_pod_deletion(
|
||||
pod_info.name, pod_info.pod_type, worker_url
|
||||
);
|
||||
|
||||
// Handle PD mode removal
|
||||
if pd_mode && pod_info.pod_type.is_some() {
|
||||
// Use PD-aware worker removal
|
||||
if let Some(pod_type) = &pod_info.pod_type {
|
||||
if let Err(e) = router.remove_pd_worker(&worker_url, pod_type.clone()).await {
|
||||
error!(
|
||||
"Failed to remove PD worker {} from router: {}",
|
||||
worker_url, e
|
||||
);
|
||||
use crate::routers::pd_router::PDRouter;
|
||||
|
||||
// Try to downcast to PDRouter for PD-specific removal
|
||||
if let Some(pd_router) = router.as_any().downcast_ref::<PDRouter>() {
|
||||
match &pod_info.pod_type {
|
||||
Some(PodType::Prefill) => {
|
||||
if let Err(e) = pd_router.remove_prefill_server(&worker_url).await {
|
||||
error!("Failed to remove prefill server {}: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
Some(PodType::Decode) => {
|
||||
if let Err(e) = pd_router.remove_decode_server(&worker_url).await {
|
||||
error!("Failed to remove decode server {}: {}", worker_url, e);
|
||||
}
|
||||
}
|
||||
Some(PodType::Regular) | None => {
|
||||
// Fall back to regular remove_worker
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// PD mode but not a PDRouter, use generic removal
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
} else {
|
||||
// Fallback to regular worker removal
|
||||
// Regular mode removal
|
||||
router.remove_worker(&worker_url);
|
||||
}
|
||||
} else {
|
||||
@@ -462,11 +493,9 @@ async fn handle_pod_deletion(
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::router::Router;
|
||||
use k8s_openapi::api::core::v1::{Pod, PodCondition, PodSpec, PodStatus};
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::ObjectMeta;
|
||||
use k8s_openapi::apimachinery::pkg::apis::meta::v1::Time;
|
||||
use std::sync::RwLock;
|
||||
|
||||
// Helper function to create a Pod for testing PodInfo::from_pod
|
||||
fn create_k8s_pod(
|
||||
@@ -546,14 +575,14 @@ mod tests {
|
||||
}
|
||||
|
||||
// Helper to create a Router instance for testing event handlers
|
||||
fn create_test_router() -> Arc<Router> {
|
||||
let workers = Arc::new(RwLock::new(Vec::new()));
|
||||
Arc::new(Router::Random {
|
||||
workers,
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
_health_checker: None,
|
||||
})
|
||||
fn create_test_router() -> Arc<dyn RouterTrait> {
|
||||
use crate::config::PolicyConfig;
|
||||
use crate::policies::PolicyFactory;
|
||||
use crate::routers::router::Router;
|
||||
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
let router = Router::new(vec![], policy, 5, 1).unwrap();
|
||||
Arc::new(router) as Arc<dyn RouterTrait>
|
||||
}
|
||||
|
||||
// Helper to create a PD config for testing
|
||||
|
||||
Reference in New Issue
Block a user