[feature] [sgl-router] Add a dp-aware routing strategy (#6869)

This commit is contained in:
Rui Chen
2025-07-30 20:58:48 +08:00
committed by GitHub
parent 55ecdc0a8e
commit a730ce8162
19 changed files with 726 additions and 16 deletions

View File

@@ -21,6 +21,10 @@ pub struct RouterConfig {
pub worker_startup_timeout_secs: u64,
/// Worker health check interval in seconds
pub worker_startup_check_interval_secs: u64,
/// Enable data-parallelism-aware scheduling
pub dp_aware: bool,
/// The api key used for the authorization with the worker
pub api_key: Option<String>,
/// Service discovery configuration (optional)
pub discovery: Option<DiscoveryConfig>,
/// Metrics configuration (optional)
@@ -205,6 +209,8 @@ impl Default for RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
@@ -311,6 +317,8 @@ mod tests {
request_timeout_secs: 30,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig::default()),
metrics: Some(MetricsConfig::default()),
log_dir: Some("/var/log".to_string()),
@@ -727,6 +735,8 @@ mod tests {
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("sglang".to_string()),
@@ -774,6 +784,8 @@ mod tests {
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: None,
@@ -812,6 +824,8 @@ mod tests {
request_timeout_secs: 900,
worker_startup_timeout_secs: 600,
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("production".to_string()),

View File

@@ -313,6 +313,14 @@ impl ConfigValidator {
}
}
// Service discovery conflicts with dp-aware routing for now,
// since the combination is not fully supported yet
if has_service_discovery && config.dp_aware {
return Err(ConfigError::IncompatibleConfig {
reason: "DP-aware routing is not compatible with service discovery".to_string(),
});
}
Ok(())
}

View File

@@ -17,6 +17,8 @@ pub enum WorkerError {
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
/// Invalid URL format
InvalidUrl { url: String },
}
impl fmt::Display for WorkerError {
@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
WorkerError::InvalidUrl { url } => {
write!(f, "Invalid URL format: {}", url)
}
}
}
}

View File

@@ -162,6 +162,27 @@ impl BasicWorker {
self.metadata.health_config = config;
self
}
/// Return the worker URL with any "@dp_rank" suffix stripped.
///
/// DP-aware worker URLs are encoded as "http://host:port@dp_rank". This
/// validates that encoding (exactly one '@', integer rank) and yields the
/// plain base URL; URLs without '@' are returned unchanged.
pub fn normalised_url(&self) -> WorkerResult<&str> {
let raw = self.url();
if !raw.contains('@') {
// Plain URL, nothing to strip.
return Ok(raw);
}
// Expect exactly "base@rank"; anything else is malformed.
let pieces: Vec<&str> = raw.split('@').collect();
match pieces.as_slice() {
[base, rank] if rank.parse::<usize>().is_ok() => Ok(*base),
_ => Err(WorkerError::InvalidUrl {
url: raw.to_string(),
}),
}
}
}
#[async_trait]
@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
use std::time::Duration;
// Perform actual HTTP health check
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
} else {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check returned status: {}", response.status()),
})
}
@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
Err(e) => {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check request failed: {}", e),
})
}

View File

@@ -37,6 +37,8 @@ struct Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
@@ -136,6 +138,8 @@ impl Router {
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
@@ -161,6 +165,8 @@ impl Router {
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
dp_aware = false,
api_key = None,
log_dir = None,
log_level = None,
service_discovery = false,
@@ -193,6 +199,8 @@ impl Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
@@ -225,6 +233,8 @@ impl Router {
eviction_interval_secs,
max_tree_size,
max_payload_size,
dp_aware,
api_key,
log_dir,
log_level,
service_discovery,

View File

@@ -45,6 +45,8 @@ impl RouterFactory {
policy,
router_config.worker_startup_timeout_secs,
router_config.worker_startup_check_interval_secs,
router_config.dp_aware,
router_config.api_key.clone(),
)?;
Ok(Box::new(router))

View File

@@ -30,6 +30,8 @@ pub struct Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
@@ -42,6 +44,8 @@ impl Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
@@ -51,6 +55,14 @@ impl Router {
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
let worker_urls = if dp_aware {
// worker addresses are now in the format "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Create Worker trait objects from URLs
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
@@ -89,6 +101,8 @@ impl Router {
policy,
timeout_secs,
interval_secs,
dp_aware,
api_key,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
@@ -160,6 +174,62 @@ impl Router {
}
}
/// Query a worker's `/get_server_info` endpoint (blocking) and return its
/// data-parallel size (`dp_size`).
///
/// The optional bearer `api_key` is attached for authorization. Returns a
/// human-readable error string on network failure, non-2xx status,
/// malformed JSON, or a missing / non-u64 / overflowing `dp_size` field.
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
// NOTE(review): a fresh blocking client per call looks intentional — this
// runs only at startup / worker registration, not on the request path.
let sync_client = reqwest::blocking::Client::new();
let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
if let Some(key) = api_key {
req_builder = req_builder.bearer_auth(key);
}
let res = req_builder
.send()
.map_err(|e| format!("error response: {}", e))?;
if !res.status().is_success() {
return Err(format!("unexpected status code: {}", res.status()));
}
let body = res
.text()
.map_err(|e| format!("failed to read text from response: {}", e))?;
let server_info: serde_json::Value =
serde_json::from_str(&body).map_err(|e| format!("failed to decode JSON: {}", e))?;
let dp_size = server_info
.get("dp_size")
.and_then(|v| v.as_u64())
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
// Checked narrowing instead of a manual `usize::MAX as u64` comparison;
// only relevant on targets where usize is narrower than u64.
usize::try_from(dp_size).map_err(|_| format!("dp_size is too large: {}", dp_size))
}
// Given a list of workers, return a list of workers with dp_rank as suffix.
// Fails fast: the first worker whose dp_size cannot be fetched aborts the
// whole expansion.
fn get_dp_aware_workers(
worker_urls: &[String],
api_key: &Option<String>,
) -> Result<Vec<String>, String> {
let mut dp_aware_workers: Vec<String> = Vec::new();
for url in worker_urls {
// Ask the worker how many DP ranks it serves, then emit one
// "url@rank" pseudo-worker per rank.
let dp_size = Self::get_worker_dp_size(url, api_key)
.map_err(|e| format!("Failed to get DP size for {}: {}", url, e))?;
dp_aware_workers.extend((0..dp_size).map(|rank| format!("{}@{}", url, rank)));
}
Ok(dp_aware_workers)
}
fn select_first_worker(&self) -> Result<String, String> {
let workers_guard = self.workers.read().unwrap();
if workers_guard.is_empty() {
@@ -178,6 +248,21 @@ impl Router {
) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now();
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let mut request_builder = client.get(format!("{}{}", worker_url, route));
// Copy all headers from original request except for /health because it does not need authorization
@@ -292,7 +377,7 @@ impl Router {
worker_url = %worker_url,
"Removing failed worker"
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
@@ -392,7 +477,7 @@ impl Router {
request_id = %request_id,
"Removing failed worker after typed request failures worker_url={}", worker_url
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
@@ -415,6 +500,23 @@ impl Router {
}
}
// TODO (rui): Better accommodate to the Worker abstraction
/// Split a dp-aware worker URL of the form "http://host:port@dp_rank" into
/// its base URL and numeric dp_rank. Errors (as a message string) when the
/// URL does not contain exactly one '@' or the rank is not an integer.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
let fields: Vec<&str> = worker_url.split('@').collect();
match fields.as_slice() {
[base, rank_str] => rank_str
.parse::<usize>()
.map(|dp_rank| (*base, dp_rank))
.map_err(|_| {
format!(
"failed to parse dp_rank from worker_url: {}",
worker_url
)
}),
_ => Err(format!("invalid worker_url format: {}", worker_url)),
}
}
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
@@ -429,9 +531,47 @@ impl Router {
let request_id = get_request_id(req);
let start = Instant::now();
let mut request_builder = client
.post(format!("{}{}", worker_url, route))
.json(typed_req); // Use json() directly with typed request
let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
// Parse the request body
let mut json_val = match serde_json::to_value(typed_req) {
Ok(j) => j,
Err(e) => {
return HttpResponse::BadRequest()
.body(format!("Convert into serde_json::Value failed: {}", e));
}
};
// Insert the data_parallel_rank field
if let Some(map) = json_val.as_object_mut() {
map.insert(
String::from("data_parallel_rank"),
serde_json::json!(dp_rank),
);
debug!(
"Modified request body: {}",
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
);
} else {
return HttpResponse::BadRequest()
.body("Failed to insert the data_parallel_rank field into the request body");
}
client
.post(format!("{}{}", worker_url_prefix, route))
.json(&json_val)
} else {
client
.post(format!("{}{}", worker_url, route))
.json(typed_req) // Use json() directly with typed request
};
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
@@ -560,12 +700,35 @@ impl Router {
Ok(res) => {
if res.status().is_success() {
let mut workers_guard = self.workers.write().unwrap();
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if workers_guard.iter().any(|w| w.url() == dp_url) {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
workers_guard.push(new_worker);
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
RouterMetrics::set_active_workers(workers_guard.len());
// If cache aware policy, initialize the worker in the tree
@@ -612,11 +775,81 @@ impl Router {
}
}
/// Remove all the worker(s) that match the URL prefix
///
/// In dp-aware mode `worker_url` is a base URL ("http://host:port") and every
/// registered entry of the form "http://host:port@dp_rank" is removed;
/// otherwise the single worker whose URL matches exactly is removed. No
/// request is sent to the remote worker in either case.
pub fn remove_worker(&self, worker_url: &str) {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut candidate_workers: Vec<String> = Vec::new();
let mut removed_workers: Vec<String> = Vec::new();
// Trailing '@' prevents "http://h:1" from matching "http://h:10@...".
let worker_url_prefix = format!("{}@", worker_url);
{
// find the candidate workers to be removed
// Phase 1: snapshot matching URLs under a read lock only.
let workers_guard = self.workers.read().unwrap();
for w in workers_guard.iter() {
if w.url().starts_with(&worker_url_prefix) {
candidate_workers.push(w.url().to_string());
}
}
}
{
// do the removing on the worker_urls
// Phase 2: take the write lock and remove. A candidate seen in
// phase 1 may have been removed concurrently in between, hence
// the re-lookup via position() and the warn! fallback.
let mut workers_guard = self.workers.write().unwrap();
for dp_url in candidate_workers.iter() {
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
workers_guard.remove(index);
info!("Removed worker: {}", dp_url);
removed_workers.push(dp_url.to_string());
} else {
warn!("Worker {} not found, skipping removal", dp_url);
continue;
}
}
// Update the gauge while still holding the write lock so the
// metric agrees with the list contents.
RouterMetrics::set_active_workers(workers_guard.len());
}
// If cache aware policy, remove the workers from the tree
// (only entries actually removed above, outside any worker lock).
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
for dp_url in removed_workers.iter() {
cache_aware.remove_worker(dp_url);
info!("Removed worker from tree: {}", dp_url);
}
}
} else {
// Non-dp mode: exact-match removal of a single worker.
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
warn!("Worker {} not found, skipping removal", worker_url);
return;
}
// If cache aware policy, remove the workers from the tree
// NOTE: the write lock is still held here — assumed fine since
// cache_aware.remove_worker does not touch self.workers.
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
}
}
}
/// Remove a specific failed worker; for internal usage
fn remove_failed_worker(&self, worker_url: &str) {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
info!("Removed failed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
warn!("Worker {} not found, skipping removal", worker_url);
@@ -634,6 +867,20 @@ impl Router {
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
@@ -710,6 +957,20 @@ impl Router {
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
// Send requests to all workers concurrently without headers
let mut tasks = Vec::new();
for worker_url in &worker_urls {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let request_builder = client.post(format!("{}/flush_cache", worker_url));
tasks.push(request_builder.send());
}
@@ -948,6 +1222,8 @@ mod tests {
policy: Arc::new(RandomPolicy::new()),
timeout_secs: 5,
interval_secs: 1,
dp_aware: false,
api_key: None,
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,

View File

@@ -581,7 +581,7 @@ mod tests {
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1).unwrap();
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}