PD Rust LB (PO2) (#6437)
Co-authored-by: fzyzcjy <5236035+fzyzcjy@users.noreply.github.com>
This commit is contained in:
171
sgl-pdlb/src/server.rs
Normal file
171
sgl-pdlb/src/server.rs
Normal file
@@ -0,0 +1,171 @@
|
||||
use crate::io_struct::{ChatReqInput, GenerateReqInput};
|
||||
use crate::lb_state::{LBConfig, LBState};
|
||||
use crate::strategy_lb::EngineType;
|
||||
use actix_web::{HttpRequest, HttpResponse, HttpServer, get, post, web};
|
||||
use reqwest::Method;
|
||||
use serde_json::json;
|
||||
use std::io::Write;
|
||||
|
||||
#[get("/health")]
|
||||
pub async fn health(_req: HttpRequest, _: web::Data<LBState>) -> HttpResponse {
|
||||
HttpResponse::Ok().body("Ok")
|
||||
}
|
||||
|
||||
#[get("/health_generate")]
|
||||
pub async fn health_generate(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
app_state
|
||||
.route_collect(&servers, Method::GET, "/health_generate", None)
|
||||
.await?;
|
||||
// FIXME: log the response
|
||||
Ok(HttpResponse::Ok().body("Health check passed on all servers"))
|
||||
}
|
||||
|
||||
#[post("/flush_cache")]
|
||||
pub async fn flush_cache(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
app_state
|
||||
.route_collect(&servers, Method::POST, "/flush_cache", None)
|
||||
.await?;
|
||||
Ok(HttpResponse::Ok().body("Cache flushed on all servers"))
|
||||
}
|
||||
|
||||
#[get("/get_model_info")]
|
||||
pub async fn get_model_info(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
// Return the first server's model info
|
||||
let engine = app_state.strategy_lb.get_one_server();
|
||||
app_state
|
||||
.route_one(&engine, Method::GET, "/get_model_info", None, false)
|
||||
.await?
|
||||
.into()
|
||||
}
|
||||
|
||||
#[post("/generate")]
|
||||
pub async fn generate(
|
||||
_req: HttpRequest,
|
||||
req: web::Json<GenerateReqInput>,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
app_state
|
||||
.generate("/generate", Box::new(req.into_inner()))
|
||||
.await
|
||||
}
|
||||
|
||||
#[post("/v1/chat/completions")]
|
||||
pub async fn chat_completions(
|
||||
_req: HttpRequest,
|
||||
req: web::Json<ChatReqInput>,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
app_state
|
||||
.generate("/v1/chat/completions", Box::new(req.into_inner()))
|
||||
.await
|
||||
}
|
||||
|
||||
#[get("/get_server_info")]
|
||||
pub async fn get_server_info(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let servers = app_state.strategy_lb.get_all_servers();
|
||||
let responses = app_state
|
||||
.route_collect(&servers, Method::GET, "/get_server_info", None)
|
||||
.await?;
|
||||
let mut prefill_infos = Vec::new();
|
||||
let mut decode_infos = Vec::new();
|
||||
for (i, resp) in responses.iter().enumerate() {
|
||||
let json = resp.to_json()?;
|
||||
match servers[i].engine_type {
|
||||
EngineType::Prefill => prefill_infos.push(json),
|
||||
EngineType::Decode => decode_infos.push(json),
|
||||
}
|
||||
}
|
||||
Ok(HttpResponse::Ok().json(json!({
|
||||
"prefill": prefill_infos,
|
||||
"decode": decode_infos,
|
||||
})))
|
||||
}
|
||||
|
||||
#[get("/get_loads")]
|
||||
pub async fn get_loads(
|
||||
_req: HttpRequest,
|
||||
app_state: web::Data<LBState>,
|
||||
) -> Result<HttpResponse, actix_web::Error> {
|
||||
let (prefill_loads, decode_loads) = app_state.get_engine_loads().await?;
|
||||
Ok(HttpResponse::Ok().json(json!({
|
||||
"prefill": prefill_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>(),
|
||||
"decode": decode_loads.into_iter().map(|l| l.to_json()).collect::<Vec<_>>()
|
||||
})))
|
||||
}
|
||||
|
||||
pub async fn periodic_logging(lb_state: LBState) {
|
||||
// FIXME: currently we can just clone the lb_state to log as the lb is stateless
|
||||
loop {
|
||||
tokio::time::sleep(std::time::Duration::from_secs(lb_state.log_interval)).await;
|
||||
let (prefill_loads, decode_loads) = match lb_state.get_engine_loads().await {
|
||||
Ok((prefill_loads, decode_loads)) => (prefill_loads, decode_loads),
|
||||
Err(e) => {
|
||||
log::error!("Failed to get engine loads: {}", e);
|
||||
continue;
|
||||
}
|
||||
};
|
||||
let prefill_loads = prefill_loads
|
||||
.into_iter()
|
||||
.map(|l| l.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
let decode_loads = decode_loads
|
||||
.into_iter()
|
||||
.map(|l| l.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
log::info!("Prefill loads: {}", prefill_loads.join(", "));
|
||||
log::info!("Decode loads: {}", decode_loads.join(", "));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn startup(lb_config: LBConfig, lb_state: LBState) -> std::io::Result<()> {
|
||||
let app_state = web::Data::new(lb_state);
|
||||
|
||||
println!("Starting server at {}:{}", lb_config.host, lb_config.port);
|
||||
|
||||
// default level is info
|
||||
env_logger::Builder::new()
|
||||
.format(|buf, record| {
|
||||
writeln!(
|
||||
buf,
|
||||
"{} - {} - {}",
|
||||
chrono::Local::now().format("%Y-%m-%d %H:%M:%S"),
|
||||
record.level(),
|
||||
record.args()
|
||||
)
|
||||
})
|
||||
.filter(None, log::LevelFilter::Info)
|
||||
.init();
|
||||
|
||||
HttpServer::new(move || {
|
||||
actix_web::App::new()
|
||||
.wrap(actix_web::middleware::Logger::default())
|
||||
.app_data(app_state.clone())
|
||||
.service(health)
|
||||
.service(health_generate)
|
||||
.service(flush_cache)
|
||||
.service(get_model_info)
|
||||
.service(get_server_info)
|
||||
.service(get_loads)
|
||||
.service(generate)
|
||||
.service(chat_completions)
|
||||
})
|
||||
.bind((lb_config.host, lb_config.port))?
|
||||
.run()
|
||||
.await?;
|
||||
|
||||
std::io::Result::Ok(())
|
||||
}
|
||||
Reference in New Issue
Block a user