[feature] [sgl-router] Add a dp-aware routing strategy (#6869)
This commit is contained in:
@@ -21,6 +21,10 @@ pub struct RouterConfig {
|
||||
pub worker_startup_timeout_secs: u64,
|
||||
/// Worker health check interval in seconds
|
||||
pub worker_startup_check_interval_secs: u64,
|
||||
/// Enable data parallelism aware schedule
|
||||
pub dp_aware: bool,
|
||||
/// The api key used for the authorization with the worker
|
||||
pub api_key: Option<String>,
|
||||
/// Service discovery configuration (optional)
|
||||
pub discovery: Option<DiscoveryConfig>,
|
||||
/// Metrics configuration (optional)
|
||||
@@ -205,6 +209,8 @@ impl Default for RouterConfig {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 300,
|
||||
worker_startup_check_interval_secs: 10,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
@@ -311,6 +317,8 @@ mod tests {
|
||||
request_timeout_secs: 30,
|
||||
worker_startup_timeout_secs: 60,
|
||||
worker_startup_check_interval_secs: 5,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig::default()),
|
||||
metrics: Some(MetricsConfig::default()),
|
||||
log_dir: Some("/var/log".to_string()),
|
||||
@@ -727,6 +735,8 @@ mod tests {
|
||||
request_timeout_secs: 120,
|
||||
worker_startup_timeout_secs: 60,
|
||||
worker_startup_check_interval_secs: 5,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: Some("sglang".to_string()),
|
||||
@@ -774,6 +784,8 @@ mod tests {
|
||||
request_timeout_secs: 300,
|
||||
worker_startup_timeout_secs: 180,
|
||||
worker_startup_check_interval_secs: 15,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: None,
|
||||
@@ -812,6 +824,8 @@ mod tests {
|
||||
request_timeout_secs: 900,
|
||||
worker_startup_timeout_secs: 600,
|
||||
worker_startup_check_interval_secs: 20,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: Some("production".to_string()),
|
||||
|
||||
@@ -313,6 +313,14 @@ impl ConfigValidator {
|
||||
}
|
||||
}
|
||||
|
||||
// Service discovery is conflict with dp_aware routing for now
|
||||
// since it's not fully supported yet
|
||||
if has_service_discovery && config.dp_aware {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "DP-aware routing is not compatible with service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ pub enum WorkerError {
|
||||
NetworkError { url: String, error: String },
|
||||
/// Worker is at capacity
|
||||
WorkerAtCapacity { url: String },
|
||||
/// Invalid URL format
|
||||
InvalidUrl { url: String },
|
||||
}
|
||||
|
||||
impl fmt::Display for WorkerError {
|
||||
@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
|
||||
WorkerError::WorkerAtCapacity { url } => {
|
||||
write!(f, "Worker at capacity: {}", url)
|
||||
}
|
||||
WorkerError::InvalidUrl { url } => {
|
||||
write!(f, "Invalid URL format: {}", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,6 +162,27 @@ impl BasicWorker {
|
||||
self.metadata.health_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||
if self.url().contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let parts: Vec<&str> = self.url().split('@').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(WorkerError::InvalidUrl {
|
||||
url: self.url().to_string(),
|
||||
});
|
||||
}
|
||||
// Ensure the second part (the dp_rank) can be parsed as an integer
|
||||
match parts[1].parse::<usize>() {
|
||||
Ok(_) => Ok(parts[0]),
|
||||
Err(_) => Err(WorkerError::InvalidUrl {
|
||||
url: self.url().to_string(),
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
Ok(self.url())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
|
||||
use std::time::Duration;
|
||||
|
||||
// Perform actual HTTP health check
|
||||
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
|
||||
let url = self.normalised_url()?;
|
||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
|
||||
// Use the shared client with a custom timeout for this request
|
||||
@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
|
||||
} else {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
url: url.to_string(),
|
||||
reason: format!("Health check returned status: {}", response.status()),
|
||||
})
|
||||
}
|
||||
@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
|
||||
Err(e) => {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
url: url.to_string(),
|
||||
reason: format!("Health check request failed: {}", e),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -37,6 +37,8 @@ struct Router {
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
@@ -136,6 +138,8 @@ impl Router {
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||
dp_aware: self.dp_aware,
|
||||
api_key: self.api_key.clone(),
|
||||
discovery,
|
||||
metrics,
|
||||
log_dir: self.log_dir.clone(),
|
||||
@@ -161,6 +165,8 @@ impl Router {
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||
dp_aware = false,
|
||||
api_key = None,
|
||||
log_dir = None,
|
||||
log_level = None,
|
||||
service_discovery = false,
|
||||
@@ -193,6 +199,8 @@ impl Router {
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
@@ -225,6 +233,8 @@ impl Router {
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
max_payload_size,
|
||||
dp_aware,
|
||||
api_key,
|
||||
log_dir,
|
||||
log_level,
|
||||
service_discovery,
|
||||
|
||||
@@ -45,6 +45,8 @@ impl RouterFactory {
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
router_config.dp_aware,
|
||||
router_config.api_key.clone(),
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
|
||||
@@ -30,6 +30,8 @@ pub struct Router {
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
@@ -42,6 +44,8 @@ impl Router {
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
) -> Result<Self, String> {
|
||||
// Update active workers gauge
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
@@ -51,6 +55,14 @@ impl Router {
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
}
|
||||
|
||||
let worker_urls = if dp_aware {
|
||||
// worker address now in the format of "http://host:port@dp_rank"
|
||||
Self::get_dp_aware_workers(&worker_urls, &api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
|
||||
} else {
|
||||
worker_urls
|
||||
};
|
||||
|
||||
// Create Worker trait objects from URLs
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
@@ -89,6 +101,8 @@ impl Router {
|
||||
policy,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
dp_aware,
|
||||
api_key,
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
_health_checker: Some(health_checker),
|
||||
@@ -160,6 +174,62 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send() {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let server_info = res
|
||||
.text()
|
||||
.map_err(|e| format!("failed to read text from response: {}", e))?;
|
||||
|
||||
let server_info: serde_json::Value = serde_json::from_str(&server_info)
|
||||
.map_err(|e| format!("failed to decode JSON: {}", e))?;
|
||||
|
||||
let dp_size = server_info
|
||||
.get("dp_size")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
|
||||
|
||||
Ok(if dp_size > usize::MAX as u64 {
|
||||
return Err(format!("dp_size is too large: {}", dp_size));
|
||||
} else {
|
||||
dp_size as usize
|
||||
})
|
||||
} else {
|
||||
Err(format!("unexpected status code: {}", res.status()))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("error response: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
// Given a list of workers, return a list of workers with dp_rank as suffix
|
||||
fn get_dp_aware_workers(
|
||||
worker_urls: &[String],
|
||||
api_key: &Option<String>,
|
||||
) -> Result<Vec<String>, String> {
|
||||
let mut dp_aware_workers: Vec<String> = Vec::new();
|
||||
|
||||
for url in worker_urls {
|
||||
match Self::get_worker_dp_size(url, api_key) {
|
||||
Ok(dp_size) => {
|
||||
for i in 0..dp_size {
|
||||
dp_aware_workers.push(format!("{}@{}", url, i));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dp_aware_workers)
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
if workers_guard.is_empty() {
|
||||
@@ -178,6 +248,21 @@ impl Router {
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
let start = Instant::now();
|
||||
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
@@ -292,7 +377,7 @@ impl Router {
|
||||
worker_url = %worker_url,
|
||||
"Removing failed worker"
|
||||
);
|
||||
self.remove_worker(&worker_url);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -392,7 +477,7 @@ impl Router {
|
||||
request_id = %request_id,
|
||||
"Removing failed worker after typed request failures worker_url={}", worker_url
|
||||
);
|
||||
self.remove_worker(&worker_url);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -415,6 +500,23 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (rui): Better accommodate to the Worker abstraction
|
||||
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
|
||||
let parts: Vec<&str> = worker_url.split('@').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(format!("invalid worker_url format: {}", worker_url));
|
||||
}
|
||||
|
||||
// Parse the second part (dp_rank) into an integer
|
||||
match parts[1].parse::<usize>() {
|
||||
Ok(dp_rank) => Ok((parts[0], dp_rank)),
|
||||
Err(_) => Err(format!(
|
||||
"failed to parse dp_rank from worker_url: {}",
|
||||
worker_url
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
@@ -429,9 +531,47 @@ impl Router {
|
||||
let request_id = get_request_id(req);
|
||||
let start = Instant::now();
|
||||
|
||||
let mut request_builder = client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.json(typed_req); // Use json() directly with typed request
|
||||
let mut request_builder = if self.dp_aware {
|
||||
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the request body
|
||||
let mut json_val = match serde_json::to_value(typed_req) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
return HttpResponse::BadRequest()
|
||||
.body(format!("Convert into serde_json::Value failed: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
// Insert the data_parallel_rank field
|
||||
if let Some(map) = json_val.as_object_mut() {
|
||||
map.insert(
|
||||
String::from("data_parallel_rank"),
|
||||
serde_json::json!(dp_rank),
|
||||
);
|
||||
debug!(
|
||||
"Modified request body: {}",
|
||||
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
||||
);
|
||||
} else {
|
||||
return HttpResponse::BadRequest()
|
||||
.body("Failed to insert the data_parallel_rank field into the request body");
|
||||
}
|
||||
|
||||
client
|
||||
.post(format!("{}{}", worker_url_prefix, route))
|
||||
.json(&json_val)
|
||||
} else {
|
||||
client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.json(typed_req) // Use json() directly with typed request
|
||||
};
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
@@ -560,12 +700,35 @@ impl Router {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
if self.dp_aware {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
if workers_guard.iter().any(|w| w.url() == dp_url) {
|
||||
warn!("Worker {} already exists", dp_url);
|
||||
continue;
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
worker_added = true;
|
||||
}
|
||||
if !worker_added {
|
||||
return Err(format!("No worker added for {}", worker_url));
|
||||
}
|
||||
} else {
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
|
||||
// If cache aware policy, initialize the worker in the tree
|
||||
@@ -612,11 +775,81 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all the worker(s) that match the URL prefix
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
// without contacting the remote worker
|
||||
let mut candidate_workers: Vec<String> = Vec::new();
|
||||
let mut removed_workers: Vec<String> = Vec::new();
|
||||
let worker_url_prefix = format!("{}@", worker_url);
|
||||
|
||||
{
|
||||
// find the candidate workers to be removed
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
for w in workers_guard.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
candidate_workers.push(w.url().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// do the removing on the worker_urls
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
for dp_url in candidate_workers.iter() {
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", dp_url);
|
||||
removed_workers.push(dp_url.to_string());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", dp_url);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
for dp_url in removed_workers.iter() {
|
||||
cache_aware.remove_worker(dp_url);
|
||||
info!("Removed worker from tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker(worker_url);
|
||||
info!("Removed worker from tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a specific failed worker; for internal usage
|
||||
fn remove_failed_worker(&self, worker_url: &str) {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
info!("Removed failed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
@@ -634,6 +867,20 @@ impl Router {
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
@@ -710,6 +957,20 @@ impl Router {
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
debug!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
|
||||
// Send requests to all workers concurrently without headers
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
@@ -948,6 +1222,8 @@ mod tests {
|
||||
policy: Arc::new(RandomPolicy::new()),
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
_health_checker: None,
|
||||
|
||||
@@ -581,7 +581,7 @@ mod tests {
|
||||
use crate::routers::router::Router;
|
||||
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
let router = Router::new(vec![], policy, 5, 1).unwrap();
|
||||
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
|
||||
Arc::new(router) as Arc<dyn RouterTrait>
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user