[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)
This commit is contained in:
@@ -35,7 +35,7 @@ pub struct PDRouter {
|
||||
pub interval_secs: u64,
|
||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||
pub http_client: Client,
|
||||
pub client: Client,
|
||||
_prefill_health_checker: Option<HealthChecker>,
|
||||
_decode_health_checker: Option<HealthChecker>,
|
||||
}
|
||||
@@ -177,6 +177,7 @@ impl PDRouter {
|
||||
decode_urls: Vec<String>,
|
||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||
client: Client,
|
||||
timeout_secs: u64,
|
||||
interval_secs: u64,
|
||||
) -> Result<Self, String> {
|
||||
@@ -215,17 +216,11 @@ impl PDRouter {
|
||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||
let worker_loads = Arc::new(rx);
|
||||
|
||||
// Create a shared HTTP client for all operations
|
||||
let http_client = Client::builder()
|
||||
.timeout(Duration::from_secs(timeout_secs))
|
||||
.build()
|
||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
||||
|
||||
let load_monitor_handle =
|
||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||
let monitor_urls = all_urls.clone();
|
||||
let monitor_interval = interval_secs;
|
||||
let monitor_client = http_client.clone();
|
||||
let monitor_client = client.clone();
|
||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||
let decode_policy_clone = Arc::clone(&decode_policy);
|
||||
|
||||
@@ -264,7 +259,7 @@ impl PDRouter {
|
||||
interval_secs,
|
||||
worker_loads,
|
||||
load_monitor_handle,
|
||||
http_client,
|
||||
client,
|
||||
_prefill_health_checker: Some(prefill_health_checker),
|
||||
_decode_health_checker: Some(decode_health_checker),
|
||||
})
|
||||
@@ -302,7 +297,6 @@ impl PDRouter {
|
||||
// Route a typed generate request
|
||||
pub async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
mut typed_req: GenerateReqInput,
|
||||
route: &str,
|
||||
@@ -371,7 +365,6 @@ impl PDRouter {
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
client,
|
||||
headers,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
@@ -387,7 +380,6 @@ impl PDRouter {
|
||||
// Route a typed chat request
|
||||
pub async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
mut typed_req: ChatReqInput,
|
||||
route: &str,
|
||||
@@ -459,7 +451,6 @@ impl PDRouter {
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
client,
|
||||
headers,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
@@ -475,7 +466,6 @@ impl PDRouter {
|
||||
// Route a completion request while preserving OpenAI format
|
||||
pub async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
mut typed_req: CompletionRequest,
|
||||
route: &str,
|
||||
@@ -540,7 +530,6 @@ impl PDRouter {
|
||||
|
||||
// Execute dual dispatch
|
||||
self.execute_dual_dispatch(
|
||||
client,
|
||||
headers,
|
||||
json_with_bootstrap,
|
||||
route,
|
||||
@@ -554,10 +543,8 @@ impl PDRouter {
|
||||
}
|
||||
|
||||
// Execute the dual dispatch to prefill and decode servers
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn execute_dual_dispatch(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
json_request: Value,
|
||||
route: &str,
|
||||
@@ -571,11 +558,13 @@ impl PDRouter {
|
||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||
|
||||
// Build requests using .json() method
|
||||
let mut prefill_request = client
|
||||
let mut prefill_request = self
|
||||
.client
|
||||
.post(api_path(prefill.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
let mut decode_request = client
|
||||
let mut decode_request = self
|
||||
.client
|
||||
.post(api_path(decode.url(), route))
|
||||
.json(&json_request);
|
||||
|
||||
@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
|
||||
|
||||
// PD-specific endpoints
|
||||
impl PDRouter {
|
||||
pub async fn health_generate(&self, client: &reqwest::Client) -> Response {
|
||||
pub async fn health_generate(&self) -> Response {
|
||||
// Test model generation capability by selecting a random pair and testing them
|
||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||
|
||||
@@ -1005,11 +994,11 @@ impl PDRouter {
|
||||
|
||||
// Test prefill server's health_generate
|
||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||
let prefill_result = client.get(&prefill_url).send().await;
|
||||
let prefill_result = self.client.get(&prefill_url).send().await;
|
||||
|
||||
// Test decode server's health_generate
|
||||
let decode_url = format!("{}/health_generate", decode.url());
|
||||
let decode_result = client.get(&decode_url).send().await;
|
||||
let decode_result = self.client.get(&decode_url).send().await;
|
||||
|
||||
// Check results
|
||||
let mut errors = Vec::new();
|
||||
@@ -1068,7 +1057,7 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_server_info(&self, client: &reqwest::Client) -> Response {
|
||||
pub async fn get_server_info(&self) -> Response {
|
||||
// Get info from the first decode server to match sglang's server info format
|
||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||
workers.first().map(|w| w.url().to_string())
|
||||
@@ -1081,7 +1070,8 @@ impl PDRouter {
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_decode_url {
|
||||
match client
|
||||
match self
|
||||
.client
|
||||
.get(format!("{}/get_server_info", worker_url))
|
||||
.send()
|
||||
.await
|
||||
@@ -1130,7 +1120,7 @@ impl PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn get_models(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
|
||||
pub async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
|
||||
@@ -1147,7 +1137,7 @@ impl PDRouter {
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
// Send request directly without going through Router
|
||||
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
|
||||
let mut request_builder = self.client.get(format!("{}/v1/models", worker_url));
|
||||
for (name, value) in headers {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
@@ -1224,7 +1214,7 @@ impl PDRouter {
|
||||
.into_response()
|
||||
}
|
||||
|
||||
pub async fn get_model_info(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
|
||||
pub async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||
// Extract headers first to avoid Send issues
|
||||
let headers = crate::routers::router::copy_request_headers(&req);
|
||||
|
||||
@@ -1241,7 +1231,7 @@ impl PDRouter {
|
||||
};
|
||||
|
||||
if let Some(worker_url) = first_worker_url {
|
||||
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
|
||||
let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url));
|
||||
for (name, value) in headers {
|
||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||
{
|
||||
@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter {
|
||||
self
|
||||
}
|
||||
|
||||
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
||||
async fn health(&self, _req: Request<Body>) -> Response {
|
||||
// This is a server readiness check - checking if we have healthy workers
|
||||
// Workers handle their own health checks in the background
|
||||
let mut all_healthy = true;
|
||||
@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter {
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_generate(&self, client: &Client, _req: Request<Body>) -> Response {
|
||||
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||
// Use the existing PDRouter health_generate method
|
||||
PDRouter::health_generate(self, client).await
|
||||
PDRouter::health_generate(self).await
|
||||
}
|
||||
|
||||
async fn get_server_info(&self, client: &Client, _req: Request<Body>) -> Response {
|
||||
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||
// Use the existing PDRouter get_server_info method
|
||||
PDRouter::get_server_info(self, client).await
|
||||
PDRouter::get_server_info(self).await
|
||||
}
|
||||
|
||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||
// Use the existing PDRouter get_models method
|
||||
PDRouter::get_models(self, client, req).await
|
||||
PDRouter::get_models(self, req).await
|
||||
}
|
||||
|
||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
||||
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||
// Use the existing PDRouter get_model_info method
|
||||
PDRouter::get_model_info(self, client, req).await
|
||||
PDRouter::get_model_info(self, req).await
|
||||
}
|
||||
|
||||
async fn route_generate(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &GenerateRequest,
|
||||
) -> Response {
|
||||
// Convert OpenAI format to PD format
|
||||
let pd_req = body.clone().to_pd_request();
|
||||
|
||||
PDRouter::route_generate(self, client, headers, pd_req, "/generate").await
|
||||
PDRouter::route_generate(self, headers, pd_req, "/generate").await
|
||||
}
|
||||
|
||||
async fn route_chat(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &ChatCompletionRequest,
|
||||
) -> Response {
|
||||
// Convert OpenAI format to PD format
|
||||
let pd_req = body.clone().to_pd_request();
|
||||
|
||||
PDRouter::route_chat(self, client, headers, pd_req, "/v1/chat/completions").await
|
||||
PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await
|
||||
}
|
||||
|
||||
async fn route_completion(
|
||||
&self,
|
||||
client: &Client,
|
||||
headers: Option<&HeaderMap>,
|
||||
body: &CompletionRequest,
|
||||
) -> Response {
|
||||
// Use the new method that preserves OpenAI format
|
||||
PDRouter::route_completion(self, client, headers, body.clone(), "/v1/completions").await
|
||||
PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await
|
||||
}
|
||||
|
||||
async fn flush_cache(&self, client: &Client) -> Response {
|
||||
async fn flush_cache(&self) -> Response {
|
||||
// Use the existing PDRouter flush_cache method
|
||||
PDRouter::flush_cache(self, client).await
|
||||
PDRouter::flush_cache(self, &self.client).await
|
||||
}
|
||||
|
||||
async fn get_worker_loads(&self, client: &Client) -> Response {
|
||||
async fn get_worker_loads(&self) -> Response {
|
||||
// Use the existing PDRouter get_loads method
|
||||
PDRouter::get_loads(self, client).await
|
||||
PDRouter::get_loads(self, &self.client).await
|
||||
}
|
||||
|
||||
fn router_type(&self) -> &'static str {
|
||||
@@ -1570,7 +1557,7 @@ mod tests {
|
||||
interval_secs: 1,
|
||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||
load_monitor_handle: None,
|
||||
http_client: reqwest::Client::new(),
|
||||
client: Client::new(),
|
||||
_prefill_health_checker: None,
|
||||
_decode_health_checker: None,
|
||||
}
|
||||
@@ -1959,11 +1946,10 @@ mod tests {
|
||||
router.decode_workers.write().unwrap().push(decode_worker);
|
||||
|
||||
// Test health endpoint
|
||||
let client = reqwest::Client::new();
|
||||
let http_req = axum::http::Request::builder()
|
||||
.body(axum::body::Body::empty())
|
||||
.unwrap();
|
||||
let response = router.health(&client, http_req).await;
|
||||
let response = router.health(http_req).await;
|
||||
|
||||
assert_eq!(response.status(), 200);
|
||||
|
||||
|
||||
Reference in New Issue
Block a user