[router] Refactor router and policy traits with dependency injection (#7987)

Co-authored-by: Jin Pan <jpan236@wisc.edu>
Co-authored-by: Keru Yang <rukeyang@gmail.com>
Co-authored-by: Yingyi Huang <yingyihuang2000@outlook.com>
Co-authored-by: Philip Zhu <phlipzhux@gmail.com>
This commit is contained in:
Simo Lin
2025-07-18 14:24:24 -07:00
committed by GitHub
parent 1f76fc8747
commit c8f31042a8
24 changed files with 3190 additions and 1944 deletions

View File

@@ -1,9 +1,8 @@
use crate::config::RouterConfig;
use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig;
use crate::router::Router;
use crate::routers::{RouterFactory, RouterTrait};
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
@@ -19,27 +18,19 @@ use tracing::{error, info, warn, Level};
#[derive(Debug)]
pub struct AppState {
router: Arc<Router>,
router: Arc<dyn RouterTrait>,
client: Client,
is_pd_mode: bool, // Add flag to track PD mode
}
impl AppState {
pub fn new(
worker_urls: Vec<String>,
client: Client,
policy_config: PolicyConfig,
) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
pub fn new(router_config: RouterConfig, client: Client) -> Result<Self, String> {
// Use RouterFactory to create the appropriate router type
let router = RouterFactory::create_router(&router_config)?;
// Create router based on policy
let router = Arc::new(Router::new(worker_urls, policy_config)?);
Ok(Self {
router,
client,
is_pd_mode,
})
// Convert Box<dyn RouterTrait> to Arc<dyn RouterTrait>
let router = Arc::from(router);
Ok(Self { router, client })
}
}
@@ -76,65 +67,39 @@ fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error
}
}
#[get("/liveness")]
async fn liveness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.liveness()
}
#[get("/readiness")]
async fn readiness(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.readiness()
}
#[get("/health")]
async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/health", &req)
.await
data.router.health(&data.client, &req).await
}
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Check if we're in PD mode
if data.is_pd_mode {
// For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
data.router.health_generate(&data.client, &req).await
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, aggregate info from both prefill and decode servers
data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
data.router.get_server_info(&data.client, &req).await
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, return models from the first prefill server
data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
data.router.get_models(&data.client, &req).await
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, get model info from the first prefill server
data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
data.router.get_model_info(&data.client, &req).await
}
#[post("/generate")]
@@ -143,24 +108,12 @@ async fn generate(
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
let json_body = serde_json::to_value(body.into_inner())
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
Ok(state
.router
.route_generate(&state.client, &req, json_body)
.await)
}
#[post("/v1/chat/completions")]
@@ -169,24 +122,12 @@ async fn v1_chat_completions(
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
let json_body = serde_json::to_value(body.into_inner())
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
Ok(state
.router
.route_chat(&state.client, &req, json_body)
.await)
}
#[post("/v1/completions")]
@@ -195,24 +136,12 @@ async fn v1_completions(
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
let json_body = serde_json::to_value(body.into_inner())
.map_err(|e| error::ErrorBadRequest(format!("Invalid JSON: {}", e)))?;
Ok(state
.router
.route_completion(&state.client, &req, json_body)
.await)
}
#[post("/add_worker")]
@@ -254,29 +183,19 @@ async fn remove_worker(
}
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
async fn flush_cache(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.flush_cache(&data.client).await
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
async fn get_loads(_req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router.get_worker_loads(&data.client).await
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
pub worker_urls: Vec<String>,
pub policy_config: PolicyConfig,
pub router_config: RouterConfig,
pub max_payload_size: usize,
pub log_dir: Option<String>,
pub log_level: Option<String>,
@@ -324,8 +243,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
}
info!("🚧 Initializing router on {}:{}", config.host, config.port);
info!("🚧 Initializing workers on {:?}", config.worker_urls);
info!("🚧 Policy Config: {:?}", config.policy_config);
info!("🚧 Router mode: {:?}", config.router_config.mode);
info!("🚧 Policy: {:?}", config.router_config.policy);
info!(
"🚧 Max payload size: {} MB",
config.max_payload_size / (1024 * 1024)
@@ -345,12 +264,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.build()
.expect("Failed to create HTTP client");
let app_state_init = AppState::new(
config.worker_urls.clone(),
client.clone(),
config.policy_config.clone(),
)
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let app_state_init = AppState::new(config.router_config.clone(), client.clone())
.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
let router_arc = Arc::clone(&app_state_init.router);
let app_state = web::Data::new(app_state_init);
@@ -397,6 +312,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(v1_completions)
.service(v1_models)
.service(get_model_info)
.service(liveness)
.service(readiness)
.service(health)
.service(health_generate)
.service(get_server_info)