[router] minor code clean up and refactoring (#8711)
This commit is contained in:
@@ -35,7 +35,7 @@ pub struct PDRouter {
|
|||||||
pub interval_secs: u64,
|
pub interval_secs: u64,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
pub http_client: reqwest::Client,
|
pub http_client: Client,
|
||||||
_prefill_health_checker: Option<HealthChecker>,
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
}
|
}
|
||||||
@@ -206,51 +206,17 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Initialize cache-aware components if needed for prefill policy
|
// Initialize cache-aware components if needed for prefill policy
|
||||||
let prefill_tree = if prefill_policy.name() == "cache_aware" {
|
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
|
||||||
// Initialize the policy's internal tree with prefill workers
|
|
||||||
if let Some(cache_policy) = prefill_policy
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
|
||||||
{
|
|
||||||
cache_policy.init_workers(&prefill_workers);
|
|
||||||
}
|
|
||||||
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
|
||||||
// Initialize tree with prefill workers
|
|
||||||
for worker in &prefill_workers {
|
|
||||||
tree.lock().unwrap().insert("", worker.url());
|
|
||||||
}
|
|
||||||
Some(tree)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Initialize cache-aware components if needed for decode policy
|
// Initialize cache-aware components if needed for decode policy
|
||||||
let decode_tree = if decode_policy.name() == "cache_aware" {
|
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
|
||||||
// Initialize the policy's internal tree with decode workers
|
|
||||||
if let Some(cache_policy) = decode_policy
|
|
||||||
.as_any()
|
|
||||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
|
||||||
{
|
|
||||||
cache_policy.init_workers(&decode_workers);
|
|
||||||
}
|
|
||||||
|
|
||||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
|
||||||
// Initialize tree with decode workers
|
|
||||||
for worker in &decode_workers {
|
|
||||||
tree.lock().unwrap().insert("", worker.url());
|
|
||||||
}
|
|
||||||
Some(tree)
|
|
||||||
} else {
|
|
||||||
None
|
|
||||||
};
|
|
||||||
|
|
||||||
// Set up background load monitoring for power-of-two selection
|
// Set up background load monitoring for power-of-two selection
|
||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
let worker_loads = Arc::new(rx);
|
let worker_loads = Arc::new(rx);
|
||||||
|
|
||||||
// Create a shared HTTP client for all operations
|
// Create a shared HTTP client for all operations
|
||||||
let http_client = reqwest::Client::builder()
|
let http_client = Client::builder()
|
||||||
.timeout(Duration::from_secs(timeout_secs))
|
.timeout(Duration::from_secs(timeout_secs))
|
||||||
.build()
|
.build()
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||||
@@ -304,6 +270,35 @@ impl PDRouter {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Helper function to initialize radix tree for cache-aware policies
|
||||||
|
fn initialize_radix_tree(
|
||||||
|
policy: &Arc<dyn LoadBalancingPolicy>,
|
||||||
|
workers: &[Box<dyn Worker>],
|
||||||
|
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
|
||||||
|
if let Some(cache_policy) = policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
// Initialize the policy's internal tree with workers
|
||||||
|
cache_policy.init_workers(workers);
|
||||||
|
|
||||||
|
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||||
|
|
||||||
|
{
|
||||||
|
let tree_guard = tree
|
||||||
|
.lock()
|
||||||
|
.map_err(|e| format!("Failed to lock tree: {}", e))?;
|
||||||
|
for worker in workers {
|
||||||
|
tree_guard.insert("", worker.url());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(tree))
|
||||||
|
} else {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Route a typed generate request
|
// Route a typed generate request
|
||||||
pub async fn route_generate(
|
pub async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
@@ -329,7 +324,7 @@ impl PDRouter {
|
|||||||
});
|
});
|
||||||
|
|
||||||
// Select servers
|
// Select servers
|
||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair error={}", e);
|
error!("Failed to select PD pair error={}", e);
|
||||||
@@ -417,7 +412,7 @@ impl PDRouter {
|
|||||||
.and_then(|content| content.as_str());
|
.and_then(|content| content.as_str());
|
||||||
|
|
||||||
// Select servers
|
// Select servers
|
||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair error={}", e);
|
error!("Failed to select PD pair error={}", e);
|
||||||
@@ -498,7 +493,7 @@ impl PDRouter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Select servers
|
// Select servers
|
||||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to select PD pair error={}", e);
|
error!("Failed to select PD pair error={}", e);
|
||||||
@@ -833,7 +828,6 @@ impl PDRouter {
|
|||||||
// Select a pair of prefill and decode servers
|
// Select a pair of prefill and decode servers
|
||||||
async fn select_pd_pair(
|
async fn select_pd_pair(
|
||||||
&self,
|
&self,
|
||||||
_client: &Client,
|
|
||||||
request_text: Option<&str>,
|
request_text: Option<&str>,
|
||||||
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||||
// Get read locks for both worker lists
|
// Get read locks for both worker lists
|
||||||
@@ -998,7 +992,7 @@ impl PDRouter {
|
|||||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||||
|
|
||||||
// Select a random worker pair using the policy
|
// Select a random worker pair using the policy
|
||||||
let (prefill, decode) = match self.select_pd_pair(client, None).await {
|
let (prefill, decode) = match self.select_pd_pair(None).await {
|
||||||
Ok(pair) => pair,
|
Ok(pair) => pair,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return (
|
return (
|
||||||
@@ -1921,8 +1915,7 @@ mod tests {
|
|||||||
router.prefill_workers.write().unwrap().push(healthy_worker);
|
router.prefill_workers.write().unwrap().push(healthy_worker);
|
||||||
router.decode_workers.write().unwrap().push(decode_worker);
|
router.decode_workers.write().unwrap().push(decode_worker);
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let result = router.select_pd_pair(None).await;
|
||||||
let result = router.select_pd_pair(&client, None).await;
|
|
||||||
|
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
let (prefill, _decode) = result.unwrap();
|
let (prefill, _decode) = result.unwrap();
|
||||||
@@ -1936,8 +1929,7 @@ mod tests {
|
|||||||
async fn test_empty_worker_lists() {
|
async fn test_empty_worker_lists() {
|
||||||
let router = create_test_pd_router();
|
let router = create_test_pd_router();
|
||||||
|
|
||||||
let client = reqwest::Client::new();
|
let result = router.select_pd_pair(None).await;
|
||||||
let result = router.select_pd_pair(&client, None).await;
|
|
||||||
|
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(result.unwrap_err().contains("No prefill workers available"));
|
assert!(result.unwrap_err().contains("No prefill workers available"));
|
||||||
|
|||||||
Reference in New Issue
Block a user