[router] add worker abstraction (#7960)
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
// PD (Prefill-Decode) Router Implementation
|
||||
// This module handles routing for disaggregated prefill-decode systems
|
||||
|
||||
use crate::core::{HealthChecker, Worker, WorkerFactory, WorkerLoadGuard};
|
||||
use crate::pd_types::{
|
||||
Bootstrap, ChatReqInput, EngineInfo, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||
api_path, Bootstrap, ChatReqInput, GenerateReqInput, PDRouterError, PDSelectionPolicy,
|
||||
};
|
||||
use crate::tree::Tree;
|
||||
use actix_web::http::header::{HeaderValue, CONTENT_TYPE};
|
||||
@@ -11,7 +12,6 @@ use futures_util::{StreamExt, TryStreamExt};
|
||||
use metrics::{counter, histogram};
|
||||
use serde_json::Value;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::{Arc, Mutex, RwLock};
|
||||
use std::time::{Duration, Instant};
|
||||
use tracing::{debug, error, info, warn};
|
||||
@@ -21,49 +21,17 @@ use uuid::Uuid;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct PDRouter {
|
||||
pub prefill_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<EngineInfo>>>,
|
||||
pub prefill_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub decode_workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
pub selection_policy: PDSelectionPolicy,
|
||||
pub load_tracking: Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
pub prefill_tree: Option<Arc<Mutex<Tree>>>,
|
||||
pub timeout_secs: u64,
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
pub http_client: reqwest::Client,
|
||||
}
|
||||
|
||||
// RAII guard for load tracking to ensure cleanup even on panic
|
||||
struct LoadGuard<'a> {
|
||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
urls: Vec<String>,
|
||||
}
|
||||
|
||||
impl<'a> LoadGuard<'a> {
|
||||
fn new(
|
||||
tracking: &'a Arc<dashmap::DashMap<String, Arc<AtomicUsize>>>,
|
||||
urls: Vec<String>,
|
||||
) -> Self {
|
||||
// Increment counters
|
||||
for url in &urls {
|
||||
let counter = tracking
|
||||
.entry(url.clone())
|
||||
.or_insert_with(|| Arc::new(AtomicUsize::new(0)));
|
||||
counter.fetch_add(1, Ordering::Relaxed);
|
||||
}
|
||||
LoadGuard { tracking, urls }
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for LoadGuard<'_> {
|
||||
fn drop(&mut self) {
|
||||
// Guaranteed cleanup even on panic
|
||||
for url in &self.urls {
|
||||
if let Some(counter) = self.tracking.get(url) {
|
||||
counter.fetch_sub(1, Ordering::Relaxed);
|
||||
}
|
||||
}
|
||||
}
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
|
||||
impl PDRouter {
|
||||
@@ -73,9 +41,6 @@ impl PDRouter {
|
||||
url: String,
|
||||
bootstrap_port: Option<u16>,
|
||||
) -> Result<String, PDRouterError> {
|
||||
// Create EngineInfo for the new prefill server
|
||||
let engine_info = EngineInfo::new_prefill(url.clone(), bootstrap_port);
|
||||
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
@@ -84,6 +49,9 @@ impl PDRouter {
|
||||
)
|
||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||
|
||||
// Create Worker for the new prefill server
|
||||
let worker = WorkerFactory::create_prefill(url.clone(), bootstrap_port);
|
||||
|
||||
// Add to prefill workers list
|
||||
let mut workers = self
|
||||
.prefill_workers
|
||||
@@ -93,15 +61,11 @@ impl PDRouter {
|
||||
})?;
|
||||
|
||||
// Check if already exists
|
||||
if workers.iter().any(|w| w.url == url) {
|
||||
if workers.iter().any(|w| w.url() == &url) {
|
||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||
}
|
||||
|
||||
workers.push(engine_info);
|
||||
|
||||
// Initialize load tracking
|
||||
self.load_tracking
|
||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
workers.push(worker);
|
||||
|
||||
// Add to cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
@@ -113,9 +77,6 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
pub async fn add_decode_server(&self, url: String) -> Result<String, PDRouterError> {
|
||||
// Create EngineInfo for the new decode server
|
||||
let engine_info = EngineInfo::new_decode(url.clone());
|
||||
|
||||
// Wait for the new server to be healthy
|
||||
crate::router::Router::wait_for_healthy_workers(
|
||||
&[url.clone()],
|
||||
@@ -124,6 +85,9 @@ impl PDRouter {
|
||||
)
|
||||
.map_err(|_| PDRouterError::HealthCheckFailed { url: url.clone() })?;
|
||||
|
||||
// Create Worker for the new decode server
|
||||
let worker = WorkerFactory::create_decode(url.clone());
|
||||
|
||||
// Add to decode workers list
|
||||
let mut workers = self
|
||||
.decode_workers
|
||||
@@ -133,15 +97,14 @@ impl PDRouter {
|
||||
})?;
|
||||
|
||||
// Check if already exists
|
||||
if workers.iter().any(|w| w.url == url) {
|
||||
if workers.iter().any(|w| w.url() == &url) {
|
||||
return Err(PDRouterError::WorkerAlreadyExists { url: url.clone() });
|
||||
}
|
||||
|
||||
workers.push(engine_info);
|
||||
workers.push(worker);
|
||||
|
||||
// Initialize load tracking
|
||||
self.load_tracking
|
||||
.insert(url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
// Worker tracks its own load internally
|
||||
|
||||
info!("Added decode server: {}", url);
|
||||
Ok(format!("Successfully added decode server: {}", url))
|
||||
@@ -157,7 +120,7 @@ impl PDRouter {
|
||||
|
||||
// Find and remove the server
|
||||
let initial_len = workers.len();
|
||||
workers.retain(|w| w.url != url);
|
||||
workers.retain(|w| w.url() != url);
|
||||
|
||||
if workers.len() == initial_len {
|
||||
return Err(PDRouterError::WorkerNotFound {
|
||||
@@ -166,7 +129,7 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Remove from load tracking
|
||||
self.load_tracking.remove(url);
|
||||
// Worker load tracking is internal
|
||||
|
||||
// Remove from cache tree if using cache-aware policy
|
||||
if let Some(ref tree) = self.prefill_tree {
|
||||
@@ -174,7 +137,7 @@ impl PDRouter {
|
||||
let mut tree_guard = tree.lock().unwrap();
|
||||
*tree_guard = Tree::new();
|
||||
for worker in workers.iter() {
|
||||
tree_guard.insert("", &worker.url);
|
||||
tree_guard.insert("", worker.url());
|
||||
}
|
||||
}
|
||||
|
||||
@@ -192,7 +155,7 @@ impl PDRouter {
|
||||
|
||||
// Find and remove the server
|
||||
let initial_len = workers.len();
|
||||
workers.retain(|w| w.url != url);
|
||||
workers.retain(|w| w.url() != url);
|
||||
|
||||
if workers.len() == initial_len {
|
||||
return Err(PDRouterError::WorkerNotFound {
|
||||
@@ -200,9 +163,6 @@ impl PDRouter {
|
||||
});
|
||||
}
|
||||
|
||||
// Remove from load tracking
|
||||
self.load_tracking.remove(url);
|
||||
|
||||
info!("Removed decode server: {}", url);
|
||||
Ok(format!("Successfully removed decode server: {}", url))
|
||||
}
|
||||
@@ -214,41 +174,32 @@ impl PDRouter {
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self, String> {
|
||||
// Convert URLs to EngineInfo
|
||||
let prefill_workers: Vec<EngineInfo> = prefill_urls
|
||||
// Convert URLs to Worker trait objects
|
||||
let prefill_workers: Vec<Box<dyn Worker>> = prefill_urls
|
||||
.into_iter()
|
||||
.map(|(url, port)| EngineInfo::new_prefill(url, port))
|
||||
.map(|(url, port)| WorkerFactory::create_prefill(url, port))
|
||||
.collect();
|
||||
|
||||
let decode_workers: Vec<EngineInfo> = decode_urls
|
||||
let decode_workers: Vec<Box<dyn Worker>> = decode_urls
|
||||
.into_iter()
|
||||
.map(EngineInfo::new_decode)
|
||||
.map(WorkerFactory::create_decode)
|
||||
.collect();
|
||||
|
||||
// Wait for PD workers to be healthy
|
||||
let all_urls: Vec<String> = prefill_workers
|
||||
.iter()
|
||||
.chain(decode_workers.iter())
|
||||
.map(|engine| engine.url.clone())
|
||||
.map(|worker| worker.url().to_string())
|
||||
.collect();
|
||||
crate::router::Router::wait_for_healthy_workers(&all_urls, timeout_secs, interval_secs)?;
|
||||
|
||||
// Initialize load tracking with atomic counters
|
||||
let load_tracking = Arc::new(dashmap::DashMap::new());
|
||||
for engine in &prefill_workers {
|
||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
}
|
||||
for engine in &decode_workers {
|
||||
load_tracking.insert(engine.url.clone(), Arc::new(AtomicUsize::new(0)));
|
||||
}
|
||||
|
||||
// Initialize cache-aware components if needed
|
||||
let prefill_tree = match &selection_policy {
|
||||
PDSelectionPolicy::CacheAware { .. } => {
|
||||
let tree = Arc::new(Mutex::new(Tree::new()));
|
||||
// Initialize tree with prefill workers
|
||||
for engine in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", &engine.url);
|
||||
for worker in &prefill_workers {
|
||||
tree.lock().unwrap().insert("", worker.url());
|
||||
}
|
||||
Some(tree)
|
||||
}
|
||||
@@ -283,17 +234,27 @@ impl PDRouter {
|
||||
None
|
||||
};
|
||||
|
||||
let prefill_workers = Arc::new(RwLock::new(prefill_workers));
|
||||
let decode_workers = Arc::new(RwLock::new(decode_workers));
|
||||
|
||||
// Start health checkers for both worker pools
|
||||
let prefill_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&prefill_workers), interval_secs);
|
||||
let decode_health_checker =
|
||||
crate::core::start_health_checker(Arc::clone(&decode_workers), interval_secs);
|
||||
|
||||
Ok(PDRouter {
|
||||
prefill_workers: Arc::new(RwLock::new(prefill_workers)),
|
||||
decode_workers: Arc::new(RwLock::new(decode_workers)),
|
||||
prefill_workers,
|
||||
decode_workers,
|
||||
selection_policy,
|
||||
load_tracking,
|
||||
prefill_tree,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
worker_loads,
|
||||
load_monitor_handle,
|
||||
http_client,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -330,11 +291,13 @@ impl PDRouter {
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route, prefill.url, decode.url
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||
return HttpResponse::InternalServerError()
|
||||
@@ -356,8 +319,8 @@ impl PDRouter {
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
&prefill,
|
||||
&decode,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
@@ -397,11 +360,13 @@ impl PDRouter {
|
||||
// Log routing decision
|
||||
info!(
|
||||
"PD routing: {} -> prefill={}, decode={}",
|
||||
route, prefill.url, decode.url
|
||||
route,
|
||||
prefill.url(),
|
||||
decode.url()
|
||||
);
|
||||
|
||||
// Add bootstrap info using the trait method
|
||||
if let Err(e) = typed_req.add_bootstrap_info(&prefill) {
|
||||
if let Err(e) = typed_req.add_bootstrap_info(prefill.as_ref()) {
|
||||
error!("Failed to add bootstrap info: {}", e);
|
||||
counter!("sgl_router_pd_errors_total", "error" => "bootstrap_injection").increment(1);
|
||||
return HttpResponse::InternalServerError()
|
||||
@@ -423,8 +388,8 @@ impl PDRouter {
|
||||
req,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
&prefill,
|
||||
&decode,
|
||||
prefill.as_ref(),
|
||||
decode.as_ref(),
|
||||
is_stream,
|
||||
return_logprob,
|
||||
start,
|
||||
@@ -440,22 +405,23 @@ impl PDRouter {
|
||||
req: &HttpRequest,
|
||||
json_request: serde_json::Value,
|
||||
route: &str,
|
||||
prefill: &EngineInfo,
|
||||
decode: &EngineInfo,
|
||||
prefill: &dyn Worker,
|
||||
decode: &dyn Worker,
|
||||
is_stream: bool,
|
||||
return_logprob: bool,
|
||||
start_time: Instant,
|
||||
) -> HttpResponse {
|
||||
// Update load tracking for both workers
|
||||
let _guard = LoadGuard::new(
|
||||
&self.load_tracking,
|
||||
vec![prefill.url.clone(), decode.url.clone()],
|
||||
);
|
||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||
|
||||
// Build requests using .json() method
|
||||
let mut prefill_request = client.post(prefill.api_path(route)).json(&json_request);
|
||||
let mut prefill_request = client
|
||||
.post(api_path(prefill.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
let mut decode_request = client.post(decode.api_path(route)).json(&json_request);
|
||||
let mut decode_request = client
|
||||
.post(api_path(decode.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
// Copy headers from original request
|
||||
for (name, value) in crate::router::copy_request_headers(req) {
|
||||
@@ -474,9 +440,9 @@ impl PDRouter {
|
||||
histogram!("sgl_router_pd_request_duration_seconds", "route" => route.to_string())
|
||||
.record(duration.as_secs_f64());
|
||||
counter!("sgl_router_pd_requests_total", "route" => route.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url.to_string())
|
||||
counter!("sgl_router_pd_prefill_requests_total", "worker" => prefill.url().to_string())
|
||||
.increment(1);
|
||||
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url.to_string())
|
||||
counter!("sgl_router_pd_decode_requests_total", "worker" => decode.url().to_string())
|
||||
.increment(1);
|
||||
|
||||
// Process decode response
|
||||
@@ -486,10 +452,11 @@ impl PDRouter {
|
||||
.unwrap_or(actix_web::http::StatusCode::INTERNAL_SERVER_ERROR);
|
||||
|
||||
if !status.is_success() {
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string()).increment(1);
|
||||
error!(
|
||||
"Decode server {} returned error status: {}",
|
||||
decode.url, status
|
||||
decode.url(),
|
||||
status
|
||||
);
|
||||
|
||||
// Return the error response from decode server
|
||||
@@ -508,9 +475,10 @@ impl PDRouter {
|
||||
if let Err(e) = &prefill_result {
|
||||
error!(
|
||||
"Prefill server {} failed (non-critical): {}",
|
||||
prefill.url, e
|
||||
prefill.url(),
|
||||
e
|
||||
);
|
||||
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url.to_string()).increment(1);
|
||||
counter!("sgl_router_pd_prefill_errors_total", "worker" => prefill.url().to_string()).increment(1);
|
||||
}
|
||||
|
||||
if is_stream {
|
||||
@@ -559,7 +527,7 @@ impl PDRouter {
|
||||
HttpResponse::build(status)
|
||||
.insert_header((CONTENT_TYPE, HeaderValue::from_static("text/event-stream")))
|
||||
.streaming({
|
||||
let decode_url = decode.url.clone();
|
||||
let decode_url = decode.url().to_string();
|
||||
res.bytes_stream().map_err(move |e| {
|
||||
error!("Stream error from decode server {}: {}", decode_url, e);
|
||||
counter!("sgl_router_pd_stream_errors_total", "worker" => decode_url.to_string()).increment(1);
|
||||
@@ -587,7 +555,7 @@ impl PDRouter {
|
||||
}
|
||||
Err(e) => {
|
||||
error!("Decode request failed: {}", e);
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url.to_string())
|
||||
counter!("sgl_router_pd_decode_errors_total", "worker" => decode.url().to_string())
|
||||
.increment(1);
|
||||
HttpResponse::BadGateway().body(format!("Decode server error: {}", e))
|
||||
}
|
||||
@@ -652,7 +620,7 @@ impl PDRouter {
|
||||
async fn select_pd_pair(
|
||||
&self,
|
||||
_client: &reqwest::Client,
|
||||
) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
// Check we have workers
|
||||
if self
|
||||
.prefill_workers
|
||||
@@ -681,17 +649,17 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
fn select_random(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
fn select_random(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone();
|
||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone();
|
||||
let prefill = prefill_list[rand::random::<usize>() % prefill_list.len()].clone_worker();
|
||||
let decode = decode_list[rand::random::<usize>() % decode_list.len()].clone_worker();
|
||||
|
||||
Ok((prefill, decode))
|
||||
}
|
||||
|
||||
async fn select_power_of_two(&self) -> Result<(EngineInfo, EngineInfo), String> {
|
||||
async fn select_power_of_two(&self) -> Result<(Box<dyn Worker>, Box<dyn Worker>), String> {
|
||||
let prefill_list = self.prefill_workers.read().map_err(|_| "Lock error")?;
|
||||
let decode_list = self.decode_workers.read().map_err(|_| "Lock error")?;
|
||||
|
||||
@@ -700,33 +668,45 @@ impl PDRouter {
|
||||
|
||||
let loads = self.worker_loads.borrow();
|
||||
|
||||
let p1_load = loads.get(&prefill_list[p1_idx].url).copied().unwrap_or(0);
|
||||
let p2_load = loads.get(&prefill_list[p2_idx].url).copied().unwrap_or(0);
|
||||
let d1_load = loads.get(&decode_list[d1_idx].url).copied().unwrap_or(0);
|
||||
let d2_load = loads.get(&decode_list[d2_idx].url).copied().unwrap_or(0);
|
||||
let p1_load = loads
|
||||
.get(prefill_list[p1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let p2_load = loads
|
||||
.get(prefill_list[p2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d1_load = loads
|
||||
.get(decode_list[d1_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
let d2_load = loads
|
||||
.get(decode_list[d2_idx].url())
|
||||
.copied()
|
||||
.unwrap_or(isize::MAX);
|
||||
|
||||
info!(
|
||||
"Power-of-two selection - Prefill: {}={} vs {}={} | Decode: {}={} vs {}={}",
|
||||
prefill_list[p1_idx].url,
|
||||
prefill_list[p1_idx].url(),
|
||||
p1_load,
|
||||
prefill_list[p2_idx].url,
|
||||
prefill_list[p2_idx].url(),
|
||||
p2_load,
|
||||
decode_list[d1_idx].url,
|
||||
decode_list[d1_idx].url(),
|
||||
d1_load,
|
||||
decode_list[d2_idx].url,
|
||||
decode_list[d2_idx].url(),
|
||||
d2_load
|
||||
);
|
||||
|
||||
let selected_prefill = if p1_load <= p2_load {
|
||||
prefill_list[p1_idx].clone()
|
||||
prefill_list[p1_idx].clone_worker()
|
||||
} else {
|
||||
prefill_list[p2_idx].clone()
|
||||
prefill_list[p2_idx].clone_worker()
|
||||
};
|
||||
|
||||
let selected_decode = if d1_load <= d2_load {
|
||||
decode_list[d1_idx].clone()
|
||||
decode_list[d1_idx].clone_worker()
|
||||
} else {
|
||||
decode_list[d2_idx].clone()
|
||||
decode_list[d2_idx].clone_worker()
|
||||
};
|
||||
|
||||
Ok((selected_prefill, selected_decode))
|
||||
@@ -868,11 +848,11 @@ impl PDRouter {
|
||||
let mut worker_infos = Vec::new();
|
||||
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url.clone(), "prefill"));
|
||||
worker_infos.push((worker.url().to_string(), "prefill"));
|
||||
}
|
||||
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
worker_infos.push((worker.url.clone(), "decode"));
|
||||
worker_infos.push((worker.url().to_string(), "decode"));
|
||||
}
|
||||
|
||||
// Create tasks with URL tracking
|
||||
@@ -922,7 +902,7 @@ impl PDRouter {
|
||||
pub async fn get_server_info(&self, client: &reqwest::Client) -> HttpResponse {
|
||||
// Get info from the first decode server to match sglang's server info format
|
||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access decode workers");
|
||||
};
|
||||
@@ -967,7 +947,7 @@ impl PDRouter {
|
||||
pub async fn get_models(&self, client: &reqwest::Client, req: &HttpRequest) -> HttpResponse {
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
@@ -1005,14 +985,14 @@ impl PDRouter {
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url.clone())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
let d_urls: Vec<_> = self
|
||||
.decode_workers
|
||||
.read()
|
||||
.unwrap()
|
||||
.iter()
|
||||
.map(|w| w.url.clone())
|
||||
.map(|w| w.url().to_string())
|
||||
.collect();
|
||||
|
||||
let mut prefill_loads = Vec::new();
|
||||
@@ -1048,7 +1028,7 @@ impl PDRouter {
|
||||
// Get model info from the first prefill server (matches original Rust PDLB behavior)
|
||||
// Get first prefill worker URL to avoid holding lock across await
|
||||
let first_worker_url = if let Ok(workers) = self.prefill_workers.read() {
|
||||
workers.first().map(|w| w.url.clone())
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
} else {
|
||||
return HttpResponse::InternalServerError().body("Failed to access prefill workers");
|
||||
};
|
||||
@@ -1084,13 +1064,13 @@ impl PDRouter {
|
||||
|
||||
// Flush cache on all prefill servers
|
||||
for worker in self.prefill_workers.read().unwrap().iter() {
|
||||
let url = format!("{}/flush_cache", worker.url);
|
||||
let url = format!("{}/flush_cache", worker.url());
|
||||
tasks.push(client.post(&url).send());
|
||||
}
|
||||
|
||||
// Flush cache on all decode servers
|
||||
for worker in self.decode_workers.read().unwrap().iter() {
|
||||
let url = format!("{}/flush_cache", worker.url);
|
||||
let url = format!("{}/flush_cache", worker.url());
|
||||
tasks.push(client.post(&url).send());
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user