[router] Implement HTTP Dependency Injection Pattern for Router System (#8714)
This commit is contained in:
@@ -1,18 +1,20 @@
|
|||||||
//! Factory for creating router instances
|
//! Factory for creating router instances
|
||||||
|
|
||||||
use super::{pd_router::PDRouter, router::Router, RouterTrait};
|
use super::{pd_router::PDRouter, router::Router, RouterTrait};
|
||||||
use crate::config::{PolicyConfig, RouterConfig, RoutingMode};
|
use crate::config::{PolicyConfig, RoutingMode};
|
||||||
use crate::policies::PolicyFactory;
|
use crate::policies::PolicyFactory;
|
||||||
|
use crate::server::AppContext;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Factory for creating router instances based on configuration
|
/// Factory for creating router instances based on configuration
|
||||||
pub struct RouterFactory;
|
pub struct RouterFactory;
|
||||||
|
|
||||||
impl RouterFactory {
|
impl RouterFactory {
|
||||||
/// Create a router instance from configuration
|
/// Create a router instance from application context
|
||||||
pub fn create_router(config: &RouterConfig) -> Result<Box<dyn RouterTrait>, String> {
|
pub fn create_router(ctx: &Arc<AppContext>) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
match &config.mode {
|
match &ctx.router_config.mode {
|
||||||
RoutingMode::Regular { worker_urls } => {
|
RoutingMode::Regular { worker_urls } => {
|
||||||
Self::create_regular_router(worker_urls, &config.policy, config)
|
Self::create_regular_router(worker_urls, &ctx.router_config.policy, ctx)
|
||||||
}
|
}
|
||||||
RoutingMode::PrefillDecode {
|
RoutingMode::PrefillDecode {
|
||||||
prefill_urls,
|
prefill_urls,
|
||||||
@@ -24,8 +26,8 @@ impl RouterFactory {
|
|||||||
decode_urls,
|
decode_urls,
|
||||||
prefill_policy.as_ref(),
|
prefill_policy.as_ref(),
|
||||||
decode_policy.as_ref(),
|
decode_policy.as_ref(),
|
||||||
&config.policy,
|
&ctx.router_config.policy,
|
||||||
config,
|
ctx,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -34,19 +36,20 @@ impl RouterFactory {
|
|||||||
fn create_regular_router(
|
fn create_regular_router(
|
||||||
worker_urls: &[String],
|
worker_urls: &[String],
|
||||||
policy_config: &PolicyConfig,
|
policy_config: &PolicyConfig,
|
||||||
router_config: &RouterConfig,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// Create policy
|
// Create policy
|
||||||
let policy = PolicyFactory::create_from_config(policy_config);
|
let policy = PolicyFactory::create_from_config(policy_config);
|
||||||
|
|
||||||
// Create regular router with injected policy
|
// Create regular router with injected policy and client
|
||||||
let router = Router::new(
|
let router = Router::new(
|
||||||
worker_urls.to_vec(),
|
worker_urls.to_vec(),
|
||||||
policy,
|
policy,
|
||||||
router_config.worker_startup_timeout_secs,
|
ctx.client.clone(),
|
||||||
router_config.worker_startup_check_interval_secs,
|
ctx.router_config.worker_startup_timeout_secs,
|
||||||
router_config.dp_aware,
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
router_config.api_key.clone(),
|
ctx.router_config.dp_aware,
|
||||||
|
ctx.router_config.api_key.clone(),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
@@ -59,7 +62,7 @@ impl RouterFactory {
|
|||||||
prefill_policy_config: Option<&PolicyConfig>,
|
prefill_policy_config: Option<&PolicyConfig>,
|
||||||
decode_policy_config: Option<&PolicyConfig>,
|
decode_policy_config: Option<&PolicyConfig>,
|
||||||
main_policy_config: &PolicyConfig,
|
main_policy_config: &PolicyConfig,
|
||||||
router_config: &RouterConfig,
|
ctx: &Arc<AppContext>,
|
||||||
) -> Result<Box<dyn RouterTrait>, String> {
|
) -> Result<Box<dyn RouterTrait>, String> {
|
||||||
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
// Create policies - use specific policies if provided, otherwise fall back to main policy
|
||||||
let prefill_policy =
|
let prefill_policy =
|
||||||
@@ -67,14 +70,15 @@ impl RouterFactory {
|
|||||||
let decode_policy =
|
let decode_policy =
|
||||||
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
PolicyFactory::create_from_config(decode_policy_config.unwrap_or(main_policy_config));
|
||||||
|
|
||||||
// Create PD router with separate policies
|
// Create PD router with separate policies and client
|
||||||
let router = PDRouter::new(
|
let router = PDRouter::new(
|
||||||
prefill_urls.to_vec(),
|
prefill_urls.to_vec(),
|
||||||
decode_urls.to_vec(),
|
decode_urls.to_vec(),
|
||||||
prefill_policy,
|
prefill_policy,
|
||||||
decode_policy,
|
decode_policy,
|
||||||
router_config.worker_startup_timeout_secs,
|
ctx.client.clone(),
|
||||||
router_config.worker_startup_check_interval_secs,
|
ctx.router_config.worker_startup_timeout_secs,
|
||||||
|
ctx.router_config.worker_startup_check_interval_secs,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Box::new(router))
|
Ok(Box::new(router))
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ use axum::{
|
|||||||
http::{HeaderMap, StatusCode},
|
http::{HeaderMap, StatusCode},
|
||||||
response::{IntoResponse, Response},
|
response::{IntoResponse, Response},
|
||||||
};
|
};
|
||||||
use reqwest::Client;
|
|
||||||
use std::fmt::Debug;
|
use std::fmt::Debug;
|
||||||
|
|
||||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||||
@@ -46,32 +45,27 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
|||||||
fn as_any(&self) -> &dyn std::any::Any;
|
fn as_any(&self) -> &dyn std::any::Any;
|
||||||
|
|
||||||
/// Route a health check request
|
/// Route a health check request
|
||||||
async fn health(&self, client: &Client, req: Request<Body>) -> Response;
|
async fn health(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Route a health generate request
|
/// Route a health generate request
|
||||||
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response;
|
async fn health_generate(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get server information
|
/// Get server information
|
||||||
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response;
|
async fn get_server_info(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get available models
|
/// Get available models
|
||||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response;
|
async fn get_models(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Get model information
|
/// Get model information
|
||||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response;
|
async fn get_model_info(&self, req: Request<Body>) -> Response;
|
||||||
|
|
||||||
/// Route a generate request
|
/// Route a generate request
|
||||||
async fn route_generate(
|
async fn route_generate(&self, headers: Option<&HeaderMap>, body: &GenerateRequest)
|
||||||
&self,
|
-> Response;
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
|
||||||
body: &GenerateRequest,
|
|
||||||
) -> Response;
|
|
||||||
|
|
||||||
/// Route a chat completion request
|
/// Route a chat completion request
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &ChatCompletionRequest,
|
body: &ChatCompletionRequest,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
@@ -79,16 +73,15 @@ pub trait RouterTrait: Send + Sync + Debug + WorkerManagement {
|
|||||||
/// Route a completion request
|
/// Route a completion request
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &CompletionRequest,
|
body: &CompletionRequest,
|
||||||
) -> Response;
|
) -> Response;
|
||||||
|
|
||||||
/// Flush cache on all workers
|
/// Flush cache on all workers
|
||||||
async fn flush_cache(&self, client: &Client) -> Response;
|
async fn flush_cache(&self) -> Response;
|
||||||
|
|
||||||
/// Get worker loads (for monitoring)
|
/// Get worker loads (for monitoring)
|
||||||
async fn get_worker_loads(&self, client: &Client) -> Response;
|
async fn get_worker_loads(&self) -> Response;
|
||||||
|
|
||||||
/// Get router type name
|
/// Get router type name
|
||||||
fn router_type(&self) -> &'static str;
|
fn router_type(&self) -> &'static str;
|
||||||
|
|||||||
@@ -35,7 +35,7 @@ pub struct PDRouter {
|
|||||||
pub interval_secs: u64,
|
pub interval_secs: u64,
|
||||||
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
pub worker_loads: Arc<tokio::sync::watch::Receiver<HashMap<String, isize>>>,
|
||||||
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
pub load_monitor_handle: Option<Arc<tokio::task::JoinHandle<()>>>,
|
||||||
pub http_client: Client,
|
pub client: Client,
|
||||||
_prefill_health_checker: Option<HealthChecker>,
|
_prefill_health_checker: Option<HealthChecker>,
|
||||||
_decode_health_checker: Option<HealthChecker>,
|
_decode_health_checker: Option<HealthChecker>,
|
||||||
}
|
}
|
||||||
@@ -177,6 +177,7 @@ impl PDRouter {
|
|||||||
decode_urls: Vec<String>,
|
decode_urls: Vec<String>,
|
||||||
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
prefill_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
decode_policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
client: Client,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
) -> Result<Self, String> {
|
) -> Result<Self, String> {
|
||||||
@@ -215,17 +216,11 @@ impl PDRouter {
|
|||||||
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
let (tx, rx) = tokio::sync::watch::channel(HashMap::new());
|
||||||
let worker_loads = Arc::new(rx);
|
let worker_loads = Arc::new(rx);
|
||||||
|
|
||||||
// Create a shared HTTP client for all operations
|
|
||||||
let http_client = Client::builder()
|
|
||||||
.timeout(Duration::from_secs(timeout_secs))
|
|
||||||
.build()
|
|
||||||
.map_err(|e| format!("Failed to create HTTP client: {}", e))?;
|
|
||||||
|
|
||||||
let load_monitor_handle =
|
let load_monitor_handle =
|
||||||
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
if prefill_policy.name() == "power_of_two" || decode_policy.name() == "power_of_two" {
|
||||||
let monitor_urls = all_urls.clone();
|
let monitor_urls = all_urls.clone();
|
||||||
let monitor_interval = interval_secs;
|
let monitor_interval = interval_secs;
|
||||||
let monitor_client = http_client.clone();
|
let monitor_client = client.clone();
|
||||||
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
let prefill_policy_clone = Arc::clone(&prefill_policy);
|
||||||
let decode_policy_clone = Arc::clone(&decode_policy);
|
let decode_policy_clone = Arc::clone(&decode_policy);
|
||||||
|
|
||||||
@@ -264,7 +259,7 @@ impl PDRouter {
|
|||||||
interval_secs,
|
interval_secs,
|
||||||
worker_loads,
|
worker_loads,
|
||||||
load_monitor_handle,
|
load_monitor_handle,
|
||||||
http_client,
|
client,
|
||||||
_prefill_health_checker: Some(prefill_health_checker),
|
_prefill_health_checker: Some(prefill_health_checker),
|
||||||
_decode_health_checker: Some(decode_health_checker),
|
_decode_health_checker: Some(decode_health_checker),
|
||||||
})
|
})
|
||||||
@@ -302,7 +297,6 @@ impl PDRouter {
|
|||||||
// Route a typed generate request
|
// Route a typed generate request
|
||||||
pub async fn route_generate(
|
pub async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
mut typed_req: GenerateReqInput,
|
mut typed_req: GenerateReqInput,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -371,7 +365,6 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Execute dual dispatch
|
// Execute dual dispatch
|
||||||
self.execute_dual_dispatch(
|
self.execute_dual_dispatch(
|
||||||
client,
|
|
||||||
headers,
|
headers,
|
||||||
json_with_bootstrap,
|
json_with_bootstrap,
|
||||||
route,
|
route,
|
||||||
@@ -387,7 +380,6 @@ impl PDRouter {
|
|||||||
// Route a typed chat request
|
// Route a typed chat request
|
||||||
pub async fn route_chat(
|
pub async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
mut typed_req: ChatReqInput,
|
mut typed_req: ChatReqInput,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -459,7 +451,6 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Execute dual dispatch
|
// Execute dual dispatch
|
||||||
self.execute_dual_dispatch(
|
self.execute_dual_dispatch(
|
||||||
client,
|
|
||||||
headers,
|
headers,
|
||||||
json_with_bootstrap,
|
json_with_bootstrap,
|
||||||
route,
|
route,
|
||||||
@@ -475,7 +466,6 @@ impl PDRouter {
|
|||||||
// Route a completion request while preserving OpenAI format
|
// Route a completion request while preserving OpenAI format
|
||||||
pub async fn route_completion(
|
pub async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
mut typed_req: CompletionRequest,
|
mut typed_req: CompletionRequest,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -540,7 +530,6 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Execute dual dispatch
|
// Execute dual dispatch
|
||||||
self.execute_dual_dispatch(
|
self.execute_dual_dispatch(
|
||||||
client,
|
|
||||||
headers,
|
headers,
|
||||||
json_with_bootstrap,
|
json_with_bootstrap,
|
||||||
route,
|
route,
|
||||||
@@ -554,10 +543,8 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Execute the dual dispatch to prefill and decode servers
|
// Execute the dual dispatch to prefill and decode servers
|
||||||
#[allow(clippy::too_many_arguments)]
|
|
||||||
async fn execute_dual_dispatch(
|
async fn execute_dual_dispatch(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
json_request: Value,
|
json_request: Value,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -571,11 +558,13 @@ impl PDRouter {
|
|||||||
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
let _guard = WorkerLoadGuard::new_multi(vec![prefill, decode]);
|
||||||
|
|
||||||
// Build requests using .json() method
|
// Build requests using .json() method
|
||||||
let mut prefill_request = client
|
let mut prefill_request = self
|
||||||
|
.client
|
||||||
.post(api_path(prefill.url(), route))
|
.post(api_path(prefill.url(), route))
|
||||||
.json(&json_request);
|
.json(&json_request);
|
||||||
|
|
||||||
let mut decode_request = client
|
let mut decode_request = self
|
||||||
|
.client
|
||||||
.post(api_path(decode.url(), route))
|
.post(api_path(decode.url(), route))
|
||||||
.json(&json_request);
|
.json(&json_request);
|
||||||
|
|
||||||
@@ -987,7 +976,7 @@ async fn get_worker_load(client: &reqwest::Client, worker_url: &str) -> Option<i
|
|||||||
|
|
||||||
// PD-specific endpoints
|
// PD-specific endpoints
|
||||||
impl PDRouter {
|
impl PDRouter {
|
||||||
pub async fn health_generate(&self, client: &reqwest::Client) -> Response {
|
pub async fn health_generate(&self) -> Response {
|
||||||
// Test model generation capability by selecting a random pair and testing them
|
// Test model generation capability by selecting a random pair and testing them
|
||||||
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
// Note: This endpoint actually causes the model to generate tokens, so we only test one pair
|
||||||
|
|
||||||
@@ -1005,11 +994,11 @@ impl PDRouter {
|
|||||||
|
|
||||||
// Test prefill server's health_generate
|
// Test prefill server's health_generate
|
||||||
let prefill_url = format!("{}/health_generate", prefill.url());
|
let prefill_url = format!("{}/health_generate", prefill.url());
|
||||||
let prefill_result = client.get(&prefill_url).send().await;
|
let prefill_result = self.client.get(&prefill_url).send().await;
|
||||||
|
|
||||||
// Test decode server's health_generate
|
// Test decode server's health_generate
|
||||||
let decode_url = format!("{}/health_generate", decode.url());
|
let decode_url = format!("{}/health_generate", decode.url());
|
||||||
let decode_result = client.get(&decode_url).send().await;
|
let decode_result = self.client.get(&decode_url).send().await;
|
||||||
|
|
||||||
// Check results
|
// Check results
|
||||||
let mut errors = Vec::new();
|
let mut errors = Vec::new();
|
||||||
@@ -1068,7 +1057,7 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_server_info(&self, client: &reqwest::Client) -> Response {
|
pub async fn get_server_info(&self) -> Response {
|
||||||
// Get info from the first decode server to match sglang's server info format
|
// Get info from the first decode server to match sglang's server info format
|
||||||
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
let first_decode_url = if let Ok(workers) = self.decode_workers.read() {
|
||||||
workers.first().map(|w| w.url().to_string())
|
workers.first().map(|w| w.url().to_string())
|
||||||
@@ -1081,7 +1070,8 @@ impl PDRouter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(worker_url) = first_decode_url {
|
if let Some(worker_url) = first_decode_url {
|
||||||
match client
|
match self
|
||||||
|
.client
|
||||||
.get(format!("{}/get_server_info", worker_url))
|
.get(format!("{}/get_server_info", worker_url))
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
@@ -1130,7 +1120,7 @@ impl PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_models(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
|
pub async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = crate::routers::router::copy_request_headers(&req);
|
||||||
|
|
||||||
@@ -1147,7 +1137,7 @@ impl PDRouter {
|
|||||||
|
|
||||||
if let Some(worker_url) = first_worker_url {
|
if let Some(worker_url) = first_worker_url {
|
||||||
// Send request directly without going through Router
|
// Send request directly without going through Router
|
||||||
let mut request_builder = client.get(format!("{}/v1/models", worker_url));
|
let mut request_builder = self.client.get(format!("{}/v1/models", worker_url));
|
||||||
for (name, value) in headers {
|
for (name, value) in headers {
|
||||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||||
{
|
{
|
||||||
@@ -1224,7 +1214,7 @@ impl PDRouter {
|
|||||||
.into_response()
|
.into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_model_info(&self, client: &reqwest::Client, req: Request<Body>) -> Response {
|
pub async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||||
// Extract headers first to avoid Send issues
|
// Extract headers first to avoid Send issues
|
||||||
let headers = crate::routers::router::copy_request_headers(&req);
|
let headers = crate::routers::router::copy_request_headers(&req);
|
||||||
|
|
||||||
@@ -1241,7 +1231,7 @@ impl PDRouter {
|
|||||||
};
|
};
|
||||||
|
|
||||||
if let Some(worker_url) = first_worker_url {
|
if let Some(worker_url) = first_worker_url {
|
||||||
let mut request_builder = client.get(format!("{}/get_model_info", worker_url));
|
let mut request_builder = self.client.get(format!("{}/get_model_info", worker_url));
|
||||||
for (name, value) in headers {
|
for (name, value) in headers {
|
||||||
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
if name.to_lowercase() != "content-type" && name.to_lowercase() != "content-length"
|
||||||
{
|
{
|
||||||
@@ -1384,7 +1374,7 @@ impl RouterTrait for PDRouter {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
// This is a server readiness check - checking if we have healthy workers
|
// This is a server readiness check - checking if we have healthy workers
|
||||||
// Workers handle their own health checks in the background
|
// Workers handle their own health checks in the background
|
||||||
let mut all_healthy = true;
|
let mut all_healthy = true;
|
||||||
@@ -1417,68 +1407,65 @@ impl RouterTrait for PDRouter {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, client: &Client, _req: Request<Body>) -> Response {
|
async fn health_generate(&self, _req: Request<Body>) -> Response {
|
||||||
// Use the existing PDRouter health_generate method
|
// Use the existing PDRouter health_generate method
|
||||||
PDRouter::health_generate(self, client).await
|
PDRouter::health_generate(self).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, client: &Client, _req: Request<Body>) -> Response {
|
async fn get_server_info(&self, _req: Request<Body>) -> Response {
|
||||||
// Use the existing PDRouter get_server_info method
|
// Use the existing PDRouter get_server_info method
|
||||||
PDRouter::get_server_info(self, client).await
|
PDRouter::get_server_info(self).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
// Use the existing PDRouter get_models method
|
// Use the existing PDRouter get_models method
|
||||||
PDRouter::get_models(self, client, req).await
|
PDRouter::get_models(self, req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||||
// Use the existing PDRouter get_model_info method
|
// Use the existing PDRouter get_model_info method
|
||||||
PDRouter::get_model_info(self, client, req).await
|
PDRouter::get_model_info(self, req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &GenerateRequest,
|
body: &GenerateRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Convert OpenAI format to PD format
|
// Convert OpenAI format to PD format
|
||||||
let pd_req = body.clone().to_pd_request();
|
let pd_req = body.clone().to_pd_request();
|
||||||
|
|
||||||
PDRouter::route_generate(self, client, headers, pd_req, "/generate").await
|
PDRouter::route_generate(self, headers, pd_req, "/generate").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &ChatCompletionRequest,
|
body: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Convert OpenAI format to PD format
|
// Convert OpenAI format to PD format
|
||||||
let pd_req = body.clone().to_pd_request();
|
let pd_req = body.clone().to_pd_request();
|
||||||
|
|
||||||
PDRouter::route_chat(self, client, headers, pd_req, "/v1/chat/completions").await
|
PDRouter::route_chat(self, headers, pd_req, "/v1/chat/completions").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &CompletionRequest,
|
body: &CompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
// Use the new method that preserves OpenAI format
|
// Use the new method that preserves OpenAI format
|
||||||
PDRouter::route_completion(self, client, headers, body.clone(), "/v1/completions").await
|
PDRouter::route_completion(self, headers, body.clone(), "/v1/completions").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self, client: &Client) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
// Use the existing PDRouter flush_cache method
|
// Use the existing PDRouter flush_cache method
|
||||||
PDRouter::flush_cache(self, client).await
|
PDRouter::flush_cache(self, &self.client).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_worker_loads(&self, client: &Client) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
// Use the existing PDRouter get_loads method
|
// Use the existing PDRouter get_loads method
|
||||||
PDRouter::get_loads(self, client).await
|
PDRouter::get_loads(self, &self.client).await
|
||||||
}
|
}
|
||||||
|
|
||||||
fn router_type(&self) -> &'static str {
|
fn router_type(&self) -> &'static str {
|
||||||
@@ -1570,7 +1557,7 @@ mod tests {
|
|||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
worker_loads: Arc::new(tokio::sync::watch::channel(HashMap::new()).1),
|
||||||
load_monitor_handle: None,
|
load_monitor_handle: None,
|
||||||
http_client: reqwest::Client::new(),
|
client: Client::new(),
|
||||||
_prefill_health_checker: None,
|
_prefill_health_checker: None,
|
||||||
_decode_health_checker: None,
|
_decode_health_checker: None,
|
||||||
}
|
}
|
||||||
@@ -1959,11 +1946,10 @@ mod tests {
|
|||||||
router.decode_workers.write().unwrap().push(decode_worker);
|
router.decode_workers.write().unwrap().push(decode_worker);
|
||||||
|
|
||||||
// Test health endpoint
|
// Test health endpoint
|
||||||
let client = reqwest::Client::new();
|
|
||||||
let http_req = axum::http::Request::builder()
|
let http_req = axum::http::Request::builder()
|
||||||
.body(axum::body::Body::empty())
|
.body(axum::body::Body::empty())
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let response = router.health(&client, http_req).await;
|
let response = router.health(http_req).await;
|
||||||
|
|
||||||
assert_eq!(response.status(), 200);
|
assert_eq!(response.status(), 200);
|
||||||
|
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ pub fn copy_request_headers(req: &Request<Body>) -> Vec<(String, String)> {
|
|||||||
pub struct Router {
|
pub struct Router {
|
||||||
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
workers: Arc<RwLock<Vec<Box<dyn Worker>>>>,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
client: Client,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
@@ -44,10 +45,11 @@ pub struct Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
impl Router {
|
impl Router {
|
||||||
/// Create a new router with injected policy
|
/// Create a new router with injected policy and client
|
||||||
pub fn new(
|
pub fn new(
|
||||||
worker_urls: Vec<String>,
|
worker_urls: Vec<String>,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
client: Client,
|
||||||
timeout_secs: u64,
|
timeout_secs: u64,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
dp_aware: bool,
|
dp_aware: bool,
|
||||||
@@ -94,9 +96,17 @@ impl Router {
|
|||||||
let monitor_urls = worker_urls.clone();
|
let monitor_urls = worker_urls.clone();
|
||||||
let monitor_interval = interval_secs;
|
let monitor_interval = interval_secs;
|
||||||
let policy_clone = Arc::clone(&policy);
|
let policy_clone = Arc::clone(&policy);
|
||||||
|
let client_clone = client.clone();
|
||||||
|
|
||||||
Some(Arc::new(tokio::spawn(async move {
|
Some(Arc::new(tokio::spawn(async move {
|
||||||
Self::monitor_worker_loads(monitor_urls, tx, monitor_interval, policy_clone).await;
|
Self::monitor_worker_loads(
|
||||||
|
monitor_urls,
|
||||||
|
tx,
|
||||||
|
monitor_interval,
|
||||||
|
policy_clone,
|
||||||
|
client_clone,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
})))
|
})))
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
@@ -105,6 +115,7 @@ impl Router {
|
|||||||
Ok(Router {
|
Ok(Router {
|
||||||
workers,
|
workers,
|
||||||
policy,
|
policy,
|
||||||
|
client,
|
||||||
timeout_secs,
|
timeout_secs,
|
||||||
interval_secs,
|
interval_secs,
|
||||||
dp_aware,
|
dp_aware,
|
||||||
@@ -245,7 +256,7 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn send_health_check(&self, client: &Client, worker_url: &str) -> Response {
|
pub async fn send_health_check(&self, worker_url: &str) -> Response {
|
||||||
let health_url = if self.dp_aware {
|
let health_url = if self.dp_aware {
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
match Self::extract_dp_rank(worker_url) {
|
match Self::extract_dp_rank(worker_url) {
|
||||||
@@ -263,7 +274,7 @@ impl Router {
|
|||||||
worker_url
|
worker_url
|
||||||
};
|
};
|
||||||
|
|
||||||
let request_builder = client.get(format!("{}/health", health_url));
|
let request_builder = self.client.get(format!("{}/health", health_url));
|
||||||
|
|
||||||
let response = match request_builder.send().await {
|
let response = match request_builder.send().await {
|
||||||
Ok(res) => {
|
Ok(res) => {
|
||||||
@@ -305,17 +316,12 @@ impl Router {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Helper method to proxy GET requests to the first available worker
|
// Helper method to proxy GET requests to the first available worker
|
||||||
async fn proxy_get_request(
|
async fn proxy_get_request(&self, req: Request<Body>, endpoint: &str) -> Response {
|
||||||
&self,
|
|
||||||
client: &Client,
|
|
||||||
req: Request<Body>,
|
|
||||||
endpoint: &str,
|
|
||||||
) -> Response {
|
|
||||||
let headers = copy_request_headers(&req);
|
let headers = copy_request_headers(&req);
|
||||||
|
|
||||||
match self.select_first_worker() {
|
match self.select_first_worker() {
|
||||||
Ok(worker_url) => {
|
Ok(worker_url) => {
|
||||||
let mut request_builder = client.get(format!("{}/{}", worker_url, endpoint));
|
let mut request_builder = self.client.get(format!("{}/{}", worker_url, endpoint));
|
||||||
for (name, value) in headers {
|
for (name, value) in headers {
|
||||||
if name.to_lowercase() != "content-type"
|
if name.to_lowercase() != "content-type"
|
||||||
&& name.to_lowercase() != "content-length"
|
&& name.to_lowercase() != "content-length"
|
||||||
@@ -353,7 +359,6 @@ impl Router {
|
|||||||
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
T: crate::openai_api_types::GenerationRequest + serde::Serialize + Clone,
|
||||||
>(
|
>(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -397,7 +402,6 @@ impl Router {
|
|||||||
// Send typed request directly
|
// Send typed request directly
|
||||||
let response = self
|
let response = self
|
||||||
.send_typed_request(
|
.send_typed_request(
|
||||||
client,
|
|
||||||
headers,
|
headers,
|
||||||
typed_req,
|
typed_req,
|
||||||
route,
|
route,
|
||||||
@@ -413,7 +417,7 @@ impl Router {
|
|||||||
return response;
|
return response;
|
||||||
} else {
|
} else {
|
||||||
// if the worker is healthy, it means the request is bad, so return the error response
|
// if the worker is healthy, it means the request is bad, so return the error response
|
||||||
let health_response = self.send_health_check(client, &worker_url).await;
|
let health_response = self.send_health_check(&worker_url).await;
|
||||||
if health_response.status().is_success() {
|
if health_response.status().is_success() {
|
||||||
RouterMetrics::record_request_error(route, "request_failed");
|
RouterMetrics::record_request_error(route, "request_failed");
|
||||||
return response;
|
return response;
|
||||||
@@ -483,7 +487,6 @@ impl Router {
|
|||||||
// Send typed request directly without conversion
|
// Send typed request directly without conversion
|
||||||
async fn send_typed_request<T: serde::Serialize>(
|
async fn send_typed_request<T: serde::Serialize>(
|
||||||
&self,
|
&self,
|
||||||
client: &reqwest::Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
typed_req: &T,
|
typed_req: &T,
|
||||||
route: &str,
|
route: &str,
|
||||||
@@ -536,11 +539,11 @@ impl Router {
|
|||||||
.into_response();
|
.into_response();
|
||||||
}
|
}
|
||||||
|
|
||||||
client
|
self.client
|
||||||
.post(format!("{}{}", worker_url_prefix, route))
|
.post(format!("{}{}", worker_url_prefix, route))
|
||||||
.json(&json_val)
|
.json(&json_val)
|
||||||
} else {
|
} else {
|
||||||
client
|
self.client
|
||||||
.post(format!("{}{}", worker_url, route))
|
.post(format!("{}{}", worker_url, route))
|
||||||
.json(typed_req) // Use json() directly with typed request
|
.json(typed_req) // Use json() directly with typed request
|
||||||
};
|
};
|
||||||
@@ -866,7 +869,7 @@ impl Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_worker_load(&self, client: &reqwest::Client, worker_url: &str) -> Option<isize> {
|
async fn get_worker_load(&self, worker_url: &str) -> Option<isize> {
|
||||||
let worker_url = if self.dp_aware {
|
let worker_url = if self.dp_aware {
|
||||||
// Need to extract the URL from "http://host:port@dp_rank"
|
// Need to extract the URL from "http://host:port@dp_rank"
|
||||||
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
let (worker_url_prefix, _dp_rank) = match Self::extract_dp_rank(worker_url) {
|
||||||
@@ -881,7 +884,12 @@ impl Router {
|
|||||||
worker_url
|
worker_url
|
||||||
};
|
};
|
||||||
|
|
||||||
match client.get(&format!("{}/get_load", worker_url)).send().await {
|
match self
|
||||||
|
.client
|
||||||
|
.get(&format!("{}/get_load", worker_url))
|
||||||
|
.send()
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(res) if res.status().is_success() => match res.bytes().await {
|
Ok(res) if res.status().is_success() => match res.bytes().await {
|
||||||
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
Ok(bytes) => match serde_json::from_slice::<serde_json::Value>(&bytes) {
|
||||||
Ok(data) => data
|
Ok(data) => data
|
||||||
@@ -919,18 +927,8 @@ impl Router {
|
|||||||
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
tx: tokio::sync::watch::Sender<HashMap<String, isize>>,
|
||||||
interval_secs: u64,
|
interval_secs: u64,
|
||||||
policy: Arc<dyn LoadBalancingPolicy>,
|
policy: Arc<dyn LoadBalancingPolicy>,
|
||||||
|
client: Client,
|
||||||
) {
|
) {
|
||||||
let client = match reqwest::Client::builder()
|
|
||||||
.timeout(Duration::from_secs(5))
|
|
||||||
.build()
|
|
||||||
{
|
|
||||||
Ok(c) => c,
|
|
||||||
Err(e) => {
|
|
||||||
error!("Failed to create HTTP client for load monitoring: {}", e);
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
|
let mut interval = tokio::time::interval(Duration::from_secs(interval_secs));
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
@@ -1028,7 +1026,7 @@ impl RouterTrait for Router {
|
|||||||
self
|
self
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health(&self, _client: &Client, _req: Request<Body>) -> Response {
|
async fn health(&self, _req: Request<Body>) -> Response {
|
||||||
let workers = self.workers.read().unwrap();
|
let workers = self.workers.read().unwrap();
|
||||||
let unhealthy_servers: Vec<_> = workers
|
let unhealthy_servers: Vec<_> = workers
|
||||||
.iter()
|
.iter()
|
||||||
@@ -1047,53 +1045,49 @@ impl RouterTrait for Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn health_generate(&self, req: Request<Body>) -> Response {
|
||||||
self.proxy_get_request(client, req, "health_generate").await
|
self.proxy_get_request(req, "health_generate").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn get_server_info(&self, req: Request<Body>) -> Response {
|
||||||
self.proxy_get_request(client, req, "get_server_info").await
|
self.proxy_get_request(req, "get_server_info").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_models(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn get_models(&self, req: Request<Body>) -> Response {
|
||||||
self.proxy_get_request(client, req, "v1/models").await
|
self.proxy_get_request(req, "v1/models").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(&self, client: &Client, req: Request<Body>) -> Response {
|
async fn get_model_info(&self, req: Request<Body>) -> Response {
|
||||||
self.proxy_get_request(client, req, "get_model_info").await
|
self.proxy_get_request(req, "get_model_info").await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_generate(
|
async fn route_generate(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &GenerateRequest,
|
body: &GenerateRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
self.route_typed_request(client, headers, body, "/generate")
|
self.route_typed_request(headers, body, "/generate").await
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_chat(
|
async fn route_chat(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &ChatCompletionRequest,
|
body: &ChatCompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
self.route_typed_request(client, headers, body, "/v1/chat/completions")
|
self.route_typed_request(headers, body, "/v1/chat/completions")
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn route_completion(
|
async fn route_completion(
|
||||||
&self,
|
&self,
|
||||||
client: &Client,
|
|
||||||
headers: Option<&HeaderMap>,
|
headers: Option<&HeaderMap>,
|
||||||
body: &CompletionRequest,
|
body: &CompletionRequest,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
self.route_typed_request(client, headers, body, "/v1/completions")
|
self.route_typed_request(headers, body, "/v1/completions")
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(&self, client: &Client) -> Response {
|
async fn flush_cache(&self) -> Response {
|
||||||
// Get all worker URLs
|
// Get all worker URLs
|
||||||
let worker_urls = self.get_worker_urls();
|
let worker_urls = self.get_worker_urls();
|
||||||
|
|
||||||
@@ -1117,7 +1111,7 @@ impl RouterTrait for Router {
|
|||||||
} else {
|
} else {
|
||||||
worker_url
|
worker_url
|
||||||
};
|
};
|
||||||
let request_builder = client.post(format!("{}/flush_cache", worker_url));
|
let request_builder = self.client.post(format!("{}/flush_cache", worker_url));
|
||||||
tasks.push(request_builder.send());
|
tasks.push(request_builder.send());
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1142,13 +1136,13 @@ impl RouterTrait for Router {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_worker_loads(&self, client: &Client) -> Response {
|
async fn get_worker_loads(&self) -> Response {
|
||||||
let urls = self.get_worker_urls();
|
let urls = self.get_worker_urls();
|
||||||
let mut loads = Vec::new();
|
let mut loads = Vec::new();
|
||||||
|
|
||||||
// Get loads from all workers
|
// Get loads from all workers
|
||||||
for url in &urls {
|
for url in &urls {
|
||||||
let load = self.get_worker_load(client, url).await.unwrap_or(-1);
|
let load = self.get_worker_load(url).await.unwrap_or(-1);
|
||||||
loads.push(serde_json::json!({
|
loads.push(serde_json::json!({
|
||||||
"worker": url,
|
"worker": url,
|
||||||
"load": load
|
"load": load
|
||||||
@@ -1215,6 +1209,7 @@ mod tests {
|
|||||||
interval_secs: 1,
|
interval_secs: 1,
|
||||||
dp_aware: false,
|
dp_aware: false,
|
||||||
api_key: None,
|
api_key: None,
|
||||||
|
client: Client::new(),
|
||||||
_worker_loads: Arc::new(rx),
|
_worker_loads: Arc::new(rx),
|
||||||
_load_monitor_handle: None,
|
_load_monitor_handle: None,
|
||||||
_health_checker: None,
|
_health_checker: None,
|
||||||
|
|||||||
@@ -22,29 +22,34 @@ use tokio::spawn;
|
|||||||
use tracing::{error, info, warn, Level};
|
use tracing::{error, info, warn, Level};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppContext {
|
||||||
pub router: Arc<dyn RouterTrait>,
|
|
||||||
pub client: Client,
|
pub client: Client,
|
||||||
pub _concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
pub router_config: RouterConfig,
|
||||||
|
pub concurrency_limiter: Arc<tokio::sync::Semaphore>,
|
||||||
|
// Future dependencies can be added here
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AppState {
|
impl AppContext {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
router_config: RouterConfig,
|
router_config: RouterConfig,
|
||||||
client: Client,
|
client: Client,
|
||||||
max_concurrent_requests: usize,
|
max_concurrent_requests: usize,
|
||||||
) -> Result<Self, String> {
|
) -> Self {
|
||||||
let router = RouterFactory::create_router(&router_config)?;
|
|
||||||
let router = Arc::from(router);
|
|
||||||
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
let concurrency_limiter = Arc::new(tokio::sync::Semaphore::new(max_concurrent_requests));
|
||||||
Ok(Self {
|
Self {
|
||||||
router,
|
|
||||||
client,
|
client,
|
||||||
_concurrency_limiter: concurrency_limiter,
|
router_config,
|
||||||
})
|
concurrency_limiter,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub router: Arc<dyn RouterTrait>,
|
||||||
|
pub context: Arc<AppContext>,
|
||||||
|
}
|
||||||
|
|
||||||
// Fallback handler for unmatched routes
|
// Fallback handler for unmatched routes
|
||||||
async fn sink_handler() -> Response {
|
async fn sink_handler() -> Response {
|
||||||
StatusCode::NOT_FOUND.into_response()
|
StatusCode::NOT_FOUND.into_response()
|
||||||
@@ -60,23 +65,23 @@ async fn readiness(State(state): State<Arc<AppState>>) -> Response {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn health(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
state.router.health(&state.client, req).await
|
state.router.health(req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn health_generate(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
state.router.health_generate(&state.client, req).await
|
state.router.health_generate(req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn get_server_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
state.router.get_server_info(&state.client, req).await
|
state.router.get_server_info(req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn v1_models(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
state.router.get_models(&state.client, req).await
|
state.router.get_models(req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
async fn get_model_info(State(state): State<Arc<AppState>>, req: Request) -> Response {
|
||||||
state.router.get_model_info(&state.client, req).await
|
state.router.get_model_info(req).await
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generation endpoints
|
// Generation endpoints
|
||||||
@@ -86,10 +91,7 @@ async fn generate(
|
|||||||
headers: http::HeaderMap,
|
headers: http::HeaderMap,
|
||||||
Json(body): Json<GenerateRequest>,
|
Json(body): Json<GenerateRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
state
|
state.router.route_generate(Some(&headers), &body).await
|
||||||
.router
|
|
||||||
.route_generate(&state.client, Some(&headers), &body)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn v1_chat_completions(
|
async fn v1_chat_completions(
|
||||||
@@ -97,10 +99,7 @@ async fn v1_chat_completions(
|
|||||||
headers: http::HeaderMap,
|
headers: http::HeaderMap,
|
||||||
Json(body): Json<ChatCompletionRequest>,
|
Json(body): Json<ChatCompletionRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
state
|
state.router.route_chat(Some(&headers), &body).await
|
||||||
.router
|
|
||||||
.route_chat(&state.client, Some(&headers), &body)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn v1_completions(
|
async fn v1_completions(
|
||||||
@@ -108,10 +107,7 @@ async fn v1_completions(
|
|||||||
headers: http::HeaderMap,
|
headers: http::HeaderMap,
|
||||||
Json(body): Json<CompletionRequest>,
|
Json(body): Json<CompletionRequest>,
|
||||||
) -> Response {
|
) -> Response {
|
||||||
state
|
state.router.route_completion(Some(&headers), &body).await
|
||||||
.router
|
|
||||||
.route_completion(&state.client, Some(&headers), &body)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Worker management endpoints
|
// Worker management endpoints
|
||||||
@@ -159,11 +155,11 @@ async fn remove_worker(
|
|||||||
}
|
}
|
||||||
|
|
||||||
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
async fn flush_cache(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
state.router.flush_cache(&state.client).await
|
state.router.flush_cache().await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
async fn get_loads(State(state): State<Arc<AppState>>, _req: Request) -> Response {
|
||||||
state.router.get_worker_loads(&state.client).await
|
state.router.get_worker_loads().await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct ServerConfig {
|
pub struct ServerConfig {
|
||||||
@@ -281,11 +277,21 @@ pub async fn startup(config: ServerConfig) -> Result<(), Box<dyn std::error::Err
|
|||||||
.build()
|
.build()
|
||||||
.expect("Failed to create HTTP client");
|
.expect("Failed to create HTTP client");
|
||||||
|
|
||||||
let app_state = Arc::new(AppState::new(
|
// Create the application context with all dependencies
|
||||||
|
let app_context = Arc::new(AppContext::new(
|
||||||
config.router_config.clone(),
|
config.router_config.clone(),
|
||||||
client.clone(),
|
client.clone(),
|
||||||
config.router_config.max_concurrent_requests,
|
config.router_config.max_concurrent_requests,
|
||||||
)?);
|
));
|
||||||
|
|
||||||
|
// Create router with the context
|
||||||
|
let router = RouterFactory::create_router(&app_context)?;
|
||||||
|
|
||||||
|
// Create app state with router and context
|
||||||
|
let app_state = Arc::new(AppState {
|
||||||
|
router: Arc::from(router),
|
||||||
|
context: app_context.clone(),
|
||||||
|
});
|
||||||
let router_arc = Arc::clone(&app_state.router);
|
let router_arc = Arc::clone(&app_state.router);
|
||||||
|
|
||||||
// Start the service discovery if enabled
|
// Start the service discovery if enabled
|
||||||
|
|||||||
@@ -40,7 +40,6 @@ impl Default for ServiceDiscoveryConfig {
|
|||||||
check_interval: Duration::from_secs(60),
|
check_interval: Duration::from_secs(60),
|
||||||
port: 8000, // Standard port for modern services
|
port: 8000, // Standard port for modern services
|
||||||
namespace: None, // None means watch all namespaces
|
namespace: None, // None means watch all namespaces
|
||||||
// PD mode defaults
|
|
||||||
pd_mode: false,
|
pd_mode: false,
|
||||||
prefill_selector: HashMap::new(),
|
prefill_selector: HashMap::new(),
|
||||||
decode_selector: HashMap::new(),
|
decode_selector: HashMap::new(),
|
||||||
@@ -581,7 +580,8 @@ mod tests {
|
|||||||
use crate::routers::router::Router;
|
use crate::routers::router::Router;
|
||||||
|
|
||||||
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
let policy = PolicyFactory::create_from_config(&PolicyConfig::Random);
|
||||||
let router = Router::new(vec![], policy, 5, 1, false, None).unwrap();
|
let router =
|
||||||
|
Router::new(vec![], policy, reqwest::Client::new(), 5, 1, false, None).unwrap();
|
||||||
Arc::new(router) as Arc<dyn RouterTrait>
|
Arc::new(router) as Arc<dyn RouterTrait>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -83,12 +83,12 @@ impl TestContext {
|
|||||||
.build()
|
.build()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
// Clone config for the closure
|
// Create app context
|
||||||
let config_clone = config.clone();
|
let app_context = common::create_test_context(config.clone());
|
||||||
|
|
||||||
// Create router using sync factory in a blocking context
|
// Create router using sync factory in a blocking context
|
||||||
let router =
|
let router =
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&config_clone))
|
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
||||||
.await
|
.await
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.unwrap();
|
.unwrap();
|
||||||
@@ -1433,9 +1433,12 @@ mod pd_mode_tests {
|
|||||||
cors_allowed_origins: vec![],
|
cors_allowed_origins: vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// Create app context
|
||||||
|
let app_context = common::create_test_context(config);
|
||||||
|
|
||||||
// Create router - this might fail due to health check issues
|
// Create router - this might fail due to health check issues
|
||||||
let router_result =
|
let router_result =
|
||||||
tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|||||||
@@ -1,2 +1,15 @@
|
|||||||
pub mod mock_worker;
|
pub mod mock_worker;
|
||||||
pub mod test_app;
|
pub mod test_app;
|
||||||
|
|
||||||
|
use sglang_router_rs::config::RouterConfig;
|
||||||
|
use sglang_router_rs::server::AppContext;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
/// Helper function to create AppContext for tests
|
||||||
|
pub fn create_test_context(config: RouterConfig) -> Arc<AppContext> {
|
||||||
|
Arc::new(AppContext::new(
|
||||||
|
config.clone(),
|
||||||
|
reqwest::Client::new(),
|
||||||
|
config.max_concurrent_requests,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ use reqwest::Client;
|
|||||||
use sglang_router_rs::{
|
use sglang_router_rs::{
|
||||||
config::RouterConfig,
|
config::RouterConfig,
|
||||||
routers::RouterTrait,
|
routers::RouterTrait,
|
||||||
server::{build_app, AppState},
|
server::{build_app, AppContext, AppState},
|
||||||
};
|
};
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
@@ -13,13 +13,17 @@ pub fn create_test_app(
|
|||||||
client: Client,
|
client: Client,
|
||||||
router_config: &RouterConfig,
|
router_config: &RouterConfig,
|
||||||
) -> Router {
|
) -> Router {
|
||||||
// Create AppState with the test router
|
// Create AppContext
|
||||||
|
let app_context = Arc::new(AppContext::new(
|
||||||
|
router_config.clone(),
|
||||||
|
client,
|
||||||
|
router_config.max_concurrent_requests,
|
||||||
|
));
|
||||||
|
|
||||||
|
// Create AppState with the test router and context
|
||||||
let app_state = Arc::new(AppState {
|
let app_state = Arc::new(AppState {
|
||||||
router,
|
router,
|
||||||
client,
|
context: app_context,
|
||||||
_concurrency_limiter: Arc::new(tokio::sync::Semaphore::new(
|
|
||||||
router_config.max_concurrent_requests,
|
|
||||||
)),
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Configure request ID headers (use defaults if not specified)
|
// Configure request ID headers (use defaults if not specified)
|
||||||
|
|||||||
@@ -53,10 +53,12 @@ impl TestContext {
|
|||||||
|
|
||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
|
|
||||||
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
let app_context = common::create_test_context(config);
|
||||||
.await
|
let router =
|
||||||
.unwrap()
|
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
||||||
.unwrap();
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
if !workers.is_empty() {
|
if !workers.is_empty() {
|
||||||
|
|||||||
@@ -54,10 +54,12 @@ impl TestContext {
|
|||||||
|
|
||||||
config.mode = RoutingMode::Regular { worker_urls };
|
config.mode = RoutingMode::Regular { worker_urls };
|
||||||
|
|
||||||
let router = tokio::task::spawn_blocking(move || RouterFactory::create_router(&config))
|
let app_context = common::create_test_context(config);
|
||||||
.await
|
let router =
|
||||||
.unwrap()
|
tokio::task::spawn_blocking(move || RouterFactory::create_router(&app_context))
|
||||||
.unwrap();
|
.await
|
||||||
|
.unwrap()
|
||||||
|
.unwrap();
|
||||||
let router = Arc::from(router);
|
let router = Arc::from(router);
|
||||||
|
|
||||||
if !workers.is_empty() {
|
if !workers.is_empty() {
|
||||||
|
|||||||
@@ -181,7 +181,10 @@ mod test_pd_routing {
|
|||||||
};
|
};
|
||||||
|
|
||||||
// Router creation will fail due to health checks, but config should be valid
|
// Router creation will fail due to health checks, but config should be valid
|
||||||
let result = RouterFactory::create_router(&config);
|
let app_context =
|
||||||
|
sglang_router_rs::server::AppContext::new(config, reqwest::Client::new(), 64);
|
||||||
|
let app_context = std::sync::Arc::new(app_context);
|
||||||
|
let result = RouterFactory::create_router(&app_context);
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
let error_msg = result.unwrap_err();
|
let error_msg = result.unwrap_err();
|
||||||
// Error should be about health/timeout, not configuration
|
// Error should be about health/timeout, not configuration
|
||||||
|
|||||||
Reference in New Issue
Block a user