[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)

This commit is contained in:
Simo Lin
2025-08-02 19:16:47 -07:00
committed by GitHub
parent 8ada1ab6c7
commit 828a4fe944
12 changed files with 197 additions and 186 deletions

View File

@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
pub struct Router {
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
@@ -44,10 +45,11 @@ pub struct Router {
}
impl Router {
/// Create a new router with injected policy
/// Create a new router with injected policy and client
pub fn new(
worker_urls: Vec<String>,
policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64,
interval_secs: u64,
dp_aware: bool,
@@ -94,9 +96,17 @@ impl Router {
let monitor_urls = worker_urls.clone();
let monitor_interval = interval_secs;
let policy_clone = Arc::clone(&policy);
let client_clone = client.clone();
Some(Arc::new(tokio::spawn(async move {
Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await;
Self::monitor_worker_loads(
monitor_urls,
tx,
monitor_interval,
policy_clone,
client_clone,
)
.await;
})))
} else {
None
@@ -105,6 +115,7 @@ impl Router {
Ok(Router {
workers,
policy,
client,
timeout_secs,
interval_secs,
dp_aware,
@@ -245,7 +256,7 @@ impl Router {
}
}
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
pub async fn send_health_check(&self, worker_url: &str) -> Response {
let health_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
match Self::extract_dp_rank(worker_url) {
@@ -263,7 +274,7 @@ impl Router {
worker_url
};
let request_builder = client.get(format!("{}/health", health_url));
let request_builder = self.client.get(format!("{}/health", health_url));
let response = match request_builder.send().await {
Ok(res) => {
@@ -305,17 +316,12 @@ impl Router {
}
// Helper method to proxy GET requests to the first available worker
async fn proxy_get_request(
&self,
client: &Client,
req: Request<Body>,
endpoint: &str,
) -> Response {
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
let headers = copy_request_headers(&req);
match self.select_first_worker() {
Ok(worker_url) => {
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
for (name, value) in headers {
if name.to_lowercase() != "content-type"
&& name.to_lowercase() != "content-length"
@@ -353,7 +359,6 @@ impl Router {
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
>(
&self,
client: &reqwest::Client,
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
@@ -397,7 +402,6 @@ impl Router {
// Send typed request directly
let response = self
.send_typed_request(
client,
headers,
typed_req,
route,
@@ -413,7 +417,7 @@ impl Router {
return response;
} else {
// if the worker is healthy, it means the request is bad, so return the error response
let health_response = self.send_health_check(client, &worker_url).await;
let health_response = self.send_health_check(&worker_url).await;
if health_response.status().is_success() {
RouterMetrics::record_request_error(route, "request_failed");
return response;
@@ -483,7 +487,6 @@ impl Router {
// Send typed request directly without conversion
async fn send_typed_request<T: serde::Serialize>(
&self,
client: &reqwest::Client,
headers: Option<&HeaderMap>,
typed_req: &T,
route: &str,
@@ -536,11 +539,11 @@ impl Router {
.into_response();
}
client
self.client
.post(format!("{}{}", worker_url_prefix, route))
.json(&json_val)
} else {
client
self.client
.post(format!("{}{}", worker_url, route))
.json(typed_req) // Use json() directly with typed request
};
@@ -866,7 +869,7 @@ impl Router {
}
}
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
let worker_url = if self.dp_aware {
// Need to extract the URL from "http://host:port@dp_rank"
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
@@ -881,7 +884,12 @@ impl Router {
worker_url
};
match client.get(&format!("{}/get_load", worker_url)).send().await {
match self
.client
.get(&format!("{}/get_load", worker_url))
.send()
.await
{
Ok(res) if res.status().is_success() => match res.bytes().await {
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
Ok(data) => data
@@ -919,18 +927,8 @@ impl Router {
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
interval_secs: u64,
policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
) {
let client = match reqwest::Client::builder()
.timeout(Duration::from_secs(5))
.build()
{
Ok(c) => c,
Err(e) => {
error!("Failed to create HTTP client for load monitoring: {}", e);
return;
}
};
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
loop {
@@ -1028,7 +1026,7 @@ impl RouterTrait for Router {
self
}
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
async fn health(&self, _req: Request<Body>) -> Response {
let workers = self.workers.read().unwrap();
let unhealthy_servers: Vec<_> = workers
.iter()
@@ -1047,53 +1045,49 @@ impl RouterTrait for Router {
}
}
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "health_generate").await
async fn health_generate(&self, req: Request<Body>) -> Response {
self.proxy_get_request(req, "health_generate").await
}
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_server_info").await
async fn get_server_info(&self, req: Request<Body>) -> Response {
self.proxy_get_request(req, "get_server_info").await
}
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "v1/models").await
async fn get_models(&self, req: Request<Body>) -> Response {
self.proxy_get_request(req, "v1/models").await
}
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
self.proxy_get_request(client, req, "get_model_info").await
async fn get_model_info(&self, req: Request<Body>) -> Response {
self.proxy_get_request(req, "get_model_info").await
}
async fn route_generate(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/generate")
.await
self.route_typed_request(headers, body, "/generate").await
}
async fn route_chat(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/v1/chat/completions")
self.route_typed_request(headers, body, "/v1/chat/completions")
.await
}
async fn route_completion(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response {
self.route_typed_request(client, headers, body, "/v1/completions")
self.route_typed_request(headers, body, "/v1/completions")
.await
}
async fn flush_cache(&self, client: &Client) -> Response {
async fn flush_cache(&self) -> Response {
// Get all worker URLs
let worker_urls = self.get_worker_urls();
@@ -1117,7 +1111,7 @@ impl RouterTrait for Router {
} else {
worker_url
};
let request_builder = client.post(format!("{}/flush_cache", worker_url));
let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
tasks.push(request_builder.send());
}
@@ -1142,13 +1136,13 @@ impl RouterTrait for Router {
}
}
async fn get_worker_loads(&self, client: &Client) -> Response {
async fn get_worker_loads(&self) -> Response {
let urls = self.get_worker_urls();
let mut loads = Vec::new();
// Get loads from all workers
for url in &urls {
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
let load = self.get_worker_load(url).await.unwrap_or(-1);
loads.push(serde_json::json!({
"worker": url,
"load": load
@@ -1215,6 +1209,7 @@ mod tests {
interval_secs: 1,
dp_aware: false,
api_key: None,
client: Client::new(),
_worker_loads: Arc::new(rx),
_load_monitor_handle: None,
_health_checker: None,