Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)
This commit is contained in:
@@ -1,12 +1,13 @@
|
||||
use crate::logging::{self, LoggingConfig};
|
||||
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
|
||||
use crate::prometheus::{self, PrometheusConfig};
|
||||
use crate::request_adapter::ToPdRequest;
|
||||
use crate::router::PolicyConfig;
|
||||
use crate::router::Router;
|
||||
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
|
||||
use actix_web::{
|
||||
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
|
||||
};
|
||||
use bytes::Bytes;
|
||||
use futures_util::StreamExt;
|
||||
use reqwest::Client;
|
||||
use std::collections::HashMap;
|
||||
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
|
||||
pub struct AppState {
|
||||
router: Arc<Router>,
|
||||
client: Client,
|
||||
is_pd_mode: bool, // Add flag to track PD mode
|
||||
}
|
||||
|
||||
impl AppState {
|
||||
@@ -28,9 +30,16 @@ impl AppState {
|
||||
client: Client,
|
||||
policy_config: PolicyConfig,
|
||||
) -> Result<Self, String> {
|
||||
// Check if this is PD mode from policy config
|
||||
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
|
||||
|
||||
// Create router based on policy
|
||||
let router = Arc::new(Router::new(worker_urls, policy_config)?);
|
||||
Ok(Self { router, client })
|
||||
Ok(Self {
|
||||
router,
|
||||
client,
|
||||
is_pd_mode,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
|
||||
}
|
||||
|
||||
// Custom error handler for JSON payload errors.
|
||||
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error::ErrorPayloadTooLarge("Payload too large")
|
||||
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
|
||||
error!("JSON payload error: {:?}", err);
|
||||
match &err {
|
||||
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
|
||||
error!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
);
|
||||
error::ErrorPayloadTooLarge(format!(
|
||||
"Payload too large: {} bytes exceeds limit of {} bytes",
|
||||
length, limit
|
||||
))
|
||||
}
|
||||
error::JsonPayloadError::Overflow { limit } => {
|
||||
error!("Payload overflow: exceeds limit of {} bytes", limit);
|
||||
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
|
||||
}
|
||||
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/health")]
|
||||
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
|
||||
#[get("/health_generate")]
|
||||
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
// Check if we're in PD mode
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, check health on all servers
|
||||
data.router
|
||||
.route_pd_health_generate(&data.client, &req)
|
||||
.await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/health_generate", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, aggregate info from both prefill and decode servers
|
||||
data.router.get_pd_server_info(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode - return first server's info
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_server_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/v1/models")]
|
||||
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, return models from the first prefill server
|
||||
data.router.get_pd_models(&data.client, &req).await
|
||||
} else {
|
||||
// Regular mode
|
||||
data.router
|
||||
.route_to_first(&data.client, "/v1/models", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, get model info from the first prefill server
|
||||
data.router.get_pd_model_info(&data.client, &req).await
|
||||
} else {
|
||||
data.router
|
||||
.route_to_first(&data.client, "/get_model_info", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/generate")
|
||||
.await
|
||||
async fn generate(
|
||||
req: HttpRequest,
|
||||
body: web::Json<GenerateRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/generate")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
async fn v1_chat_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
|
||||
.await
|
||||
body: web::Json<ChatCompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/v1/completions")]
|
||||
async fn v1_completions(
|
||||
req: HttpRequest,
|
||||
body: Bytes,
|
||||
data: web::Data<AppState>,
|
||||
) -> impl Responder {
|
||||
data.router
|
||||
.route_generate_request(&data.client, &req, &body, "/v1/completions")
|
||||
.await
|
||||
body: web::Json<CompletionRequest>,
|
||||
state: web::Data<AppState>,
|
||||
) -> Result<HttpResponse, Error> {
|
||||
let client = &state.client;
|
||||
let router = &state.router;
|
||||
|
||||
// Use typed request directly for both PD and regular routing
|
||||
if state.is_pd_mode {
|
||||
// For PD mode, convert to PD request with bootstrap
|
||||
let pd_request = body.into_inner().to_pd_request();
|
||||
|
||||
Ok(router
|
||||
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
|
||||
.await)
|
||||
} else {
|
||||
// For regular mode, use typed request directly
|
||||
let request = body.into_inner();
|
||||
Ok(router
|
||||
.route_typed_request(&client, &req, &request, "/v1/completions")
|
||||
.await)
|
||||
}
|
||||
}
|
||||
|
||||
#[post("/add_worker")]
|
||||
@@ -153,6 +254,25 @@ async fn remove_worker(
|
||||
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
if data.is_pd_mode {
|
||||
// For PD mode, flush cache on both prefill and decode servers
|
||||
data.router.route_pd_flush_cache(&data.client).await
|
||||
} else {
|
||||
// Route to all workers for cache flushing
|
||||
data.router
|
||||
.route_to_all(&data.client, "/flush_cache", &req)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
|
||||
// Get loads from all workers
|
||||
data.router.get_all_loads(&data.client, &req).await
|
||||
}
|
||||
|
||||
pub struct ServerConfig {
|
||||
pub host: String,
|
||||
pub port: u16,
|
||||
@@ -163,6 +283,7 @@ pub struct ServerConfig {
|
||||
pub log_dir: Option<String>,
|
||||
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
|
||||
pub prometheus_config: Option<PrometheusConfig>,
|
||||
pub request_timeout_secs: u64,
|
||||
}
|
||||
|
||||
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
|
||||
let client = Client::builder()
|
||||
.pool_idle_timeout(Some(Duration::from_secs(50)))
|
||||
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
|
||||
.build()
|
||||
.expect("Failed to create HTTP client");
|
||||
|
||||
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
|
||||
.service(add_worker)
|
||||
.service(remove_worker)
|
||||
.service(list_workers)
|
||||
// Default handler for unmatched routes.
|
||||
.service(flush_cache)
|
||||
.service(get_loads)
|
||||
.default_service(web::route().to(sink_handler))
|
||||
})
|
||||
.bind_auto_h2c((config.host, config.port))?
|
||||
|
||||
Reference in New Issue
Block a user