[feature] [sgl-router] Add a dp-aware routing strategy (#6869)
This commit is contained in:
@@ -141,6 +141,14 @@ Process:
|
||||
|
||||
For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers.
|
||||
|
||||
***Data-Parallelism Aware Routing***
|
||||
|
||||
An additional DP-aware routing strategy can be enabled on top of the sgl-router’s hybrid cache-aware load-balancing strategy by setting the `--dp-aware` flag when starting the router.
|
||||
|
||||
When this flag is enabled, the router attempts to contact the workers to retrieve the `dp_size` of each one and registers the new workers at the DP-rank level. In this mode, the router applies the cache-aware routing strategy in a more fine-grained manner, with assistance from the DP controller on the SRT side.
|
||||
|
||||
By default (when the flag is not set), the SRT’s DP controller distributes incoming requests across DP ranks in a round-robin fashion.
|
||||
|
||||
## Configuration Parameters
|
||||
|
||||
1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
|
||||
|
||||
@@ -50,6 +50,8 @@ class RouterArgs:
|
||||
eviction_interval: int = 60
|
||||
max_tree_size: int = 2**24
|
||||
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
||||
dp_aware: bool = False
|
||||
api_key: Optional[str] = None
|
||||
log_dir: Optional[str] = None
|
||||
log_level: Optional[str] = None
|
||||
# Service discovery configuration
|
||||
@@ -197,6 +199,17 @@ class RouterArgs:
|
||||
default=RouterArgs.max_payload_size,
|
||||
help="Maximum payload size in bytes",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}dp-aware",
|
||||
action="store_true",
|
||||
help="Enable data parallelism aware schedule",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}api-key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||
)
|
||||
parser.add_argument(
|
||||
f"--{prefix}log-dir",
|
||||
type=str,
|
||||
@@ -304,6 +317,8 @@ class RouterArgs:
|
||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
||||
dp_aware=getattr(args, f"{prefix}dp_aware", False),
|
||||
api_key=getattr(args, f"{prefix}api_key", None),
|
||||
log_dir=getattr(args, f"{prefix}log_dir", None),
|
||||
log_level=getattr(args, f"{prefix}log_level", None),
|
||||
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
||||
@@ -463,6 +478,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
||||
eviction_interval_secs=router_args.eviction_interval,
|
||||
max_tree_size=router_args.max_tree_size,
|
||||
max_payload_size=router_args.max_payload_size,
|
||||
dp_aware=router_args.dp_aware,
|
||||
api_key=router_args.api_key,
|
||||
log_dir=router_args.log_dir,
|
||||
log_level=router_args.log_level,
|
||||
service_discovery=router_args.service_discovery,
|
||||
|
||||
@@ -31,6 +31,10 @@ class Router:
|
||||
routing. Default: 60
|
||||
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||
dp_aware: Enable data parallelism aware schedule. Default: False
|
||||
api_key: The api key used for the authorization with the worker.
|
||||
Useful when the dp aware scheduling strategy is enabled.
|
||||
Default: None
|
||||
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
||||
log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
|
||||
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
|
||||
@@ -73,6 +77,8 @@ class Router:
|
||||
eviction_interval_secs: int = 60,
|
||||
max_tree_size: int = 2**24,
|
||||
max_payload_size: int = 256 * 1024 * 1024, # 256MB
|
||||
dp_aware: bool = False,
|
||||
api_key: Optional[str] = None,
|
||||
log_dir: Optional[str] = None,
|
||||
log_level: Optional[str] = None,
|
||||
service_discovery: bool = False,
|
||||
@@ -110,6 +116,8 @@ class Router:
|
||||
eviction_interval_secs=eviction_interval_secs,
|
||||
max_tree_size=max_tree_size,
|
||||
max_payload_size=max_payload_size,
|
||||
dp_aware=dp_aware,
|
||||
api_key=api_key,
|
||||
log_dir=log_dir,
|
||||
log_level=log_level,
|
||||
service_discovery=service_discovery,
|
||||
|
||||
@@ -8,7 +8,7 @@ if __name__ == "__main__":
|
||||
arg_parser.add_argument(
|
||||
"--timeout-per-file",
|
||||
type=int,
|
||||
default=1000,
|
||||
default=2000,
|
||||
help="The time limit for running one file in seconds.",
|
||||
)
|
||||
args = arg_parser.parse_args()
|
||||
|
||||
@@ -43,6 +43,7 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
selector=None,
|
||||
service_discovery_port=80,
|
||||
service_discovery_namespace=None,
|
||||
dp_aware=False,
|
||||
prometheus_port=None,
|
||||
prometheus_host=None,
|
||||
# PD-specific attributes
|
||||
@@ -111,6 +112,52 @@ class TestLaunchRouter(unittest.TestCase):
|
||||
)
|
||||
self.run_router_process(args)
|
||||
|
||||
def test_launch_router_common_with_dp_aware(self):
    """Launch the router against one worker URL with DP-aware routing on."""
    dp_aware_args = self.create_router_args(
        dp_aware=True,
        worker_urls=["http://localhost:8000"],
    )
    self.run_router_process(dp_aware_args)
|
||||
|
||||
def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
    """Launch the router with DP-aware routing on and no worker URLs."""
    dp_aware_args = self.create_router_args(dp_aware=True, worker_urls=[])
    self.run_router_process(dp_aware_args)
|
||||
|
||||
def test_launch_router_common_with_dp_aware_service_discovery(self):
    """Check that the router does not stay up with both modes enabled.

    The router is launched in a child process with both service discovery
    and DP-aware routing turned on. The two options conflict, so the child
    is expected to have exited by the time the wait below is over.
    """
    args = self.create_router_args(
        worker_urls=["http://localhost:8000"],
        dp_aware=True,
        service_discovery=True,
        selector=["app=test-worker"],
    )

    def run_router():
        # multiprocessing.Process discards the target's return value, so a
        # failure can only be surfaced by printing it; the parent decides
        # pass/fail purely from whether the child is still alive.
        try:
            from sglang_router.launch_router import launch_router

            launch_router(args)
        except Exception as e:
            print(e)

    process = multiprocessing.Process(target=run_router)
    try:
        process.start()
        # Wait at most 3 seconds; join() returns as soon as the child exits,
        # so the expected (failing) launch no longer costs the full delay.
        process.join(timeout=3)
        # Should have exited since service_discovery and dp_aware conflict
        self.assertFalse(process.is_alive())
    finally:
        terminate_process(process)
|
||||
|
||||
def test_launch_router_pd_mode_basic(self):
|
||||
"""Test basic PD router functionality without actually starting servers."""
|
||||
# This test just verifies the PD router can be created and configured
|
||||
|
||||
@@ -30,6 +30,7 @@ def popen_launch_router(
|
||||
service_discovery_namespace: str = None,
|
||||
prometheus_port: int = None,
|
||||
prometheus_host: str = None,
|
||||
dp_aware: bool = False,
|
||||
):
|
||||
"""
|
||||
Launch the router server process.
|
||||
@@ -49,6 +50,7 @@ def popen_launch_router(
|
||||
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
|
||||
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
|
||||
prometheus_host: Host address to bind the Prometheus metrics server.
|
||||
dp_aware: Enable data parallelism aware routing strategy.
|
||||
"""
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -69,10 +71,12 @@ def popen_launch_router(
|
||||
"5",
|
||||
"--router-policy",
|
||||
policy,
|
||||
"--allow-auto-truncate",
|
||||
]
|
||||
|
||||
if api_key is not None:
|
||||
command.extend(["--api-key", api_key])
|
||||
command.extend(["--router-api-key", api_key])
|
||||
|
||||
if max_payload_size is not None:
|
||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||
@@ -100,6 +104,9 @@ def popen_launch_router(
|
||||
if log_dir is not None:
|
||||
command.extend(["--log-dir", log_dir])
|
||||
|
||||
if dp_aware:
|
||||
command.append("--router-dp-aware")
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
start_time = time.perf_counter()
|
||||
@@ -127,6 +134,7 @@ def popen_launch_server(
|
||||
model: str,
|
||||
base_url: str,
|
||||
timeout: float,
|
||||
api_key: str = None,
|
||||
):
|
||||
_, host, port = base_url.split(":")
|
||||
host = host[2:]
|
||||
@@ -145,6 +153,9 @@ def popen_launch_server(
|
||||
"1",
|
||||
]
|
||||
|
||||
if api_key is not None:
|
||||
command.extend(["--api-key", api_key])
|
||||
|
||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||
|
||||
# intentionally don't wait and defer the job to the router health check
|
||||
@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
|
||||
response.status_code, 200, "Request with correct api key should succeed"
|
||||
)
|
||||
|
||||
def test_6_mmlu_with_dp_aware(self):
    """Run MMLU through a DP-aware, cache-aware router launched with dp_size=2."""
    print("Running test_6_mmlu_with_dp_aware...")

    # Router launched with a data-parallel size of 2
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=2,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="cache_aware",
        dp_aware=True,
    )

    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    score = run_eval(eval_args)["score"]

    min_score = 0.65
    outcome = "passed" if score >= min_score else "failed"
    self.assertGreaterEqual(
        score,
        min_score,
        f"dp aware MMLU test {outcome} with score {score:.3f} (threshold: {min_score})",
    )
|
||||
|
||||
def test_7_add_and_remove_worker_with_dp_aware(self):
    """Add and remove workers on a DP-aware round-robin router.

    A worker without an api key is added, exercised with MMLU, removed, and
    MMLU is run again. A second worker that requires an api key is then
    offered to the router, and that registration is expected to be rejected.
    """
    print("Running test_7_add_and_remove_worker_with_dp_aware...")

    # Router launched with dp_size = 1
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",  # make sure every worker processes requests
        dp_aware=True,  # dp aware strategy should work well with RR
    )

    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=64,
        num_threads=32,
        temperature=0.1,
    )
    min_score = 0.65

    def check_mmlu_score():
        # One MMLU pass through the router, asserted against the threshold
        score = run_eval(eval_args)["score"]
        outcome = "passed" if score >= min_score else "failed"
        self.assertGreaterEqual(
            score,
            min_score,
            f"MMLU test {outcome} with score {score:.3f} (threshold: {min_score})",
        )

    # 1. Launch a worker that does not require an api key
    first_url = f"http://127.0.0.1:{find_available_port()}"
    first_worker = popen_launch_server(
        self.model, first_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(first_worker)

    # 2. Register it through /add_worker; the router uses it once healthy
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={first_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # 3. MMLU with the extra worker registered
    check_mmlu_score()

    # 4. Unregister the worker through /remove_worker
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/remove_worker?url={first_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # 5. MMLU again, now without that worker
    check_mmlu_score()

    # 6. Replace the old worker process with one that has an api key set
    terminate_and_wait(first_worker)
    second_url = f"http://127.0.0.1:{find_available_port()}"
    second_worker = popen_launch_server(
        self.model,
        second_url,
        DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        api_key="correct_api_key",
    )
    self.other_process.append(second_worker)

    # 7. Registering it should fail: the router has to query the worker's
    # /get_server_info endpoint for dp_size, but was not given the api key
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={second_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertNotEqual(response.status_code, 200)
|
||||
|
||||
def test_8_lazy_fault_tolerance_with_dp_aware(self):
    """Run MMLU while an added worker is killed partway through.

    A background thread kills the extra worker 10 seconds after it has been
    registered; the MMLU score must still reach the threshold.
    """
    print("Running test_8_lazy_fault_tolerance_with_dp_aware...")

    # Router launched with dp_size = 1
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        dp_aware=True,
    )

    # 1. Launch the worker that will later be killed
    worker_url = f"http://127.0.0.1:{find_available_port()}"
    worker_process = popen_launch_server(
        self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
    )
    self.other_process.append(worker_process)

    # 2. Register it through /add_worker; the router uses it once healthy
    with requests.Session() as session:
        response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(response.status_code, 200)

    # Mimic an abrupt worker failure 10 seconds from now, in the background
    def kill_worker():
        time.sleep(10)
        kill_process_tree(worker_process.pid)
        print("Worker process killed")

    import threading

    threading.Thread(target=kill_worker, daemon=True).start()

    # 3. MMLU runs while the worker disappears underneath the router
    eval_args = SimpleNamespace(
        base_url=self.base_url,
        model=self.model,
        eval_name="mmlu",
        num_examples=256,
        num_threads=32,
        temperature=0.1,
    )
    score = run_eval(eval_args)["score"]

    min_score = 0.65
    outcome = "passed" if score >= min_score else "failed"
    self.assertGreaterEqual(
        score,
        min_score,
        f"MMLU test {outcome} with score {score:.3f} (threshold: {min_score})",
    )
|
||||
|
||||
def test_9_payload_size_with_dp_aware(self):
    """Check the router's 1MB payload limit with DP-aware routing on.

    A 0.5MB request must be accepted (200) and a 1.2MB request must be
    rejected with 413 Payload Too Large.
    """
    print("Running test_9_payload_size_with_dp_aware...")

    # Router launched with a 1MB payload limit
    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        max_payload_size=1 * 1024 * 1024,  # 1MB limit
        dp_aware=True,
    )

    def post_generate(body):
        # Each request gets its own short-lived session
        with requests.Session() as session:
            return session.post(
                f"{self.base_url}/generate",
                json=body,
                headers={"Content-Type": "application/json"},
            )

    # Case 1: 0.5MB of text is below the limit and should succeed
    small_body = {
        "text": "x" * int(0.5 * 1024 * 1024),
        "temperature": 0.0,
    }
    response = post_generate(small_body)
    self.assertEqual(
        response.status_code,
        200,
        f"0.5MB payload should succeed but got status {response.status_code}",
    )

    # Case 2: 1.2MB of text is above the limit and should be rejected
    large_body = {
        "text": "x" * int((1.2 * 1024 * 1024)),
        "temperature": 0.0,
    }
    response = post_generate(large_body)
    self.assertEqual(
        response.status_code,
        413,  # Payload Too Large
        f"1.2MB payload should fail with 413 but got status {response.status_code}",
    )
|
||||
|
||||
def test_10_api_key_with_dp_aware(self):
    """Check api-key enforcement on /generate with DP-aware routing on.

    The router is launched with api_key="correct_api_key". Requests with no
    key or with a wrong key must get 401; a request carrying the correct
    bearer token must get 200.
    """
    print("Running test_10_api_key_with_dp_aware...")

    self.process = popen_launch_router(
        self.model,
        self.base_url,
        dp_size=1,
        timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
        policy="round_robin",
        api_key="correct_api_key",
        dp_aware=True,
    )

    # Test case 1: request without api key should fail
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request without api key should fail with 401 but got status {response.status_code}",
        )

    # Test case 2: request with invalid api key should fail
    with requests.Session() as session:
        # Go through the session that was opened for this case rather than
        # the module-level requests.post, matching the other two cases.
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is, ", "temperature": 0},
            headers={"Authorization": "Bearer 123"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            401,
            f"Request with invalid api key should fail with 401 but got status {response.status_code}",
        )

    # Test case 3: request with correct api key should succeed
    with requests.Session() as session:
        response = session.post(
            f"{self.base_url}/generate",
            json={"text": "Kanye west is ", "temperature": 0},
            headers={"Authorization": "Bearer correct_api_key"},
        )
        print(f"status code: {response.status_code}, response: {response.text}")
        self.assertEqual(
            response.status_code,
            200,
            f"Request with correct api key should succeed but got status {response.status_code}",
        )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
|
||||
@@ -21,6 +21,10 @@ pub struct RouterConfig {
|
||||
pub worker_startup_timeout_secs: u64,
|
||||
/// Worker health check interval in seconds
|
||||
pub worker_startup_check_interval_secs: u64,
|
||||
/// Enable data parallelism aware schedule
|
||||
pub dp_aware: bool,
|
||||
/// The api key used for the authorization with the worker
|
||||
pub api_key: Option<String>,
|
||||
/// Service discovery configuration (optional)
|
||||
pub discovery: Option<DiscoveryConfig>,
|
||||
/// Metrics configuration (optional)
|
||||
@@ -205,6 +209,8 @@ impl Default for RouterConfig {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 300,
|
||||
worker_startup_check_interval_secs: 10,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
@@ -311,6 +317,8 @@ mod tests {
|
||||
request_timeout_secs: 30,
|
||||
worker_startup_timeout_secs: 60,
|
||||
worker_startup_check_interval_secs: 5,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig::default()),
|
||||
metrics: Some(MetricsConfig::default()),
|
||||
log_dir: Some("/var/log".to_string()),
|
||||
@@ -727,6 +735,8 @@ mod tests {
|
||||
request_timeout_secs: 120,
|
||||
worker_startup_timeout_secs: 60,
|
||||
worker_startup_check_interval_secs: 5,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: Some("sglang".to_string()),
|
||||
@@ -774,6 +784,8 @@ mod tests {
|
||||
request_timeout_secs: 300,
|
||||
worker_startup_timeout_secs: 180,
|
||||
worker_startup_check_interval_secs: 15,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: None,
|
||||
@@ -812,6 +824,8 @@ mod tests {
|
||||
request_timeout_secs: 900,
|
||||
worker_startup_timeout_secs: 600,
|
||||
worker_startup_check_interval_secs: 20,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: Some(DiscoveryConfig {
|
||||
enabled: true,
|
||||
namespace: Some("production".to_string()),
|
||||
|
||||
@@ -313,6 +313,14 @@ impl ConfigValidator {
|
||||
}
|
||||
}
|
||||
|
||||
// Service discovery conflicts with dp_aware routing for now
|
||||
// since it's not fully supported yet
|
||||
if has_service_discovery && config.dp_aware {
|
||||
return Err(ConfigError::IncompatibleConfig {
|
||||
reason: "DP-aware routing is not compatible with service discovery".to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
||||
@@ -17,6 +17,8 @@ pub enum WorkerError {
|
||||
NetworkError { url: String, error: String },
|
||||
/// Worker is at capacity
|
||||
WorkerAtCapacity { url: String },
|
||||
/// Invalid URL format
|
||||
InvalidUrl { url: String },
|
||||
}
|
||||
|
||||
impl fmt::Display for WorkerError {
|
||||
@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
|
||||
WorkerError::WorkerAtCapacity { url } => {
|
||||
write!(f, "Worker at capacity: {}", url)
|
||||
}
|
||||
WorkerError::InvalidUrl { url } => {
|
||||
write!(f, "Invalid URL format: {}", url)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -162,6 +162,27 @@ impl BasicWorker {
|
||||
self.metadata.health_config = config;
|
||||
self
|
||||
}
|
||||
|
||||
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||
if self.url().contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let parts: Vec<&str> = self.url().split('@').collect();
|
||||
if parts.len() != 2 {
|
||||
return Err(WorkerError::InvalidUrl {
|
||||
url: self.url().to_string(),
|
||||
});
|
||||
}
|
||||
// Ensure the second part (the dp_rank) can be parsed as an integer
|
||||
match parts[1].parse::<usize>() {
|
||||
Ok(_) => Ok(parts[0]),
|
||||
Err(_) => Err(WorkerError::InvalidUrl {
|
||||
url: self.url().to_string(),
|
||||
}),
|
||||
}
|
||||
} else {
|
||||
Ok(self.url())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
|
||||
use std::time::Duration;
|
||||
|
||||
// Perform actual HTTP health check
|
||||
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
|
||||
let url = self.normalised_url()?;
|
||||
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||
|
||||
// Use the shared client with a custom timeout for this request
|
||||
@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
|
||||
} else {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
url: url.to_string(),
|
||||
reason: format!("Health check returned status: {}", response.status()),
|
||||
})
|
||||
}
|
||||
@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
|
||||
Err(e) => {
|
||||
self.set_healthy(false);
|
||||
Err(WorkerError::HealthCheckFailed {
|
||||
url: self.url().to_string(),
|
||||
url: url.to_string(),
|
||||
reason: format!("Health check request failed: {}", e),
|
||||
})
|
||||
}
|
||||
|
||||
@@ -37,6 +37,8 @@ struct Router {
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
@@ -136,6 +138,8 @@ impl Router {
|
||||
request_timeout_secs: self.request_timeout_secs,
|
||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||
dp_aware: self.dp_aware,
|
||||
api_key: self.api_key.clone(),
|
||||
discovery,
|
||||
metrics,
|
||||
log_dir: self.log_dir.clone(),
|
||||
@@ -161,6 +165,8 @@ impl Router {
|
||||
eviction_interval_secs = 60,
|
||||
max_tree_size = 2usize.pow(24),
|
||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||
dp_aware = false,
|
||||
api_key = None,
|
||||
log_dir = None,
|
||||
log_level = None,
|
||||
service_discovery = false,
|
||||
@@ -193,6 +199,8 @@ impl Router {
|
||||
eviction_interval_secs: u64,
|
||||
max_tree_size: usize,
|
||||
max_payload_size: usize,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
log_dir: Option<String>,
|
||||
log_level: Option<String>,
|
||||
service_discovery: bool,
|
||||
@@ -225,6 +233,8 @@ impl Router {
|
||||
eviction_interval_secs,
|
||||
max_tree_size,
|
||||
max_payload_size,
|
||||
dp_aware,
|
||||
api_key,
|
||||
log_dir,
|
||||
log_level,
|
||||
service_discovery,
|
||||
|
||||
@@ -45,6 +45,8 @@ impl RouterFactory {
|
||||
policy,
|
||||
router_config.worker_startup_timeout_secs,
|
||||
router_config.worker_startup_check_interval_secs,
|
||||
router_config.dp_aware,
|
||||
router_config.api_key.clone(),
|
||||
)?;
|
||||
|
||||
Ok(Box::new(router))
|
||||
|
||||
@@ -30,6 +30,8 @@ pub struct Router {
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
_health_checker: Option<HealthChecker>,
|
||||
@@ -42,6 +44,8 @@ impl Router {
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
api_key: Option<String>,
|
||||
) -> Result<Self, String> {
|
||||
// Update active workers gauge
|
||||
RouterMetrics::set_active_workers(worker_urls.len());
|
||||
@@ -51,6 +55,14 @@ impl Router {
|
||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||
}
|
||||
|
||||
let worker_urls = if dp_aware {
|
||||
// worker address now in the format of "http://host:port@dp_rank"
|
||||
Self::get_dp_aware_workers(&worker_urls, &api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
|
||||
} else {
|
||||
worker_urls
|
||||
};
|
||||
|
||||
// Create Worker trait objects from URLs
|
||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||
.iter()
|
||||
@@ -89,6 +101,8 @@ impl Router {
|
||||
policy,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
dp_aware,
|
||||
api_key,
|
||||
_worker_loads: worker_loads,
|
||||
_load_monitor_handle: load_monitor_handle,
|
||||
_health_checker: Some(health_checker),
|
||||
@@ -160,6 +174,62 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
|
||||
let sync_client = reqwest::blocking::Client::new();
|
||||
let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
|
||||
if let Some(key) = api_key {
|
||||
req_builder = req_builder.bearer_auth(key);
|
||||
}
|
||||
|
||||
match req_builder.send() {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let server_info = res
|
||||
.text()
|
||||
.map_err(|e| format!("failed to read text from response: {}", e))?;
|
||||
|
||||
let server_info: serde_json::Value = serde_json::from_str(&server_info)
|
||||
.map_err(|e| format!("failed to decode JSON: {}", e))?;
|
||||
|
||||
let dp_size = server_info
|
||||
.get("dp_size")
|
||||
.and_then(|v| v.as_u64())
|
||||
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
|
||||
|
||||
Ok(if dp_size > usize::MAX as u64 {
|
||||
return Err(format!("dp_size is too large: {}", dp_size));
|
||||
} else {
|
||||
dp_size as usize
|
||||
})
|
||||
} else {
|
||||
Err(format!("unexpected status code: {}", res.status()))
|
||||
}
|
||||
}
|
||||
Err(e) => Err(format!("error response: {}", e)),
|
||||
}
|
||||
}
|
||||
|
||||
// Given a list of workers, return a list of workers with dp_rank as suffix
|
||||
fn get_dp_aware_workers(
|
||||
worker_urls: &[String],
|
||||
api_key: &Option<String>,
|
||||
) -> Result<Vec<String>, String> {
|
||||
let mut dp_aware_workers: Vec<String> = Vec::new();
|
||||
|
||||
for url in worker_urls {
|
||||
match Self::get_worker_dp_size(url, api_key) {
|
||||
Ok(dp_size) => {
|
||||
for i in 0..dp_size {
|
||||
dp_aware_workers.push(format!("{}@{}", url, i));
|
||||
}
|
||||
}
|
||||
Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
|
||||
}
|
||||
}
|
||||
|
||||
Ok(dp_aware_workers)
|
||||
}
|
||||
|
||||
fn select_first_worker(&self) -> Result<String, String> {
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
if workers_guard.is_empty() {
|
||||
@@ -178,6 +248,21 @@ impl Router {
|
||||
) -> HttpResponse {
|
||||
let request_id = get_request_id(req);
|
||||
let start = Instant::now();
|
||||
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||
|
||||
// Copy all headers from original request except for /health because it does not need authorization
|
||||
@@ -292,7 +377,7 @@ impl Router {
|
||||
worker_url = %worker_url,
|
||||
"Removing failed worker"
|
||||
);
|
||||
self.remove_worker(&worker_url);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -392,7 +477,7 @@ impl Router {
|
||||
request_id = %request_id,
|
||||
"Removing failed worker after typed request failures worker_url={}", worker_url
|
||||
);
|
||||
self.remove_worker(&worker_url);
|
||||
self.remove_failed_worker(&worker_url);
|
||||
break;
|
||||
}
|
||||
}
|
||||
@@ -415,6 +500,23 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
// TODO (rui): Better accommodate to the Worker abstraction
|
||||
/// Splits a DP-aware worker address `"<base_url>@<dp_rank>"` into its parts.
///
/// The split happens at the LAST `@`, so a base URL that itself contains
/// `@` (e.g. `http://user:pw@host:port`) is still accepted once the rank
/// suffix is appended.
///
/// # Errors
/// Returns a message when there is no `@` in `worker_url`, or when the text
/// after the last `@` does not parse as a `usize`.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
    let (base_url, rank) = worker_url
        .rsplit_once('@')
        .ok_or_else(|| format!("invalid worker_url format: {}", worker_url))?;

    // Parse the suffix (dp_rank) into an integer
    rank.parse::<usize>()
        .map(|dp_rank| (base_url, dp_rank))
        .map_err(|_| {
            format!(
                "failed to parse dp_rank from worker_url: {}",
                worker_url
            )
        })
}
|
||||
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
@@ -429,9 +531,47 @@ impl Router {
|
||||
let request_id = get_request_id(req);
|
||||
let start = Instant::now();
|
||||
|
||||
let mut request_builder = client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.json(typed_req); // Use json() directly with typed request
|
||||
let mut request_builder = if self.dp_aware {
|
||||
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
|
||||
// Parse the request body
|
||||
let mut json_val = match serde_json::to_value(typed_req) {
|
||||
Ok(j) => j,
|
||||
Err(e) => {
|
||||
return HttpResponse::BadRequest()
|
||||
.body(format!("Convert into serde_json::Value failed: {}", e));
|
||||
}
|
||||
};
|
||||
|
||||
// Insert the data_parallel_rank field
|
||||
if let Some(map) = json_val.as_object_mut() {
|
||||
map.insert(
|
||||
String::from("data_parallel_rank"),
|
||||
serde_json::json!(dp_rank),
|
||||
);
|
||||
debug!(
|
||||
"Modified request body: {}",
|
||||
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
||||
);
|
||||
} else {
|
||||
return HttpResponse::BadRequest()
|
||||
.body("Failed to insert the data_parallel_rank field into the request body");
|
||||
}
|
||||
|
||||
client
|
||||
.post(format!("{}{}", worker_url_prefix, route))
|
||||
.json(&json_val)
|
||||
} else {
|
||||
client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.json(typed_req) // Use json() directly with typed request
|
||||
};
|
||||
|
||||
// Copy all headers from original request
|
||||
for (name, value) in copy_request_headers(req) {
|
||||
@@ -560,12 +700,35 @@ impl Router {
|
||||
Ok(res) => {
|
||||
if res.status().is_success() {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
if self.dp_aware {
|
||||
// Need to contact the worker to extract the dp_size,
|
||||
// and add them as multiple workers
|
||||
let url_vec = vec![String::from(worker_url)];
|
||||
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||
let mut worker_added: bool = false;
|
||||
for dp_url in &dp_url_vec {
|
||||
if workers_guard.iter().any(|w| w.url() == dp_url) {
|
||||
warn!("Worker {} already exists", dp_url);
|
||||
continue;
|
||||
}
|
||||
info!("Added worker: {}", dp_url);
|
||||
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
worker_added = true;
|
||||
}
|
||||
if !worker_added {
|
||||
return Err(format!("No worker added for {}", worker_url));
|
||||
}
|
||||
} else {
|
||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||
return Err(format!("Worker {} already exists", worker_url));
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
}
|
||||
info!("Added worker: {}", worker_url);
|
||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
||||
workers_guard.push(new_worker);
|
||||
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
|
||||
// If cache aware policy, initialize the worker in the tree
|
||||
@@ -612,11 +775,81 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove all the worker(s) that match the URL prefix
|
||||
pub fn remove_worker(&self, worker_url: &str) {
|
||||
if self.dp_aware {
|
||||
// remove dp-aware workers in a prefix-matching fashion
|
||||
// without contacting the remote worker
|
||||
let mut candidate_workers: Vec<String> = Vec::new();
|
||||
let mut removed_workers: Vec<String> = Vec::new();
|
||||
let worker_url_prefix = format!("{}@", worker_url);
|
||||
|
||||
{
|
||||
// find the candidate workers to be removed
|
||||
let workers_guard = self.workers.read().unwrap();
|
||||
for w in workers_guard.iter() {
|
||||
if w.url().starts_with(&worker_url_prefix) {
|
||||
candidate_workers.push(w.url().to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// do the removing on the worker_urls
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
for dp_url in candidate_workers.iter() {
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", dp_url);
|
||||
removed_workers.push(dp_url.to_string());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", dp_url);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
for dp_url in removed_workers.iter() {
|
||||
cache_aware.remove_worker(dp_url);
|
||||
info!("Removed worker from tree: {}", dp_url);
|
||||
}
|
||||
}
|
||||
} else {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
return;
|
||||
}
|
||||
|
||||
// If cache aware policy, remove the workers from the tree
|
||||
if let Some(cache_aware) = self
|
||||
.policy
|
||||
.as_any()
|
||||
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||
{
|
||||
cache_aware.remove_worker(worker_url);
|
||||
info!("Removed worker from tree: {}", worker_url);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove a specific failed worker; for internal usage
|
||||
fn remove_failed_worker(&self, worker_url: &str) {
|
||||
let mut workers_guard = self.workers.write().unwrap();
|
||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||
workers_guard.remove(index);
|
||||
info!("Removed worker: {}", worker_url);
|
||||
info!("Removed failed worker: {}", worker_url);
|
||||
RouterMetrics::set_active_workers(workers_guard.len());
|
||||
} else {
|
||||
warn!("Worker {} not found, skipping removal", worker_url);
|
||||
@@ -634,6 +867,20 @@ impl Router {
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
@@ -710,6 +957,20 @@ impl Router {
|
||||
|
||||
// Static version of get_worker_load for use in monitoring task
|
||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if worker_url.contains("@") {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
debug!("Failed to extract dp_rank: {}", e);
|
||||
return None;
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
|
||||
// Send requests to all workers concurrently without headers
|
||||
let mut tasks = Vec::new();
|
||||
for worker_url in &worker_urls {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
Ok(tup) => tup,
|
||||
Err(e) => {
|
||||
error!("Failed to extract dp_rank: {}", e);
|
||||
return HttpResponse::InternalServerError().finish();
|
||||
}
|
||||
};
|
||||
worker_url_prefix
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
@@ -948,6 +1222,8 @@ mod tests {
|
||||
policy: Arc::new(RandomPolicy::new()),
|
||||
timeout_secs: 5,
|
||||
interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
_health_checker: None,
|
||||
|
||||
@@ -581,7 +581,7 @@ mod tests {
|
||||
use crate::routers::router::Router;
|
||||
|
||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||
let router = Router::new(vec![], policy, 5, 1).unwrap();
|
||||
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
|
||||
Arc::new(router) as Arc<dyn RouterTrait>
|
||||
}
|
||||
|
||||
|
||||
@@ -31,6 +31,8 @@ impl TestContext {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
@@ -950,6 +952,8 @@ mod error_tests {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
|
||||
@@ -16,6 +16,8 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 300,
|
||||
worker_startup_check_interval_secs: 10,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
@@ -37,6 +39,8 @@ pub fn create_test_config_no_workers() -> RouterConfig {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 0, // No wait
|
||||
worker_startup_check_interval_secs: 10,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
|
||||
@@ -42,6 +42,8 @@ impl RequestTestContext {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
|
||||
@@ -46,6 +46,8 @@ impl StreamingTestContext {
|
||||
request_timeout_secs: 600,
|
||||
worker_startup_timeout_secs: 1,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
|
||||
@@ -169,6 +169,8 @@ mod test_pd_routing {
|
||||
request_timeout_secs: 60,
|
||||
worker_startup_timeout_secs: 10,
|
||||
worker_startup_check_interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
discovery: None,
|
||||
metrics: None,
|
||||
log_dir: None,
|
||||
|
||||
Reference in New Issue
Block a user