[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)

This commit is contained in:
Simo Lin
2025-08-02 19:16:47 -07:00
committed by GitHub
parent 8ada1ab6c7
commit 828a4fe944
12 changed files with 197 additions and 186 deletions

View File

@@ -35,7 +35,7 @@ pub struct PDRouter {
pub interval_secs: u64,
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
pub http_client: Client,
pub client: Client,
_prefill_health_checker: Option<HealthChecker>,
_decode_health_checker: Option<HealthChecker>,
}
@@ -177,6 +177,7 @@ impl PDRouter {
decode_urls: Vec<String>,
prefill_policy: Arc<dyn LoadBalancingPolicy>,
decode_policy: Arc<dyn LoadBalancingPolicy>,
client: Client,
timeout_secs: u64,
interval_secs: u64,
) -> Result<Self, String> {
@@ -215,17 +216,11 @@ impl PDRouter {
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
let worker_loads = Arc::new(rx);
// Create a shared HTTP client for all operations
let http_client = Client::builder()
.timeout(Duration::from_secs(timeout_secs))
.build()
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
let load_monitor_handle =
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
let monitor_urls = all_urls.clone();
let monitor_interval = interval_secs;
let monitor_client = http_client.clone();
let monitor_client = client.clone();
let prefill_policy_clone = Arc::clone(&prefill_policy);
let decode_policy_clone = Arc::clone(&decode_policy);
@@ -264,7 +259,7 @@ impl PDRouter {
interval_secs,
worker_loads,
load_monitor_handle,
http_client,
client,
_prefill_health_checker: Some(prefill_health_checker),
_decode_health_checker: Some(decode_health_checker),
})
@@ -302,7 +297,6 @@ impl PDRouter {
// Route a typed generate request
pub async fn route_generate(
&self,
client: &Client,
headers: Option<&HeaderMap>,
mut typed_req: GenerateReqInput,
route: &str,
@@ -371,7 +365,6 @@ impl PDRouter {
// Execute dual dispatch
self.execute_dual_dispatch(
client,
headers,
json_with_bootstrap,
route,
@@ -387,7 +380,6 @@ impl PDRouter {
// Route a typed chat request
pub async fn route_chat(
&self,
client: &Client,
headers: Option<&HeaderMap>,
mut typed_req: ChatReqInput,
route: &str,
@@ -459,7 +451,6 @@ impl PDRouter {
// Execute dual dispatch
self.execute_dual_dispatch(
client,
headers,
json_with_bootstrap,
route,
@@ -475,7 +466,6 @@ impl PDRouter {
// Route a completion request while preserving OpenAI format
pub async fn route_completion(
&self,
client: &Client,
headers: Option<&HeaderMap>,
mut typed_req: CompletionRequest,
route: &str,
@@ -540,7 +530,6 @@ impl PDRouter {
// Execute dual dispatch
self.execute_dual_dispatch(
client,
headers,
json_with_bootstrap,
route,
@@ -554,10 +543,8 @@ impl PDRouter {
}
// Execute the dual dispatch to prefill and decode servers
#[allow(clippy::too_many_arguments)]
async fn execute_dual_dispatch(
&self,
client: &Client,
headers: Option<&HeaderMap>,
json_request: Value,
route: &str,
@@ -571,11 +558,13 @@ impl PDRouter {
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
// Build requests using .json() method
let mut prefill_request = client
let mut prefill_request = self
.client
.post(api_path(prefill.url(), route))
.json(&json_request);
let mut decode_request = client
let mut decode_request = self
.client
.post(api_path(decode.url(), route))
.json(&json_request);
@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
// PD-specific endpoints
impl PDRouter {
pub async fn health_generate(&self, client: &reqwest::Client) -> Response {
pub async fn health_generate(&self) -> Response {
// Test model generation capability by selecting a random pair and testing them
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
@@ -1005,11 +994,11 @@ impl PDRouter {
// Test prefill server's health_generate
let prefill_url = format!("{}/health_generate", prefill.url());
let prefill_result = client.get(&prefill_url).send().await;
let prefill_result = self.client.get(&prefill_url).send().await;
// Test decode server's health_generate
let decode_url = format!("{}/health_generate", decode.url());
let decode_result = client.get(&decode_url).send().await;
let decode_result = self.client.get(&decode_url).send().await;
// Check results
let mut errors = Vec::new();
@@ -1068,7 +1057,7 @@ impl PDRouter {
}
}
pub async fn get_server_info(&self, client: &reqwest::Client) -> Response {
pub async fn get_server_info(&self) -> Response {
// Get info from the first decode server to match sglang's server info format
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
workers.first().map(|w| w.url().to_string())
@@ -1081,7 +1070,8 @@ impl PDRouter {
};
if let Some(worker_url) = first_decode_url {
match client
match self
.client
.get(format!("{}/get_server_info", worker_url))
.send()
.await
@@ -1130,7 +1120,7 @@ impl PDRouter {
}
}
pub async fn get_models(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
pub async fn get_models(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
@@ -1147,7 +1137,7 @@ impl PDRouter {
if let Some(worker_url) = first_worker_url {
// Send request directly without going through Router
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
let mut request_builder = self.client.get(format!("{}/v1/models", worker_url));
for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
@@ -1224,7 +1214,7 @@ impl PDRouter {
.into_response()
}
pub async fn get_model_info(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
pub async fn get_model_info(&self, req: Request<Body>) -> Response {
// Extract headers first to avoid Send issues
let headers = crate::routers::router::copy_request_headers(&req);
@@ -1241,7 +1231,7 @@ impl PDRouter {
};
if let Some(worker_url) = first_worker_url {
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url));
for (name, value) in headers {
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
{
@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter {
self
}
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
async fn health(&self, _req: Request<Body>) -> Response {
// This is a server readiness check - checking if we have healthy workers
// Workers handle their own health checks in the background
let mut all_healthy = true;
@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter {
}
}
async fn health_generate(&self, client: &Client, _req: Request<Body>) -> Response {
async fn health_generate(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter health_generate method
PDRouter::health_generate(self, client).await
PDRouter::health_generate(self).await
}
async fn get_server_info(&self, client: &Client, _req: Request<Body>) -> Response {
async fn get_server_info(&self, _req: Request<Body>) -> Response {
// Use the existing PDRouter get_server_info method
PDRouter::get_server_info(self, client).await
PDRouter::get_server_info(self).await
}
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
async fn get_models(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_models method
PDRouter::get_models(self, client, req).await
PDRouter::get_models(self, req).await
}
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
async fn get_model_info(&self, req: Request<Body>) -> Response {
// Use the existing PDRouter get_model_info method
PDRouter::get_model_info(self, client, req).await
PDRouter::get_model_info(self, req).await
}
async fn route_generate(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &GenerateRequest,
) -> Response {
// Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request();
PDRouter::route_generate(self, client, headers, pd_req, "/generate").await
PDRouter::route_generate(self, headers, pd_req, "/generate").await
}
async fn route_chat(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &ChatCompletionRequest,
) -> Response {
// Convert OpenAI format to PD format
let pd_req = body.clone().to_pd_request();
PDRouter::route_chat(self, client, headers, pd_req, "/v1/chat/completions").await
PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await
}
async fn route_completion(
&self,
client: &Client,
headers: Option<&HeaderMap>,
body: &CompletionRequest,
) -> Response {
// Use the new method that preserves OpenAI format
PDRouter::route_completion(self, client, headers, body.clone(), "/v1/completions").await
PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await
}
async fn flush_cache(&self, client: &Client) -> Response {
async fn flush_cache(&self) -> Response {
// Use the existing PDRouter flush_cache method
PDRouter::flush_cache(self, client).await
PDRouter::flush_cache(self, &self.client).await
}
async fn get_worker_loads(&self, client: &Client) -> Response {
async fn get_worker_loads(&self) -> Response {
// Use the existing PDRouter get_loads method
PDRouter::get_loads(self, client).await
PDRouter::get_loads(self, &self.client).await
}
fn router_type(&self) -> &'static str {
@@ -1570,7 +1557,7 @@ mod tests {
interval_secs: 1,
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
load_monitor_handle: None,
http_client: reqwest::Client::new(),
client: Client::new(),
_prefill_health_checker: None,
_decode_health_checker: None,
}
@@ -1959,11 +1946,10 @@ mod tests {
router.decode_workers.write().unwrap().push(decode_worker);
// Test health endpoint
let client = reqwest::Client::new();
let http_req = axum::http::Request::builder()
.body(axum::body::Body::empty())
.unwrap();
let response = router.health(&client, http_req).await;
let response = router.health(http_req).await;
assert_eq!(response.status(), 200);