[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)
This commit is contained in:
@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
||||
pub struct Router {
|
||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
client: Client,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
@@ -44,10 +45,11 @@ pub struct Router {
|
||||
}
|
||||
|
||||
impl Router {
|
||||
/// Create a new router with injected policy
|
||||
/// Create a new router with injected policy and client
|
||||
pub fn new(
|
||||
worker_urls: Vec<String>,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
client: Client,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
dp_aware: bool,
|
||||
@@ -94,9 +96,17 @@ impl Router {
|
||||
let monitor_urls = worker_urls.clone();
|
||||
let monitor_interval = interval_secs;
|
||||
let policy_clone = Arc::clone(&policy);
|
||||
let client_clone = client.clone();
|
||||
|
||||
Some(Arc::new(tokio::spawn(async move {
|
||||
Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await;
|
||||
Self::monitor_worker_loads(
|
||||
monitor_urls,
|
||||
tx,
|
||||
monitor_interval,
|
||||
policy_clone,
|
||||
client_clone,
|
||||
)
|
||||
.await;
|
||||
})))
|
||||
} else {
|
||||
None
|
||||
@@ -105,6 +115,7 @@ impl Router {
|
||||
Ok(Router {
|
||||
workers,
|
||||
policy,
|
||||
client,
|
||||
timeout_secs,
|
||||
interval_secs,
|
||||
dp_aware,
|
||||
@@ -245,7 +256,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
|
||||
pub async fn send_health_check(&self, worker_url: &str) -> Response {
|
||||
let health_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
match Self::extract_dp_rank(worker_url) {
|
||||
@@ -263,7 +274,7 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
let request_builder = client.get(format!("{}/health", health_url));
|
||||
let request_builder = self.client.get(format!("{}/health", health_url));
|
||||
|
||||
let response = match request_builder.send().await {
|
||||
Ok(res) => {
|
||||
@@ -305,17 +316,12 @@ impl Router {
|
||||
}
|
||||
|
||||
// Helper method to proxy GET requests to the first available worker
|
||||
async fn proxy_get_request(
|
||||
&self,
|
||||
client: &Client,
|
||||
req: Request<Body>,
|
||||
endpoint: &str,
|
||||
) -> Response {
|
||||
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||
let headers = copy_request_headers(&req);
|
||||
|
||||
match self.select_first_worker() {
|
||||
Ok(worker_url) => {
|
||||
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
|
||||
let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
|
||||
for (name, value) in headers {
|
||||
if name.to_lowercase() != "content-type"
|
||||
&& name.to_lowercase() != "content-length"
|
||||
@@ -353,7 +359,6 @@ impl Router {
|
||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||
>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
@@ -397,7 +402,6 @@ impl Router {
|
||||
// Send typed request directly
|
||||
let response = self
|
||||
.send_typed_request(
|
||||
client,
|
||||
headers,
|
||||
typed_req,
|
||||
route,
|
||||
@@ -413,7 +417,7 @@ impl Router {
|
||||
return response;
|
||||
} else {
|
||||
// if the worker is healthy, it means the request is bad, so return the error response
|
||||
let health_response = self.send_health_check(client, &worker_url).await;
|
||||
let health_response = self.send_health_check(&worker_url).await;
|
||||
if health_response.status().is_success() {
|
||||
RouterMetrics::record_request_error(route, "request_failed");
|
||||
return response;
|
||||
@@ -483,7 +487,6 @@ impl Router {
|
||||
// Send typed request directly without conversion
|
||||
async fn send_typed_request<T: serde::Serialize>(
|
||||
&self,
|
||||
client: &reqwest::Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
typed_req: &T,
|
||||
route: &str,
|
||||
@@ -536,11 +539,11 @@ impl Router {
|
||||
.into_response();
|
||||
}
|
||||
|
||||
client
|
||||
self.client
|
||||
.post(format!("{}{}", worker_url_prefix, route))
|
||||
.json(&json_val)
|
||||
} else {
|
||||
client
|
||||
self.client
|
||||
.post(format!("{}{}", worker_url, route))
|
||||
.json(typed_req) // Use json() directly with typed request
|
||||
};
|
||||
@@ -866,7 +869,7 @@ impl Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
||||
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
|
||||
let worker_url = if self.dp_aware {
|
||||
// Need to extract the URL from "http://host:port@dp_rank"
|
||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||
@@ -881,7 +884,12 @@ impl Router {
|
||||
worker_url
|
||||
};
|
||||
|
||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
||||
match self
|
||||
.client
|
||||
.get(&format!("{}/get_load", worker_url))
|
||||
.send()
|
||||
.await
|
||||
{
|
||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||
Ok(data) => data
|
||||
@@ -919,18 +927,8 @@ impl Router {
|
||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||
interval_secs: u64,
|
||||
policy: Arc<dyn LoadBalancingPolicy>,
|
||||
client: Client,
|
||||
) {
|
||||
let client = match reqwest::Client::builder()
|
||||
.timeout(Duration::from_secs(5))
|
||||
.build()
|
||||
{
|
||||
Ok(c) => c,
|
||||
Err(e) => {
|
||||
error!("Failed to create HTTP client for load monitoring: {}", e);
|
||||
return;
|
||||
}
|
||||
};
|
||||
|
||||
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
|
||||
|
||||
loop {
|
||||
@@ -1028,7 +1026,7 @@ impl RouterTrait for Router {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
let workers = self.workers.read().unwrap();
|
||||
let unhealthy_servers: Vec<_> = workers
|
||||
.iter()
|
||||
@@ -1047,53 +1045,49 @@ impl RouterTrait for Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "health_generate").await
|
||||
async fn health_generate(&self, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(req, "health_generate").await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "get_server_info").await
|
||||
async fn get_server_info(&self, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(req, "get_server_info").await
|
||||
}
|
||||
|
||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "v1/models").await
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(req, "v1/models").await
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(client, req, "get_model_info").await
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||
self.proxy_get_request(req, "get_model_info").await
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/generate")
|
||||
.await
|
||||
self.route_typed_request(headers, body, "/generate").await
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/v1/chat/completions")
|
||||
self.route_typed_request(headers, body, "/v1/chat/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response {
|
||||
self.route_typed_request(client, headers, body, "/v1/completions")
|
||||
self.route_typed_request(headers, body, "/v1/completions")
|
||||
.await
|
||||
}
|
||||
|
||||
async fn flush_cache(&self, client: &Client) -> Response {
|
||||
async fn flush_cache(&self) -> Response {
|
||||
// Get all worker URLs
|
||||
let worker_urls = self.get_worker_urls();
|
||||
|
||||
@@ -1117,7 +1111,7 @@ impl RouterTrait for Router {
|
||||
} else {
|
||||
worker_url
|
||||
};
|
||||
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
||||
let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
|
||||
tasks.push(request_builder.send());
|
||||
}
|
||||
|
||||
@@ -1142,13 +1136,13 @@ impl RouterTrait for Router {
|
||||
}
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self, client: &Client) -> Response {
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
let urls = self.get_worker_urls();
|
||||
let mut loads = Vec::new();
|
||||
|
||||
// Get loads from all workers
|
||||
for url in &urls {
|
||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
||||
let load = self.get_worker_load(url).await.unwrap_or(-1);
|
||||
loads.push(serde_json::json!({
|
||||
"worker": url,
|
||||
"load": load
|
||||
@@ -1215,6 +1209,7 @@ mod tests {
|
||||
interval_secs: 1,
|
||||
dp_aware: false,
|
||||
api_key: None,
|
||||
client: Client::new(),
|
||||
_worker_loads: Arc::new(rx),
|
||||
_load_monitor_handle: None,
|
||||
_health_checker: None,
|
||||
|
||||
Reference in New Issue
Block a user