Merge PDLB (Prefill-Decode Load Balancer) into SGLang Router (#7096)

This commit is contained in:
Simo Lin
2025-06-18 11:28:15 -07:00
committed by GitHub
parent 712bf9ec9b
commit 09ae5b20f3
13 changed files with 4045 additions and 187 deletions

View File

@@ -1,12 +1,13 @@
use crate::logging::{self, LoggingConfig};
use crate::openai_api_types::{ChatCompletionRequest, CompletionRequest, GenerateRequest};
use crate::prometheus::{self, PrometheusConfig};
use crate::request_adapter::ToPdRequest;
use crate::router::PolicyConfig;
use crate::router::Router;
use crate::service_discovery::{start_service_discovery, ServiceDiscoveryConfig};
use actix_web::{
error, get, post, web, App, Error, HttpRequest, HttpResponse, HttpServer, Responder,
};
use bytes::Bytes;
use futures_util::StreamExt;
use reqwest::Client;
use std::collections::HashMap;
@@ -20,6 +21,7 @@ use tracing::{error, info, warn, Level};
pub struct AppState {
router: Arc<Router>,
client: Client,
is_pd_mode: bool, // Add flag to track PD mode
}
impl AppState {
@@ -28,9 +30,16 @@ impl AppState {
client: Client,
policy_config: PolicyConfig,
) -> Result<Self, String> {
// Check if this is PD mode from policy config
let is_pd_mode = matches!(policy_config, PolicyConfig::PrefillDecodeConfig { .. });
// Create router based on policy
let router = Arc::new(Router::new(worker_urls, policy_config)?);
Ok(Self { router, client })
Ok(Self {
router,
client,
is_pd_mode,
})
}
}
@@ -46,8 +55,25 @@ async fn sink_handler(_req: HttpRequest, mut payload: web::Payload) -> Result<Ht
}
// Custom error handler for JSON payload errors.
fn json_error_handler(_err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error::ErrorPayloadTooLarge("Payload too large")
fn json_error_handler(err: error::JsonPayloadError, _req: &HttpRequest) -> Error {
error!("JSON payload error: {:?}", err);
match &err {
error::JsonPayloadError::OverflowKnownLength { length, limit } => {
error!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
);
error::ErrorPayloadTooLarge(format!(
"Payload too large: {} bytes exceeds limit of {} bytes",
length, limit
))
}
error::JsonPayloadError::Overflow { limit } => {
error!("Payload overflow: exceeds limit of {} bytes", limit);
error::ErrorPayloadTooLarge(format!("Payload exceeds limit of {} bytes", limit))
}
_ => error::ErrorBadRequest(format!("Invalid JSON payload: {}", err)),
}
}
#[get("/health")]
@@ -59,59 +85,134 @@ async fn health(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
#[get("/health_generate")]
async fn health_generate(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
// Check if we're in PD mode
if data.is_pd_mode {
// For PD mode, check health on all servers
data.router
.route_pd_health_generate(&data.client, &req)
.await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/health_generate", &req)
.await
}
}
#[get("/get_server_info")]
async fn get_server_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
if data.is_pd_mode {
// For PD mode, aggregate info from both prefill and decode servers
data.router.get_pd_server_info(&data.client, &req).await
} else {
// Regular mode - return first server's info
data.router
.route_to_first(&data.client, "/get_server_info", &req)
.await
}
}
#[get("/v1/models")]
async fn v1_models(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
if data.is_pd_mode {
// For PD mode, return models from the first prefill server
data.router.get_pd_models(&data.client, &req).await
} else {
// Regular mode
data.router
.route_to_first(&data.client, "/v1/models", &req)
.await
}
}
#[get("/get_model_info")]
async fn get_model_info(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
if data.is_pd_mode {
// For PD mode, get model info from the first prefill server
data.router.get_pd_model_info(&data.client, &req).await
} else {
data.router
.route_to_first(&data.client, "/get_model_info", &req)
.await
}
}
#[post("/generate")]
async fn generate(req: HttpRequest, body: Bytes, data: web::Data<AppState>) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/generate")
.await
async fn generate(
req: HttpRequest,
body: web::Json<GenerateRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/generate")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/generate")
.await)
}
}
#[post("/v1/chat/completions")]
async fn v1_chat_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/chat/completions")
.await
body: web::Json<ChatCompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_chat_typed(&client, &req, pd_request, "/v1/chat/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/chat/completions")
.await)
}
}
#[post("/v1/completions")]
async fn v1_completions(
req: HttpRequest,
body: Bytes,
data: web::Data<AppState>,
) -> impl Responder {
data.router
.route_generate_request(&data.client, &req, &body, "/v1/completions")
.await
body: web::Json<CompletionRequest>,
state: web::Data<AppState>,
) -> Result<HttpResponse, Error> {
let client = &state.client;
let router = &state.router;
// Use typed request directly for both PD and regular routing
if state.is_pd_mode {
// For PD mode, convert to PD request with bootstrap
let pd_request = body.into_inner().to_pd_request();
Ok(router
.route_pd_generate_typed(&client, &req, pd_request, "/v1/completions")
.await)
} else {
// For regular mode, use typed request directly
let request = body.into_inner();
Ok(router
.route_typed_request(&client, &req, &request, "/v1/completions")
.await)
}
}
#[post("/add_worker")]
@@ -153,6 +254,25 @@ async fn remove_worker(
HttpResponse::Ok().body(format!("Successfully removed worker: {}", worker_url))
}
#[post("/flush_cache")]
async fn flush_cache(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
if data.is_pd_mode {
// For PD mode, flush cache on both prefill and decode servers
data.router.route_pd_flush_cache(&data.client).await
} else {
// Route to all workers for cache flushing
data.router
.route_to_all(&data.client, "/flush_cache", &req)
.await
}
}
#[get("/get_loads")]
async fn get_loads(req: HttpRequest, data: web::Data<AppState>) -> impl Responder {
// Get loads from all workers
data.router.get_all_loads(&data.client, &req).await
}
pub struct ServerConfig {
pub host: String,
pub port: u16,
@@ -163,6 +283,7 @@ pub struct ServerConfig {
pub log_dir: Option<String>,
pub service_discovery_config: Option<ServiceDiscoveryConfig>,
pub prometheus_config: Option<PrometheusConfig>,
pub request_timeout_secs: u64,
}
pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
@@ -215,6 +336,7 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
let client = Client::builder()
.pool_idle_timeout(Some(Duration::from_secs(50)))
.timeout(Duration::from_secs(config.request_timeout_secs)) // Use configurable timeout
.build()
.expect("Failed to create HTTP client");
@@ -276,7 +398,8 @@ pub async fn startup(config: ServerConfig) -> std::io::Result<()> {
.service(add_worker)
.service(remove_worker)
.service(list_workers)
// Default handler for unmatched routes.
.service(flush_cache)
.service(get_loads)
.default_service(web::route().to(sink_handler))
})
.bind_auto_h2c((config.host, config.port))?