From a59cbea92d9f2745f7e159e8f938cd4458435c5f Mon Sep 17 00:00:00 2001 From: Simo Lin Date: Fri, 8 Aug 2025 13:10:14 -0700 Subject: [PATCH] [router] harden retries + metrics; fix streaming load; header filtering (#8972) --- sgl-router/src/routers/router.rs | 46 ++++++++++++++++++++++---------- 1 file changed, 32 insertions(+), 14 deletions(-) diff --git a/sgl-router/src/routers/router.rs b/sgl-router/src/routers/router.rs index 894629b9b..aa5b3768f 100644 --- a/sgl-router/src/routers/router.rs +++ b/sgl-router/src/routers/router.rs @@ -7,7 +7,7 @@ use crate::routers::{RouterTrait, WorkerManagement}; use axum::{ body::Body, extract::Request, - http::{header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, + http::{header::CONTENT_LENGTH, header::CONTENT_TYPE, HeaderMap, HeaderValue, StatusCode}, response::{IntoResponse, Response}, Json, }; @@ -351,9 +351,8 @@ impl Router { Ok(worker_url) => { let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint)); for (name, value) in headers { - if name.to_lowercase() != "content-type" - && name.to_lowercase() != "content-length" - { + let name_lc = name.to_lowercase(); + if name_lc != "content-type" && name_lc != "content-length" { request_builder = request_builder.header(name, value); } } @@ -406,6 +405,14 @@ impl Router { // Select worker based on text let worker_url = self.select_generate_worker_from_text(&text); + if worker_url.is_empty() { + RouterMetrics::record_request_error(route, "no_healthy_workers"); + return ( + StatusCode::SERVICE_UNAVAILABLE, + "No healthy workers available", + ) + .into_response(); + } let mut request_retries = 0; // Try the same worker multiple times @@ -443,9 +450,15 @@ impl Router { if response.status().is_success() { let duration = start.elapsed(); + RouterMetrics::record_request(route); RouterMetrics::record_generate_duration(duration); return response; } else { + let status = response.status(); + if status.is_client_error() && status != StatusCode::TOO_MANY_REQUESTS { + RouterMetrics::record_request_error(route, "client_error"); + return response; + } // if the worker is healthy, it means the request is bad, so return the error response let health_response = self.send_health_check(&worker_url).await; if health_response.status().is_success() { @@ -473,6 +486,9 @@ impl Router { self.remove_worker(&worker_url); break; } + + let backoff_ms = (100u64 * (request_retries as u64)).min(1000); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; } } @@ -524,8 +540,6 @@ impl Router { is_stream: bool, load_incremented: bool, // Whether load was incremented for this request ) -> Response { - let start = Instant::now(); - let mut request_builder = if self.dp_aware { let (worker_url_prefix, dp_rank) = match Self::extract_dp_rank(worker_url) { Ok(tup) => tup, @@ -582,9 +596,7 @@ impl Router { if let Some(headers) = headers { for (name, value) in headers { // Skip Content-Type and Content-Length as .json() sets them - if name.to_string().to_lowercase() != "content-type" - && name.to_string().to_lowercase() != "content-length" - { + if *name != CONTENT_TYPE && *name != CONTENT_LENGTH { request_builder = request_builder.header(name, value); } } @@ -639,11 +651,6 @@ impl Router { } } - // Record metrics - let duration = start.elapsed(); - RouterMetrics::record_generate_duration(duration); - RouterMetrics::record_request(route); - response } else if load_incremented { // For streaming with load tracking, we need to manually decrement when done @@ -656,6 +663,7 @@ impl Router { // Spawn task to forward stream and detect completion tokio::spawn(async move { let mut stream = stream; + let mut decremented = false; while let Some(chunk) = stream.next().await { match chunk { Ok(bytes) => { @@ -674,6 +682,7 @@ impl Router { &worker_url, worker.load(), ); + decremented = true; } } } @@ -687,6 +696,15 @@ impl Router { } } } + if !decremented { + if let Ok(workers_guard) = workers.read() { + if let Some(worker) = workers_guard.iter().find(|w| w.url() == &worker_url) + { + worker.decrement_load(); + RouterMetrics::set_running_requests(&worker_url, worker.load()); + } + } + } }); let stream = UnboundedReceiverStream::new(rx);