From a4cca7fc53da2b0c58495e208bb17e0199246e12 Mon Sep 17 00:00:00 2001
From: Byron Hsu
Date: Wed, 11 Dec 2024 12:13:08 -0800
Subject: [PATCH] [router] Add retry-based fault tolerance (#2452)

---
 rust/py_test/test_launch_server.py | 56 ++++++++++++++++-
 rust/src/router.rs                 | 98 +++++++++++++++++++++++++++---
 2 files changed, 144 insertions(+), 10 deletions(-)

diff --git a/rust/py_test/test_launch_server.py b/rust/py_test/test_launch_server.py
index 2591abb5c..daa0b821e 100644
--- a/rust/py_test/test_launch_server.py
+++ b/rust/py_test/test_launch_server.py
@@ -6,6 +6,7 @@
 from types import SimpleNamespace
 
 import requests
+from sglang.srt.utils import kill_process_tree
 from sglang.test.run_eval import run_eval
 from sglang.test.test_utils import (
     DEFAULT_MODEL_NAME_FOR_TEST,
@@ -182,7 +183,7 @@ class TestLaunchServer(unittest.TestCase):
             timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
             policy="round_robin",  # use round robin to make sure every worker processes requests
         )
-        # 1. start a worker, and wait until it is healthy
+        # 1. start a worker
         port = find_available_port()
         worker_url = f"http://127.0.0.1:{port}"
         worker_process = popen_launch_server(
@@ -226,6 +227,59 @@ class TestLaunchServer(unittest.TestCase):
         msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
         self.assertGreaterEqual(score, THRESHOLD, msg)
+
+    def test_3_lazy_fault_tolerance(self):
+        print("Running test_3_lazy_fault_tolerance...")
+        # DP size = 1
+        self.process = popen_launch_router(
+            self.model,
+            self.base_url,
+            dp_size=1,
+            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
+            policy="round_robin",
+        )
+
+        # 1. start a worker
+        port = find_available_port()
+        worker_url = f"http://127.0.0.1:{port}"
+        worker_process = popen_launch_server(
+            self.model, worker_url, DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH
+        )
+        self.other_process.append(worker_process)
+
+        # 2. use the /add_worker API to add it to the router; it will start serving requests once healthy
+        with requests.Session() as session:
+            response = session.post(f"{self.base_url}/add_worker?url={worker_url}")
+            print(f"status code: {response.status_code}, response: {response.text}")
+            self.assertEqual(response.status_code, 200)
+
+        # Start a thread to kill the worker after 10 seconds to mimic an abrupt worker failure
+        def kill_worker():
+            time.sleep(10)
+            kill_process_tree(worker_process.pid)
+            print("Worker process killed")
+
+        import threading
+
+        kill_thread = threading.Thread(target=kill_worker)
+        kill_thread.daemon = True
+        kill_thread.start()
+
+        # 3. run mmlu
+        args = SimpleNamespace(
+            base_url=self.base_url,
+            model=self.model,
+            eval_name="mmlu",
+            num_examples=256,
+            num_threads=32,
+            temperature=0.1,
+        )
+        metrics = run_eval(args)
+        score = metrics["score"]
+        THRESHOLD = 0.65
+        passed = score >= THRESHOLD
+        msg = f"MMLU test {'passed' if passed else 'failed'} with score {score:.3f} (threshold: {THRESHOLD})"
+        self.assertGreaterEqual(score, THRESHOLD, msg)
 
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/rust/src/router.rs b/rust/src/router.rs
index d07949d49..08f6cdefa 100644
--- a/rust/src/router.rs
+++ b/rust/src/router.rs
@@ -274,10 +274,49 @@ impl Router {
     }
 
     pub async fn route_to_first(&self, client: &reqwest::Client, route: &str) -> HttpResponse {
-        match self.select_first_worker() {
-            Ok(worker_url) => self.send_request(client, &worker_url, route).await,
-            Err(e) => HttpResponse::InternalServerError().body(e),
+        const MAX_REQUEST_RETRIES: u32 = 3;
+        const MAX_TOTAL_RETRIES: u32 = 6;
+        let mut total_retries = 0;
+
+        while total_retries < MAX_TOTAL_RETRIES {
+            match self.select_first_worker() {
+                Ok(worker_url) => {
+                    let mut request_retries = 0;
+
+                    // Try the same worker multiple times
+                    while request_retries < MAX_REQUEST_RETRIES {
+                        if total_retries >= 1 {
+                            info!("Retrying request after {} failed attempts", total_retries);
+                        }
+
+                        let response = self.send_request(client, &worker_url, route).await;
+
+                        if response.status().is_success() {
+                            return response;
+                        }
+
+                        warn!(
+                            "Request to {} failed (attempt {}/{})",
+                            worker_url,
+                            request_retries + 1,
+                            MAX_REQUEST_RETRIES
+                        );
+
+                        request_retries += 1;
+                        total_retries += 1;
+
+                        if request_retries == MAX_REQUEST_RETRIES {
+                            warn!("Removing failed worker: {}", worker_url);
+                            self.remove_worker(&worker_url);
+                            break;
+                        }
+                    }
+                }
+                Err(e) => return HttpResponse::InternalServerError().body(e),
+            }
         }
+
+        HttpResponse::InternalServerError().body("All retry attempts failed")
     }
 
     fn get_text_from_request(&self, body: &Bytes, route: &str) -> String {
@@ -488,9 +527,46 @@ impl Router {
         body: &Bytes,
         route: &str,
     ) -> HttpResponse {
-        let worker_url = self.select_generate_worker(&body, route);
-        self.send_generate_request(client, req, body, route, &worker_url)
-            .await
+        const MAX_REQUEST_RETRIES: u32 = 3;
+        const MAX_TOTAL_RETRIES: u32 = 6;
+        let mut total_retries = 0;
+
+        while total_retries < MAX_TOTAL_RETRIES {
+            let worker_url = self.select_generate_worker(body, route);
+            let mut request_retries = 0;
+
+            // Try the same worker multiple times
+            while request_retries < MAX_REQUEST_RETRIES {
+                if total_retries >= 1 {
+                    info!("Retrying request after {} failed attempts", total_retries);
+                }
+                let response = self
+                    .send_generate_request(client, req, body, route, &worker_url)
+                    .await;
+
+                if response.status().is_success() {
+                    return response;
+                }
+
+                warn!(
+                    "Generate request to {} failed (attempt {}/{})",
+                    worker_url,
+                    request_retries + 1,
+                    MAX_REQUEST_RETRIES
+                );
+
+                request_retries += 1;
+                total_retries += 1;
+
+                if request_retries == MAX_REQUEST_RETRIES {
+                    warn!("Removing failed worker: {}", worker_url);
+                    self.remove_worker(&worker_url);
+                    break;
+                }
+            }
+        }
+
+        HttpResponse::InternalServerError().body("All retry attempts failed")
     }
 
     pub async fn add_worker(&self, worker_url: &str) -> Result<String, String> {
@@ -590,9 +666,13 @@ impl Router {
             Router::RoundRobin { worker_urls, .. }
             | Router::Random { worker_urls }
             | Router::CacheAware { worker_urls, .. } => {
                 let mut urls = worker_urls.write().unwrap();
-                let index = urls.iter().position(|url| url == &worker_url).unwrap();
-                urls.remove(index);
-                info!("Removed worker: {}", worker_url);
+                if let Some(index) = urls.iter().position(|url| url == &worker_url) {
+                    urls.remove(index);
+                    info!("Removed worker: {}", worker_url);
+                } else {
+                    warn!("Worker {} not found, skipping removal", worker_url);
+                    return;
+                }
             }
         }
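
Note on the retry policy above: route_to_first and route_generate_request implement the same two-level budget: up to MAX_REQUEST_RETRIES (3) attempts against a single worker and up to MAX_TOTAL_RETRIES (6) attempts overall, with a worker evicted via remove_worker once it exhausts its per-worker budget. The standalone sketch below isolates that policy so it can be compiled and exercised without the router; the helper retry_with_eviction, the plain Vec<String> worker list, and the mock send closure are illustrative assumptions, not code from this patch.

// Illustrative sketch of the patch's retry/eviction policy, not part of the PR.
const MAX_REQUEST_RETRIES: u32 = 3;
const MAX_TOTAL_RETRIES: u32 = 6;

fn retry_with_eviction<F>(workers: &mut Vec<String>, mut send: F) -> Result<String, String>
where
    F: FnMut(&str) -> Result<String, String>,
{
    let mut total_retries = 0;
    while total_retries < MAX_TOTAL_RETRIES {
        // Mirror select_first_worker(): error out when no workers remain.
        let worker = match workers.first().cloned() {
            Some(url) => url,
            None => return Err("No workers are available".to_string()),
        };

        let mut request_retries = 0;
        // Try the same worker multiple times before giving up on it.
        while request_retries < MAX_REQUEST_RETRIES {
            match send(&worker) {
                Ok(response) => return Ok(response),
                Err(_) => {
                    request_retries += 1;
                    total_retries += 1;
                    if request_retries == MAX_REQUEST_RETRIES {
                        // Per-worker budget exhausted: evict this worker and
                        // let the outer loop select another one.
                        workers.retain(|url| url != &worker);
                        break;
                    }
                }
            }
        }
    }
    Err("All retry attempts failed".to_string())
}

fn main() {
    // w1 always fails and w2 always succeeds: after three failed attempts,
    // w1 is evicted and the fourth attempt lands on w2.
    let mut workers = vec!["http://w1".to_string(), "http://w2".to_string()];
    let result = retry_with_eviction(&mut workers, |url| {
        if url == "http://w1" {
            Err("connection refused".to_string())
        } else {
            Ok(format!("200 OK from {}", url))
        }
    });
    println!("{:?} (remaining workers: {:?})", result, workers);
}

Running the sketch prints Ok("200 OK from http://w2") with only http://w2 left in the list: a worker is dropped after three consecutive failures and the next attempt goes to a remaining worker, which is the behavior the test_3_lazy_fault_tolerance test exercises end to end.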