[router] Refactor router and policy traits with dependency injection (#7987)
Co-authored-by: Jin Pan <jpan236@wisc.edu> Co-authored-by: Keru Yang <rukeyang@gmail.com> Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com> Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
@@ -1,9 +1,8 @@
|
||||
use crate::config::RouterConfig;
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::prometheus::{self, PrometheusConfig};
|
||||
use crate::request_adapter::ToPdRequest;
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use crate::routers::{RouterFactory, RouterTrait};
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use actix_web::{
|
||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
@@ -19,27 +18,19 @@ use tracing::{error, info, warn, Level};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct AppState {
|
||||
router: Arc<Router>,
|
||||
router: Arc<dyn RouterTrait>,
|
||||
client: Client,
|
||||
is_pd_mode: bool, // Add flag to track PD mode
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
pub fn new(
|
||||
worker_urls: Vec<String>,
|
||||
client: Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Result<Self, String> {
|
||||
// Check if this is PD mode from policy config
|
||||
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
|
||||
// Use RouterFactory to create the appropriate router type
|
||||
let router = RouterFactory::create_router(&router_config)?;
|
||||
|
||||
// Create router based on policy
|
||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||
Ok(Self {
|
||||
router,
|
||||
client,
|
||||
is_pd_mode,
|
||||
})
|
||||
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
|
||||
let router = Arc::from(router);
|
||||
|
||||
Ok(Self { router, client })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -76,65 +67,39 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/liveness")]
|
||||
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.liveness()
|
||||
}
|
||||
|
||||
#[get("/readiness")]
|
||||
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.readiness()
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health", &req)
|
||||
.await
|
||||
data.router.health(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Check if we're in PD mode
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, check health on all servers
|
||||
data.router
|
||||
.route_pd_health_generate(&data.client, &req)
|
||||
.await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
}
|
||||
data.router.health_generate(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, aggregate info from both prefill and decode servers
|
||||
data.router.get_pd_server_info(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode - return first server's info
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_server_info(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, return models from the first prefill server
|
||||
data.router.get_pd_models(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_models(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, get model info from the first prefill server
|
||||
data.router.get_pd_model_info(&data.client, &req).await
|
||||
} else {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
}
|
||||
data.router.get_model_info(&data.client, &req).await
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
@@ -143,24 +108,12 @@ async fn generate(
|
||||
body: web::Json<GenerateRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/generate")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_generate(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
@@ -169,24 +122,12 @@ async fn v1_chat_completions(
|
||||
body: web::Json<ChatCompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_chat(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
@@ -195,24 +136,12 @@ async fn v1_completions(
|
||||
body: web::Json<CompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||
.await)
|
||||
}
|
||||
let json_body = serde_json::to_value(body.into_inner())
|
||||
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
|
||||
Ok(state
|
||||
.router
|
||||
.route_completion(&state.client, &req, json_body)
|
||||
.await)
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
@@ -254,29 +183,19 @@ async fn remove_worker(
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, flush cache on both prefill and decode servers
|
||||
data.router.route_pd_flush_cache(&data.client).await
|
||||
} else {
|
||||
// Route to all workers for cache flushing
|
||||
data.router
|
||||
.route_to_all(&data.client, "/flush_cache", &req)
|
||||
.await
|
||||
}
|
||||
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.flush_cache(&data.client).await
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Get loads from all workers
|
||||
data.router.get_all_loads(&data.client, &req).await
|
||||
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router.get_worker_loads(&data.client).await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
pub worker_urls: Vec<String>,
|
||||
pub policy_config: PolicyConfig,
|
||||
pub router_config: RouterConfig,
|
||||
pub max_payload_size: usize,
|
||||
pub log_dir: Option<String>,
|
||||
pub log_level: Option<String>,
|
||||
@@ -324,8 +243,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
}
|
||||
|
||||
info!("🚧 Initializing router on {}:{}", config.host, config.port);
|
||||
info!("🚧 Initializing workers on {:?}", config.worker_urls);
|
||||
info!("🚧 Policy Config: {:?}", config.policy_config);
|
||||
info!("🚧 Router mode: {:?}", config.router_config.mode);
|
||||
info!("🚧 Policy: {:?}", config.router_config.policy);
|
||||
info!(
|
||||
"🚧 Max payload size: {} MB",
|
||||
config.max_payload_size / (1024 * 1024)
|
||||
@@ -345,12 +264,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
let app_state_init = AppState::new(
|
||||
config.worker_urls.clone(),
|
||||
client.clone(),
|
||||
config.policy_config.clone(),
|
||||
)
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
|
||||
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
|
||||
let router_arc = Arc::clone(&app_state_init.router);
|
||||
let app_state = web::Data::new(app_state_init);
|
||||
|
||||
@@ -397,6 +312,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.service(v1_completions)
|
||||
.service(v1_models)
|
||||
.service(get_model_info)
|
||||
.service(liveness)
|
||||
.service(readiness)
|
||||
.service(health)
|
||||
.service(health_generate)
|
||||
.service(get_server_info)
|
||||
|
||||
Reference in New Issue
Block a user