[feature] [sgl-router] Add a dp-aware routing strategy (#6869)
This commit is contained in:
@@ -141,6 +141,14 @@ Process:
|
|||||||
|
|
||||||
For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers.
|
For unbalanced systems, this strategy tracks pending request counts per worker and routes new requests to the least busy worker. This helps maintain optimal load distribution across workers.
|
||||||
|
|
||||||
|
***Data-Parallelism Aware Routing***
|
||||||
|
|
||||||
|
An additional DP-aware routing strategy can be enabled on top of the sgl-router’s hybrid cache-aware load-balancing strategy by setting the `--dp-aware` flag when starting the router.
|
||||||
|
|
||||||
|
When this flag is enabled, the router attempts to contact the workers to retrieve the `dp_size` of each one and registers the new workers at the DP-rank level. In this mode, the router applies the cache-aware routing strategy in a more fine-grained manner, with assistance from the DP controller on the SRT side.
|
||||||
|
|
||||||
|
By default (when the flag is not set), the SRT’s DP controller distributes incoming requests across DP ranks in a round-robin fashion.
|
||||||
|
|
||||||
## Configuration Parameters
|
## Configuration Parameters
|
||||||
|
|
||||||
1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
|
1. `cache_threshold`: (float, 0.0 to 1.0, default: 0.5)
|
||||||
|
|||||||
@@ -50,6 +50,8 @@ class RouterArgs:
|
|||||||
eviction_interval: int = 60
|
eviction_interval: int = 60
|
||||||
max_tree_size: int = 2**24
|
max_tree_size: int = 2**24
|
||||||
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
max_payload_size: int = 256 * 1024 * 1024 # 256MB default for large batches
|
||||||
|
dp_aware: bool = False
|
||||||
|
api_key: Optional[str] = None
|
||||||
log_dir: Optional[str] = None
|
log_dir: Optional[str] = None
|
||||||
log_level: Optional[str] = None
|
log_level: Optional[str] = None
|
||||||
# Service discovery configuration
|
# Service discovery configuration
|
||||||
@@ -197,6 +199,17 @@ class RouterArgs:
|
|||||||
default=RouterArgs.max_payload_size,
|
default=RouterArgs.max_payload_size,
|
||||||
help="Maximum payload size in bytes",
|
help="Maximum payload size in bytes",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}dp-aware",
|
||||||
|
action="store_true",
|
||||||
|
help="Enable data parallelism aware schedule",
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
f"--{prefix}api-key",
|
||||||
|
type=str,
|
||||||
|
default=None,
|
||||||
|
help="The api key used for the authorization with the worker. Useful when the dp aware scheduling strategy is enaled.",
|
||||||
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
f"--{prefix}log-dir",
|
f"--{prefix}log-dir",
|
||||||
type=str,
|
type=str,
|
||||||
@@ -304,6 +317,8 @@ class RouterArgs:
|
|||||||
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
eviction_interval=getattr(args, f"{prefix}eviction_interval"),
|
||||||
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
max_tree_size=getattr(args, f"{prefix}max_tree_size"),
|
||||||
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
max_payload_size=getattr(args, f"{prefix}max_payload_size"),
|
||||||
|
dp_aware=getattr(args, f"{prefix}dp_aware", False),
|
||||||
|
api_key=getattr(args, f"{prefix}api_key", None),
|
||||||
log_dir=getattr(args, f"{prefix}log_dir", None),
|
log_dir=getattr(args, f"{prefix}log_dir", None),
|
||||||
log_level=getattr(args, f"{prefix}log_level", None),
|
log_level=getattr(args, f"{prefix}log_level", None),
|
||||||
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
service_discovery=getattr(args, f"{prefix}service_discovery", False),
|
||||||
@@ -463,6 +478,8 @@ def launch_router(args: argparse.Namespace) -> Optional[Router]:
|
|||||||
eviction_interval_secs=router_args.eviction_interval,
|
eviction_interval_secs=router_args.eviction_interval,
|
||||||
max_tree_size=router_args.max_tree_size,
|
max_tree_size=router_args.max_tree_size,
|
||||||
max_payload_size=router_args.max_payload_size,
|
max_payload_size=router_args.max_payload_size,
|
||||||
|
dp_aware=router_args.dp_aware,
|
||||||
|
api_key=router_args.api_key,
|
||||||
log_dir=router_args.log_dir,
|
log_dir=router_args.log_dir,
|
||||||
log_level=router_args.log_level,
|
log_level=router_args.log_level,
|
||||||
service_discovery=router_args.service_discovery,
|
service_discovery=router_args.service_discovery,
|
||||||
|
|||||||
@@ -31,6 +31,10 @@ class Router:
|
|||||||
routing. Default: 60
|
routing. Default: 60
|
||||||
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
max_payload_size: Maximum payload size in bytes. Default: 256MB
|
||||||
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
max_tree_size: Maximum size of the approximation tree for cache-aware routing. Default: 2^24
|
||||||
|
dp_aware: Enable data-parallelism-aware scheduling. Default: False
|
||||||
|
api_key: The api key used for the authorization with the worker.
|
||||||
|
Useful when the dp aware scheduling strategy is enabled.
|
||||||
|
Default: None
|
||||||
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
log_dir: Directory to store log files. If None, logs are only output to console. Default: None
|
||||||
log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
|
log_level: Logging level. Options: 'debug', 'info', 'warning', 'error', 'critical'.
|
||||||
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
|
service_discovery: Enable Kubernetes service discovery. When enabled, the router will
|
||||||
@@ -73,6 +77,8 @@ class Router:
|
|||||||
eviction_interval_secs: int = 60,
|
eviction_interval_secs: int = 60,
|
||||||
max_tree_size: int = 2**24,
|
max_tree_size: int = 2**24,
|
||||||
max_payload_size: int = 256 * 1024 * 1024, # 256MB
|
max_payload_size: int = 256 * 1024 * 1024, # 256MB
|
||||||
|
dp_aware: bool = False,
|
||||||
|
api_key: Optional[str] = None,
|
||||||
log_dir: Optional[str] = None,
|
log_dir: Optional[str] = None,
|
||||||
log_level: Optional[str] = None,
|
log_level: Optional[str] = None,
|
||||||
service_discovery: bool = False,
|
service_discovery: bool = False,
|
||||||
@@ -110,6 +116,8 @@ class Router:
|
|||||||
eviction_interval_secs=eviction_interval_secs,
|
eviction_interval_secs=eviction_interval_secs,
|
||||||
max_tree_size=max_tree_size,
|
max_tree_size=max_tree_size,
|
||||||
max_payload_size=max_payload_size,
|
max_payload_size=max_payload_size,
|
||||||
|
dp_aware=dp_aware,
|
||||||
|
api_key=api_key,
|
||||||
log_dir=log_dir,
|
log_dir=log_dir,
|
||||||
log_level=log_level,
|
log_level=log_level,
|
||||||
service_discovery=service_discovery,
|
service_discovery=service_discovery,
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ if __name__ == "__main__":
|
|||||||
arg_parser.add_argument(
|
arg_parser.add_argument(
|
||||||
"--timeout-per-file",
|
"--timeout-per-file",
|
||||||
type=int,
|
type=int,
|
||||||
default=1000,
|
default=2000,
|
||||||
help="The time limit for running one file in seconds.",
|
help="The time limit for running one file in seconds.",
|
||||||
)
|
)
|
||||||
args = arg_parser.parse_args()
|
args = arg_parser.parse_args()
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
selector=None,
|
selector=None,
|
||||||
service_discovery_port=80,
|
service_discovery_port=80,
|
||||||
service_discovery_namespace=None,
|
service_discovery_namespace=None,
|
||||||
|
dp_aware=False,
|
||||||
prometheus_port=None,
|
prometheus_port=None,
|
||||||
prometheus_host=None,
|
prometheus_host=None,
|
||||||
# PD-specific attributes
|
# PD-specific attributes
|
||||||
@@ -111,6 +112,52 @@ class TestLaunchRouter(unittest.TestCase):
|
|||||||
)
|
)
|
||||||
self.run_router_process(args)
|
self.run_router_process(args)
|
||||||
|
|
||||||
|
def test_launch_router_common_with_dp_aware(self):
|
||||||
|
args = self.create_router_args(
|
||||||
|
worker_urls=["http://localhost:8000"],
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
self.run_router_process(args)
|
||||||
|
|
||||||
|
def test_launch_router_with_empty_worker_urls_with_dp_aware(self):
|
||||||
|
args = self.create_router_args(
|
||||||
|
worker_urls=[],
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
self.run_router_process(args)
|
||||||
|
|
||||||
|
def test_launch_router_common_with_dp_aware_service_discovery(self):
|
||||||
|
# Test launching the router with both service_discovery and dp_aware enabled
|
||||||
|
# Should fail since service_discovery and dp_aware conflict with each other
|
||||||
|
args = self.create_router_args(
|
||||||
|
worker_urls=["http://localhost:8000"],
|
||||||
|
dp_aware=True,
|
||||||
|
service_discovery=True,
|
||||||
|
selector=["app=test-worker"],
|
||||||
|
)
|
||||||
|
|
||||||
|
def run_router():
|
||||||
|
try:
|
||||||
|
from sglang_router.launch_router import launch_router
|
||||||
|
|
||||||
|
router = launch_router(args)
|
||||||
|
if router is None:
|
||||||
|
return 1
|
||||||
|
return 0
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
|
return 1
|
||||||
|
|
||||||
|
process = multiprocessing.Process(target=run_router)
|
||||||
|
try:
|
||||||
|
process.start()
|
||||||
|
# Wait 3 seconds
|
||||||
|
time.sleep(3)
|
||||||
|
# Should fail since service_discovery and dp_aware conflict with each other
|
||||||
|
self.assertFalse(process.is_alive())
|
||||||
|
finally:
|
||||||
|
terminate_process(process)
|
||||||
|
|
||||||
def test_launch_router_pd_mode_basic(self):
|
def test_launch_router_pd_mode_basic(self):
|
||||||
"""Test basic PD router functionality without actually starting servers."""
|
"""Test basic PD router functionality without actually starting servers."""
|
||||||
# This test just verifies the PD router can be created and configured
|
# This test just verifies the PD router can be created and configured
|
||||||
|
|||||||
@@ -30,6 +30,7 @@ def popen_launch_router(
|
|||||||
service_discovery_namespace: str = None,
|
service_discovery_namespace: str = None,
|
||||||
prometheus_port: int = None,
|
prometheus_port: int = None,
|
||||||
prometheus_host: str = None,
|
prometheus_host: str = None,
|
||||||
|
dp_aware: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Launch the router server process.
|
Launch the router server process.
|
||||||
@@ -49,6 +50,7 @@ def popen_launch_router(
|
|||||||
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
|
service_discovery_namespace: Kubernetes namespace to watch for pods. If None, watches all namespaces.
|
||||||
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
|
prometheus_port: Port to expose Prometheus metrics. If None, Prometheus metrics are disabled.
|
||||||
prometheus_host: Host address to bind the Prometheus metrics server.
|
prometheus_host: Host address to bind the Prometheus metrics server.
|
||||||
|
dp_aware: Enable data parallelism aware routing strategy.
|
||||||
"""
|
"""
|
||||||
_, host, port = base_url.split(":")
|
_, host, port = base_url.split(":")
|
||||||
host = host[2:]
|
host = host[2:]
|
||||||
@@ -69,10 +71,12 @@ def popen_launch_router(
|
|||||||
"5",
|
"5",
|
||||||
"--router-policy",
|
"--router-policy",
|
||||||
policy,
|
policy,
|
||||||
|
"--allow-auto-truncate",
|
||||||
]
|
]
|
||||||
|
|
||||||
if api_key is not None:
|
if api_key is not None:
|
||||||
command.extend(["--api-key", api_key])
|
command.extend(["--api-key", api_key])
|
||||||
|
command.extend(["--router-api-key", api_key])
|
||||||
|
|
||||||
if max_payload_size is not None:
|
if max_payload_size is not None:
|
||||||
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
command.extend(["--router-max-payload-size", str(max_payload_size)])
|
||||||
@@ -100,6 +104,9 @@ def popen_launch_router(
|
|||||||
if log_dir is not None:
|
if log_dir is not None:
|
||||||
command.extend(["--log-dir", log_dir])
|
command.extend(["--log-dir", log_dir])
|
||||||
|
|
||||||
|
if dp_aware:
|
||||||
|
command.append("--router-dp-aware")
|
||||||
|
|
||||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||||
|
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
@@ -127,6 +134,7 @@ def popen_launch_server(
|
|||||||
model: str,
|
model: str,
|
||||||
base_url: str,
|
base_url: str,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
api_key: str = None,
|
||||||
):
|
):
|
||||||
_, host, port = base_url.split(":")
|
_, host, port = base_url.split(":")
|
||||||
host = host[2:]
|
host = host[2:]
|
||||||
@@ -145,6 +153,9 @@ def popen_launch_server(
|
|||||||
"1",
|
"1",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
if api_key is not None:
|
||||||
|
command.extend(["--api-key", api_key])
|
||||||
|
|
||||||
process = subprocess.Popen(command, stdout=None, stderr=None)
|
process = subprocess.Popen(command, stdout=None, stderr=None)
|
||||||
|
|
||||||
# intentionally don't wait and defer the job to the router health check
|
# intentionally don't wait and defer the job to the router health check
|
||||||
@@ -426,6 +437,274 @@ class TestLaunchServer(unittest.TestCase):
|
|||||||
response.status_code, 200, "Request with correct api key should succeed"
|
response.status_code, 200, "Request with correct api key should succeed"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_6_mmlu_with_dp_aware(self):
|
||||||
|
print("Running test_6_mmlu_with_dp_aware...")
|
||||||
|
# DP size = 2
|
||||||
|
self.process = popen_launch_router(
|
||||||
|
self.model,
|
||||||
|
self.base_url,
|
||||||
|
dp_size=2,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
policy="cache_aware",
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
|
||||||
|
metrics = run_eval(args)
|
||||||
|
score = metrics["score"]
|
||||||
|
THRESHOLD = 0.65
|
||||||
|
passed = score >= THRESHOLD
|
||||||
|
msg = f"dp aware MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||||
|
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||||
|
|
||||||
|
def test_7_add_and_remove_worker_with_dp_aware(self):
|
||||||
|
print("Running test_7_add_and_remove_worker_with_dp_aware...")
|
||||||
|
|
||||||
|
# Set dp_size = 1
|
||||||
|
self.process = popen_launch_router(
|
||||||
|
self.model,
|
||||||
|
self.base_url,
|
||||||
|
dp_size=1,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
policy="round_robin", # make sure every worker processes requests
|
||||||
|
dp_aware=True, # dp aware strategy should work well with RR
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Start a worker
|
||||||
|
port = find_available_port()
|
||||||
|
worker_url = f"http://127.0.0.1:{port}"
|
||||||
|
worker_process = popen_launch_server(
|
||||||
|
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||||
|
)
|
||||||
|
self.other_process.append(worker_process)
|
||||||
|
|
||||||
|
# 2. Use the /add_worker API to add it to the router
|
||||||
|
# It will be used by router after it is healthy
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
# 3. Run mmlu
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=64,
|
||||||
|
num_threads=32,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
score = metrics["score"]
|
||||||
|
THRESHOLD = 0.65
|
||||||
|
passed = score >= THRESHOLD
|
||||||
|
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||||
|
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||||
|
|
||||||
|
# 4. Use the /remove_worker API to remove it from the router
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(f"{self.base_url}/remove_worker?url={worker_url}")
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
# 5. Run mmlu again
|
||||||
|
metrics = run_eval(args)
|
||||||
|
score = metrics["score"]
|
||||||
|
THRESHOLD = 0.65
|
||||||
|
passed = score >= THRESHOLD
|
||||||
|
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||||
|
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||||
|
|
||||||
|
# 6. Start another worker with api_key set
|
||||||
|
terminate_and_wait(worker_process) # terminate the old worker process
|
||||||
|
port = find_available_port()
|
||||||
|
worker_url = f"http://127.0.0.1:{port}"
|
||||||
|
worker_process = popen_launch_server(
|
||||||
|
self.model,
|
||||||
|
worker_url,
|
||||||
|
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
api_key="correct_api_key",
|
||||||
|
)
|
||||||
|
self.other_process.append(worker_process)
|
||||||
|
|
||||||
|
# 7. Use the /add_worker API to add it to the router
|
||||||
|
# Should fail since the router would contact the worker's
|
||||||
|
# /get_server_info endpoint for the dp_size info, but it
|
||||||
|
# has no knowledge of the api key
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertNotEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
def test_8_lazy_fault_tolerance_with_dp_aware(self):
|
||||||
|
print("Running test_8_lazy_fault_tolerance_with_dp_aware...")
|
||||||
|
|
||||||
|
# Set dp_size = 1
|
||||||
|
self.process = popen_launch_router(
|
||||||
|
self.model,
|
||||||
|
self.base_url,
|
||||||
|
dp_size=1,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
policy="round_robin",
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# 1. Start a worker
|
||||||
|
port = find_available_port()
|
||||||
|
worker_url = f"http://127.0.0.1:{port}"
|
||||||
|
worker_process = popen_launch_server(
|
||||||
|
self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
|
||||||
|
)
|
||||||
|
self.other_process.append(worker_process)
|
||||||
|
|
||||||
|
# 2. Use the /add_worker API to add it to the router
|
||||||
|
# It will be used by router after it is healthy
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(response.status_code, 200)
|
||||||
|
|
||||||
|
# Start a thread to kill the worker after 10 seconds to mimic
|
||||||
|
# abrupt worker failure
|
||||||
|
def kill_worker():
|
||||||
|
time.sleep(10)
|
||||||
|
kill_process_tree(worker_process.pid)
|
||||||
|
print("Worker process killed")
|
||||||
|
|
||||||
|
import threading
|
||||||
|
|
||||||
|
kill_thread = threading.Thread(target=kill_worker)
|
||||||
|
kill_thread.daemon = True
|
||||||
|
kill_thread.start()
|
||||||
|
|
||||||
|
# 3. Run mmlu
|
||||||
|
args = SimpleNamespace(
|
||||||
|
base_url=self.base_url,
|
||||||
|
model=self.model,
|
||||||
|
eval_name="mmlu",
|
||||||
|
num_examples=256,
|
||||||
|
num_threads=32,
|
||||||
|
temperature=0.1,
|
||||||
|
)
|
||||||
|
metrics = run_eval(args)
|
||||||
|
score = metrics["score"]
|
||||||
|
THRESHOLD = 0.65
|
||||||
|
passed = score >= THRESHOLD
|
||||||
|
msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
|
||||||
|
self.assertGreaterEqual(score, THRESHOLD, msg)
|
||||||
|
|
||||||
|
def test_9_payload_size_with_dp_aware(self):
|
||||||
|
print("Running test_9_payload_size_with_dp_aware...")
|
||||||
|
|
||||||
|
# Start the router with 1MB limit
|
||||||
|
self.process = popen_launch_router(
|
||||||
|
self.model,
|
||||||
|
self.base_url,
|
||||||
|
dp_size=1,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
policy="round_robin",
|
||||||
|
max_payload_size=1 * 1024 * 1024, # 1MB limit
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case 1: Payload well under the 1MB limit (0.5MB) should succeed
|
||||||
|
payload_0_5_mb = {
|
||||||
|
"text": "x" * int(0.5 * 1024 * 1024), # 0.5MB of text
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json=payload_0_5_mb,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
200,
|
||||||
|
f"0.5MB payload should succeed but got status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case 2: Payload over 1MB should fail
|
||||||
|
payload_1_plus_mb = {
|
||||||
|
"text": "x" * int((1.2 * 1024 * 1024)), # 1.2MB of text
|
||||||
|
"temperature": 0.0,
|
||||||
|
}
|
||||||
|
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json=payload_1_plus_mb,
|
||||||
|
headers={"Content-Type": "application/json"},
|
||||||
|
)
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
413, # Payload Too Large
|
||||||
|
f"1.2MB payload should fail with 413 but got status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_10_api_key_with_dp_aware(self):
|
||||||
|
print("Running test_10_api_key_with_dp_aware...")
|
||||||
|
|
||||||
|
self.process = popen_launch_router(
|
||||||
|
self.model,
|
||||||
|
self.base_url,
|
||||||
|
dp_size=1,
|
||||||
|
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
|
||||||
|
policy="round_robin",
|
||||||
|
api_key="correct_api_key",
|
||||||
|
dp_aware=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case 1: request without api key should fail
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json={"text": "Kanye west is, ", "temperature": 0},
|
||||||
|
)
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
401,
|
||||||
|
f"Request without api key should fail with 401 but got status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case 2: request with invalid api key should fail
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = requests.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json={"text": "Kanye west is, ", "temperature": 0},
|
||||||
|
headers={"Authorization": "Bearer 123"},
|
||||||
|
)
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
401,
|
||||||
|
f"Request without api key should fail with 401 but got status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Test case 3: request with correct api key should succeed
|
||||||
|
with requests.Session() as session:
|
||||||
|
response = session.post(
|
||||||
|
f"{self.base_url}/generate",
|
||||||
|
json={"text": "Kanye west is ", "temperature": 0},
|
||||||
|
headers={"Authorization": "Bearer correct_api_key"},
|
||||||
|
)
|
||||||
|
print(f"status code: {response.status_code}, response: {response.text}")
|
||||||
|
self.assertEqual(
|
||||||
|
response.status_code,
|
||||||
|
200,
|
||||||
|
f"Request with correct api key should succeed but got status {response.status_code}",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|||||||
@@ -21,6 +21,10 @@ pub struct RouterConfig {
|
|||||||
pub worker_startup_timeout_secs: u64,
|
pub worker_startup_timeout_secs: u64,
|
||||||
/// Worker health check interval in seconds
|
/// Worker health check interval in seconds
|
||||||
pub worker_startup_check_interval_secs: u64,
|
pub worker_startup_check_interval_secs: u64,
|
||||||
|
/// Enable data parallelism aware schedule
|
||||||
|
pub dp_aware: bool,
|
||||||
|
/// The api key used for the authorization with the worker
|
||||||
|
pub api_key: Option<String>,
|
||||||
/// Service discovery configuration (optional)
|
/// Service discovery configuration (optional)
|
||||||
pub discovery: Option<DiscoveryConfig>,
|
pub discovery: Option<DiscoveryConfig>,
|
||||||
/// Metrics configuration (optional)
|
/// Metrics configuration (optional)
|
||||||
@@ -205,6 +209,8 @@ impl Default for RouterConfig {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 300,
|
worker_startup_timeout_secs: 300,
|
||||||
worker_startup_check_interval_secs: 10,
|
worker_startup_check_interval_secs: 10,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
@@ -311,6 +317,8 @@ mod tests {
|
|||||||
request_timeout_secs: 30,
|
request_timeout_secs: 30,
|
||||||
worker_startup_timeout_secs: 60,
|
worker_startup_timeout_secs: 60,
|
||||||
worker_startup_check_interval_secs: 5,
|
worker_startup_check_interval_secs: 5,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: Some(DiscoveryConfig::default()),
|
discovery: Some(DiscoveryConfig::default()),
|
||||||
metrics: Some(MetricsConfig::default()),
|
metrics: Some(MetricsConfig::default()),
|
||||||
log_dir: Some("/var/log".to_string()),
|
log_dir: Some("/var/log".to_string()),
|
||||||
@@ -727,6 +735,8 @@ mod tests {
|
|||||||
request_timeout_secs: 120,
|
request_timeout_secs: 120,
|
||||||
worker_startup_timeout_secs: 60,
|
worker_startup_timeout_secs: 60,
|
||||||
worker_startup_check_interval_secs: 5,
|
worker_startup_check_interval_secs: 5,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: Some(DiscoveryConfig {
|
discovery: Some(DiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
namespace: Some("sglang".to_string()),
|
namespace: Some("sglang".to_string()),
|
||||||
@@ -774,6 +784,8 @@ mod tests {
|
|||||||
request_timeout_secs: 300,
|
request_timeout_secs: 300,
|
||||||
worker_startup_timeout_secs: 180,
|
worker_startup_timeout_secs: 180,
|
||||||
worker_startup_check_interval_secs: 15,
|
worker_startup_check_interval_secs: 15,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: Some(DiscoveryConfig {
|
discovery: Some(DiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
namespace: None,
|
namespace: None,
|
||||||
@@ -812,6 +824,8 @@ mod tests {
|
|||||||
request_timeout_secs: 900,
|
request_timeout_secs: 900,
|
||||||
worker_startup_timeout_secs: 600,
|
worker_startup_timeout_secs: 600,
|
||||||
worker_startup_check_interval_secs: 20,
|
worker_startup_check_interval_secs: 20,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: Some(DiscoveryConfig {
|
discovery: Some(DiscoveryConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
namespace: Some("production".to_string()),
|
namespace: Some("production".to_string()),
|
||||||
|
|||||||
@@ -313,6 +313,14 @@ impl ConfigValidator {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Service discovery conflicts with dp_aware routing for now
|
||||||
|
// since it's not fully supported yet
|
||||||
|
if has_service_discovery && config.dp_aware {
|
||||||
|
return Err(ConfigError::IncompatibleConfig {
|
||||||
|
reason: "DP-aware routing is not compatible with service discovery".to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ pub enum WorkerError {
|
|||||||
NetworkError { url: String, error: String },
|
NetworkError { url: String, error: String },
|
||||||
/// Worker is at capacity
|
/// Worker is at capacity
|
||||||
WorkerAtCapacity { url: String },
|
WorkerAtCapacity { url: String },
|
||||||
|
/// Invalid URL format
|
||||||
|
InvalidUrl { url: String },
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for WorkerError {
|
impl fmt::Display for WorkerError {
|
||||||
@@ -37,6 +39,9 @@ impl fmt::Display for WorkerError {
|
|||||||
WorkerError::WorkerAtCapacity { url } => {
|
WorkerError::WorkerAtCapacity { url } => {
|
||||||
write!(f, "Worker at capacity: {}", url)
|
write!(f, "Worker at capacity: {}", url)
|
||||||
}
|
}
|
||||||
|
WorkerError::InvalidUrl { url } => {
|
||||||
|
write!(f, "Invalid URL format: {}", url)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -162,6 +162,27 @@ impl BasicWorker {
|
|||||||
self.metadata.health_config = config;
|
self.metadata.health_config = config;
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn normalised_url(&self) -> WorkerResult<&str> {
|
||||||
|
if self.url().contains("@") {
|
||||||
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
|
let parts: Vec<&str> = self.url().split('@').collect();
|
||||||
|
if parts.len() != 2 {
|
||||||
|
return Err(WorkerError::InvalidUrl {
|
||||||
|
url: self.url().to_string(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
// Ensure the second part (the dp_rank) can be parsed as an integer
|
||||||
|
match parts[1].parse::<usize>() {
|
||||||
|
Ok(_) => Ok(parts[0]),
|
||||||
|
Err(_) => Err(WorkerError::InvalidUrl {
|
||||||
|
url: self.url().to_string(),
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Ok(self.url())
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -186,7 +207,8 @@ impl Worker for BasicWorker {
|
|||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
// Perform actual HTTP health check
|
// Perform actual HTTP health check
|
||||||
let health_url = format!("{}{}", self.url(), self.metadata.health_config.endpoint);
|
let url = self.normalised_url()?;
|
||||||
|
let health_url = format!("{}{}", url, self.metadata.health_config.endpoint);
|
||||||
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
let timeout = Duration::from_secs(self.metadata.health_config.timeout_secs);
|
||||||
|
|
||||||
// Use the shared client with a custom timeout for this request
|
// Use the shared client with a custom timeout for this request
|
||||||
@@ -203,7 +225,7 @@ impl Worker for BasicWorker {
|
|||||||
} else {
|
} else {
|
||||||
self.set_healthy(false);
|
self.set_healthy(false);
|
||||||
Err(WorkerError::HealthCheckFailed {
|
Err(WorkerError::HealthCheckFailed {
|
||||||
url: self.url().to_string(),
|
url: url.to_string(),
|
||||||
reason: format!("Health check returned status: {}", response.status()),
|
reason: format!("Health check returned status: {}", response.status()),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -211,7 +233,7 @@ impl Worker for BasicWorker {
|
|||||||
Err(e) => {
|
Err(e) => {
|
||||||
self.set_healthy(false);
|
self.set_healthy(false);
|
||||||
Err(WorkerError::HealthCheckFailed {
|
Err(WorkerError::HealthCheckFailed {
|
||||||
url: self.url().to_string(),
|
url: url.to_string(),
|
||||||
reason: format!("Health check request failed: {}", e),
|
reason: format!("Health check request failed: {}", e),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,8 @@ struct Router {
|
|||||||
eviction_interval_secs: u64,
|
eviction_interval_secs: u64,
|
||||||
max_tree_size: usize,
|
max_tree_size: usize,
|
||||||
max_payload_size: usize,
|
max_payload_size: usize,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
log_dir: Option<String>,
|
log_dir: Option<String>,
|
||||||
log_level: Option<String>,
|
log_level: Option<String>,
|
||||||
service_discovery: bool,
|
service_discovery: bool,
|
||||||
@@ -136,6 +138,8 @@ impl Router {
|
|||||||
request_timeout_secs: self.request_timeout_secs,
|
request_timeout_secs: self.request_timeout_secs,
|
||||||
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
worker_startup_timeout_secs: self.worker_startup_timeout_secs,
|
||||||
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
worker_startup_check_interval_secs: self.worker_startup_check_interval,
|
||||||
|
dp_aware: self.dp_aware,
|
||||||
|
api_key: self.api_key.clone(),
|
||||||
discovery,
|
discovery,
|
||||||
metrics,
|
metrics,
|
||||||
log_dir: self.log_dir.clone(),
|
log_dir: self.log_dir.clone(),
|
||||||
@@ -161,6 +165,8 @@ impl Router {
|
|||||||
eviction_interval_secs = 60,
|
eviction_interval_secs = 60,
|
||||||
max_tree_size = 2usize.pow(24),
|
max_tree_size = 2usize.pow(24),
|
||||||
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
max_payload_size = 256 * 1024 * 1024, // 256MB default for large batches
|
||||||
|
dp_aware = false,
|
||||||
|
api_key = None,
|
||||||
log_dir = None,
|
log_dir = None,
|
||||||
log_level = None,
|
log_level = None,
|
||||||
service_discovery = false,
|
service_discovery = false,
|
||||||
@@ -193,6 +199,8 @@ impl Router {
|
|||||||
eviction_interval_secs: u64,
|
eviction_interval_secs: u64,
|
||||||
max_tree_size: usize,
|
max_tree_size: usize,
|
||||||
max_payload_size: usize,
|
max_payload_size: usize,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
log_dir: Option<String>,
|
log_dir: Option<String>,
|
||||||
log_level: Option<String>,
|
log_level: Option<String>,
|
||||||
service_discovery: bool,
|
service_discovery: bool,
|
||||||
@@ -225,6 +233,8 @@ impl Router {
|
|||||||
eviction_interval_secs,
|
eviction_interval_secs,
|
||||||
max_tree_size,
|
max_tree_size,
|
||||||
max_payload_size,
|
max_payload_size,
|
||||||
|
dp_aware,
|
||||||
|
api_key,
|
||||||
log_dir,
|
log_dir,
|
||||||
log_level,
|
log_level,
|
||||||
service_discovery,
|
service_discovery,
|
||||||
|
|||||||
@@ -45,6 +45,8 @@ impl RouterFactory {
|
|||||||
policy,
|
policy,
|
||||||
router_config.worker_startup_timeout_secs,
|
router_config.worker_startup_timeout_secs,
|
||||||
router_config.worker_startup_check_interval_secs,
|
router_config.worker_startup_check_interval_secs,
|
||||||
|
router_config.dp_aware,
|
||||||
|
router_config.api_key.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
|
|||||||
@@ -30,6 +30,8 @@ pub struct Router {
|
|||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
_worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
_load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
_health_checker: Option<HealthChecker>,
|
_health_checker: Option<HealthChecker>,
|
||||||
@@ -42,6 +44,8 @@ impl Router {
|
|||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
|
dp_aware: bool,
|
||||||
|
api_key: Option<String>,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
// Update active workers gauge
|
// Update active workers gauge
|
||||||
RouterMetrics::set_active_workers(worker_urls.len());
|
RouterMetrics::set_active_workers(worker_urls.len());
|
||||||
@@ -51,6 +55,14 @@ impl Router {
|
|||||||
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
Self::wait_for_healthy_workers(&worker_urls, timeout_secs, interval_secs)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let worker_urls = if dp_aware {
|
||||||
|
// worker address now in the format of "http://host:port@dp_rank"
|
||||||
|
Self::get_dp_aware_workers(&worker_urls, &api_key)
|
||||||
|
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?
|
||||||
|
} else {
|
||||||
|
worker_urls
|
||||||
|
};
|
||||||
|
|
||||||
// Create Worker trait objects from URLs
|
// Create Worker trait objects from URLs
|
||||||
let workers: Vec<Box<dyn Worker>> = worker_urls
|
let workers: Vec<Box<dyn Worker>> = worker_urls
|
||||||
.iter()
|
.iter()
|
||||||
@@ -89,6 +101,8 @@ impl Router {
|
|||||||
policy,
|
policy,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
|
dp_aware,
|
||||||
|
api_key,
|
||||||
_worker_loads: worker_loads,
|
_worker_loads: worker_loads,
|
||||||
_load_monitor_handle: load_monitor_handle,
|
_load_monitor_handle: load_monitor_handle,
|
||||||
_health_checker: Some(health_checker),
|
_health_checker: Some(health_checker),
|
||||||
@@ -160,6 +174,62 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn get_worker_dp_size(worker_url: &str, api_key: &Option<String>) -> Result<usize, String> {
|
||||||
|
let sync_client = reqwest::blocking::Client::new();
|
||||||
|
let mut req_builder = sync_client.get(&format!("{}/get_server_info", worker_url));
|
||||||
|
if let Some(key) = api_key {
|
||||||
|
req_builder = req_builder.bearer_auth(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
match req_builder.send() {
|
||||||
|
Ok(res) => {
|
||||||
|
if res.status().is_success() {
|
||||||
|
let server_info = res
|
||||||
|
.text()
|
||||||
|
.map_err(|e| format!("failed to read text from response: {}", e))?;
|
||||||
|
|
||||||
|
let server_info: serde_json::Value = serde_json::from_str(&server_info)
|
||||||
|
.map_err(|e| format!("failed to decode JSON: {}", e))?;
|
||||||
|
|
||||||
|
let dp_size = server_info
|
||||||
|
.get("dp_size")
|
||||||
|
.and_then(|v| v.as_u64())
|
||||||
|
.ok_or_else(|| String::from("dp_size not found or not an u64"))?;
|
||||||
|
|
||||||
|
Ok(if dp_size > usize::MAX as u64 {
|
||||||
|
return Err(format!("dp_size is too large: {}", dp_size));
|
||||||
|
} else {
|
||||||
|
dp_size as usize
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Err(format!("unexpected status code: {}", res.status()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => Err(format!("error response: {}", e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Given a list of workers, return a list of workers with dp_rank as suffix
|
||||||
|
fn get_dp_aware_workers(
|
||||||
|
worker_urls: &[String],
|
||||||
|
api_key: &Option<String>,
|
||||||
|
) -> Result<Vec<String>, String> {
|
||||||
|
let mut dp_aware_workers: Vec<String> = Vec::new();
|
||||||
|
|
||||||
|
for url in worker_urls {
|
||||||
|
match Self::get_worker_dp_size(url, api_key) {
|
||||||
|
Ok(dp_size) => {
|
||||||
|
for i in 0..dp_size {
|
||||||
|
dp_aware_workers.push(format!("{}@{}", url, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => return Err(format!("Failed to get DP size for {}: {}", url, e)),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(dp_aware_workers)
|
||||||
|
}
|
||||||
|
|
||||||
fn select_first_worker(&self) -> Result<String, String> {
|
fn select_first_worker(&self) -> Result<String, String> {
|
||||||
let workers_guard = self.workers.read().unwrap();
|
let workers_guard = self.workers.read().unwrap();
|
||||||
if workers_guard.is_empty() {
|
if workers_guard.is_empty() {
|
||||||
@@ -178,6 +248,21 @@ impl Router {
|
|||||||
) -> HttpResponse {
|
) -> HttpResponse {
|
||||||
let request_id = get_request_id(req);
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
|
let worker_url = if self.dp_aware {
|
||||||
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
|
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
|
Ok(tup) => tup,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
|
return HttpResponse::InternalServerError().finish();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
worker_url_prefix
|
||||||
|
} else {
|
||||||
|
worker_url
|
||||||
|
};
|
||||||
|
|
||||||
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
let mut request_builder = client.get(format!("{}{}", worker_url, route));
|
||||||
|
|
||||||
// Copy all headers from original request except for /health because it does not need authorization
|
// Copy all headers from original request except for /health because it does not need authorization
|
||||||
@@ -292,7 +377,7 @@ impl Router {
|
|||||||
worker_url = %worker_url,
|
worker_url = %worker_url,
|
||||||
"Removing failed worker"
|
"Removing failed worker"
|
||||||
);
|
);
|
||||||
self.remove_worker(&worker_url);
|
self.remove_failed_worker(&worker_url);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -392,7 +477,7 @@ impl Router {
|
|||||||
request_id = %request_id,
|
request_id = %request_id,
|
||||||
"Removing failed worker after typed request failures worker_url={}", worker_url
|
"Removing failed worker after typed request failures worker_url={}", worker_url
|
||||||
);
|
);
|
||||||
self.remove_worker(&worker_url);
|
self.remove_failed_worker(&worker_url);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -415,6 +500,23 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO (rui): Better accommodate to the Worker abstraction
/// Split a dp-aware worker URL of the form `"http://host:port@dp_rank"` into
/// the plain worker URL and the numeric DP rank.
///
/// The split is made at the *last* `@`, so a URL whose authority carries
/// userinfo (for example `http://user:pw@host:port@3`) is still accepted.
/// Returns `Err` when the input contains no `@` or when the text after the
/// last `@` does not parse as a `usize`.
fn extract_dp_rank(worker_url: &str) -> Result<(&str, usize), String> {
    let (url_prefix, rank_text) = worker_url
        .rsplit_once('@')
        .ok_or_else(|| format!("invalid worker_url format: {}", worker_url))?;

    // Parse the suffix (dp_rank) into an integer
    rank_text
        .parse::<usize>()
        .map(|dp_rank| (url_prefix, dp_rank))
        .map_err(|_| format!("failed to parse dp_rank from worker_url: {}", worker_url))
}
|
||||||
|
|
||||||
// Send typed request directly without conversion
|
// Send typed request directly without conversion
|
||||||
async fn send_typed_request<T: serde::Serialize>(
|
async fn send_typed_request<T: serde::Serialize>(
|
||||||
&self,
|
&self,
|
||||||
@@ -429,9 +531,47 @@ impl Router {
|
|||||||
let request_id = get_request_id(req);
|
let request_id = get_request_id(req);
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
let mut request_builder = client
|
let mut request_builder = if self.dp_aware {
|
||||||
.post(format!("{}{}", worker_url, route))
|
let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
.json(typed_req); // Use json() directly with typed request
|
Ok(tup) => tup,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
|
return HttpResponse::InternalServerError().finish();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse the request body
|
||||||
|
let mut json_val = match serde_json::to_value(typed_req) {
|
||||||
|
Ok(j) => j,
|
||||||
|
Err(e) => {
|
||||||
|
return HttpResponse::BadRequest()
|
||||||
|
.body(format!("Convert into serde_json::Value failed: {}", e));
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Insert the data_parallel_rank field
|
||||||
|
if let Some(map) = json_val.as_object_mut() {
|
||||||
|
map.insert(
|
||||||
|
String::from("data_parallel_rank"),
|
||||||
|
serde_json::json!(dp_rank),
|
||||||
|
);
|
||||||
|
debug!(
|
||||||
|
"Modified request body: {}",
|
||||||
|
serde_json::to_string(&json_val).unwrap_or(String::from("ERR"))
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
return HttpResponse::BadRequest()
|
||||||
|
.body("Failed to insert the data_parallel_rank field into the request body");
|
||||||
|
}
|
||||||
|
|
||||||
|
client
|
||||||
|
.post(format!("{}{}", worker_url_prefix, route))
|
||||||
|
.json(&json_val)
|
||||||
|
} else {
|
||||||
|
client
|
||||||
|
.post(format!("{}{}", worker_url, route))
|
||||||
|
.json(typed_req) // Use json() directly with typed request
|
||||||
|
};
|
||||||
|
|
||||||
// Copy all headers from original request
|
// Copy all headers from original request
|
||||||
for (name, value) in copy_request_headers(req) {
|
for (name, value) in copy_request_headers(req) {
|
||||||
@@ -560,12 +700,35 @@ impl Router {
|
|||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
if res.status().is_success() {
|
if res.status().is_success() {
|
||||||
let mut workers_guard = self.workers.write().unwrap();
|
let mut workers_guard = self.workers.write().unwrap();
|
||||||
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
if self.dp_aware {
|
||||||
return Err(format!("Worker {} already exists", worker_url));
|
// Need to contact the worker to extract the dp_size,
|
||||||
|
// and add them as multiple workers
|
||||||
|
let url_vec = vec![String::from(worker_url)];
|
||||||
|
let dp_url_vec = Self::get_dp_aware_workers(&url_vec, &self.api_key)
|
||||||
|
.map_err(|e| format!("Failed to get dp-aware workers: {}", e))?;
|
||||||
|
let mut worker_added: bool = false;
|
||||||
|
for dp_url in &dp_url_vec {
|
||||||
|
if workers_guard.iter().any(|w| w.url() == dp_url) {
|
||||||
|
warn!("Worker {} already exists", dp_url);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
info!("Added worker: {}", dp_url);
|
||||||
|
let new_worker = WorkerFactory::create_regular(dp_url.to_string());
|
||||||
|
workers_guard.push(new_worker);
|
||||||
|
worker_added = true;
|
||||||
|
}
|
||||||
|
if !worker_added {
|
||||||
|
return Err(format!("No worker added for {}", worker_url));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if workers_guard.iter().any(|w| w.url() == worker_url) {
|
||||||
|
return Err(format!("Worker {} already exists", worker_url));
|
||||||
|
}
|
||||||
|
info!("Added worker: {}", worker_url);
|
||||||
|
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
||||||
|
workers_guard.push(new_worker);
|
||||||
}
|
}
|
||||||
info!("Added worker: {}", worker_url);
|
|
||||||
let new_worker = WorkerFactory::create_regular(worker_url.to_string());
|
|
||||||
workers_guard.push(new_worker);
|
|
||||||
RouterMetrics::set_active_workers(workers_guard.len());
|
RouterMetrics::set_active_workers(workers_guard.len());
|
||||||
|
|
||||||
// If cache aware policy, initialize the worker in the tree
|
// If cache aware policy, initialize the worker in the tree
|
||||||
@@ -612,11 +775,81 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Remove all the worker(s) that match the URL prefix
|
||||||
pub fn remove_worker(&self, worker_url: &str) {
|
pub fn remove_worker(&self, worker_url: &str) {
|
||||||
|
if self.dp_aware {
|
||||||
|
// remove dp-aware workers in a prefix-matching fashion
|
||||||
|
// without contacting the remote worker
|
||||||
|
let mut candidate_workers: Vec<String> = Vec::new();
|
||||||
|
let mut removed_workers: Vec<String> = Vec::new();
|
||||||
|
let worker_url_prefix = format!("{}@", worker_url);
|
||||||
|
|
||||||
|
{
|
||||||
|
// find the candidate workers to be removed
|
||||||
|
let workers_guard = self.workers.read().unwrap();
|
||||||
|
for w in workers_guard.iter() {
|
||||||
|
if w.url().starts_with(&worker_url_prefix) {
|
||||||
|
candidate_workers.push(w.url().to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
// do the removing on the worker_urls
|
||||||
|
let mut workers_guard = self.workers.write().unwrap();
|
||||||
|
for dp_url in candidate_workers.iter() {
|
||||||
|
if let Some(index) = workers_guard.iter().position(|w| w.url() == dp_url) {
|
||||||
|
workers_guard.remove(index);
|
||||||
|
info!("Removed worker: {}", dp_url);
|
||||||
|
removed_workers.push(dp_url.to_string());
|
||||||
|
} else {
|
||||||
|
warn!("Worker {} not found, skipping removal", dp_url);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
RouterMetrics::set_active_workers(workers_guard.len());
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cache aware policy, remove the workers from the tree
|
||||||
|
if let Some(cache_aware) = self
|
||||||
|
.policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
for dp_url in removed_workers.iter() {
|
||||||
|
cache_aware.remove_worker(dp_url);
|
||||||
|
info!("Removed worker from tree: {}", dp_url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
let mut workers_guard = self.workers.write().unwrap();
|
||||||
|
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||||
|
workers_guard.remove(index);
|
||||||
|
info!("Removed worker: {}", worker_url);
|
||||||
|
RouterMetrics::set_active_workers(workers_guard.len());
|
||||||
|
} else {
|
||||||
|
warn!("Worker {} not found, skipping removal", worker_url);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// If cache aware policy, remove the workers from the tree
|
||||||
|
if let Some(cache_aware) = self
|
||||||
|
.policy
|
||||||
|
.as_any()
|
||||||
|
.downcast_ref::<crate::policies::CacheAwarePolicy>()
|
||||||
|
{
|
||||||
|
cache_aware.remove_worker(worker_url);
|
||||||
|
info!("Removed worker from tree: {}", worker_url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove a specific failed worker; for internal usage
|
||||||
|
fn remove_failed_worker(&self, worker_url: &str) {
|
||||||
let mut workers_guard = self.workers.write().unwrap();
|
let mut workers_guard = self.workers.write().unwrap();
|
||||||
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
if let Some(index) = workers_guard.iter().position(|w| w.url() == worker_url) {
|
||||||
workers_guard.remove(index);
|
workers_guard.remove(index);
|
||||||
info!("Removed worker: {}", worker_url);
|
info!("Removed failed worker: {}", worker_url);
|
||||||
RouterMetrics::set_active_workers(workers_guard.len());
|
RouterMetrics::set_active_workers(workers_guard.len());
|
||||||
} else {
|
} else {
|
||||||
warn!("Worker {} not found, skipping removal", worker_url);
|
warn!("Worker {} not found, skipping removal", worker_url);
|
||||||
@@ -634,6 +867,20 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||||
|
let worker_url = if self.dp_aware {
|
||||||
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
|
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
|
Ok(tup) => tup,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
worker_url_prefix
|
||||||
|
} else {
|
||||||
|
worker_url
|
||||||
|
};
|
||||||
|
|
||||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||||
@@ -710,6 +957,20 @@ impl Router {
|
|||||||
|
|
||||||
// Static version of get_worker_load for use in monitoring task
|
// Static version of get_worker_load for use in monitoring task
|
||||||
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
async fn get_worker_load_static(client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||||
|
let worker_url = if worker_url.contains("@") {
|
||||||
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
|
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
|
Ok(tup) => tup,
|
||||||
|
Err(e) => {
|
||||||
|
debug!("Failed to extract dp_rank: {}", e);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
worker_url_prefix
|
||||||
|
} else {
|
||||||
|
worker_url
|
||||||
|
};
|
||||||
|
|
||||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||||
@@ -862,6 +1123,19 @@ impl RouterTrait for Router {
|
|||||||
// Send requests to all workers concurrently without headers
|
// Send requests to all workers concurrently without headers
|
||||||
let mut tasks = Vec::new();
|
let mut tasks = Vec::new();
|
||||||
for worker_url in &worker_urls {
|
for worker_url in &worker_urls {
|
||||||
|
let worker_url = if self.dp_aware {
|
||||||
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
|
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
|
Ok(tup) => tup,
|
||||||
|
Err(e) => {
|
||||||
|
error!("Failed to extract dp_rank: {}", e);
|
||||||
|
return HttpResponse::InternalServerError().finish();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
worker_url_prefix
|
||||||
|
} else {
|
||||||
|
worker_url
|
||||||
|
};
|
||||||
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
||||||
tasks.push(request_builder.send());
|
tasks.push(request_builder.send());
|
||||||
}
|
}
|
||||||
@@ -948,6 +1222,8 @@ mod tests {
|
|||||||
policy: Arc::new(RandomPolicy::new()),
|
policy: Arc::new(RandomPolicy::new()),
|
||||||
timeout_secs: 5,
|
timeout_secs: 5,
|
||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
_worker_loads: Arc::new(rx),
|
_worker_loads: Arc::new(rx),
|
||||||
_load_monitor_handle: None,
|
_load_monitor_handle: None,
|
||||||
_health_checker: None,
|
_health_checker: None,
|
||||||
|
|||||||
@@ -581,7 +581,7 @@ mod tests {
|
|||||||
use crate::routers::router::Router;
|
use crate::routers::router::Router;
|
||||||
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||||
let router = Router::new(vec![], policy, 5, 1).unwrap();
|
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
|
||||||
Arc::new(router) as Arc<dyn RouterTrait>
|
Arc::new(router) as Arc<dyn RouterTrait>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -31,6 +31,8 @@ impl TestContext {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
worker_startup_check_interval_secs: 1,
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
@@ -950,6 +952,8 @@ mod error_tests {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
worker_startup_check_interval_secs: 1,
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
|
|||||||
@@ -16,6 +16,8 @@ pub fn create_test_config(worker_urls: Vec<String>) -> RouterConfig {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 300,
|
worker_startup_timeout_secs: 300,
|
||||||
worker_startup_check_interval_secs: 10,
|
worker_startup_check_interval_secs: 10,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
@@ -37,6 +39,8 @@ pub fn create_test_config_no_workers() -> RouterConfig {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 0, // No wait
|
worker_startup_timeout_secs: 0, // No wait
|
||||||
worker_startup_check_interval_secs: 10,
|
worker_startup_check_interval_secs: 10,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ impl RequestTestContext {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
worker_startup_check_interval_secs: 1,
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
|
|||||||
@@ -46,6 +46,8 @@ impl StreamingTestContext {
|
|||||||
request_timeout_secs: 600,
|
request_timeout_secs: 600,
|
||||||
worker_startup_timeout_secs: 1,
|
worker_startup_timeout_secs: 1,
|
||||||
worker_startup_check_interval_secs: 1,
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
|
|||||||
@@ -169,6 +169,8 @@ mod test_pd_routing {
|
|||||||
request_timeout_secs: 60,
|
request_timeout_secs: 60,
|
||||||
worker_startup_timeout_secs: 10,
|
worker_startup_timeout_secs: 10,
|
||||||
worker_startup_check_interval_secs: 1,
|
worker_startup_check_interval_secs: 1,
|
||||||
|
dp_aware: false,
|
||||||
|
api_key: None,
|
||||||
discovery: None,
|
discovery: None,
|
||||||
metrics: None,
|
metrics: None,
|
||||||
log_dir: None,
|
log_dir: None,
|
||||||
|
|||||||
Reference in New Issue
Block a user