[router] create worker removal step and clean up worker manager (#11921)
This commit is contained in:
@@ -85,6 +85,8 @@ def _popen_launch_router(
|
|||||||
str(prom_port),
|
str(prom_port),
|
||||||
"--router-prometheus-host",
|
"--router-prometheus-host",
|
||||||
"127.0.0.1",
|
"127.0.0.1",
|
||||||
|
"--router-log-level",
|
||||||
|
"warn",
|
||||||
]
|
]
|
||||||
|
|
||||||
proc = subprocess.Popen(cmd)
|
proc = subprocess.Popen(cmd)
|
||||||
|
|||||||
@@ -1,9 +1,31 @@
|
|||||||
|
import time
|
||||||
from types import SimpleNamespace
|
from types import SimpleNamespace
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_workers(
|
||||||
|
base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
|
||||||
|
) -> None:
|
||||||
|
"""Poll /workers endpoint until expected number of workers are registered."""
|
||||||
|
start = time.perf_counter()
|
||||||
|
with requests.Session() as session:
|
||||||
|
while time.perf_counter() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
|
||||||
|
if r.status_code == 200:
|
||||||
|
workers = r.json().get("workers", [])
|
||||||
|
if len(workers) >= expected_count:
|
||||||
|
return
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(0.5)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.e2e
|
@pytest.mark.e2e
|
||||||
def test_embeddings_basic(
|
def test_embeddings_basic(
|
||||||
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
|
e2e_router_only_rr, e2e_primary_embedding_worker, e2e_embedding_model
|
||||||
@@ -12,8 +34,11 @@ def test_embeddings_basic(
|
|||||||
worker_url = e2e_primary_embedding_worker.url
|
worker_url = e2e_primary_embedding_worker.url
|
||||||
|
|
||||||
# Attach embedding worker to router-only instance
|
# Attach embedding worker to router-only instance
|
||||||
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
|
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
|
||||||
r.raise_for_status()
|
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for worker to be registered
|
||||||
|
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||||
|
|
||||||
# Simple embedding request with two inputs
|
# Simple embedding request with two inputs
|
||||||
payload = {
|
payload = {
|
||||||
|
|||||||
@@ -198,6 +198,8 @@ def pd_cluster(e2e_model: str):
|
|||||||
"--policy",
|
"--policy",
|
||||||
"round_robin",
|
"round_robin",
|
||||||
"--pd-disaggregation",
|
"--pd-disaggregation",
|
||||||
|
"--log-level",
|
||||||
|
"warn",
|
||||||
]
|
]
|
||||||
for url, bport in prefill:
|
for url, bport in prefill:
|
||||||
cmd += ["--prefill", url, str(bport)]
|
cmd += ["--prefill", url, str(bport)]
|
||||||
|
|||||||
@@ -8,13 +8,39 @@ import requests
|
|||||||
from sglang.test.run_eval import run_eval
|
from sglang.test.run_eval import run_eval
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_for_workers(
|
||||||
|
base_url: str, expected_count: int, timeout: float = 60.0, headers: dict = None
|
||||||
|
) -> None:
|
||||||
|
"""Poll /workers endpoint until expected number of workers are registered."""
|
||||||
|
start = time.perf_counter()
|
||||||
|
with requests.Session() as session:
|
||||||
|
while time.perf_counter() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = session.get(f"{base_url}/workers", headers=headers, timeout=5)
|
||||||
|
if r.status_code == 200:
|
||||||
|
workers = r.json().get("workers", [])
|
||||||
|
if len(workers) >= expected_count:
|
||||||
|
return
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(0.5)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Expected {expected_count} workers at {base_url}, timed out after {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.e2e
|
@pytest.mark.e2e
|
||||||
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
|
def test_mmlu(e2e_router_only_rr, e2e_two_workers_dp2, e2e_model):
|
||||||
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
|
# Attach two dp=2 workers (total 4 GPUs) to a fresh router-only instance
|
||||||
base = e2e_router_only_rr.url
|
base = e2e_router_only_rr.url
|
||||||
for w in e2e_two_workers_dp2:
|
for w in e2e_two_workers_dp2:
|
||||||
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
|
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
|
||||||
r.raise_for_status()
|
assert (
|
||||||
|
r.status_code == 202
|
||||||
|
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for workers to be registered
|
||||||
|
_wait_for_workers(base, expected_count=2, timeout=60.0)
|
||||||
|
|
||||||
args = SimpleNamespace(
|
args = SimpleNamespace(
|
||||||
base_url=base,
|
base_url=base,
|
||||||
@@ -35,8 +61,13 @@ def test_genai_bench(
|
|||||||
"""Attach a worker to the regular router and run a short genai-bench."""
|
"""Attach a worker to the regular router and run a short genai-bench."""
|
||||||
base = e2e_router_only_rr.url
|
base = e2e_router_only_rr.url
|
||||||
for w in e2e_two_workers_dp2:
|
for w in e2e_two_workers_dp2:
|
||||||
r = requests.post(f"{base}/add_worker", params={"url": w.url}, timeout=180)
|
r = requests.post(f"{base}/workers", json={"url": w.url}, timeout=180)
|
||||||
r.raise_for_status()
|
assert (
|
||||||
|
r.status_code == 202
|
||||||
|
), f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for workers to be registered
|
||||||
|
_wait_for_workers(base, expected_count=2, timeout=60.0)
|
||||||
|
|
||||||
genai_bench_runner(
|
genai_bench_runner(
|
||||||
router_url=base,
|
router_url=base,
|
||||||
@@ -59,8 +90,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
|
|||||||
base = e2e_router_only_rr.url
|
base = e2e_router_only_rr.url
|
||||||
worker_url = e2e_primary_worker.url
|
worker_url = e2e_primary_worker.url
|
||||||
|
|
||||||
r = requests.post(f"{base}/add_worker", params={"url": worker_url}, timeout=180)
|
r = requests.post(f"{base}/workers", json={"url": worker_url}, timeout=180)
|
||||||
r.raise_for_status()
|
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for worker to be registered
|
||||||
|
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||||
|
|
||||||
with requests.Session() as s:
|
with requests.Session() as s:
|
||||||
for i in range(8):
|
for i in range(8):
|
||||||
@@ -77,8 +111,11 @@ def test_add_and_remove_worker_live(e2e_router_only_rr, e2e_primary_worker, e2e_
|
|||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
|
|
||||||
# Remove the worker
|
# Remove the worker
|
||||||
r = requests.post(f"{base}/remove_worker", params={"url": worker_url}, timeout=60)
|
from urllib.parse import quote
|
||||||
r.raise_for_status()
|
|
||||||
|
encoded_url = quote(worker_url, safe="")
|
||||||
|
r = requests.delete(f"{base}/workers/{encoded_url}", timeout=60)
|
||||||
|
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.e2e
|
@pytest.mark.e2e
|
||||||
@@ -86,8 +123,11 @@ def test_lazy_fault_tolerance_live(e2e_router_only_rr, e2e_primary_worker, e2e_m
|
|||||||
base = e2e_router_only_rr.url
|
base = e2e_router_only_rr.url
|
||||||
worker = e2e_primary_worker
|
worker = e2e_primary_worker
|
||||||
|
|
||||||
r = requests.post(f"{base}/add_worker", params={"url": worker.url}, timeout=180)
|
r = requests.post(f"{base}/workers", json={"url": worker.url}, timeout=180)
|
||||||
r.raise_for_status()
|
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for worker to be registered
|
||||||
|
_wait_for_workers(base, expected_count=1, timeout=60.0)
|
||||||
|
|
||||||
def killer():
|
def killer():
|
||||||
time.sleep(10)
|
time.sleep(10)
|
||||||
@@ -129,20 +169,30 @@ def test_dp_aware_worker_expansion_and_api_key(
|
|||||||
|
|
||||||
# Attach worker; router should expand to dp_size logical workers
|
# Attach worker; router should expand to dp_size logical workers
|
||||||
r = requests.post(
|
r = requests.post(
|
||||||
f"{router_url}/add_worker",
|
f"{router_url}/workers",
|
||||||
params={"url": worker_url, "api_key": api_key},
|
json={"url": worker_url, "api_key": api_key},
|
||||||
headers={"Authorization": f"Bearer {api_key}"},
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
timeout=180,
|
timeout=180,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
assert r.status_code == 202, f"Expected 202 ACCEPTED, got {r.status_code}: {r.text}"
|
||||||
|
|
||||||
|
# Wait for workers to be registered and expanded
|
||||||
|
_wait_for_workers(
|
||||||
|
router_url,
|
||||||
|
expected_count=2,
|
||||||
|
timeout=60.0,
|
||||||
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the expanded workers have correct URLs
|
||||||
r = requests.get(
|
r = requests.get(
|
||||||
f"{router_url}/list_workers",
|
f"{router_url}/workers",
|
||||||
headers={"Authorization": f"Bearer {api_key}"},
|
headers={"Authorization": f"Bearer {api_key}"},
|
||||||
timeout=30,
|
timeout=30,
|
||||||
)
|
)
|
||||||
r.raise_for_status()
|
r.raise_for_status()
|
||||||
urls = r.json().get("urls", [])
|
workers = r.json().get("workers", [])
|
||||||
|
urls = [w["url"] for w in workers]
|
||||||
assert len(urls) == 2
|
assert len(urls) == 2
|
||||||
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
|
assert set(urls) == {f"{worker_url}@0", f"{worker_url}@1"}
|
||||||
|
|
||||||
|
|||||||
@@ -267,6 +267,8 @@ def popen_launch_workers_and_router(
|
|||||||
policy,
|
policy,
|
||||||
"--model-path",
|
"--model-path",
|
||||||
model,
|
model,
|
||||||
|
"--log-level",
|
||||||
|
"warn",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Add worker URLs
|
# Add worker URLs
|
||||||
|
|||||||
@@ -133,19 +133,90 @@ class RouterManager:
|
|||||||
time.sleep(0.2)
|
time.sleep(0.2)
|
||||||
raise TimeoutError(f"Router at {base_url} did not become healthy")
|
raise TimeoutError(f"Router at {base_url} did not become healthy")
|
||||||
|
|
||||||
def add_worker(self, base_url: str, worker_url: str) -> None:
|
def add_worker(self, base_url: str, worker_url: str, timeout: float = 30.0) -> None:
|
||||||
r = requests.post(f"{base_url}/add_worker", params={"url": worker_url})
|
r = requests.post(f"{base_url}/workers", json={"url": worker_url})
|
||||||
assert r.status_code == 200, f"add_worker failed: {r.status_code} {r.text}"
|
assert (
|
||||||
|
r.status_code == 202
|
||||||
|
), f"add_worker failed: {r.status_code} {r.text}" # ACCEPTED status
|
||||||
|
|
||||||
def remove_worker(self, base_url: str, worker_url: str) -> None:
|
# Poll until worker is actually added and healthy
|
||||||
r = requests.post(f"{base_url}/remove_worker", params={"url": worker_url})
|
from urllib.parse import quote
|
||||||
assert r.status_code == 200, f"remove_worker failed: {r.status_code} {r.text}"
|
|
||||||
|
encoded_url = quote(worker_url, safe="")
|
||||||
|
start = time.time()
|
||||||
|
with requests.Session() as s:
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
|
||||||
|
if r.status_code == 200:
|
||||||
|
data = r.json()
|
||||||
|
# Check if registration job failed
|
||||||
|
job_status = data.get("job_status")
|
||||||
|
if job_status and job_status.get("state") == "failed":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Worker registration failed: {job_status.get('message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
# Check if worker is healthy and registered (not just in job queue)
|
||||||
|
if data.get("is_healthy", False):
|
||||||
|
return
|
||||||
|
# Worker not ready yet, continue polling
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(0.1)
|
||||||
|
raise TimeoutError(
|
||||||
|
f"Worker {worker_url} was not added and healthy after {timeout}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
def remove_worker(
|
||||||
|
self, base_url: str, worker_url: str, timeout: float = 30.0
|
||||||
|
) -> None:
|
||||||
|
# URL encode the worker_url for path parameter
|
||||||
|
from urllib.parse import quote
|
||||||
|
|
||||||
|
encoded_url = quote(worker_url, safe="")
|
||||||
|
r = requests.delete(f"{base_url}/workers/{encoded_url}")
|
||||||
|
assert (
|
||||||
|
r.status_code == 202
|
||||||
|
), f"remove_worker failed: {r.status_code} {r.text}" # ACCEPTED status
|
||||||
|
|
||||||
|
# Poll until worker is actually removed (GET returns 404) or timeout
|
||||||
|
start = time.time()
|
||||||
|
last_status = None
|
||||||
|
with requests.Session() as s:
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = s.get(f"{base_url}/workers/{encoded_url}", timeout=2)
|
||||||
|
if r.status_code == 404:
|
||||||
|
# Worker successfully removed
|
||||||
|
return
|
||||||
|
elif r.status_code == 200:
|
||||||
|
# Check if removal job failed
|
||||||
|
data = r.json()
|
||||||
|
job_status = data.get("job_status")
|
||||||
|
if job_status:
|
||||||
|
last_status = job_status
|
||||||
|
if job_status.get("state") == "failed":
|
||||||
|
raise RuntimeError(
|
||||||
|
f"Worker removal failed: {job_status.get('message', 'Unknown error')}"
|
||||||
|
)
|
||||||
|
# Worker still being processed, continue polling
|
||||||
|
except requests.RequestException:
|
||||||
|
pass
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
# Provide detailed timeout error with last known status
|
||||||
|
error_msg = f"Worker {worker_url} was not removed after {timeout}s"
|
||||||
|
if last_status:
|
||||||
|
error_msg += f". Last job status: {last_status}"
|
||||||
|
raise TimeoutError(error_msg)
|
||||||
|
|
||||||
def list_workers(self, base_url: str) -> list[str]:
|
def list_workers(self, base_url: str) -> list[str]:
|
||||||
r = requests.get(f"{base_url}/list_workers")
|
r = requests.get(f"{base_url}/workers")
|
||||||
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
|
assert r.status_code == 200, f"list_workers failed: {r.status_code} {r.text}"
|
||||||
data = r.json()
|
data = r.json()
|
||||||
return data.get("urls", [])
|
# Extract URLs from WorkerInfo objects
|
||||||
|
workers = data.get("workers", [])
|
||||||
|
return [w["url"] for w in workers]
|
||||||
|
|
||||||
def stop_all(self):
|
def stop_all(self):
|
||||||
for p in self._children:
|
for p in self._children:
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ import os
|
|||||||
import subprocess
|
import subprocess
|
||||||
import time
|
import time
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Dict, Iterable, List, Tuple
|
from typing import Dict, Iterable, List, Optional, Tuple
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import requests
|
import requests
|
||||||
@@ -84,7 +84,7 @@ def mock_workers():
|
|||||||
|
|
||||||
procs: List[subprocess.Popen] = []
|
procs: List[subprocess.Popen] = []
|
||||||
|
|
||||||
def _start(n: int, args: List[str] | None = None):
|
def _start(n: int, args: Optional[List[str]] = None):
|
||||||
args = args or []
|
args = args or []
|
||||||
new_procs: List[subprocess.Popen] = []
|
new_procs: List[subprocess.Popen] = []
|
||||||
urls: List[str] = []
|
urls: List[str] = []
|
||||||
|
|||||||
@@ -15,11 +15,9 @@ use tracing::{debug, error, info, warn};
|
|||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
config::{RouterConfig, RoutingMode},
|
config::{RouterConfig, RoutingMode},
|
||||||
core::{
|
core::workflow::{
|
||||||
workflow::{
|
steps::WorkerRemovalRequest, WorkflowContext, WorkflowEngine, WorkflowId,
|
||||||
WorkflowContext, WorkflowEngine, WorkflowId, WorkflowInstanceId, WorkflowStatus,
|
WorkflowInstanceId, WorkflowStatus,
|
||||||
},
|
|
||||||
WorkerManager,
|
|
||||||
},
|
},
|
||||||
metrics::RouterMetrics,
|
metrics::RouterMetrics,
|
||||||
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
|
protocols::worker_spec::{JobStatus, WorkerConfigRequest},
|
||||||
@@ -320,11 +318,29 @@ impl JobQueue {
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
Job::RemoveWorker { url } => {
|
Job::RemoveWorker { url } => {
|
||||||
let result = WorkerManager::remove_worker(url, context);
|
let engine = context
|
||||||
|
.workflow_engine
|
||||||
|
.get()
|
||||||
|
.ok_or_else(|| "Workflow engine not initialized".to_string())?;
|
||||||
|
|
||||||
|
let instance_id = Self::start_worker_removal_workflow(engine, url, context).await?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Started worker removal workflow for {} (instance: {})",
|
||||||
|
url, instance_id
|
||||||
|
);
|
||||||
|
|
||||||
|
let timeout_duration = Duration::from_secs(30);
|
||||||
|
|
||||||
|
let result =
|
||||||
|
Self::wait_for_workflow_completion(engine, instance_id, url, timeout_duration)
|
||||||
|
.await;
|
||||||
|
|
||||||
// Clean up job status when removing worker
|
// Clean up job status when removing worker
|
||||||
if let Some(queue) = context.worker_job_queue.get() {
|
if let Some(queue) = context.worker_job_queue.get() {
|
||||||
queue.remove_status(url);
|
queue.remove_status(url);
|
||||||
}
|
}
|
||||||
|
|
||||||
result
|
result
|
||||||
}
|
}
|
||||||
Job::InitializeWorkersFromConfig { router_config } => {
|
Job::InitializeWorkersFromConfig { router_config } => {
|
||||||
@@ -424,6 +440,27 @@ impl JobQueue {
|
|||||||
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
|
.map_err(|e| format!("Failed to start worker registration workflow: {:?}", e))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Start worker removal workflow
|
||||||
|
async fn start_worker_removal_workflow(
|
||||||
|
engine: &Arc<WorkflowEngine>,
|
||||||
|
url: &str,
|
||||||
|
context: &Arc<AppContext>,
|
||||||
|
) -> Result<WorkflowInstanceId, String> {
|
||||||
|
let removal_request = WorkerRemovalRequest {
|
||||||
|
url: url.to_string(),
|
||||||
|
dp_aware: context.router_config.dp_aware,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut workflow_context = WorkflowContext::new(WorkflowInstanceId::new());
|
||||||
|
workflow_context.set("removal_request", removal_request);
|
||||||
|
workflow_context.set_arc("app_context", Arc::clone(context));
|
||||||
|
|
||||||
|
engine
|
||||||
|
.start_workflow(WorkflowId::new("worker_removal"), workflow_context)
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to start worker removal workflow: {:?}", e))
|
||||||
|
}
|
||||||
|
|
||||||
/// Wait for workflow completion with adaptive polling
|
/// Wait for workflow completion with adaptive polling
|
||||||
async fn wait_for_workflow_completion(
|
async fn wait_for_workflow_completion(
|
||||||
engine: &Arc<WorkflowEngine>,
|
engine: &Arc<WorkflowEngine>,
|
||||||
|
|||||||
@@ -29,5 +29,5 @@ pub use worker::{
|
|||||||
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
Worker, WorkerFactory, WorkerLoadGuard, WorkerType,
|
||||||
};
|
};
|
||||||
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
pub use worker_builder::{BasicWorkerBuilder, DPAwareWorkerBuilder};
|
||||||
pub use worker_manager::{DpInfo, LoadMonitor, ServerInfo, WorkerManager};
|
pub use worker_manager::{LoadMonitor, WorkerManager};
|
||||||
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
pub use worker_registry::{WorkerId, WorkerRegistry, WorkerRegistryStats};
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -14,5 +14,5 @@ pub use engine::WorkflowEngine;
|
|||||||
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
|
pub use event::{EventBus, EventSubscriber, LoggingSubscriber, WorkflowEvent};
|
||||||
pub use executor::{FunctionStep, StepExecutor};
|
pub use executor::{FunctionStep, StepExecutor};
|
||||||
pub use state::WorkflowStateStore;
|
pub use state::WorkflowStateStore;
|
||||||
pub use steps::create_worker_registration_workflow;
|
pub use steps::{create_worker_registration_workflow, create_worker_removal_workflow};
|
||||||
pub use types::*;
|
pub use types::*;
|
||||||
|
|||||||
@@ -2,11 +2,17 @@
|
|||||||
//!
|
//!
|
||||||
//! This module contains concrete step implementations for various workflows:
|
//! This module contains concrete step implementations for various workflows:
|
||||||
//! - Worker registration and activation
|
//! - Worker registration and activation
|
||||||
|
//! - Worker removal
|
||||||
//! - Future: Tokenizer fetching, LoRA updates, etc.
|
//! - Future: Tokenizer fetching, LoRA updates, etc.
|
||||||
|
|
||||||
pub mod worker_registration;
|
pub mod worker_registration;
|
||||||
|
pub mod worker_removal;
|
||||||
|
|
||||||
pub use worker_registration::{
|
pub use worker_registration::{
|
||||||
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
|
create_worker_registration_workflow, ActivateWorkerStep, CreateWorkerStep,
|
||||||
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
|
DetectConnectionModeStep, DiscoverMetadataStep, RegisterWorkerStep, UpdatePoliciesStep,
|
||||||
};
|
};
|
||||||
|
pub use worker_removal::{
|
||||||
|
create_worker_removal_workflow, FindWorkersToRemoveStep, RemoveFromPolicyRegistryStep,
|
||||||
|
RemoveFromWorkerRegistryStep, UpdateRemainingPoliciesStep, WorkerRemovalRequest,
|
||||||
|
};
|
||||||
|
|||||||
@@ -16,13 +16,14 @@ use std::{collections::HashMap, sync::Arc, time::Duration};
|
|||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use once_cell::sync::Lazy;
|
use once_cell::sync::Lazy;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::Value;
|
use serde_json::Value;
|
||||||
use tracing::{debug, info, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
core::{
|
core::{
|
||||||
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
|
workflow::*, BasicWorkerBuilder, CircuitBreakerConfig, ConnectionMode,
|
||||||
DPAwareWorkerBuilder, DpInfo, HealthConfig, Worker, WorkerManager, WorkerType,
|
DPAwareWorkerBuilder, HealthConfig, Worker, WorkerType,
|
||||||
},
|
},
|
||||||
grpc_client::SglangSchedulerClient,
|
grpc_client::SglangSchedulerClient,
|
||||||
protocols::worker_spec::WorkerConfigRequest,
|
protocols::worker_spec::WorkerConfigRequest,
|
||||||
@@ -37,6 +38,82 @@ static HTTP_CLIENT: Lazy<Client> = Lazy::new(|| {
|
|||||||
.expect("Failed to create HTTP client")
|
.expect("Failed to create HTTP client")
|
||||||
});
|
});
|
||||||
|
|
||||||
|
/// Server information returned from worker endpoints
|
||||||
|
#[derive(Debug, Clone, Deserialize, Serialize)]
|
||||||
|
struct ServerInfo {
|
||||||
|
#[serde(alias = "model")]
|
||||||
|
model_id: Option<String>,
|
||||||
|
model_path: Option<String>,
|
||||||
|
dp_size: Option<usize>,
|
||||||
|
version: Option<String>,
|
||||||
|
max_batch_size: Option<usize>,
|
||||||
|
max_total_tokens: Option<usize>,
|
||||||
|
max_prefill_tokens: Option<usize>,
|
||||||
|
max_running_requests: Option<usize>,
|
||||||
|
max_num_reqs: Option<usize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DpInfo {
|
||||||
|
pub dp_size: usize,
|
||||||
|
pub model_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse server info from JSON response using serde
|
||||||
|
fn parse_server_info(json: Value) -> Result<ServerInfo, String> {
|
||||||
|
serde_json::from_value(json).map_err(|e| format!("Failed to parse server info: {}", e))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get server info from /get_server_info endpoint
|
||||||
|
async fn get_server_info(url: &str, api_key: Option<&str>) -> Result<ServerInfo, String> {
|
||||||
|
let base_url = url.trim_end_matches('/');
|
||||||
|
let server_info_url = format!("{}/get_server_info", base_url);
|
||||||
|
|
||||||
|
let mut req = HTTP_CLIENT.get(&server_info_url);
|
||||||
|
if let Some(key) = api_key {
|
||||||
|
req = req.bearer_auth(key);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = req
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to connect to {}: {}", server_info_url, e))?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(format!(
|
||||||
|
"Server returned status {} from {}",
|
||||||
|
response.status(),
|
||||||
|
server_info_url
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let json = response
|
||||||
|
.json::<Value>()
|
||||||
|
.await
|
||||||
|
.map_err(|e| format!("Failed to parse response from {}: {}", server_info_url, e))?;
|
||||||
|
|
||||||
|
parse_server_info(json)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get DP info for a worker URL
|
||||||
|
async fn get_dp_info(url: &str, api_key: Option<&str>) -> Result<DpInfo, String> {
|
||||||
|
let info = get_server_info(url, api_key).await?;
|
||||||
|
|
||||||
|
let dp_size = info
|
||||||
|
.dp_size
|
||||||
|
.ok_or_else(|| format!("No dp_size in response from {}", url))?;
|
||||||
|
|
||||||
|
let model_id = info
|
||||||
|
.model_id
|
||||||
|
.or_else(|| {
|
||||||
|
info.model_path
|
||||||
|
.and_then(|path| path.split('/').next_back().map(|s| s.to_string()))
|
||||||
|
})
|
||||||
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
|
|
||||||
|
Ok(DpInfo { dp_size, model_id })
|
||||||
|
}
|
||||||
|
|
||||||
/// Helper: Strip protocol prefix from URL
|
/// Helper: Strip protocol prefix from URL
|
||||||
fn strip_protocol(url: &str) -> String {
|
fn strip_protocol(url: &str) -> String {
|
||||||
url.trim_start_matches("http://")
|
url.trim_start_matches("http://")
|
||||||
@@ -83,49 +160,6 @@ async fn try_grpc_health_check(url: &str, timeout_secs: u64) -> Result<(), Strin
|
|||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Helper: Fetch HTTP metadata
|
|
||||||
async fn fetch_http_metadata(
|
|
||||||
url: &str,
|
|
||||||
api_key: Option<&str>,
|
|
||||||
) -> Result<HashMap<String, String>, String> {
|
|
||||||
let clean_url = strip_protocol(url);
|
|
||||||
let info_url = if clean_url.starts_with("http://") || clean_url.starts_with("https://") {
|
|
||||||
format!("{}/get_server_info", clean_url)
|
|
||||||
} else {
|
|
||||||
format!("http://{}/get_server_info", clean_url)
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut request = HTTP_CLIENT.get(&info_url);
|
|
||||||
if let Some(key) = api_key {
|
|
||||||
request = request.header("Authorization", format!("Bearer {}", key));
|
|
||||||
}
|
|
||||||
|
|
||||||
let response = request
|
|
||||||
.send()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to fetch HTTP metadata: {}", e))?;
|
|
||||||
|
|
||||||
let server_info: Value = response
|
|
||||||
.json()
|
|
||||||
.await
|
|
||||||
.map_err(|e| format!("Failed to parse HTTP metadata: {}", e))?;
|
|
||||||
|
|
||||||
let mut labels = HashMap::new();
|
|
||||||
|
|
||||||
if let Some(model_path) = server_info.get("model_path").and_then(|v| v.as_str()) {
|
|
||||||
if !model_path.is_empty() {
|
|
||||||
labels.insert("model_path".to_string(), model_path.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if let Some(tokenizer_path) = server_info.get("tokenizer_path").and_then(|v| v.as_str()) {
|
|
||||||
if !tokenizer_path.is_empty() {
|
|
||||||
labels.insert("tokenizer_path".to_string(), tokenizer_path.to_string());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
Ok(labels)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Helper: Fetch gRPC metadata
|
/// Helper: Fetch gRPC metadata
|
||||||
async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> {
|
async fn fetch_grpc_metadata(url: &str) -> Result<HashMap<String, String>, String> {
|
||||||
let grpc_url = if url.starts_with("grpc://") {
|
let grpc_url = if url.starts_with("grpc://") {
|
||||||
@@ -266,7 +300,18 @@ impl StepExecutor for DiscoverMetadataStep {
|
|||||||
|
|
||||||
let discovered_labels = match connection_mode.as_ref() {
|
let discovered_labels = match connection_mode.as_ref() {
|
||||||
ConnectionMode::Http => {
|
ConnectionMode::Http => {
|
||||||
fetch_http_metadata(&config.url, config.api_key.as_deref()).await
|
match get_server_info(&config.url, config.api_key.as_deref()).await {
|
||||||
|
Ok(server_info) => {
|
||||||
|
let mut labels = HashMap::new();
|
||||||
|
if let Some(model_path) = server_info.model_path {
|
||||||
|
if !model_path.is_empty() {
|
||||||
|
labels.insert("model_path".to_string(), model_path);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(labels)
|
||||||
|
}
|
||||||
|
Err(e) => Err(e),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await,
|
ConnectionMode::Grpc { .. } => fetch_grpc_metadata(&config.url).await,
|
||||||
}
|
}
|
||||||
@@ -314,7 +359,7 @@ impl StepExecutor for DiscoverDPInfoStep {
|
|||||||
debug!("Discovering DP info for {} (DP-aware)", config.url);
|
debug!("Discovering DP info for {} (DP-aware)", config.url);
|
||||||
|
|
||||||
// Get DP info from worker
|
// Get DP info from worker
|
||||||
let dp_info = WorkerManager::get_dp_info(&config.url, config.api_key.as_deref())
|
let dp_info = get_dp_info(&config.url, config.api_key.as_deref())
|
||||||
.await
|
.await
|
||||||
.map_err(|e| WorkflowError::StepFailed {
|
.map_err(|e| WorkflowError::StepFailed {
|
||||||
step_id: StepId::new("discover_dp_info"),
|
step_id: StepId::new("discover_dp_info"),
|
||||||
@@ -327,7 +372,7 @@ impl StepExecutor for DiscoverDPInfoStep {
|
|||||||
);
|
);
|
||||||
|
|
||||||
// Store DP info in context
|
// Store DP info in context
|
||||||
context.set("dp_info", Arc::new(dp_info));
|
context.set("dp_info", dp_info);
|
||||||
|
|
||||||
Ok(StepResult::Success)
|
Ok(StepResult::Success)
|
||||||
}
|
}
|
||||||
@@ -522,7 +567,7 @@ impl StepExecutor for CreateWorkerStep {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Store workers (plural) and labels in context
|
// Store workers (plural) and labels in context
|
||||||
context.set("workers", Arc::new(workers));
|
context.set("workers", workers);
|
||||||
context.set("labels", final_labels);
|
context.set("labels", final_labels);
|
||||||
|
|
||||||
Ok(StepResult::Success)
|
Ok(StepResult::Success)
|
||||||
@@ -595,7 +640,7 @@ impl StepExecutor for RegisterWorkerStep {
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
context.set("worker_ids", Arc::new(worker_ids));
|
context.set("worker_ids", worker_ids);
|
||||||
Ok(StepResult::Success)
|
Ok(StepResult::Success)
|
||||||
} else {
|
} else {
|
||||||
// Non-DP-aware path: Register single worker
|
// Non-DP-aware path: Register single worker
|
||||||
|
|||||||
310
sgl-router/src/core/workflow/steps/worker_removal.rs
Normal file
310
sgl-router/src/core/workflow/steps/worker_removal.rs
Normal file
@@ -0,0 +1,310 @@
|
|||||||
|
//! Worker Removal Workflow Steps
|
||||||
|
//!
|
||||||
|
//! This module implements the workflow steps for removing workers from the router.
|
||||||
|
//! Handles both single worker removal and DP-aware worker removal with prefix matching.
|
||||||
|
//!
|
||||||
|
//! Steps:
|
||||||
|
//! 1. FindWorkersToRemove - Identify workers to remove based on URL (handles DP-aware prefix matching)
|
||||||
|
//! 2. RemoveFromPolicyRegistry - Remove workers from policy registry and cache-aware policies
|
||||||
|
//! 3. RemoveFromWorkerRegistry - Remove workers from worker registry
|
||||||
|
//! 4. UpdateRemainingPolicies - Update cache-aware policies for remaining workers
|
||||||
|
|
||||||
|
use std::{collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
core::{workflow::*, Worker},
|
||||||
|
server::AppContext,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Request structure for worker removal
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct WorkerRemovalRequest {
|
||||||
|
pub url: String,
|
||||||
|
pub dp_aware: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Step 1: Find workers to remove based on URL
|
||||||
|
pub struct FindWorkersToRemoveStep;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StepExecutor for FindWorkersToRemoveStep {
|
||||||
|
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||||
|
let request: Arc<WorkerRemovalRequest> = context
|
||||||
|
.get("removal_request")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("removal_request".to_string()))?;
|
||||||
|
let app_context: Arc<AppContext> = context
|
||||||
|
.get("app_context")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Finding workers to remove for {} (dp_aware: {})",
|
||||||
|
request.url, request.dp_aware
|
||||||
|
);
|
||||||
|
|
||||||
|
let workers_to_remove: Vec<Arc<dyn Worker>> = if request.dp_aware {
|
||||||
|
// DP-aware: Find all workers with matching prefix
|
||||||
|
let worker_url_prefix = format!("{}@", request.url);
|
||||||
|
let all_workers = app_context.worker_registry.get_all();
|
||||||
|
|
||||||
|
all_workers
|
||||||
|
.iter()
|
||||||
|
.filter(|worker| worker.url().starts_with(&worker_url_prefix))
|
||||||
|
.cloned()
|
||||||
|
.collect()
|
||||||
|
} else {
|
||||||
|
// Non-DP-aware: Find single worker by exact URL
|
||||||
|
match app_context.worker_registry.get_by_url(&request.url) {
|
||||||
|
Some(worker) => vec![worker],
|
||||||
|
None => Vec::new(),
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if workers_to_remove.is_empty() {
|
||||||
|
let error_msg = if request.dp_aware {
|
||||||
|
format!("No workers found with prefix {}@", request.url)
|
||||||
|
} else {
|
||||||
|
format!("Worker {} not found", request.url)
|
||||||
|
};
|
||||||
|
return Err(WorkflowError::StepFailed {
|
||||||
|
step_id: StepId::new("find_workers_to_remove"),
|
||||||
|
message: error_msg,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Found {} worker(s) to remove for {}",
|
||||||
|
workers_to_remove.len(),
|
||||||
|
request.url
|
||||||
|
);
|
||||||
|
|
||||||
|
// Store workers and their model IDs for subsequent steps
|
||||||
|
let worker_urls: Vec<String> = workers_to_remove
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.url().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let affected_models: HashSet<String> = workers_to_remove
|
||||||
|
.iter()
|
||||||
|
.map(|w| w.model_id().to_string())
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
context.set("workers_to_remove", workers_to_remove);
|
||||||
|
context.set("worker_urls", worker_urls);
|
||||||
|
context.set("affected_models", affected_models);
|
||||||
|
|
||||||
|
Ok(StepResult::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||||
|
false // Worker not found is not retryable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Step 2: Remove workers from policy registry
|
||||||
|
pub struct RemoveFromPolicyRegistryStep;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StepExecutor for RemoveFromPolicyRegistryStep {
|
||||||
|
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||||
|
let app_context: Arc<AppContext> = context
|
||||||
|
.get("app_context")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||||
|
let workers_to_remove: Arc<Vec<Arc<dyn Worker>>> = context
|
||||||
|
.get("workers_to_remove")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("workers_to_remove".to_string()))?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Removing {} worker(s) from policy registry",
|
||||||
|
workers_to_remove.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
for worker in workers_to_remove.iter() {
|
||||||
|
let model_id = worker.model_id().to_string();
|
||||||
|
let worker_url = worker.url();
|
||||||
|
|
||||||
|
// Remove from cache-aware policy
|
||||||
|
app_context
|
||||||
|
.policy_registry
|
||||||
|
.remove_worker_from_cache_aware(&model_id, worker_url);
|
||||||
|
|
||||||
|
// Notify policy registry
|
||||||
|
app_context.policy_registry.on_worker_removed(&model_id);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Removed worker {} from policy registry (model: {})",
|
||||||
|
worker_url, model_id
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(StepResult::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||||
|
false // Policy removal is not retryable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Step 3: Remove workers from worker registry
|
||||||
|
pub struct RemoveFromWorkerRegistryStep;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StepExecutor for RemoveFromWorkerRegistryStep {
|
||||||
|
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||||
|
let app_context: Arc<AppContext> = context
|
||||||
|
.get("app_context")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||||
|
let worker_urls: Arc<Vec<String>> = context
|
||||||
|
.get("worker_urls")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Removing {} worker(s) from worker registry",
|
||||||
|
worker_urls.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut removed_count = 0;
|
||||||
|
for worker_url in worker_urls.iter() {
|
||||||
|
if app_context
|
||||||
|
.worker_registry
|
||||||
|
.remove_by_url(worker_url)
|
||||||
|
.is_some()
|
||||||
|
{
|
||||||
|
removed_count += 1;
|
||||||
|
debug!("Removed worker {} from registry", worker_url);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if removed_count != worker_urls.len() {
|
||||||
|
return Err(WorkflowError::StepFailed {
|
||||||
|
step_id: StepId::new("remove_from_worker_registry"),
|
||||||
|
message: format!(
|
||||||
|
"Expected to remove {} workers but only removed {}",
|
||||||
|
worker_urls.len(),
|
||||||
|
removed_count
|
||||||
|
),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(StepResult::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||||
|
false // Worker removal is not retryable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Step 4: Update cache-aware policies for remaining workers
|
||||||
|
pub struct UpdateRemainingPoliciesStep;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StepExecutor for UpdateRemainingPoliciesStep {
|
||||||
|
async fn execute(&self, context: &mut WorkflowContext) -> WorkflowResult<StepResult> {
|
||||||
|
let app_context: Arc<AppContext> = context
|
||||||
|
.get("app_context")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("app_context".to_string()))?;
|
||||||
|
let affected_models: Arc<HashSet<String>> = context
|
||||||
|
.get("affected_models")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("affected_models".to_string()))?;
|
||||||
|
let worker_urls: Arc<Vec<String>> = context
|
||||||
|
.get("worker_urls")
|
||||||
|
.ok_or_else(|| WorkflowError::ContextValueNotFound("worker_urls".to_string()))?;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Updating cache-aware policies for {} affected model(s)",
|
||||||
|
affected_models.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
for model_id in affected_models.iter() {
|
||||||
|
let remaining_workers = app_context.worker_registry.get_by_model_fast(model_id);
|
||||||
|
|
||||||
|
if let Some(policy) = app_context.policy_registry.get_policy(model_id) {
|
||||||
|
if policy.name() == "cache_aware" && !remaining_workers.is_empty() {
|
||||||
|
app_context
|
||||||
|
.policy_registry
|
||||||
|
.init_cache_aware_policy(model_id, &remaining_workers);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"Updated cache-aware policy for model {} ({} remaining workers)",
|
||||||
|
model_id,
|
||||||
|
remaining_workers.len()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Log final result at info level
|
||||||
|
if worker_urls.len() == 1 {
|
||||||
|
info!("Removed worker {}", worker_urls[0]);
|
||||||
|
} else {
|
||||||
|
info!(
|
||||||
|
"Removed {} DP-aware workers: {:?}",
|
||||||
|
worker_urls.len(),
|
||||||
|
worker_urls
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(StepResult::Success)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_retryable(&self, _error: &WorkflowError) -> bool {
|
||||||
|
false // Policy update is not retryable
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create a worker removal workflow definition
|
||||||
|
pub fn create_worker_removal_workflow() -> WorkflowDefinition {
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
WorkflowDefinition::new("worker_removal", "Remove worker from router")
|
||||||
|
.add_step(
|
||||||
|
StepDefinition::new(
|
||||||
|
"find_workers_to_remove",
|
||||||
|
"Find workers to remove",
|
||||||
|
Arc::new(FindWorkersToRemoveStep),
|
||||||
|
)
|
||||||
|
.with_timeout(Duration::from_secs(10))
|
||||||
|
.with_retry(RetryPolicy {
|
||||||
|
max_attempts: 1,
|
||||||
|
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.add_step(
|
||||||
|
StepDefinition::new(
|
||||||
|
"remove_from_policy_registry",
|
||||||
|
"Remove workers from policy registry",
|
||||||
|
Arc::new(RemoveFromPolicyRegistryStep),
|
||||||
|
)
|
||||||
|
.with_timeout(Duration::from_secs(10))
|
||||||
|
.with_retry(RetryPolicy {
|
||||||
|
max_attempts: 1,
|
||||||
|
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.add_step(
|
||||||
|
StepDefinition::new(
|
||||||
|
"remove_from_worker_registry",
|
||||||
|
"Remove workers from worker registry",
|
||||||
|
Arc::new(RemoveFromWorkerRegistryStep),
|
||||||
|
)
|
||||||
|
.with_timeout(Duration::from_secs(10))
|
||||||
|
.with_retry(RetryPolicy {
|
||||||
|
max_attempts: 1,
|
||||||
|
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
.add_step(
|
||||||
|
StepDefinition::new(
|
||||||
|
"update_remaining_policies",
|
||||||
|
"Update cache-aware policies for remaining workers",
|
||||||
|
Arc::new(UpdateRemainingPoliciesStep),
|
||||||
|
)
|
||||||
|
.with_timeout(Duration::from_secs(10))
|
||||||
|
.with_retry(RetryPolicy {
|
||||||
|
max_attempts: 1,
|
||||||
|
backoff: BackoffStrategy::Fixed(Duration::from_secs(0)),
|
||||||
|
}),
|
||||||
|
)
|
||||||
|
}
|
||||||
@@ -149,7 +149,11 @@ pub fn responses_to_chat(req: &ResponsesRequest) -> Result<ChatCompletionRequest
|
|||||||
|
|
||||||
Ok(ChatCompletionRequest {
|
Ok(ChatCompletionRequest {
|
||||||
messages,
|
messages,
|
||||||
model: req.model.clone().unwrap_or_else(|| "default".to_string()),
|
model: if req.model.is_empty() {
|
||||||
|
"default".to_string()
|
||||||
|
} else {
|
||||||
|
req.model.clone()
|
||||||
|
},
|
||||||
temperature: req.temperature,
|
temperature: req.temperature,
|
||||||
max_completion_tokens: req.max_output_tokens,
|
max_completion_tokens: req.max_output_tokens,
|
||||||
stream: is_streaming,
|
stream: is_streaming,
|
||||||
@@ -311,7 +315,7 @@ mod tests {
|
|||||||
let req = ResponsesRequest {
|
let req = ResponsesRequest {
|
||||||
input: ResponseInput::Text("Hello, world!".to_string()),
|
input: ResponseInput::Text("Hello, world!".to_string()),
|
||||||
instructions: Some("You are a helpful assistant.".to_string()),
|
instructions: Some("You are a helpful assistant.".to_string()),
|
||||||
model: Some("gpt-4".to_string()),
|
model: "gpt-4".to_string(),
|
||||||
temperature: Some(0.7),
|
temperature: Some(0.7),
|
||||||
..Default::default()
|
..Default::default()
|
||||||
};
|
};
|
||||||
|
|||||||
@@ -324,10 +324,11 @@ async fn route_responses_background(
|
|||||||
incomplete_details: None,
|
incomplete_details: None,
|
||||||
instructions: request.instructions.clone(),
|
instructions: request.instructions.clone(),
|
||||||
max_output_tokens: request.max_output_tokens,
|
max_output_tokens: request.max_output_tokens,
|
||||||
model: request
|
model: if request.model.is_empty() {
|
||||||
.model
|
"default".to_string()
|
||||||
.clone()
|
} else {
|
||||||
.unwrap_or_else(|| "default".to_string()),
|
request.model.clone()
|
||||||
|
},
|
||||||
output: Vec::new(),
|
output: Vec::new(),
|
||||||
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
|
parallel_tool_calls: request.parallel_tool_calls.unwrap_or(true),
|
||||||
previous_response_id: request.previous_response_id.clone(),
|
previous_response_id: request.previous_response_id.clone(),
|
||||||
@@ -622,10 +623,11 @@ async fn process_and_transform_sse_stream(
|
|||||||
|
|
||||||
// Create event emitter for OpenAI-compatible streaming
|
// Create event emitter for OpenAI-compatible streaming
|
||||||
let response_id = format!("resp_{}", Uuid::new_v4());
|
let response_id = format!("resp_{}", Uuid::new_v4());
|
||||||
let model = original_request
|
let model = if original_request.model.is_empty() {
|
||||||
.model
|
"default".to_string()
|
||||||
.clone()
|
} else {
|
||||||
.unwrap_or_else(|| "default".to_string());
|
original_request.model.clone()
|
||||||
|
};
|
||||||
let created_at = chrono::Utc::now().timestamp() as u64;
|
let created_at = chrono::Utc::now().timestamp() as u64;
|
||||||
let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at);
|
let mut event_emitter = ResponseStreamEventEmitter::new(response_id, model, created_at);
|
||||||
|
|
||||||
|
|||||||
@@ -608,10 +608,11 @@ async fn execute_tool_loop_streaming_internal(
|
|||||||
|
|
||||||
// Create response event emitter
|
// Create response event emitter
|
||||||
let response_id = format!("resp_{}", Uuid::new_v4());
|
let response_id = format!("resp_{}", Uuid::new_v4());
|
||||||
let model = current_request
|
let model = if current_request.model.is_empty() {
|
||||||
.model
|
"default".to_string()
|
||||||
.clone()
|
} else {
|
||||||
.unwrap_or_else(|| "default".to_string());
|
current_request.model.clone()
|
||||||
|
};
|
||||||
let created_at = SystemTime::now()
|
let created_at = SystemTime::now()
|
||||||
.duration_since(UNIX_EPOCH)
|
.duration_since(UNIX_EPOCH)
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
|||||||
@@ -22,8 +22,12 @@ use tracing::{error, info, warn, Level};
|
|||||||
use crate::{
|
use crate::{
|
||||||
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
config::{ConnectionMode, HistoryBackend, RouterConfig, RoutingMode},
|
||||||
core::{
|
core::{
|
||||||
worker_to_info, workflow::WorkflowEngine, Job, JobQueue, JobQueueConfig, LoadMonitor,
|
worker_to_info,
|
||||||
WorkerManager, WorkerRegistry, WorkerType,
|
workflow::{
|
||||||
|
create_worker_registration_workflow, create_worker_removal_workflow, LoggingSubscriber,
|
||||||
|
WorkflowEngine,
|
||||||
|
},
|
||||||
|
Job, JobQueue, JobQueueConfig, LoadMonitor, WorkerManager, WorkerRegistry, WorkerType,
|
||||||
},
|
},
|
||||||
data_connector::{
|
data_connector::{
|
||||||
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
MemoryConversationItemStorage, MemoryConversationStorage, MemoryResponseStorage,
|
||||||
@@ -439,51 +443,6 @@ async fn v1_conversations_delete_item(
|
|||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
|
||||||
struct AddWorkerQuery {
|
|
||||||
url: String,
|
|
||||||
api_key: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn add_worker(
|
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
Query(AddWorkerQuery { url, api_key }): Query<AddWorkerQuery>,
|
|
||||||
) -> Response {
|
|
||||||
// Warn if router has API key but worker is being added without one
|
|
||||||
if state.context.router_config.api_key.is_some() && api_key.is_none() {
|
|
||||||
warn!(
|
|
||||||
"Adding worker {} without API key while router has API key configured. \
|
|
||||||
Worker will be accessible without authentication. \
|
|
||||||
If the worker requires the same API key as the router, please specify it explicitly.",
|
|
||||||
url
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
let result = WorkerManager::add_worker(&url, &api_key, &state.context).await;
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(message) => (StatusCode::OK, message).into_response(),
|
|
||||||
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn list_workers(State(state): State<Arc<AppState>>) -> Response {
|
|
||||||
let worker_list = WorkerManager::get_worker_urls(&state.context.worker_registry);
|
|
||||||
Json(json!({ "urls": worker_list })).into_response()
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn remove_worker(
|
|
||||||
State(state): State<Arc<AppState>>,
|
|
||||||
Query(AddWorkerQuery { url, .. }): Query<AddWorkerQuery>,
|
|
||||||
) -> Response {
|
|
||||||
let result = WorkerManager::remove_worker(&url, &state.context);
|
|
||||||
|
|
||||||
match result {
|
|
||||||
Ok(message) => (StatusCode::OK, message).into_response(),
|
|
||||||
Err(error) => (StatusCode::BAD_REQUEST, error).into_response(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
|
match WorkerManager::flush_cache_all(&state.context.worker_registry, &state.context.client)
|
||||||
.await
|
.await
|
||||||
@@ -566,6 +525,12 @@ async fn create_worker(
|
|||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Populate dp_aware from router's configuration
|
||||||
|
let config = WorkerConfigRequest {
|
||||||
|
dp_aware: state.context.router_config.dp_aware,
|
||||||
|
..config
|
||||||
|
};
|
||||||
|
|
||||||
// Submit job for async processing
|
// Submit job for async processing
|
||||||
let worker_url = config.url.clone();
|
let worker_url = config.url.clone();
|
||||||
let job = Job::AddWorker {
|
let job = Job::AddWorker {
|
||||||
@@ -761,9 +726,6 @@ pub fn build_app(
|
|||||||
.route("/get_server_info", get(get_server_info));
|
.route("/get_server_info", get(get_server_info));
|
||||||
|
|
||||||
let admin_routes = Router::new()
|
let admin_routes = Router::new()
|
||||||
.route("/add_worker", post(add_worker))
|
|
||||||
.route("/remove_worker", post(remove_worker))
|
|
||||||
.route("/list_workers", get(list_workers))
|
|
||||||
.route("/flush_cache", post(flush_cache))
|
.route("/flush_cache", post(flush_cache))
|
||||||
.route("/get_loads", get(get_loads))
|
.route("/get_loads", get(get_loads))
|
||||||
.route_layer(axum::middleware::from_fn_with_state(
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
@@ -1018,15 +980,16 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
|
|
||||||
engine
|
engine
|
||||||
.event_bus()
|
.event_bus()
|
||||||
.subscribe(Arc::new(crate::core::workflow::LoggingSubscriber))
|
.subscribe(Arc::new(LoggingSubscriber))
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
engine.register_workflow(crate::core::workflow::create_worker_registration_workflow());
|
engine.register_workflow(create_worker_registration_workflow());
|
||||||
|
engine.register_workflow(create_worker_removal_workflow());
|
||||||
app_context
|
app_context
|
||||||
.workflow_engine
|
.workflow_engine
|
||||||
.set(engine)
|
.set(engine)
|
||||||
.expect("WorkflowEngine should only be initialized once");
|
.expect("WorkflowEngine should only be initialized once");
|
||||||
info!("Workflow engine initialized with worker registration workflow");
|
info!("Workflow engine initialized with worker registration and removal workflows");
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"Initializing workers for routing mode: {:?}",
|
"Initializing workers for routing mode: {:?}",
|
||||||
|
|||||||
@@ -18,11 +18,7 @@ use rustls;
|
|||||||
use tokio::{task, time};
|
use tokio::{task, time};
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
use crate::{
|
use crate::{core::Job, protocols::worker_spec::WorkerConfigRequest, server::AppContext};
|
||||||
core::{Job, WorkerManager},
|
|
||||||
protocols::worker_spec::WorkerConfigRequest,
|
|
||||||
server::AppContext,
|
|
||||||
};
|
|
||||||
|
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct ServiceDiscoveryConfig {
|
pub struct ServiceDiscoveryConfig {
|
||||||
@@ -386,7 +382,7 @@ async fn handle_pod_event(
|
|||||||
reasoning_parser: None,
|
reasoning_parser: None,
|
||||||
tool_parser: None,
|
tool_parser: None,
|
||||||
chat_template: None,
|
chat_template: None,
|
||||||
api_key: None,
|
api_key: app_context.router_config.api_key.clone(),
|
||||||
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
|
health_check_timeout_secs: app_context.router_config.health_check.timeout_secs,
|
||||||
health_check_interval_secs: app_context
|
health_check_interval_secs: app_context
|
||||||
.router_config
|
.router_config
|
||||||
@@ -453,8 +449,24 @@ async fn handle_pod_deletion(
|
|||||||
pod_info.name, pod_info.pod_type, worker_url
|
pod_info.name, pod_info.pod_type, worker_url
|
||||||
);
|
);
|
||||||
|
|
||||||
if let Err(e) = WorkerManager::remove_worker(&worker_url, &app_context) {
|
let job = Job::RemoveWorker {
|
||||||
error!("Failed to remove worker {}: {}", worker_url, e);
|
url: worker_url.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(job_queue) = app_context.worker_job_queue.get() {
|
||||||
|
if let Err(e) = job_queue.submit(job).await {
|
||||||
|
error!(
|
||||||
|
"Failed to submit worker removal job for {}: {}",
|
||||||
|
worker_url, e
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
debug!("Submitted worker removal job for {}", worker_url);
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
error!(
|
||||||
|
"JobQueue not initialized, cannot remove worker {}",
|
||||||
|
worker_url
|
||||||
|
);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
debug!(
|
debug!(
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ use sglang_router_rs::{
|
|||||||
config::{
|
config::{
|
||||||
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
CircuitBreakerConfig, ConnectionMode, PolicyConfig, RetryConfig, RouterConfig, RoutingMode,
|
||||||
},
|
},
|
||||||
core::WorkerManager,
|
core::Job,
|
||||||
routers::{RouterFactory, RouterTrait},
|
routers::{RouterFactory, RouterTrait},
|
||||||
server::AppContext,
|
server::AppContext,
|
||||||
};
|
};
|
||||||
@@ -112,22 +112,51 @@ impl TestContext {
|
|||||||
// Create app context
|
// Create app context
|
||||||
let app_context = common::create_test_context(config.clone());
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
// Initialize workers in the registry before creating router
|
// Submit worker initialization job (same as real server does)
|
||||||
if !worker_urls.is_empty() {
|
if !worker_urls.is_empty() {
|
||||||
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
|
let job_queue = app_context
|
||||||
|
.worker_job_queue
|
||||||
|
.get()
|
||||||
|
.expect("JobQueue should be initialized");
|
||||||
|
let job = Job::InitializeWorkersFromConfig {
|
||||||
|
router_config: Box::new(config.clone()),
|
||||||
|
};
|
||||||
|
job_queue
|
||||||
|
.submit(job)
|
||||||
.await
|
.await
|
||||||
.expect("Failed to initialize workers");
|
.expect("Failed to submit worker initialization job");
|
||||||
|
|
||||||
|
// Poll until all workers are healthy (up to 10 seconds)
|
||||||
|
let expected_count = worker_urls.len();
|
||||||
|
let start = tokio::time::Instant::now();
|
||||||
|
let timeout_duration = tokio::time::Duration::from_secs(10);
|
||||||
|
loop {
|
||||||
|
let healthy_workers = app_context
|
||||||
|
.worker_registry
|
||||||
|
.get_all()
|
||||||
|
.iter()
|
||||||
|
.filter(|w| w.is_healthy())
|
||||||
|
.count();
|
||||||
|
|
||||||
|
if healthy_workers >= expected_count {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if start.elapsed() > timeout_duration {
|
||||||
|
panic!(
|
||||||
|
"Timeout waiting for {} workers to become healthy (only {} ready)",
|
||||||
|
expected_count, healthy_workers
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create router
|
// Create router
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
// Wait for router to discover workers
|
|
||||||
if !workers.is_empty() {
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
workers,
|
workers,
|
||||||
router,
|
router,
|
||||||
@@ -711,221 +740,6 @@ mod model_info_tests {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
|
||||||
mod worker_management_tests {
|
|
||||||
use super::*;
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add_new_worker() {
|
|
||||||
let ctx = TestContext::new(vec![]).await;
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Start a mock worker
|
|
||||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
|
||||||
port: 18301,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
});
|
|
||||||
let url = worker.start().await.unwrap();
|
|
||||||
|
|
||||||
// Add the worker
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri(format!("/add_worker?url={}", url))
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
// List workers to verify
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("GET")
|
|
||||||
.uri("/list_workers")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
|
||||||
let workers = body_json["urls"].as_array().unwrap();
|
|
||||||
assert!(workers.iter().any(|w| w.as_str().unwrap() == url));
|
|
||||||
|
|
||||||
worker.stop().await;
|
|
||||||
ctx.shutdown().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_remove_existing_worker() {
|
|
||||||
let ctx = TestContext::new(vec![MockWorkerConfig {
|
|
||||||
port: 18302,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
}])
|
|
||||||
.await;
|
|
||||||
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Get the worker URL
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("GET")
|
|
||||||
.uri("/list_workers")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
|
||||||
let workers = body_json["urls"].as_array().unwrap();
|
|
||||||
let worker_url = workers[0].as_str().unwrap();
|
|
||||||
|
|
||||||
// Remove the worker
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri(format!("/remove_worker?url={}", worker_url))
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("GET")
|
|
||||||
.uri("/list_workers")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
|
||||||
let body = axum::body::to_bytes(resp.into_body(), usize::MAX)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let body_json: serde_json::Value = serde_json::from_slice(&body).unwrap();
|
|
||||||
let workers = body_json["urls"].as_array().unwrap();
|
|
||||||
assert!(workers.is_empty());
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add_worker_invalid_url() {
|
|
||||||
let ctx = TestContext::new(vec![]).await;
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Invalid URL format
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri("/add_worker?url=not-a-valid-url")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
|
||||||
|
|
||||||
// Missing URL parameter
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri("/add_worker")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
|
||||||
|
|
||||||
// Empty URL
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri("/add_worker?url=")
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
|
||||||
|
|
||||||
ctx.shutdown().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add_duplicate_worker() {
|
|
||||||
// Start a mock worker
|
|
||||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
|
||||||
port: 18303,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Healthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
});
|
|
||||||
let url = worker.start().await.unwrap();
|
|
||||||
|
|
||||||
let ctx = TestContext::new(vec![]).await;
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Add worker first time
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri(format!("/add_worker?url={}", url))
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let resp = app.clone().oneshot(req).await.unwrap();
|
|
||||||
assert_eq!(resp.status(), StatusCode::OK);
|
|
||||||
|
|
||||||
tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
|
|
||||||
|
|
||||||
// Try to add same worker again
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri(format!("/add_worker?url={}", url))
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
|
||||||
// Should return error for duplicate
|
|
||||||
assert_eq!(resp.status(), StatusCode::BAD_REQUEST);
|
|
||||||
|
|
||||||
worker.stop().await;
|
|
||||||
ctx.shutdown().await;
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn test_add_unhealthy_worker() {
|
|
||||||
// Start unhealthy worker
|
|
||||||
let mut worker = MockWorker::new(MockWorkerConfig {
|
|
||||||
port: 18304,
|
|
||||||
worker_type: WorkerType::Regular,
|
|
||||||
health_status: HealthStatus::Unhealthy,
|
|
||||||
response_delay_ms: 0,
|
|
||||||
fail_rate: 0.0,
|
|
||||||
});
|
|
||||||
let url = worker.start().await.unwrap();
|
|
||||||
|
|
||||||
let ctx = TestContext::new(vec![]).await;
|
|
||||||
let app = ctx.create_app().await;
|
|
||||||
|
|
||||||
// Try to add unhealthy worker
|
|
||||||
let req = Request::builder()
|
|
||||||
.method("POST")
|
|
||||||
.uri(format!("/add_worker?url={}", url))
|
|
||||||
.body(Body::empty())
|
|
||||||
.unwrap();
|
|
||||||
let resp = app.oneshot(req).await.unwrap();
|
|
||||||
|
|
||||||
// Router should reject unhealthy workers
|
|
||||||
assert!(
|
|
||||||
resp.status() == StatusCode::BAD_REQUEST
|
|
||||||
|| resp.status() == StatusCode::SERVICE_UNAVAILABLE
|
|
||||||
);
|
|
||||||
|
|
||||||
worker.stop().await;
|
|
||||||
ctx.shutdown().await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod router_policy_tests {
|
mod router_policy_tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
|||||||
let worker_job_queue = Arc::new(OnceLock::new());
|
let worker_job_queue = Arc::new(OnceLock::new());
|
||||||
let workflow_engine = Arc::new(OnceLock::new());
|
let workflow_engine = Arc::new(OnceLock::new());
|
||||||
|
|
||||||
Arc::new(AppContext::new(
|
let app_context = Arc::new(AppContext::new(
|
||||||
config,
|
config,
|
||||||
client,
|
client,
|
||||||
rate_limiter,
|
rate_limiter,
|
||||||
@@ -81,7 +81,32 @@ pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
|||||||
load_monitor,
|
load_monitor,
|
||||||
worker_job_queue,
|
worker_job_queue,
|
||||||
workflow_engine,
|
workflow_engine,
|
||||||
))
|
));
|
||||||
|
|
||||||
|
// Initialize JobQueue after AppContext is created
|
||||||
|
let weak_context = Arc::downgrade(&app_context);
|
||||||
|
let job_queue = sglang_router_rs::core::JobQueue::new(
|
||||||
|
sglang_router_rs::core::JobQueueConfig::default(),
|
||||||
|
weak_context,
|
||||||
|
);
|
||||||
|
app_context
|
||||||
|
.worker_job_queue
|
||||||
|
.set(job_queue)
|
||||||
|
.expect("JobQueue should only be initialized once");
|
||||||
|
|
||||||
|
// Initialize WorkflowEngine and register workflows
|
||||||
|
use sglang_router_rs::core::workflow::{
|
||||||
|
create_worker_registration_workflow, create_worker_removal_workflow, WorkflowEngine,
|
||||||
|
};
|
||||||
|
let engine = Arc::new(WorkflowEngine::new());
|
||||||
|
engine.register_workflow(create_worker_registration_workflow());
|
||||||
|
engine.register_workflow(create_worker_removal_workflow());
|
||||||
|
app_context
|
||||||
|
.workflow_engine
|
||||||
|
.set(engine)
|
||||||
|
.expect("WorkflowEngine should only be initialized once");
|
||||||
|
|
||||||
|
app_context
|
||||||
}
|
}
|
||||||
|
|
||||||
// Tokenizer download configuration
|
// Tokenizer download configuration
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ use reqwest::Client;
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::{
|
use sglang_router_rs::{
|
||||||
config::{RouterConfig, RoutingMode},
|
config::{RouterConfig, RoutingMode},
|
||||||
core::WorkerManager,
|
|
||||||
routers::{RouterFactory, RouterTrait},
|
routers::{RouterFactory, RouterTrait},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -51,13 +50,6 @@ impl TestContext {
|
|||||||
|
|
||||||
let app_context = common::create_test_context(config.clone());
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
// Initialize workers in the registry before creating router
|
|
||||||
if !worker_urls.is_empty() {
|
|
||||||
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
|
|
||||||
.await
|
|
||||||
.expect("Failed to initialize workers");
|
|
||||||
}
|
|
||||||
|
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ use reqwest::Client;
|
|||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use sglang_router_rs::{
|
use sglang_router_rs::{
|
||||||
config::{RouterConfig, RoutingMode},
|
config::{RouterConfig, RoutingMode},
|
||||||
core::WorkerManager,
|
|
||||||
routers::{RouterFactory, RouterTrait},
|
routers::{RouterFactory, RouterTrait},
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -52,13 +51,6 @@ impl TestContext {
|
|||||||
|
|
||||||
let app_context = common::create_test_context(config.clone());
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
// Initialize workers in the registry before creating router
|
|
||||||
if !worker_urls.is_empty() {
|
|
||||||
WorkerManager::initialize_workers(&config, &app_context.worker_registry, None)
|
|
||||||
.await
|
|
||||||
.expect("Failed to initialize workers");
|
|
||||||
}
|
|
||||||
|
|
||||||
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
let router = RouterFactory::create_router(&app_context).await.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user