[router] minor code clean up and and refactoring (#8711)
This commit is contained in:
@@ -35,7 +35,7 @@ pub struct PDRouter {
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
pub http_client: reqwest::Client,
|
||||
pub http_client: Client,
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
@@ -206,51 +206,17 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed for prefill policy
|
||||
let prefill_tree = if prefill_policy.name() == "cache_aware" {
|
||||
// Initialize the policy's internal tree with prefill workers
|
||||
if let Some(cache_policy) = prefill_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&prefill_workers);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with prefill workers
|
||||
for worker in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let prefill_tree = Self::initialize_radix_tree(&prefill_policy, &prefill_workers)?;
|
||||
|
||||
// Initialize cache-aware components if needed for decode policy
|
||||
let decode_tree = if decode_policy.name() == "cache_aware" {
|
||||
// Initialize the policy's internal tree with decode workers
|
||||
if let Some(cache_policy) = decode_policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_policy.init_workers(&decode_workers);
|
||||
}
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with decode workers
|
||||
for worker in &decode_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let decode_tree = Self::initialize_radix_tree(&decode_policy, &decode_workers)?;
|
||||
|
||||
// Set up background load monitoring for power-of-two selection
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Create a shared HTTP client for all operations
|
||||
let http_client = reqwest::Client::builder()
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
@@ -304,6 +270,35 @@ impl PDRouter {
|
||||
})
|
||||
}
|
||||
|
||||
// Helper function to initialize radix tree for cache-aware policies
|
||||
fn initialize_radix_tree(
|
||||
policy: &Arc<dyn LoadBalancingPolicy>,
|
||||
workers: &[Box<dyn Worker>],
|
||||
) -> Result<Option<Arc<Mutex<Tree>>>, String> {
|
||||
if let Some(cache_policy) = policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
// Initialize the policy's internal tree with workers
|
||||
cache_policy.init_workers(workers);
|
||||
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
|
||||
{
|
||||
let tree_guard = tree
|
||||
.lock()
|
||||
.map_err(|e| format!("Failed to lock tree: {}", e))?;
|
||||
for worker in workers {
|
||||
tree_guard.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Some(tree))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
// Route a typed generate request
|
||||
pub async fn route_generate(
|
||||
&self,
|
||||
@@ -329,7 +324,7 @@ impl PDRouter {
|
||||
});
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair error={}", e);
|
||||
@@ -417,7 +412,7 @@ impl PDRouter {
|
||||
.and_then(|content| content.as_str());
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair error={}", e);
|
||||
@@ -498,7 +493,7 @@ impl PDRouter {
|
||||
};
|
||||
|
||||
// Select servers
|
||||
let (prefill, decode) = match self.select_pd_pair(client, request_text).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(request_text).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
error!("Failed to select PD pair error={}", e);
|
||||
@@ -833,7 +828,6 @@ impl PDRouter {
|
||||
// Select a pair of prefill and decode servers
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
_client: &Client,
|
||||
request_text: Option<&str>,
|
||||
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
// Get read locks for both worker lists
|
||||
@@ -998,7 +992,7 @@ impl PDRouter {
|
||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||
|
||||
// Select a random worker pair using the policy
|
||||
let (prefill, decode) = match self.select_pd_pair(client, None).await {
|
||||
let (prefill, decode) = match self.select_pd_pair(None).await {
|
||||
Ok(pair) => pair,
|
||||
Err(e) => {
|
||||
return (
|
||||
@@ -1921,8 +1915,7 @@ mod tests {
|
||||
router.prefill_workers.write().unwrap().push(healthy_worker);
|
||||
router.decode_workers.write().unwrap().push(decode_worker);
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let result = router.select_pd_pair(&client, None).await;
|
||||
let result = router.select_pd_pair(None).await;
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (prefill, _decode) = result.unwrap();
|
||||
@@ -1936,8 +1929,7 @@ mod tests {
|
||||
async fn test_empty_worker_lists() {
|
||||
let router = create_test_pd_router();
|
||||
|
||||
let client = reqwest::Client::new();
|
||||
let result = router.select_pd_pair(&client, None).await;
|
||||
let result = router.select_pd_pair(None).await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().contains("No prefill workers available"));
|
||||
|
||||
Reference in New Issue
Block a user