[feature] [sgl-router] Add a dp-aware routing strategy (#6869)

This commit is contained in:
Rui Chen
2025-07-30 20:58:48 +08:00
committed by GitHub
parent 55ecdc0a8e
commit a730ce8162
19 changed files with 726 additions and 16 deletions

View File

@@ -50,6 +50,8 @@ class RouterArgs:
eviction_interval: int = 60
max_tree_size: int = 2**24
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
dp_aware: bool = False
api_key: Optional[str] = None
log_dir: Optional[str] = None
log_level: Optional[str] = None
# Service discovery configuration
@@ -197,6 +199,17 @@ class RouterArgs:
default=RouterArgs.max_payload_size,
help="Maximum payload size in bytes",
)
parser.add_argument(
f"--{prefix}dp-aware",
action="store_true",
help="Enable data parallelism aware schedule",
)
parser.add_argument(
f"--{prefix}api-key",
type=str,
default=None,
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enabled.",
)
parser.add_argument(
f"--{prefix}log-dir",
type=str,
@@ -304,6 +317,8 @@ class RouterArgs:
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
dp_aware=getattr(args, f"{prefix}dp_aware", False),
api_key=getattr(args, f"{prefix}api_key", None),
log_dir=getattr(args, f"{prefix}log_dir", None),
log_level=getattr(args, f"{prefix}log_level", None),
service_discovery=getattr(args, f"{prefix}service_discovery", False),
@@ -463,6 +478,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
eviction_interval_secs=router_args.eviction_interval,
max_tree_size=router_args.max_tree_size,
max_payload_size=router_args.max_payload_size,
dp_aware=router_args.dp_aware,
api_key=router_args.api_key,
log_dir=router_args.log_dir,
log_level=router_args.log_level,
service_discovery=router_args.service_discovery,

View File

@@ -31,6 +31,10 @@ class Router:
routing. Default: 60
max_payload_size: Maximum payload size in bytes. Default: 256MB
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
dp_aware: Enable data parallelism aware schedule. Default: False
api_key: The api key used for the authorization with the worker.
Useful when the dp aware scheduling strategy is enabled.
Default: None
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
@@ -73,6 +77,8 @@ class Router:
eviction_interval_secs: int = 60,
max_tree_size: int = 2**24,
max_payload_size: int = 256 * 1024 * 1024, # 256MB
dp_aware: bool = False,
api_key: Optional[str] = None,
log_dir: Optional[str] = None,
log_level: Optional[str] = None,
service_discovery: bool = False,
@@ -110,6 +116,8 @@ class Router:
eviction_interval_secs=eviction_interval_secs,
max_tree_size=max_tree_size,
max_payload_size=max_payload_size,
dp_aware=dp_aware,
api_key=api_key,
log_dir=log_dir,
log_level=log_level,
service_discovery=service_discovery,

View File

@@ -8,7 +8,7 @@ if __name__ == "__main__":
arg_parser.add_argument(
"--timeout-per-file",
type=int,
default=1000,
default=2000,
help="The time limit for running one file in seconds.",
)
args = arg_parser.parse_args()

View File

@@ -43,6 +43,7 @@ class TestLaunchRouter(unittest.TestCase):
selector=None,
service_discovery_port=80,
service_discovery_namespace=None,
dp_aware=False,
prometheus_port=None,
prometheus_host=None,
# PD-specific attributes
@@ -111,6 +112,52 @@ class TestLaunchRouter(unittest.TestCase):
)
self.run_router_process(args)
def test_launch_router_common_with_dp_aware(self):
    """Router should launch successfully with dp-aware scheduling enabled."""
    router_args = self.create_router_args(
        worker_urls=["http://localhost:8000"],
        dp_aware=True,
    )
    self.run_router_process(router_args)
def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
    """Router should still launch with no workers when dp-aware is enabled."""
    router_args = self.create_router_args(worker_urls=[], dp_aware=True)
    self.run_router_process(router_args)
def test_launch_router_common_with_dp_aware_service_discovery(self):
    """Launching with both service_discovery and dp_aware enabled must fail.

    The two options are mutually exclusive for now, so the router process
    is expected to exit shortly after startup.
    """
    router_args = self.create_router_args(
        worker_urls=["http://localhost:8000"],
        dp_aware=True,
        service_discovery=True,
        selector=["app=test-worker"],
    )

    def run_router():
        try:
            from sglang_router.launch_router import launch_router

            router = launch_router(router_args)
            return 1 if router is None else 0
        except Exception as e:
            print(e)
            return 1

    process = multiprocessing.Process(target=run_router)
    try:
        process.start()
        # Give the router a moment to start up (and fail).
        time.sleep(3)
        # The process should have exited because the two options conflict.
        self.assertFalse(process.is_alive())
    finally:
        terminate_process(process)
def test_launch_router_pd_mode_basic(self):
"""Test basic PD router functionality without actually starting servers."""
# This test just verifies the PD router can be created and configured

View File

@@ -30,6 +30,7 @@ def popen_launch_router(
service_discovery_namespace: str = None,
prometheus_port: int = None,
prometheus_host: str = None,
dp_aware: bool = False,
):
"""
Launch the router server process.
@@ -49,6 +50,7 @@ def popen_launch_router(
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
prometheus_host: Host address to bind the Prometheus metrics server.
dp_aware: Enable data parallelism aware routing strategy.
"""
_, host, port = base_url.split(":")
host = host[2:]
@@ -69,10 +71,12 @@ def popen_launch_router(
"5",
"--router-policy",
policy,
"--allow-auto-truncate",
]
if api_key is not None:
command.extend(["--api-key", api_key])
command.extend(["--router-api-key", api_key])
if max_payload_size is not None:
command.extend(["--router-max-payload-size", str(max_payload_size)])
@@ -100,6 +104,9 @@ def popen_launch_router(
if log_dir is not None:
command.extend(["--log-dir", log_dir])
if dp_aware:
command.append("--router-dp-aware")
process = subprocess.Popen(command, stdout=None, stderr=None)
start_time = time.perf_counter()
@@ -127,6 +134,7 @@ def popen_launch_server(
model: str,
base_url: str,
timeout: float,
api_key: str = None,
):
_, host, port = base_url.split(":")
host = host[2:]
@@ -145,6 +153,9 @@ def popen_launch_server(
"1",
]
if api_key is not None:
command.extend(["--api-key", api_key])
process = subprocess.Popen(command, stdout=None, stderr=None)
# intentionally don't wait and defer the job to the router health check
@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
response.status_code, 200, "Request with correct api key should succeed"
)
def test_6_mmlu_with_dp_aware(self):
    """Run MMLU through a dp-aware, cache-aware router with DP size 2."""
    print("Running test_6_mmlu_with_dp_aware...")
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=2,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="cache_aware",
        dp_aware=True,
    )
    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(eval_args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
def test_7_add_and_remove_worker_with_dp_aware(self):
    """Exercise /add_worker and /remove_worker with dp-aware routing enabled."""
    print("Running test_7_add_and_remove_worker_with_dp_aware...")
    # Round-robin guarantees every worker receives traffic; the dp-aware
    # strategy should compose cleanly with it. dp_size is set to 1.
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        dp_aware=True,
    )

    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )

    def assert_mmlu_passes():
        # Shared MMLU scoring used after each topology change.
        metrics = run_eval(eval_args)
        score = metrics["score"]
        THRESHOLD = 0.65
        passed = score >= THRESHOLD
        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
        self.assertGreaterEqual(score, THRESHOLD, msg)

    # 1. Start a worker.
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(worker_process)

    # 2. Register it with the router via /add_worker; the router uses it
    # once it is healthy.
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # 3. Run MMLU with the worker attached.
    assert_mmlu_passes()

    # 4. Detach the worker via /remove_worker.
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # 5. Run MMLU again after removal.
    assert_mmlu_passes()

    # 6. Replace the worker with one that requires an api key.
    terminate_and_wait(worker_process)  # terminate the old worker process
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model,
        worker_url,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        api_key="correct_api_key",
    )
    self.other_process.append(worker_process)

    # 7. Adding it should fail: the router queries the worker's
    # /get_server_info endpoint for the dp_size info, but it has no
    # knowledge of the api key.
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertNotEqual(response.status_code, 200)
def test_8_lazy_fault_tolerance_with_dp_aware(self):
    """Router must survive an abrupt worker death while an eval is running."""
    import threading

    print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
    # dp_size is set to 1.
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        dp_aware=True,
    )

    # 1. Start a worker.
    port = find_available_port()
    worker_url = f"http://127.0.0.1:{port}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(worker_process)

    # 2. Register it with the router; it is used once healthy.
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # Kill the worker after 10 seconds to mimic an abrupt worker failure
    # while the evaluation below is still in flight.
    def kill_worker():
        time.sleep(10)
        kill_process_tree(worker_process.pid)
        print("Worker process killed")

    kill_thread = threading.Thread(target=kill_worker, daemon=True)
    kill_thread.start()

    # 3. Run MMLU; remaining workers must absorb the load.
    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=256,
        num_threads=32,
        temperature=0.1,
    )
    metrics = run_eval(eval_args)
    score = metrics["score"]
    THRESHOLD = 0.65
    passed = score >= THRESHOLD
    msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
    self.assertGreaterEqual(score, THRESHOLD, msg)
def test_9_payload_size_with_dp_aware(self):
    """Payload-size limits must still be enforced when dp-aware routing is on."""
    print("Running test_9_payload_size_with_dp_aware...")
    # Start the router with a 1MB request payload limit.
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        max_payload_size=1 * 1024 * 1024,  # 1MB limit
        dp_aware=True,
    )
    json_headers = {"Content-Type": "application/json"}

    # Test case 1: a 0.5MB payload (under the limit) should succeed.
    small_payload = {
        "text": "x" * int(0.5 * 1024 * 1024),  # 0.5MB of text
        "temperature": 0.0,
    }
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json=small_payload,
            headers=json_headers,
        )
    self.assertEqual(
        response.status_code,
        200,
        f"0.5MB payload should succeed but got status {response.status_code}",
    )

    # Test case 2: a 1.2MB payload (over the limit) should be rejected.
    large_payload = {
        "text": "x" * int(1.2 * 1024 * 1024),  # 1.2MB of text
        "temperature": 0.0,
    }
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json=large_payload,
            headers=json_headers,
        )
    self.assertEqual(
        response.status_code,
        413,  # Payload Too Large
        f"1.2MB payload should fail with 413 but got status {response.status_code}",
    )
def test_10_api_key_with_dp_aware(self):
    """Verify router api-key enforcement when dp-aware routing is enabled.

    Three cases: no key -> 401, invalid key -> 401, correct key -> 200.
    """
    print("Running test_10_api_key_with_dp_aware...")
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
        dp_aware=True,
    )

    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request without api key should fail with 401 but got status {response.status_code}",
        )

    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Fix: use the session (the original called module-level
        # requests.post here, leaving the session unused).
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            # Fix: message previously said "without api key" for this case.
            f"Request with invalid api key should fail with 401 but got status {response.status_code}",
        )

    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            200,
            f"Request with correct api key should succeed but got status {response.status_code}",
        )
if __name__ == "__main__":
unittest.main()

View File

@@ -21,6 +21,10 @@ pub struct RouterConfig {
pub worker_startup_timeout_secs: u64,
/// Worker health check interval in seconds
pub worker_startup_check_interval_secs: u64,
/// Enable data parallelism aware schedule
pub dp_aware: bool,
/// The api key used for the authorization with the worker
pub api_key: Option<String>,
/// Service discovery configuration (optional)
pub discovery: Option<DiscoveryConfig>,
/// Metrics configuration (optional)
@@ -205,6 +209,8 @@ impl Default for RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
@@ -311,6 +317,8 @@ mod tests {
request_timeout_secs: 30,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig::default()),
metrics: Some(MetricsConfig::default()),
log_dir: Some("/var/log".to_string()),
@@ -727,6 +735,8 @@ mod tests {
request_timeout_secs: 120,
worker_startup_timeout_secs: 60,
worker_startup_check_interval_secs: 5,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("sglang".to_string()),
@@ -774,6 +784,8 @@ mod tests {
request_timeout_secs: 300,
worker_startup_timeout_secs: 180,
worker_startup_check_interval_secs: 15,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: None,
@@ -812,6 +824,8 @@ mod tests {
request_timeout_secs: 900,
worker_startup_timeout_secs: 600,
worker_startup_check_interval_secs: 20,
dp_aware: false,
api_key: None,
discovery: Some(DiscoveryConfig {
enabled: true,
namespace: Some("production".to_string()),

View File

@@ -313,6 +313,14 @@ impl ConfigValidator {
}
}
// Service discovery conflicts with dp_aware routing for now
// since the combination is not fully supported yet
if has_service_discovery && config.dp_aware {
return Err(ConfigError::IncompatibleConfig {
reason: "DP-aware routing is not compatible with service discovery".to_string(),
});
}
Ok(())
}

View File

@@ -17,6 +17,8 @@ pub enum WorkerError {
NetworkError { url: String, error: String },
/// Worker is at capacity
WorkerAtCapacity { url: String },
/// Invalid URL format
InvalidUrl { url: String },
}
impl fmt::Display for WorkerError {
@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
WorkerError::WorkerAtCapacity { url } => {
write!(f, "Worker at capacity: {}", url)
}
WorkerError::InvalidUrl { url } => {
write!(f, "Invalid URL format: {}", url)
}
}
}
}

View File

@@ -162,6 +162,27 @@ impl BasicWorker {
self.metadata.health_config = config;
self
}
/// Return the worker URL with any DP-rank suffix removed.
///
/// dp-aware mode encodes workers as "http://host:port@dp_rank"; health
/// checks and other direct HTTP calls need the plain base URL. Returns
/// `WorkerError::InvalidUrl` when an '@' is present but the suffix is not
/// a single valid integer rank.
pub fn normalised_url(&self) -> WorkerResult<&str> {
    match self.url().split_once('@') {
        // No DP-rank suffix: the URL is usable as-is.
        None => Ok(self.url()),
        Some((base, rank)) => {
            // Reject more than one '@' or a non-integer rank.
            if rank.contains('@') || rank.parse::<usize>().is_err() {
                Err(WorkerError::InvalidUrl {
                    url: self.url().to_string(),
                })
            } else {
                Ok(base)
            }
        }
    }
}
}
#[async_trait]
@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
use std::time::Duration;
// Perform actual HTTP health check
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
let url = self.normalised_url()?;
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
// Use the shared client with a custom timeout for this request
@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
} else {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check returned status: {}", response.status()),
})
}
@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
Err(e) => {
self.set_healthy(false);
Err(WorkerError::HealthCheckFailed {
url: self.url().to_string(),
url: url.to_string(),
reason: format!("Health check request failed: {}", e),
})
}

View File

@@ -37,6 +37,8 @@ struct Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
@@ -136,6 +138,8 @@ impl Router {
request_timeout_secs: self.request_timeout_secs,
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
worker_startup_check_interval_secs: self.worker_startup_check_interval,
dp_aware: self.dp_aware,
api_key: self.api_key.clone(),
discovery,
metrics,
log_dir: self.log_dir.clone(),
@@ -161,6 +165,8 @@ impl Router {
eviction_interval_secs = 60,
max_tree_size = 2usize.pow(24),
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
dp_aware = false,
api_key = None,
log_dir = None,
log_level = None,
service_discovery = false,
@@ -193,6 +199,8 @@ impl Router {
eviction_interval_secs: u64,
max_tree_size: usize,
max_payload_size: usize,
dp_aware: bool,
api_key: Option<String>,
log_dir: Option<String>,
log_level: Option<String>,
service_discovery: bool,
@@ -225,6 +233,8 @@ impl Router {
eviction_interval_secs,
max_tree_size,
max_payload_size,
dp_aware,
api_key,
log_dir,
log_level,
service_discovery,

View File

@@ -45,6 +45,8 @@ impl RouterFactory {
policy,
router_config.worker_startup_timeout_secs,
router_config.worker_startup_check_interval_secs,
router_config.dp_aware,
router_config.api_key.clone(),
)?;
Ok(Box::new(router))

View File

@@ -30,6 +30,8 @@ pub struct Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
_health_checker: Option<HealthChecker>,
@@ -42,6 +44,8 @@ impl Router {
policy: Arc<dyn LoadBalancingPolicy>,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
api_key: Option<String>,
) -> Result<Self, String> {
// Update active workers gauge
RouterMetrics::set_active_workers(worker_urls.len());
@@ -51,6 +55,14 @@ impl Router {
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
}
let worker_urls = if dp_aware {
// worker address now in the format of "http://host:port@dp_rank"
Self::get_dp_aware_workers(&worker_urls, &api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
} else {
worker_urls
};
// Create Worker trait objects from URLs
let workers: Vec<Box<dyn Worker>> = worker_urls
.iter()
@@ -89,6 +101,8 @@ impl Router {
policy,
timeout_secs,
interval_secs,
dp_aware,
api_key,
_worker_loads: worker_loads,
_load_monitor_handle: load_monitor_handle,
_health_checker: Some(health_checker),
@@ -160,6 +174,62 @@ impl Router {
}
}
/// Query a worker's `/get_server_info` endpoint and return its reported `dp_size`.
///
/// A blocking client is used because this runs during synchronous router
/// setup. `api_key`, when present, is sent as a bearer token so workers
/// launched with `--api-key` accept the probe.
///
/// Improvements over the original: the nested `match` is flattened with `?`,
/// the awkward `Ok(if .. { return Err(..) } ..)` construct is replaced by
/// `usize::try_from`, and the needless `&format!` borrow is dropped.
/// All error strings are unchanged.
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
    let sync_client = reqwest::blocking::Client::new();
    let mut req_builder = sync_client.get(format!("{}/get_server_info", worker_url));
    if let Some(key) = api_key {
        req_builder = req_builder.bearer_auth(key);
    }
    let res = req_builder
        .send()
        .map_err(|e| format!("error response: {}", e))?;
    if !res.status().is_success() {
        return Err(format!("unexpected status code: {}", res.status()));
    }
    let body = res
        .text()
        .map_err(|e| format!("failed to read text from response: {}", e))?;
    let server_info: serde_json::Value =
        serde_json::from_str(&body).map_err(|e| format!("failed to decode JSON: {}", e))?;
    let dp_size = server_info
        .get("dp_size")
        .and_then(|v| v.as_u64())
        .ok_or_else(|| String::from("dp_size not found or not an u64"))?;
    // Guard against truncation on targets where usize is narrower than u64.
    usize::try_from(dp_size).map_err(|_| format!("dp_size is too large: {}", dp_size))
}
// Given a list of workers, return a list of workers with dp_rank as suffix.
// Each base URL expands into one pseudo-URL per DP rank ("url@rank"),
// after querying the worker itself for its dp_size.
fn get_dp_aware_workers(
    worker_urls: &[String],
    api_key: &Option<String>,
) -> Result<Vec<String>, String> {
    let mut dp_aware_workers = Vec::new();
    for url in worker_urls {
        let dp_size = Self::get_worker_dp_size(url, api_key)
            .map_err(|e| format!("Failed to get DP size for {}: {}", url, e))?;
        dp_aware_workers.extend((0..dp_size).map(|rank| format!("{}@{}", url, rank)));
    }
    Ok(dp_aware_workers)
}
fn select_first_worker(&self) -> Result<String, String> {
let workers_guard = self.workers.read().unwrap();
if workers_guard.is_empty() {
@@ -178,6 +248,21 @@ impl Router {
) -> HttpResponse {
let request_id = get_request_id(req);
let start = Instant::now();
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let mut request_builder = client.get(format!("{}{}", worker_url, route));
// Copy all headers from original request except for /health because it does not need authorization
@@ -292,7 +377,7 @@ impl Router {
worker_url = %worker_url,
"Removing failed worker"
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
@@ -392,7 +477,7 @@ impl Router {
request_id = %request_id,
"Removing failed worker after typed request failures worker_url={}", worker_url
);
self.remove_worker(&worker_url);
self.remove_failed_worker(&worker_url);
break;
}
}
@@ -415,6 +500,23 @@ impl Router {
}
}
// TODO (rui): Better accommodate to the Worker abstraction
/// Split a dp-aware pseudo-URL "http://host:port@dp_rank" into its base URL
/// and integer DP rank. Errors if there is no '@', more than one '@', or a
/// non-integer rank.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
    let Some((base_url, rank_str)) = worker_url.split_once('@') else {
        return Err(format!("invalid worker_url format: {}", worker_url));
    };
    // A second '@' means the URL is malformed, matching the original
    // "exactly two parts" requirement.
    if rank_str.contains('@') {
        return Err(format!("invalid worker_url format: {}", worker_url));
    }
    rank_str
        .parse::<usize>()
        .map(|dp_rank| (base_url, dp_rank))
        .map_err(|_| format!("failed to parse dp_rank from worker_url: {}", worker_url))
}
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
@@ -429,9 +531,47 @@ impl Router {
let request_id = get_request_id(req);
let start = Instant::now();
let mut request_builder = client
.post(format!("{}{}", worker_url, route))
.json(typed_req); // Use json() directly with typed request
let mut request_builder = if self.dp_aware {
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
// Parse the request body
let mut json_val = match serde_json::to_value(typed_req) {
Ok(j) => j,
Err(e) => {
return HttpResponse::BadRequest()
.body(format!("Convert into serde_json::Value failed: {}", e));
}
};
// Insert the data_parallel_rank field
if let Some(map) = json_val.as_object_mut() {
map.insert(
String::from("data_parallel_rank"),
serde_json::json!(dp_rank),
);
debug!(
"Modified request body: {}",
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
);
} else {
return HttpResponse::BadRequest()
.body("Failed to insert the data_parallel_rank field into the request body");
}
client
.post(format!("{}{}", worker_url_prefix, route))
.json(&json_val)
} else {
client
.post(format!("{}{}", worker_url, route))
.json(typed_req) // Use json() directly with typed request
};
// Copy all headers from original request
for (name, value) in copy_request_headers(req) {
@@ -560,12 +700,35 @@ impl Router {
Ok(res) => {
if res.status().is_success() {
let mut workers_guard = self.workers.write().unwrap();
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
if self.dp_aware {
// Need to contact the worker to extract the dp_size,
// and add them as multiple workers
let url_vec = vec![String::from(worker_url)];
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
let mut worker_added: bool = false;
for dp_url in &dp_url_vec {
if workers_guard.iter().any(|w| w.url() == dp_url) {
warn!("Worker {} already exists", dp_url);
continue;
}
info!("Added worker: {}", dp_url);
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
workers_guard.push(new_worker);
worker_added = true;
}
if !worker_added {
return Err(format!("No worker added for {}", worker_url));
}
} else {
if workers_guard.iter().any(|w| w.url() == worker_url) {
return Err(format!("Worker {} already exists", worker_url));
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
}
info!("Added worker: {}", worker_url);
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
workers_guard.push(new_worker);
RouterMetrics::set_active_workers(workers_guard.len());
// If cache aware policy, initialize the worker in the tree
@@ -612,11 +775,81 @@ impl Router {
}
}
/// Remove all the worker(s) that match the URL prefix.
///
/// In dp-aware mode a single base URL was expanded into several pseudo-URLs
/// of the form "url@dp_rank" at registration time, so removal is done by
/// prefix match ("url@") against every registered worker, without
/// contacting the remote host. Otherwise exactly one worker with a matching
/// URL is removed. In both modes the worker is also dropped from the
/// cache-aware policy's tree when that policy is active.
pub fn remove_worker(&self, worker_url: &str) {
if self.dp_aware {
// remove dp-aware workers in a prefix-matching fashion
// without contacting the remote worker
let mut candidate_workers: Vec<String> = Vec::new();
let mut removed_workers: Vec<String> = Vec::new();
// Trailing '@' prevents "http://h:1" from matching "http://h:10@0".
let worker_url_prefix = format!("{}@", worker_url);
{
// find the candidate workers to be removed (read lock only;
// the write lock is acquired separately below)
let workers_guard = self.workers.read().unwrap();
for w in workers_guard.iter() {
if w.url().starts_with(&worker_url_prefix) {
candidate_workers.push(w.url().to_string());
}
}
}
{
// do the removing on the worker_urls
// NOTE(review): between the read lock above and this write lock
// another thread may mutate the list, hence the per-candidate
// position() re-check — confirm this window is acceptable.
let mut workers_guard = self.workers.write().unwrap();
for dp_url in candidate_workers.iter() {
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
workers_guard.remove(index);
info!("Removed worker: {}", dp_url);
removed_workers.push(dp_url.to_string());
} else {
warn!("Worker {} not found, skipping removal", dp_url);
continue;
}
}
RouterMetrics::set_active_workers(workers_guard.len());
}
// If cache aware policy, remove the workers from the tree
// (only those actually removed above)
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
for dp_url in removed_workers.iter() {
cache_aware.remove_worker(dp_url);
info!("Removed worker from tree: {}", dp_url);
}
}
} else {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
// Nothing matched: log and skip the cache-aware cleanup below.
warn!("Worker {} not found, skipping removal", worker_url);
return;
}
// If cache aware policy, remove the workers from the tree
if let Some(cache_aware) = self
.policy
.as_any()
.downcast_ref::<crate::policies::CacheAwarePolicy>()
{
cache_aware.remove_worker(worker_url);
info!("Removed worker from tree: {}", worker_url);
}
}
}
/// Remove a specific failed worker; for internal usage
fn remove_failed_worker(&self, worker_url: &str) {
let mut workers_guard = self.workers.write().unwrap();
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
workers_guard.remove(index);
info!("Removed worker: {}", worker_url);
info!("Removed failed worker: {}", worker_url);
RouterMetrics::set_active_workers(workers_guard.len());
} else {
warn!("Worker {} not found, skipping removal", worker_url);
@@ -634,6 +867,20 @@ impl Router {
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
@@ -710,6 +957,20 @@ impl Router {
// Static version of get_worker_load for use in monitoring task
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
let worker_url = if worker_url.contains("@") {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
debug!("Failed to extract dp_rank: {}", e);
return None;
}
};
worker_url_prefix
} else {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
// Send requests to all workers concurrently without headers
let mut tasks = Vec::new();
for worker_url in &worker_urls {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
Ok(tup) => tup,
Err(e) => {
error!("Failed to extract dp_rank: {}", e);
return HttpResponse::InternalServerError().finish();
}
};
worker_url_prefix
} else {
worker_url
};
let request_builder = client.post(format!("{}/flush_cache", worker_url));
tasks.push(request_builder.send());
}
@@ -948,6 +1222,8 @@ mod tests {
policy: Arc::new(RandomPolicy::new()),
timeout_secs: 5,
interval_secs: 1,
dp_aware: false,
api_key: None,
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,

View File

@@ -581,7 +581,7 @@ mod tests {
use crate::routers::router::Router;
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
let router = Router::new(vec![], policy, 5, 1).unwrap();
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
Arc::new(router) as Arc<dyn RouterTrait>
}

View File

@@ -31,6 +31,8 @@ impl TestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
@@ -950,6 +952,8 @@ mod error_tests {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,

View File

@@ -16,6 +16,8 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 300,
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,
@@ -37,6 +39,8 @@ pub fn create_test_config_no_workers() -> RouterConfig {
request_timeout_secs: 600,
worker_startup_timeout_secs: 0, // No wait
worker_startup_check_interval_secs: 10,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,

View File

@@ -42,6 +42,8 @@ impl RequestTestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,

View File

@@ -46,6 +46,8 @@ impl StreamingTestContext {
request_timeout_secs: 600,
worker_startup_timeout_secs: 1,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,

View File

@@ -169,6 +169,8 @@ mod test_pd_routing {
request_timeout_secs: 60,
worker_startup_timeout_secs: 10,
worker_startup_check_interval_secs: 1,
dp_aware: false,
api_key: None,
discovery: None,
metrics: None,
log_dir: None,